pde_opt.numerics.functions
Function representations for PDEs.
- class pde_opt.numerics.functions.PeriodicCNN(*args: Any, **kwargs: Any)
Stack of periodic convolution blocks; the final convolution returns the requested number of output channels.
Translation-equivariant on the torus as long as stride=1 and only pointwise nonlinearities are used. Accepts input of shape (C, H, W) or (B, C, H, W) and returns output with the same spatial size (H, W).
- __init__(in_channels: int, hidden_channels: Sequence[int] = (32, 64, 64), out_channels: int | None = None, kernel_size: int = 3, act: Callable[[jax.Array], jax.Array] = jax.nn.gelu, *, key: jax.Array)
- layers: Tuple[equinox.Module, ...]
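The equivariance claim can be checked numerically. The sketch below is not the library's implementation: it assumes that "periodic" means circular (wrap-around) padding before a stride-1 convolution, and the helpers `periodic_conv2d` and `tiny_periodic_cnn` are hypothetical. It uses only `jax.numpy.pad` and `jax.lax.conv_general_dilated`, accepts both (C, H, W) and (B, C, H, W) inputs, and shows that rolling the input on the torus rolls the output by the same amount.

```python
import jax
import jax.numpy as jnp


def periodic_conv2d(x, w):
    """Hypothetical helper: stride-1 2D convolution with wrap-around padding.

    x: (C_in, H, W) or (B, C_in, H, W); w: (C_out, C_in, k, k) with odd k.
    The output keeps the input's spatial size (H, W) and has C_out channels.
    """
    unbatched = x.ndim == 3
    if unbatched:
        x = x[None]  # add a batch axis: (1, C_in, H, W)
    pad = w.shape[-1] // 2
    # Circular padding on H and W: the kernel sees the domain as a torus.
    x = jnp.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode="wrap")
    y = jax.lax.conv_general_dilated(
        x, w, window_strides=(1, 1), padding="VALID",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
    )
    return y[0] if unbatched else y


def tiny_periodic_cnn(x, weights, act=jax.nn.gelu):
    """Hypothetical stack: periodic conv + pointwise activation, then a final periodic conv."""
    for w in weights[:-1]:
        x = act(periodic_conv2d(x, w))
    return periodic_conv2d(x, weights[-1])  # final conv sets the output channel count


# Usage: 2 input channels -> 8 hidden channels -> 1 output channel on a 16x16 torus.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
weights = [0.1 * jax.random.normal(k1, (8, 2, 3, 3)),
           0.1 * jax.random.normal(k2, (1, 8, 3, 3))]
u = jax.random.normal(k3, (2, 16, 16))               # (C, H, W)
out = tiny_periodic_cnn(u, weights)                   # (1, 16, 16)
out_batched = tiny_periodic_cnn(u[None], weights)     # (1, 1, 16, 16)

# Rolling the input on the torus should roll the output identically.
shifted = tiny_periodic_cnn(jnp.roll(u, (3, -5), axis=(1, 2)), weights)
err = jnp.max(jnp.abs(shifted - jnp.roll(out, (3, -5), axis=(1, 2))))
```

Wrap padding gives every pixel the same neighbourhood on the torus; a stride greater than 1 or any position-dependent operation would break this property.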
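- class pde_opt.numerics.functions.LegendrePolynomialExpansion(*args: Any, **kwargs: Any)
Legendre polynomial expansion: represents a function as a linear combination of Legendre polynomials up to degree max_degree, with coefficients stored in params.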
- params: jax.Array
- max_degree: int
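For reference, such an expansion has the form f(x) = sum_{n=0}^{max_degree} params[n] P_n(x). The sketch below evaluates it with the Bonnet three-term recurrence. It assumes that `params[n]` is the coefficient of P_n and that inputs already lie in [-1, 1]; how the class orders its coefficients or rescales inputs is not documented here, and `legendre_basis` and `legendre_expansion` are hypothetical helpers.

```python
import jax.numpy as jnp


def legendre_basis(x, max_degree):
    """Return [P_0(x), ..., P_max_degree(x)] stacked along a new leading axis."""
    basis = [jnp.ones_like(x)]
    if max_degree >= 1:
        basis.append(x)
    for n in range(1, max_degree):
        # Bonnet recurrence: (n + 1) P_{n+1} = (2n + 1) x P_n - n P_{n-1}
        basis.append(((2 * n + 1) * x * basis[n] - n * basis[n - 1]) / (n + 1))
    return jnp.stack(basis)


def legendre_expansion(params, x):
    """Hypothetical helper: f(x) = sum_n params[n] * P_n(x), max_degree = len(params) - 1."""
    basis = legendre_basis(x, params.shape[0] - 1)
    return jnp.tensordot(params, basis, axes=1)  # contract the degree axis


x = jnp.linspace(-1.0, 1.0, 5)
f = legendre_expansion(jnp.array([0.0, 0.0, 1.0]), x)  # P_2(x) = (3x^2 - 1) / 2
# f -> [1.0, -0.125, -0.5, -0.125, 1.0]
```

The recurrence avoids evaluating high-degree monomials directly, which is better conditioned than expanding each P_n in powers of x.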
- class pde_opt.numerics.functions.DiffusionLegendrePolynomials(*args: Any, **kwargs: Any)
Diffusion coefficient parameterized by a Legendre polynomial expansion.
The expansion is exponentiated (exp) to ensure the diffusion coefficient is positive.
- expansion: LegendrePolynomialExpansion
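The positivity trick can be sketched as D(c) = exp(sum_n params[n] P_n(c)): whatever real coefficients an optimizer produces, the exponential keeps D strictly positive (a negative diffusion coefficient would make the problem ill-posed). This assumes the exponential is applied to the full expansion output, which the docstring implies but does not spell out; `legendre_sum` and `diffusivity` are hypothetical functions, not the class's API.

```python
import jax.numpy as jnp


def legendre_sum(params, c):
    """sum_n params[n] * P_n(c), generating P_n with the Bonnet recurrence."""
    p_prev, p_curr = jnp.ones_like(c), c
    total = params[0] * p_prev
    for n in range(1, params.shape[0]):
        total = total + params[n] * p_curr
        # Advance one degree: (n + 1) P_{n+1} = (2n + 1) c P_n - n P_{n-1}
        p_prev, p_curr = p_curr, ((2 * n + 1) * c * p_curr - n * p_prev) / (n + 1)
    return total


def diffusivity(params, c):
    """Hypothetical: D(c) = exp(expansion(c)), strictly positive for any real params."""
    return jnp.exp(legendre_sum(params, c))


c = jnp.linspace(-1.0, 1.0, 101)
D = diffusivity(jnp.array([-3.0, 2.0, -4.0]), c)  # negative coefficients, positive D
```

Working in log space also means the unconstrained coefficients can be optimized directly, with no projection or clipping step needed to keep D valid.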
- class pde_opt.numerics.functions.ChemicalPotentialLegendrePolynomials(*args: Any, **kwargs: Any)
Chemical potential parameterized by a Legendre polynomial expansion (expansion) and a prior function (prior_fn).
- expansion: LegendrePolynomialExpansion
- prior_fn: Callable
- class pde_opt.numerics.functions.Mixer2d(*args: Any, **kwargs: Any)
- __init__(img_size, patch_size, hidden_size, mix_patch_size, mix_hidden_size, num_blocks, *, key)
- conv_in: equinox.nn.Conv2d
- conv_out: equinox.nn.ConvTranspose2d
- blocks: list
- norm: equinox.nn.LayerNorm
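Mixer2d has no docstring. Its name and fields (a conv_in, a list of blocks, a LayerNorm and a ConvTranspose2d conv_out) suggest the MLP-Mixer design, so the sketch below illustrates that design under stated assumptions: conv_in embeds non-overlapping patch_size x patch_size patches (kernel size = stride = patch_size), and each block applies a token-mixing MLP of width mix_patch_size across patches and a channel-mixing MLP of width mix_hidden_size across hidden_size features, each with a residual connection. All function names are hypothetical, the normalisation has no learned parameters, and the reverse step through conv_out is omitted.

```python
import jax
import jax.numpy as jnp


def patch_embed(x, w, patch_size):
    """Hypothetical patch embedding: conv with kernel size = stride = patch_size.

    x: (C, H, W); w: (hidden_size, C, p, p). Returns (hidden_size, num_patches).
    """
    y = jax.lax.conv_general_dilated(
        x[None], w, window_strides=(patch_size, patch_size), padding="VALID",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
    )[0]  # (hidden_size, H // p, W // p)
    return y.reshape(y.shape[0], -1)


def mlp(z, w1, w2):
    """Two-layer MLP with GELU, acting on the last axis of z."""
    return jax.nn.gelu(z @ w1) @ w2


def norm(h, eps=1e-5):
    """Parameter-free normalisation over the whole (hidden_size, num_patches) array."""
    return (h - h.mean()) / jnp.sqrt(h.var() + eps)


def mixer_block(h, tok_w1, tok_w2, ch_w1, ch_w2):
    """h: (hidden_size, num_patches); both mixing steps are residual."""
    h = h + mlp(norm(h), tok_w1, tok_w2)       # token mixing: across patches
    h = h + mlp(norm(h).T, ch_w1, ch_w2).T     # channel mixing: across features
    return h


# Usage: a 1x16x16 field, 4x4 patches -> 16 patches, hidden_size = 8.
C, H, W, p, hidden, mix_patch, mix_hidden = 1, 16, 16, 4, 8, 32, 16
n_patches = (H // p) * (W // p)
ks = jax.random.split(jax.random.PRNGKey(0), 6)
x = jax.random.normal(ks[0], (C, H, W))
w_in = 0.1 * jax.random.normal(ks[1], (hidden, C, p, p))
h0 = patch_embed(x, w_in, p)
h1 = mixer_block(
    h0,
    0.1 * jax.random.normal(ks[2], (n_patches, mix_patch)),
    0.1 * jax.random.normal(ks[3], (mix_patch, n_patches)),
    0.1 * jax.random.normal(ks[4], (hidden, mix_hidden)),
    0.1 * jax.random.normal(ks[5], (mix_hidden, hidden)),
)
```

Because the strided convolution splits the field into (H / patch_size) * (W / patch_size) tokens, the token-mixing weights scale with the number of patches rather than with the number of pixels.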
Modules
- This module contains a periodic CNN class for representing functions in PDEs.
- This module contains the Legendre polynomial expansion class for representing functions in PDEs.
- This module contains a Mixer MLP architecture for representing functions in PDEs.