pde_opt.numerics.functions.cnn

This module provides periodic convolutional neural network (CNN) classes for representing functions in PDEs.

Classes

PeriodicCNN(*args, **kwargs)

Stack of periodic conv blocks; final conv returns requested channels.

PeriodicConvBlock(*args, **kwargs)

Conv2d -> activation with periodic padding.

class pde_opt.numerics.functions.cnn.PeriodicConvBlock(*args: Any, **kwargs: Any)

A Conv2d followed by an activation, applied with periodic (wrap-around) padding. The block is translation-equivariant with respect to periodic shifts of its input.
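
The following minimal sketch, written in plain JAX rather than with the module's own classes, illustrates what periodic padding does and why it gives translation equivariance. It assumes an odd kernel size and an unbatched (C, H, W) input, and `periodic_conv2d` is a hypothetical helper: the input is wrap-padded with `jnp.pad(..., mode="wrap")`, then a VALID convolution returns an output of the original (H, W) size, and a circular shift of the input produces the same circular shift of the output.

```python
import jax
import jax.numpy as jnp


def periodic_conv2d(x: jax.Array, w: jax.Array) -> jax.Array:
    """Hypothetical helper: circular 2D convolution.

    x: (C_in, H, W); w: (C_out, C_in, k, k) with k odd.
    """
    p = w.shape[-1] // 2
    # Wrap-around padding: the first p rows/columns are copied past the end, and vice versa.
    xp = jnp.pad(x, ((0, 0), (p, p), (p, p)), mode="wrap")
    # A VALID convolution over the padded input returns an (H, W) output.
    y = jax.lax.conv_general_dilated(
        xp[None],  # add a batch axis -> (1, C_in, H + 2p, W + 2p)
        w,
        window_strides=(1, 1),
        padding="VALID",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
        precision=jax.lax.Precision.HIGHEST,
    )
    return y[0]


kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (2, 16, 16))
w = jax.random.normal(kw, (4, 2, 3, 3))
y = periodic_conv2d(x, w)
print(y.shape)  # (4, 16, 16)

# Shifting the input on the torus shifts the output by the same amount.
shift = (3, -5)
y_shifted = periodic_conv2d(jnp.roll(x, shift, axis=(1, 2)), w)
print(jnp.allclose(y_shifted, jnp.roll(y, shift, axis=(1, 2)), atol=1e-4))
```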

__init__(in_channels: int, out_channels: int, kernel_size: int = 3, act: Callable[[jax.Array], jax.Array] = jax.nn.gelu, *, key: jax.Array)
conv: equinox.nn.Conv2d
act: Callable[[jax.Array], jax.Array]

class pde_opt.numerics.functions.cnn.PeriodicCNN(*args: Any, **kwargs: Any)

A stack of periodic conv blocks followed by a final convolution that returns the requested number of output channels.

Translation-equivariant on the torus as long as stride=1 and only pointwise nonlinearities are used. Accepts inputs of shape (C, H, W) or (B, C, H, W) and returns an output with the same spatial size (H, W).

__init__(in_channels: int, hidden_channels: Sequence[int] = (32, 64, 64), out_channels: int | None = None, kernel_size: int = 3, act: Callable[[jax.Array], jax.Array] = jax.nn.gelu, *, key: jax.Array)
layers: Tuple[equinox.Module, ...]