pde_opt.numerics.functions

Function representations for PDEs.

class pde_opt.numerics.functions.PeriodicCNN(*args: Any, **kwargs: Any)

A stack of periodic convolution blocks; the final convolution outputs the requested number of channels.

The network is translation-equivariant on the torus as long as stride=1 and only pointwise nonlinearities are used. Accepts inputs of shape (C, H, W) or (B, C, H, W) and returns an output with the same spatial size.

__init__(in_channels: int, hidden_channels: Sequence[int] = (32, 64, 64), out_channels: int | None = None, kernel_size: int = 3, act: Callable[[jax.Array], jax.Array] = jax.nn.gelu, *, key: jax.Array)
layers: Tuple[equinox.Module, ...]
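The periodic boundary handling and the equivariance property can be illustrated with a minimal plain-JAX sketch. It assumes that periodicity comes from wrap-around padding followed by a VALID, stride-1 convolution; it is not the PeriodicCNN implementation, and the helpers `periodic_conv2d` and `tiny_periodic_cnn` are hypothetical. The usage at the end checks that circularly shifting the input shifts the output by the same amount.

```python
import jax
import jax.numpy as jnp


def periodic_conv2d(x, w):
    """Periodic (circular) 2D convolution with stride 1 (hypothetical helper).

    x: (B, C_in, H, W); w: (C_out, C_in, k, k) with k odd.
    Assumes periodicity is implemented via wrap-around padding.
    """
    p = w.shape[-1] // 2
    # Wrap-around padding glues opposite edges together (torus topology).
    xp = jnp.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode="wrap")
    # A VALID convolution on the padded input returns the original H, W.
    return jax.lax.conv_general_dilated(
        xp, w, window_strides=(1, 1), padding="VALID",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
    )


def tiny_periodic_cnn(x, weights, act=jax.nn.gelu):
    """Hypothetical stand-in for PeriodicCNN: accepts (C,H,W) or (B,C,H,W)."""
    unbatched = x.ndim == 3
    if unbatched:
        x = x[None]  # add a batch axis
    for w in weights[:-1]:
        x = act(periodic_conv2d(x, w))  # pointwise nonlinearity only
    x = periodic_conv2d(x, weights[-1])  # final conv: no activation
    return x[0] if unbatched else x


# Usage: 2 input channels -> hidden (8, 8) -> 1 output channel, kernel size 3.
key = jax.random.PRNGKey(0)
channels = [2, 8, 8, 1]
keys = jax.random.split(key, len(channels))
weights = [
    0.1 * jax.random.normal(k, (c_out, c_in, 3, 3))
    for k, c_in, c_out in zip(keys[1:], channels[:-1], channels[1:])
]
x = jax.random.normal(keys[0], (2, 16, 16))
y = tiny_periodic_cnn(x, weights)

# Translation equivariance: rolling the input rolls the output identically.
y_shifted = tiny_periodic_cnn(jnp.roll(x, (3, 5), axis=(1, 2)), weights)
print(y.shape, jnp.allclose(y_shifted, jnp.roll(y, (3, 5), axis=(1, 2)), atol=1e-5))
```

Wrap padding is what makes the equivariance hold at the boundary: with zero padding, pixels near the edges would see different neighbourhoods after a shift.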
class pde_opt.numerics.functions.LegendrePolynomialExpansion(*args: Any, **kwargs: Any)

Legendre polynomial expansion.

__init__(params: jax.Array)
params: jax.Array
max_degree: int
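A Legendre expansion evaluates a weighted sum f(x) = sum_{n=0}^{N} c_n P_n(x). The sketch below assumes that `params` holds the coefficients c_0, ..., c_N (so max_degree = N = len(params) - 1) and that x lies in [-1, 1]; the function name `legendre_expansion` is hypothetical. It uses Bonnet's recurrence, (n+1) P_{n+1}(x) = (2n+1) x P_n(x) - n P_{n-1}(x), so each polynomial is built from the previous two rather than from scratch.

```python
import jax.numpy as jnp


def legendre_expansion(params, x):
    """Evaluate sum_n params[n] * P_n(x) for n = 0, ..., len(params) - 1.

    Hypothetical helper; assumes params are Legendre coefficients, x in [-1, 1].
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    max_degree = params.shape[0] - 1
    p_prev = jnp.ones_like(x)  # P_0(x) = 1
    out = params[0] * p_prev
    if max_degree == 0:
        return out
    p_curr = x  # P_1(x) = x
    out = out + params[1] * p_curr
    for n in range(1, max_degree):
        # Bonnet recurrence: (n+1) P_{n+1} = (2n+1) x P_n - n P_{n-1}
        p_next = ((2 * n + 1) * x * p_curr - n * p_prev) / (n + 1)
        out = out + params[n + 1] * p_next
        p_prev, p_curr = p_curr, p_next
    return out


coeffs = jnp.array([0.0, 0.0, 1.0])  # f(x) = P_2(x) = (3x^2 - 1) / 2
print(legendre_expansion(coeffs, jnp.linspace(-1.0, 1.0, 5)))
```

Because the loop length depends only on the (static) shape of `params`, the function traces cleanly under `jax.jit` and `jax.grad`.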
class pde_opt.numerics.functions.DiffusionLegendrePolynomials(*args: Any, **kwargs: Any)

Legendre polynomial representation of a diffusion coefficient.

Applies exp to the Legendre expansion so that the result is always positive.

__init__(params: jax.Array)
expansion: LegendrePolynomialExpansion
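Exponentiating an expansion is a common way to get a strictly positive coefficient from unconstrained parameters, because exp maps every real number into (0, inf). A minimal sketch, assuming the class computes D(x) = exp(sum_n c_n P_n(x)) with `params` as the Legendre coefficients; `legendre_sum` and `diffusion_coefficient` are hypothetical names, and the Legendre helper is repeated so the block is self-contained.

```python
import jax
import jax.numpy as jnp


def legendre_sum(params, x):
    # sum_n params[n] * P_n(x) via Bonnet's recurrence (hypothetical helper).
    p_prev, p_curr = jnp.ones_like(x), x
    out = params[0] * p_prev
    for n in range(1, params.shape[0]):
        out = out + params[n] * p_curr
        p_prev, p_curr = p_curr, ((2 * n + 1) * x * p_curr - n * p_prev) / (n + 1)
    return out


def diffusion_coefficient(params, x):
    # exp maps any real value to (0, inf), so D(x) > 0 mathematically
    # (float32 can still underflow to 0 for very negative arguments).
    return jnp.exp(legendre_sum(params, x))


params = jnp.array([-1.0, 3.0, -5.0])  # arbitrary, unconstrained coefficients
x = jnp.linspace(-1.0, 1.0, 101)
D = diffusion_coefficient(params, x)
print(bool(jnp.all(D > 0)))

# No projection or clipping is needed when fitting params by gradient descent.
grad = jax.grad(lambda p: jnp.mean(diffusion_coefficient(p, x)))(params)
```

The design choice trades a constrained optimisation problem for an unconstrained one: any parameter vector the optimiser produces yields a valid (positive) diffusion coefficient.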
class pde_opt.numerics.functions.ChemicalPotentialLegendrePolynomials(*args: Any, **kwargs: Any)

Legendre polynomial representation of a chemical potential, optionally combined with a prior function (prior_fn).

__init__(params: jax.Array, prior_fn: Callable | None = None)
expansion: LegendrePolynomialExpansion
prior_fn: Callable
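How `prior_fn` is combined with the expansion is not documented here. One plausible pattern, and it is only an assumption, is additive: mu(x) = prior(x) + sum_n c_n P_n(x), with no prior term when `prior_fn` is None. The sketch follows that assumption; `chemical_potential` and the logarithmic `log_prior` are hypothetical illustrations, not part of pde_opt.

```python
import jax.numpy as jnp


def legendre_sum(params, x):
    # sum_n params[n] * P_n(x) via Bonnet's recurrence (hypothetical helper).
    p_prev, p_curr = jnp.ones_like(x), x
    out = params[0] * p_prev
    for n in range(1, params.shape[0]):
        out = out + params[n] * p_curr
        p_prev, p_curr = p_curr, ((2 * n + 1) * x * p_curr - n * p_prev) / (n + 1)
    return out


def chemical_potential(params, x, prior_fn=None):
    # ASSUMPTION: the prior is added to the learned expansion; None means no prior.
    learned = legendre_sum(params, x)
    if prior_fn is None:
        return learned
    return prior_fn(x) + learned


def log_prior(x):
    # Hypothetical prior log((1 + x) / (1 - x)); diverges at x = -1 and x = 1.
    return jnp.log1p(x) - jnp.log1p(-x)


params = jnp.array([0.0, -2.0])  # learned part: -2 * P_1(x) = -2x
x = jnp.linspace(-0.9, 0.9, 7)
mu_plain = chemical_potential(params, x)
mu_prior = chemical_potential(params, x, prior_fn=log_prior)
print(mu_plain, mu_prior)
```

Under this assumption, the prior carries known structure (such as singular behaviour that a low-degree polynomial cannot capture) while the Legendre coefficients only need to learn a smooth correction.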
class pde_opt.numerics.functions.Mixer2d(*args: Any, **kwargs: Any)
__init__(img_size, patch_size, hidden_size, mix_patch_size, mix_hidden_size, num_blocks, *, key)
conv_in: equinox.nn.Conv2d
conv_out: equinox.nn.ConvTranspose2d
blocks: list
norm: equinox.nn.LayerNorm
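Mixer2d's attributes (a patchifying conv_in, a list of blocks, a LayerNorm, and a transposed-convolution conv_out) resemble the usual MLP-Mixer layout: split the field into patches, alternately mix information across patches and across channels, then map back to the grid. The plain-JAX sketch below shows one mixer block under that assumption. It also assumes img_size is (C, H, W) and that mix_patch_size and mix_hidden_size are the hidden widths of the patch-mixing and channel-mixing MLPs; all helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, dim, width):
    # One-hidden-layer MLP parameters (hypothetical initialisation).
    k1, k2 = jax.random.split(key)
    return (
        jax.random.normal(k1, (dim, width)) / jnp.sqrt(dim), jnp.zeros(width),
        jax.random.normal(k2, (width, dim)) / jnp.sqrt(width), jnp.zeros(dim),
    )


def mlp(x, w1, b1, w2, b2):
    # Applied along the last axis of x.
    return jax.nn.gelu(x @ w1 + b1) @ w2 + b2


def layer_norm(x, eps=1e-5):
    # Normalise over the whole (hidden_size, num_patches) array, no affine part.
    return (x - x.mean()) / jnp.sqrt(x.var() + eps)


def patchify(img, w):
    # img: (C, H, W); w: (hidden_size, C, p, p). Stride p => non-overlapping patches.
    p = w.shape[-1]
    out = jax.lax.conv_general_dilated(
        img[None], w, window_strides=(p, p), padding="VALID",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
    )[0]  # (hidden_size, H/p, W/p)
    return out.reshape(out.shape[0], -1)  # (hidden_size, num_patches)


def mixer_block(tokens, patch_mlp, hidden_mlp):
    # Patch mixing: MLP along the patch axis, shared across channels.
    tokens = tokens + mlp(layer_norm(tokens), *patch_mlp)
    # Channel mixing: MLP along the channel axis, shared across patches.
    tokens = tokens + mlp(layer_norm(tokens).T, *hidden_mlp).T
    return tokens


img_size, patch_size, hidden_size = (1, 16, 16), 4, 32
mix_patch_size, mix_hidden_size = 64, 64
num_patches = (img_size[1] // patch_size) * (img_size[2] // patch_size)

k_conv, k_patch, k_hidden, k_img = jax.random.split(jax.random.PRNGKey(0), 4)
w_in = 0.1 * jax.random.normal(k_conv, (hidden_size, img_size[0], patch_size, patch_size))
patch_mlp = init_mlp(k_patch, num_patches, mix_patch_size)
hidden_mlp = init_mlp(k_hidden, hidden_size, mix_hidden_size)

tokens = patchify(jax.random.normal(k_img, img_size), w_in)
out = mixer_block(tokens, patch_mlp, hidden_mlp)
print(tokens.shape, out.shape)
```

Each block preserves the (hidden_size, num_patches) shape, which is what allows num_blocks of them to be stacked before the final norm and the transposed convolution back to image resolution.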

Modules

cnn

This module contains a periodic CNN class for representing functions in PDEs.

legendre

This module contains the Legendre polynomial expansion class for representing functions in PDEs.

mixer_mlp

This module contains a Mixer MLP architecture for representing functions in PDEs.