pde_opt.numerics.functions.cnn
This module contains a periodic CNN class for representing functions in PDEs.
Classes
PeriodicCNN | Stack of periodic conv blocks; final conv returns requested channels.
PeriodicConvBlock | Conv2d -> activation with periodic padding.
- class pde_opt.numerics.functions.cnn.PeriodicConvBlock(*args: Any, **kwargs: Any)
Conv2d -> activation with periodic padding. Translation-equivariant.
- __init__(in_channels: int, out_channels: int, kernel_size: int = 3, act: Callable[[jax.Array], jax.Array] = jax.nn.gelu, *, key: jax.Array)
- conv: equinox.nn.Conv2d
- act: Callable[[jax.Array], jax.Array]
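The "Conv2d -> activation with periodic padding" pattern can be sketched in plain JAX. This is an illustration only, not the module's implementation: the function name `periodic_conv_block`, the explicit `weight` and `bias` arguments, and the use of `jnp.pad(..., mode="wrap")` are assumptions made for the example, and an odd `kernel_size` is assumed so that the padding is symmetric.

```python
import jax
import jax.numpy as jnp


def periodic_conv_block(x, weight, bias, act=jax.nn.gelu):
    """Hypothetical stand-in for a periodic conv block: wrap-pad, convolve, activate.

    x:      (C_in, H, W)
    weight: (C_out, C_in, k, k), with odd k (assumption)
    bias:   (C_out,)
    """
    p = weight.shape[-1] // 2
    # Periodic (wrap-around) padding on the two spatial axes only.
    x_pad = jnp.pad(x, ((0, 0), (p, p), (p, p)), mode="wrap")
    # A stride-1 "VALID" convolution on the padded input restores H x W.
    y = jax.lax.conv_general_dilated(
        x_pad[None],  # add a batch axis: (1, C_in, H + 2p, W + 2p)
        weight,
        window_strides=(1, 1),
        padding="VALID",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
    )[0]
    # Pointwise activation applied after adding a per-channel bias.
    return act(y + bias[:, None, None])


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
weight = 0.1 * jax.random.normal(key_w, (4, 2, 3, 3))
bias = jnp.zeros(4)
x = jax.random.normal(key_x, (2, 16, 16))

y = periodic_conv_block(x, weight, bias)
print(y.shape)
```

Padding by `kernel_size // 2` with wrap-around values and then convolving without further padding is one common way to treat the grid as a torus while keeping the spatial size unchanged.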
- class pde_opt.numerics.functions.cnn.PeriodicCNN(*args: Any, **kwargs: Any)
Stack of periodic conv blocks; final conv returns requested channels.
Translation-equivariant on a torus (a periodic domain) as long as stride=1 and only pointwise nonlinearities are used. Accepts input of shape (C, H, W) or (B, C, H, W) and returns output with the same spatial size (H, W).
- __init__(in_channels: int, hidden_channels: Sequence[int] = (32, 64, 64), out_channels: int | None = None, kernel_size: int = 3, act: Callable[[jax.Array], jax.Array] = jax.nn.gelu, *, key: jax.Array)
- layers: Tuple[equinox.Module, ...]
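The translation-equivariance property stated above can be checked numerically: shifting the input with wrap-around and then applying the network should match applying the network and then shifting the output. The sketch below uses a hypothetical stand-in, `tiny_periodic_cnn`, written in plain JAX rather than the `PeriodicCNN` class itself. The bias-free layers, the absence of an activation after the final convolution, and the use of `jax.vmap` for batched input are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def periodic_conv(x, weight):
    """Stride-1 convolution of a (C, H, W) array with wrap-around padding."""
    p = weight.shape[-1] // 2
    x_pad = jnp.pad(x, ((0, 0), (p, p), (p, p)), mode="wrap")
    return jax.lax.conv_general_dilated(
        x_pad[None],
        weight,
        window_strides=(1, 1),
        padding="VALID",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
    )[0]


def tiny_periodic_cnn(weights, x):
    """Hypothetical stand-in: hidden convs with GELU, then a final linear conv."""
    for w in weights[:-1]:
        x = jax.nn.gelu(periodic_conv(x, w))
    return periodic_conv(x, weights[-1])


keys = jax.random.split(jax.random.PRNGKey(0), 4)
channels = [1, 8, 8, 2]  # in_channels, two hidden widths, out_channels
weights = [
    0.1 * jax.random.normal(k, (c_out, c_in, 3, 3))
    for k, c_in, c_out in zip(keys[:3], channels[:-1], channels[1:])
]
x = jax.random.normal(keys[3], (1, 16, 16))

# Equivariance check: shift-then-apply versus apply-then-shift.
shift = (3, 5)
shifted_then_applied = tiny_periodic_cnn(weights, jnp.roll(x, shift, axis=(-2, -1)))
applied_then_shifted = jnp.roll(tiny_periodic_cnn(weights, x), shift, axis=(-2, -1))
max_diff = jnp.max(jnp.abs(shifted_then_applied - applied_then_shifted))
print("max difference:", float(max_diff))

# Batched input (B, C, H, W): map the single-sample function over the batch axis.
xb = jnp.stack([x, 2.0 * x])
yb = jax.vmap(tiny_periodic_cnn, in_axes=(None, 0))(weights, xb)
print(yb.shape)
```

The difference should be at the level of floating-point round-off. A strided convolution or a non-pointwise operation such as pooling would generally break this equality, which is why the docstring restricts the claim to stride=1 and pointwise nonlinearities.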