pde_opt

Public API for pde_opt.

class pde_opt.PDEModel(equation_type: Type[BaseEquation], domain: Domain, solver_type: Type[diffrax.AbstractSolver])[source]

Manage the solving and optimization of partial differential equations (PDEs).

The PDEModel class provides a unified interface for solving PDEs and optimizing their parameters using gradient-based methods. It supports both forward simulation and parameter estimation.

The class is designed to work with JAX-based PDE implementations and leverages automatic differentiation for efficient gradient computation during optimization.

equation_type

The equation class to optimize. Must be a subclass of BaseEquation from numerics.equations.

Type:

Type[BaseEquation]

domain

The spatial domain for solving the equation. Contains grid information, boundary conditions, and coordinate systems.

Type:

domains.Domain

solver_type

The numerical solver class for time integration. Must be a subclass of dfx.AbstractSolver. This can be an existing diffrax solver such as Tsit5 or a custom solver such as those defined in numerics.solvers.

Type:

Type[dfx.AbstractSolver]

For examples, see the documentation notebooks.

__init__(equation_type: Type[BaseEquation], domain: Domain, solver_type: Type[diffrax.AbstractSolver])[source]

Initialize the PDE optimization model.

Parameters:
  • equation_type (Type[BaseEquation]) – The equation class to optimize. Must be a subclass of BaseEquation.

  • domain (domains.Domain) – The spatial domain for the PDE. Defines the grid resolution, spatial bounds, and coordinate system.

  • solver_type (Type[dfx.AbstractSolver]) – The numerical solver for time integration. Must be a subclass of dfx.AbstractSolver. This can be an existing diffrax solver such as Tsit5 or a custom solver such as those defined in numerics.solvers.

Raises:

ValueError – If the equation and solver are incompatible (e.g., solver requires attributes that the equation doesn’t provide).

Note

The solver and equation compatibility is automatically checked during initialization. Some solvers require specific attributes from equations (e.g., fourier_symbol, fft, ifft for semi-implicit spectral methods).
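The compatibility check described in this note can be pictured as a plain attribute lookup. The sketch below is an illustration under assumptions, not pde_opt's implementation: `check_compatibility`, `ToySolver`, and `ToyEquation` are hypothetical names, and the only detail taken from this page is that solvers list what they need in `required_equation_attrs`.

```python
class ToySolver:
    # Mirrors the documented `required_equation_attrs` class attribute.
    required_equation_attrs = ["fourier_symbol", "fft", "ifft"]


class ToyEquation:
    # Hypothetical equation that provides only two of the three attributes.
    fft = staticmethod(lambda x: x)
    ifft = staticmethod(lambda x: x)


def check_compatibility(equation, solver_type):
    """Raise ValueError if `equation` lacks an attribute the solver needs."""
    required = getattr(solver_type, "required_equation_attrs", [])
    missing = [name for name in required if not hasattr(equation, name)]
    if missing:
        raise ValueError(
            f"{solver_type.__name__} requires equation attributes {missing}"
        )


try:
    check_compatibility(ToyEquation(), ToySolver)
    error_message = None
except ValueError as err:
    error_message = str(err)
```

Here `error_message` names the missing `fourier_symbol` attribute. The real check may apply stricter rules, for example rejecting attributes that are still None.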

solve(parameters: Dict[str, Any], y0, ts, solver_parameters: Dict[str, Any] = {}, adjoint=diffrax.ForwardMode, dt0=1e-06, max_steps=1000000, stepsize_controller=diffrax.ConstantStepSize)[source]

Solve the PDE with given parameters and initial conditions.

This method performs forward simulation of the PDE using the specified solver and parameters. The solution is computed at the requested time points and returned.

Parameters:
  • parameters (Dict[str, Any]) – Dictionary of equation parameters. Keys should match the parameter names expected by the equation class.

  • y0 – Initial condition array. Shape should match the spatial dimensions of the domain.

  • ts – Time points at which to save the solution. Should be a 1D array of increasing time values. The solver will integrate from ts[0] to ts[-1] and return solutions at all specified time points.

  • solver_parameters (Dict[str, Any], optional) – Additional parameters for the solver. These are passed directly to the solver constructor.

  • adjoint (dfx.AbstractAdjoint, optional) – Adjoint mode for automatic differentiation. Defaults to ForwardMode() for forward-mode AD. Use RecursiveCheckpointAdjoint() for reverse-mode AD when the number of parameters is large.

  • dt0 (float, optional) – Initial time step for the solver. Defaults to 1e-6. With the default ConstantStepSize controller the step stays fixed at this value; an adaptive controller may change it during integration.

  • max_steps (int, optional) – Maximum number of integration steps. Defaults to 1,000,000.

  • stepsize_controller (dfx.AbstractStepSizeController, optional) – Controller for the step size. Defaults to ConstantStepSize().

Returns:

Solution array with shape (len(ts), *y0.shape).

residual_single(parameters, solver_parameters, y0, values, ts, adjoint=diffrax.ForwardMode)[source]

Compute residuals for a single trajectory.

This method computes the difference between model predictions and observed data for a single initial condition and trajectory. It’s used internally by the batched residuals computation.

Parameters:
  • parameters (Dict[str, Any]) – Current parameter values for the equation.

  • solver_parameters (Dict[str, Any]) – Parameters for the numerical solver.

  • y0 (jax.Array) – Initial condition.

  • values (jax.Array) – Observed values for computing residuals.

  • ts (jax.Array) – Time points with shape (timepoints,).

  • adjoint (dfx.AbstractAdjoint, optional) – Adjoint mode for automatic differentiation. Defaults to ForwardMode().

Returns:

Residuals array with shape (timepoints, *y0.shape).

The residuals are computed as: values - predicted[1:] (values should not include the initial condition).
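The slicing convention above can be checked on toy arrays. This is a sketch with made-up shapes, written in NumPy rather than JAX, and it assumes the prediction is saved at every entry of ts, including ts[0]; under that assumption the residual has one slice per observed snapshot.

```python
import numpy as np

# Hypothetical shapes: 4 time points on a 3x3 grid.
ts = np.linspace(0.0, 0.3, 4)
predicted = np.stack([np.full((3, 3), t) for t in ts])  # includes the state at ts[0]
values = predicted[1:] + 0.1                            # observations after ts[0]

# Documented convention: drop the initial condition from the prediction.
residuals = values - predicted[1:]
```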

regularization(parameters, weights, lambda_reg)[source]

Compute parameter regularization term.

This method computes a weighted L2 regularization term for the parameters to prevent overfitting and improve parameter stability during optimization. The regularization is computed as:

\[\text{reg} = \lambda \sum_i w_i p_i^2\]

where λ is the regularization coefficient, w_i are the weights, and p_i are the parameter values.

Parameters:
  • parameters (Dict[str, Any]) – Current parameter values for the equation. Can contain nested structures (pytrees) of parameters.

  • weights (Dict[str, Any]) – Regularization weights with the same structure as parameters. Higher weights penalize large parameter values more strongly. None values in weights are ignored.

  • lambda_reg (float) – Regularization coefficient controlling the overall strength of the regularization term.

Returns:

Scalar regularization term to be added to the loss function.
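The formula above can be sketched for a flat dictionary of parameters. The helper name `weighted_l2` and the sample values are hypothetical, and NumPy stands in for JAX. Nested pytrees would need a tree traversal (for instance with jax.tree_util), which this sketch leaves out.

```python
import numpy as np


def weighted_l2(parameters, weights, lambda_reg):
    """lambda * sum_i w_i * p_i**2 over a flat dict; None weights are skipped."""
    total = 0.0
    for name, value in parameters.items():
        weight = weights.get(name)
        if weight is None:
            continue  # this parameter is not regularized
        total += float(np.sum(np.asarray(weight) * np.asarray(value) ** 2))
    return lambda_reg * total


parameters = {"kappa": np.array(2.0), "coeffs": np.array([1.0, -1.0, 0.5])}
weights = {"kappa": None, "coeffs": np.array([1.0, 1.0, 4.0])}
reg = weighted_l2(parameters, weights, lambda_reg=0.1)
```

With these values the weighted sum is 1 + 1 + 4 × 0.25 = 3, so `reg` is 0.3. The None weight excludes `kappa` from the sum.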

residuals(parameters, y0s__values, solver_parameters, ts, weights, lambda_reg, adjoint=diffrax.ForwardMode)[source]

Compute batched residuals and regularization for parameter optimization.

This method computes the difference between model predictions and observed data for multiple trajectories simultaneously, along with parameter regularization. It’s used internally by the optimization algorithms in the train() method.

Parameters:
  • parameters (Dict[str, Any]) – Current parameter values for the equation.

  • y0s__values (Tuple[jax.Array, jax.Array]) – Tuple containing:

    • y0s: Batch of initial conditions

    • values: Batch of observed values

  • solver_parameters (Dict[str, Any]) – Parameters for the numerical solver.

  • ts (jax.Array) – Time points for the simulation.

  • weights (Dict[str, jax.Array]) – Regularization weights for each parameter. Should match the structure of parameters.

  • lambda_reg (float) – Regularization coefficient controlling the strength of parameter penalties.

  • adjoint (dfx.AbstractAdjoint, optional) – Adjoint mode for automatic differentiation. Defaults to ForwardMode().

Returns:

Tuple containing:
  • batch_residuals: Residuals array with shape (batch_size, timepoints, *y0.shape)

  • reg: Scalar regularization term

Return type:

Tuple[jax.Array, float]

mse(parameters, y0s__values, solver_parameters, ts, weights, lambda_reg, adjoint=diffrax.RecursiveCheckpointAdjoint)[source]

Compute the mean squared error loss for parameter optimization.

This method computes the mean squared error between model predictions and observed data, plus a regularization term. It’s used as the objective function for the “mse” optimization method in the train() method.

The loss function is:

\[\text{MSE} = \text{mean}((\text{predicted} - \text{observed})^2) + \lambda \cdot \text{regularization}\]

Parameters:
  • parameters (Dict[str, Any]) – Current parameter values for the equation.

  • y0s__values (Tuple[jax.Array, jax.Array]) – Tuple containing:

    • y0s: Batch of initial conditions

    • values: Batch of observed values

  • solver_parameters (Dict[str, Any]) – Parameters for the numerical solver.

  • ts (jax.Array) – Time points for the simulation.

  • weights (Dict[str, jax.Array]) – Regularization weights for each parameter. Should match the structure of parameters.

  • lambda_reg (float) – Regularization coefficient controlling the strength of parameter penalties.

  • adjoint (dfx.AbstractAdjoint, optional) – Adjoint mode for automatic differentiation. Defaults to RecursiveCheckpointAdjoint().

Returns:

Mean squared error loss including regularization term.

Return type:

float
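As a rough illustration of this loss, the sketch below assumes the mean runs over every batch, time, and grid entry of the residuals, and that the regularization term arrives as a precomputed scalar. The helper name `mse_loss` and the array shapes are hypothetical.

```python
import numpy as np


def mse_loss(batch_residuals, reg):
    """Mean of squared residuals over every entry, plus the regularization term."""
    return float(np.mean(batch_residuals ** 2)) + reg


# Hypothetical batch: 2 trajectories, 3 saved times, 4x4 grid.
batch_residuals = np.full((2, 3, 4, 4), 0.5)
loss = mse_loss(batch_residuals, reg=0.3)
```

Every residual is 0.5, so the mean square is 0.25 and the loss is 0.55.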

train(data, inds, opt_parameters, other_parameters, solver_parameters, weights, lambda_reg, method='least_squares', max_steps=100)[source]

Train the model by optimizing parameters to fit observed data.

This method performs parameter estimation by minimizing the difference between model predictions and observed data. It supports two optimization approaches: least-squares (which uses the Levenberg-Marquardt algorithm) and mean squared error minimization (which uses the BFGS algorithm).

Parameters:
  • data (Dict[str, List]) – Training data dictionary with keys:

    • “ys”: List of solution snapshots at different times

    • “ts”: List of corresponding time points

    Example: {“ys”: [y0, y1, y2, …], “ts”: [0, 0.1, 0.2, …]}

  • inds (List[List[int]]) – Indices specifying which data points to use for each training trajectory. Each inner list represents a trajectory: [initial_time_index, …intermediate_indices…]. Example: [[0, 1, 2], [0, 1, 2]] for two trajectories using time points 0, 1, 2.

  • opt_parameters (Dict[str, jax.Array]) – Parameters to optimize.

  • other_parameters (Dict[str, Any]) – Fixed parameters that won’t be optimized.

  • solver_parameters (Dict[str, Any]) – Parameters for the numerical solver during optimization. Passed to the solver constructor.

  • weights (Dict[str, jax.Array]) – Regularization weights for each parameter. Should have the same structure as opt_parameters.

  • lambda_reg (float) – Regularization coefficient. Controls the strength of parameter regularization.

  • method (str, optional) – Optimization method. Defaults to “least_squares”. Options:

    • “least_squares”: Uses the Levenberg-Marquardt algorithm with the ForwardMode adjoint. Best when the number of parameters is small (no neural networks).

    • “mse”: Uses the BFGS algorithm with the RecursiveCheckpointAdjoint adjoint. Better when the number of parameters is large (for example, when using neural networks).

  • max_steps (int, optional) – Maximum number of optimization iterations. Defaults to 100.

Returns:

Optimized parameters combined with fixed parameters. The returned dictionary contains both the optimized parameters and the other_parameters, ready for use in the solve() method.

Return type:

Dict[str, Any]
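The layout of data and inds may be easier to see on a toy dataset. The sketch below is one reading of the parameter descriptions, not the library's code: the helper `split_trajectory` is hypothetical, and it assumes the first index of each inner list selects the initial condition while the remaining indices select the target snapshots. The last lines show the documented return convention, where optimized and fixed parameters are merged into one dictionary.

```python
import numpy as np

# Hypothetical data: five snapshots of a 4x4 field.
data = {
    "ys": [np.full((4, 4), float(i)) for i in range(5)],
    "ts": [0.0, 0.1, 0.2, 0.3, 0.4],
}
inds = [[0, 1, 2], [2, 3, 4]]  # two trajectories, three snapshots each


def split_trajectory(data, idx):
    """Assumed reading: first index is the initial condition, the rest are targets."""
    y0 = data["ys"][idx[0]]
    values = np.stack([data["ys"][i] for i in idx[1:]])
    ts = np.array([data["ts"][i] for i in idx])
    return y0, values, ts


y0, values, ts = split_trajectory(data, inds[1])

# The documented return value merges optimized and fixed parameters.
opt_parameters = {"kappa": 0.02}
other_parameters = {"mobility": 1.0}
combined = {**opt_parameters, **other_parameters}
```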

optimize(objective_function, y0, ts, opt_parameters, other_parameters, solver_parameters, weights, lambda_reg, max_steps=100)[source]

Optimize parameters by minimizing a scalar function of the solution.

This method performs parameter optimization by minimizing a user-provided scalar function that takes the solution as input. It uses the BFGS algorithm for optimization and supports parameter regularization.

Parameters:
  • objective_function (Callable) – A callable function that takes the solution array (shape: (len(ts), *y0.shape)) and returns a scalar value to minimize. The function should be JAX-compatible for automatic differentiation.

  • y0 – Initial condition array. Shape should match the spatial dimensions of the domain.

  • ts – Time points at which to save the solution. Should be a 1D array of increasing time values.

  • opt_parameters (Dict[str, jax.Array]) – Parameters to optimize.

  • other_parameters (Dict[str, Any]) – Fixed parameters that won’t be optimized.

  • solver_parameters (Dict[str, Any]) – Parameters for the numerical solver during optimization. Passed to the solver constructor.

  • weights (Dict[str, jax.Array]) – Regularization weights for each parameter. Should have the same structure as opt_parameters.

  • lambda_reg (float) – Regularization coefficient. Controls the strength of parameter regularization.

  • max_steps (int, optional) – Maximum number of optimization iterations. Defaults to 100.

Returns:

Optimized parameters combined with fixed parameters. The returned dictionary contains both the optimized parameters and the other_parameters, ready for use in the solve() method.

Return type:

Dict[str, Any]
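An objective function for this method only needs to map the solution array to a scalar. The example below is a hypothetical objective, the spatial variance of the final snapshot. It uses only indexing and array methods, so the same function should work on NumPy and JAX arrays alike; it is exercised here on a NumPy stand-in for the solution.

```python
import numpy as np


def final_variance(solution):
    """Scalar objective: spatial variance of the last saved snapshot."""
    final = solution[-1]
    return ((final - final.mean()) ** 2).mean()


# Stand-in solution with shape (len(ts), *y0.shape) = (3, 4, 4).
solution = np.zeros((3, 4, 4))
solution[-1, :2, :] = 1.0  # half of the final field is 1, half is 0
objective_value = final_variance(solution)
```

Half of the final field is 1 and half is 0, so the mean is 0.5 and the variance is 0.25.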

class pde_opt.PDEEnv(*args: Any, **kwargs: Any)[source]

Reinforcement learning environment for controlling partial differential equations.

The PDEEnv class provides a Gymnasium-compatible environment for reinforcement learning control of PDEs. It allows agents to learn control policies by taking actions that modify equation parameters and receiving rewards based on the resulting system behavior.

equation_type

The PDE equation class to control.

Type:

Type[BaseEquation]

domain

Spatial domain for the PDE simulation.

Type:

domains.Domain

solver_type

Numerical solver for time integration.

Type:

Type[diffrax.AbstractSolver]

end_time

Maximum simulation time for each episode.

Type:

float

step_dt

Time duration for each RL step.

Type:

float

numeric_dt

Internal time step for numerical integration.

Type:

float

observation_space

Gymnasium observation space.

Type:

gym.Space

action_space

Gymnasium action space.

Type:

gym.Space

For an example, see the documentation notebooks.

__init__(equation_type: Type[BaseEquation], domain: Domain, solver_type: Type[diffrax.AbstractSolver], end_time: float, step_dt: float, numeric_dt: float, state_to_observation_func: Callable, reward_function: Callable, reset_func: Callable, reset_control_value, update_control_value: Callable, update_control_parameter: Callable, action_space_config: Dict[str, Any], static_equation_parameters: Dict[str, Any], control_equation_parameter_name: str, solver_parameters: Dict[str, Any])[source]

Initialize the PDE reinforcement learning environment.

Parameters:
  • equation_type (Type[BaseEquation]) – The PDE equation class to control. Must be a subclass of BaseEquation from numerics.equations.

  • domain (domains.Domain) – Spatial domain for the PDE simulation. Defines the grid resolution, spatial bounds, and coordinate system.

  • solver_type (Type[diffrax.AbstractSolver]) – Numerical solver for time integration. Must be compatible with the equation type.

  • end_time (float) – Maximum simulation time for each episode. Episodes terminate when this time is reached.

  • step_dt (float) – Time duration for each RL step. This is the time interval between consecutive actions taken by the agent.

  • numeric_dt (float) – Internal time step for numerical integration. Typically smaller than step_dt for numerical stability.

  • state_to_observation_func (Callable) – Function that converts the PDE state to an observation for the RL agent.

  • reward_function (Callable) – Function that computes the reward from the current PDE state.

  • reset_func (Callable) – Function that generates initial conditions for the PDE. Should accept (domain, seed=None) and return initial state.

  • reset_control_value – Initial value for the control parameter.

  • update_control_value (Callable) – Function that updates the control value based on the agent’s action. Should accept (action, old_value) and return the new control value.

  • update_control_parameter (Callable) – Function that converts control values to equation parameters. Should accept (old_value, new_value) and return the parameter value for the equation.

  • action_space_config (Dict[str, Any]) – Configuration for the action space. Should contain:

    • “type”: “continuous” or “discrete”

    • For continuous: “shape”, “low”, “high”

    • For discrete: “num_actions”, “action_mapping”

  • static_equation_parameters (Dict[str, Any]) – Fixed parameters for the equation that are not controlled by the agent. Keys should match equation parameter names.

  • control_equation_parameter_name (str) – Name of the equation parameter that will be controlled by the agent’s actions.

  • solver_parameters (Dict[str, Any]) – Additional parameters for the numerical solver. Passed directly to the solver constructor.
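The callables and configuration dictionaries described above might look like the following. Only the dictionary keys and the argument lists are taken from this page. The values, the dict form of “action_mapping”, the hard-coded grid size, and the clipping rule are all assumptions made for illustration.

```python
import numpy as np

# Keys follow the documented layout; the values are made up for illustration.
continuous_config = {"type": "continuous", "shape": (1,), "low": -1.0, "high": 1.0}
discrete_config = {
    "type": "discrete",
    "num_actions": 3,
    "action_mapping": {0: -0.1, 1: 0.0, 2: 0.1},  # assumed format
}


def reset_func(domain, seed=None):
    """Return a small random initial state (grid size hard-coded for the sketch)."""
    rng = np.random.default_rng(seed)
    return 0.5 + 0.01 * rng.standard_normal((8, 8))


def update_control_value(action, old_value):
    """Treat the action as an increment and keep the control in [0, 1]."""
    return float(np.clip(old_value + np.asarray(action).sum(), 0.0, 1.0))


def update_control_parameter(old_value, new_value):
    """Pass the new control value straight through as the equation parameter."""
    return new_value


state = reset_func(domain=None, seed=0)
control = update_control_value(np.array([0.3]), old_value=0.9)
```

Passing the same seed to `reset_func` reproduces the same initial state, which is the behavior the seed argument of reset() relies on.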

reset(seed: int | None = None, options: dict | None = None)[source]

Reset the environment to initial state.

This method resets the PDE environment to its initial state, generating new initial conditions and resetting the control parameters. It follows the standard Gymnasium reset interface.

Parameters:
  • seed (Optional[int]) – Random seed for reproducible initial conditions.

  • options (Optional[dict]) – Additional options for reset.

Returns:

A tuple containing:
  • observation: Initial observation from the reset state

  • info: Additional information dictionary (currently empty)

Return type:

Tuple[np.ndarray, dict]

step(action: numpy.ndarray)[source]

Take a step in the environment.

This method executes one step of the RL environment by:

  1. Converting the agent’s action to a control parameter update

  2. Updating the PDE equation with the new control parameter

  3. Simulating the PDE forward in time for one step_dt

  4. Computing the reward and next observation

  5. Checking for episode termination

Parameters:

action (np.ndarray) – The action taken by the RL agent. Should be compatible with the environment’s action space. For continuous actions, this is typically a numpy array. For discrete actions, this is an integer.

Returns:

A tuple containing:
  • observation: Next observation from the updated state

  • reward: Reward value computed from the current state

  • terminated: Whether the episode has ended (reached end_time)

  • truncated: Whether the episode was truncated (always False for now)

  • info: Additional information dictionary (currently empty)

Return type:

Tuple[np.ndarray, float, bool, bool, dict]

class pde_opt.BaseEquation[source]

Base class for time-dependent PDE equations.

Abstract base class for time-dependent PDE equations of the form

\[\frac{d}{dt} \text{state} = F(\text{state}, t)\]

where state is the state of the system and t is the time.

Subclasses should implement the rhs method, which returns the right hand side of the equation.

abstract rhs(state: State, t: float) → State[source]

Right hand side of the equation.
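The subclassing pattern can be sketched with the standard library's abc module. `EquationSketch` is a stand-in that mirrors the documented interface, not pde_opt.BaseEquation itself, and `LinearDecay` is a deliberately trivial hypothetical equation.

```python
from abc import ABC, abstractmethod

import numpy as np


class EquationSketch(ABC):
    """Stand-in for the documented interface: d(state)/dt = F(state, t)."""

    @abstractmethod
    def rhs(self, state, t):
        """Return the time derivative of `state` at time `t`."""


class LinearDecay(EquationSketch):
    """du/dt = -rate * u."""

    def __init__(self, rate):
        self.rate = rate

    def rhs(self, state, t):
        return -self.rate * state


equation = LinearDecay(rate=2.0)
dudt = equation.rhs(np.ones((4, 4)), t=0.0)
```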

class pde_opt.AllenCahn2DPeriodic(domain: Domain, kappa: float, mu: Callable | equinox.Module, R: Callable | equinox.Module, derivs: str = 'fd')[source]

Allen-Cahn equation in 2D with periodic boundary conditions.

The Allen-Cahn equation describes phase transitions and interface dynamics. The equation is:

\[\frac{\partial u}{\partial t} = -R(u) \mu\]

where u is the concentration, R(u) is the reaction term, μ is the chemical potential, and κ is the gradient energy coefficient. The chemical potential is given by:

\[\mu = \mu_h(u) - \kappa \nabla^2 u\]

domain: Domain

Domain of the equation

kappa: float

Gradient energy coefficient

mu: Callable | equinox.Module

Function for the chemical potential

R: Callable | equinox.Module

Function for the reaction term

derivs: str = 'fd'

Type of derivative computation

rhs(state, t)[source]

Right hand side of the equation.
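A finite-difference evaluation of this right-hand side can be sketched in a few lines. The code below is an illustration in NumPy, not the library's rhs_fd. The five-point stencil, the call signatures of `mu_h` and `R`, and the double-well potential are all assumptions.

```python
import numpy as np


def laplacian_periodic(u, dx):
    """Five-point Laplacian with periodic wrap-around via np.roll."""
    return (
        np.roll(u, 1, axis=0) + np.roll(u, -1, axis=0)
        + np.roll(u, 1, axis=1) + np.roll(u, -1, axis=1)
        - 4.0 * u
    ) / dx**2


def allen_cahn_rhs(u, kappa, mu_h, R, dx):
    """du/dt = -R(u) * (mu_h(u) - kappa * lap(u))."""
    mu = mu_h(u) - kappa * laplacian_periodic(u, dx)
    return -R(u) * mu


# Hypothetical closures: derivative of the double well u^2 (1 - u)^2, constant rate.
mu_h = lambda u: 2.0 * u * (1.0 - u) * (1.0 - 2.0 * u)
R = lambda u: np.ones_like(u)

u_uniform = np.full((8, 8), 1.0)  # u = 1 is a minimum of the double well
rhs_uniform = allen_cahn_rhs(u_uniform, kappa=0.01, mu_h=mu_h, R=R, dx=0.1)
```

A uniform field at a well minimum is a steady state, so `rhs_uniform` vanishes everywhere.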

rhs_fourier(state, t)[source]
rhs_fd(state, t)[source]
__init__(domain: Domain, kappa: float, mu: Callable | equinox.Module, R: Callable | equinox.Module, derivs: str = 'fd') → None
class pde_opt.AllenCahn2DSmoothedBoundary(domain: Domain, kappa: float, f: Callable | equinox.Module, mu: Callable | equinox.Module, R: Callable | equinox.Module, theta: Callable | equinox.Module, derivs: str = 'fd')[source]

Allen-Cahn equation with smoothed boundary method for arbitrary geometries.

This class implements the Allen-Cahn equation using the smoothed boundary method, which allows for complex domain geometries through a smooth level-set function ψ.

The equation is:

\[\frac{\partial u}{\partial t} = -R(u) \mu\]

where the chemical potential includes boundary effects:

\[\mu = \mu_h(u) - \frac{\kappa}{\psi} \nabla \cdot (\psi \nabla u) - \sqrt{\kappa} \frac{|\nabla \psi|}{\psi} \sqrt{2f} \cos(\theta)\]

domain: Domain

Domain of the equation

kappa: float

Gradient energy coefficient

f: Callable | equinox.Module

Function for the free energy density

mu: Callable | equinox.Module

Function for the chemical potential

R: Callable | equinox.Module

Function for the reaction term

theta: Callable | equinox.Module

Function for the contact angle

derivs: str = 'fd'

Type of derivative computation

rhs(state, t)[source]

Right hand side of the equation.

rhs_fd(state, t)[source]
__init__(domain: Domain, kappa: float, f: Callable | equinox.Module, mu: Callable | equinox.Module, R: Callable | equinox.Module, theta: Callable | equinox.Module, derivs: str = 'fd') → None
class pde_opt.CahnHilliard2DPeriodic(domain: Domain, kappa: float, mu: Callable | equinox.Module, D: Callable | equinox.Module, derivs: str = 'fd')[source]

Cahn–Hilliard equation in 2D with periodic boundary conditions.

The Cahn-Hilliard equation describes phase separation and coarsening dynamics. The equation is:

\[\frac{\partial u}{\partial t} = \nabla \cdot (D(u) \nabla \mu)\]

where u is the concentration, D(u) is the mobility, and μ is the chemical potential. The chemical potential is given by:

\[\mu = \mu_h(u) - \kappa \nabla^2 u\]

domain: Domain

Domain of the equation

kappa: float

Gradient energy coefficient

mu: Callable | equinox.Module

Function for the chemical potential

D: Callable | equinox.Module

Function for the mobility

derivs: str = 'fd'

Type of derivative computation

fft = None
ifft = None
fourier_symbol = None
rhs(state, t)[source]

Right hand side of the equation.
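For the periodic case, the right-hand side can be evaluated with Fourier differentiation. The sketch below is an assumption-laden NumPy illustration, not the library's rhs_fourier: the function name, the regular-solution chemical potential, and the mobility u(1 − u) are hypothetical choices. Because the divergence has no zero-frequency component, the result sums to zero, which reflects mass conservation.

```python
import numpy as np


def cahn_hilliard_rhs(u, kappa, mu_h, D, length):
    """du/dt = div(D(u) grad(mu)), mu = mu_h(u) - kappa * lap(u), periodic square."""
    n = u.shape[0]
    k = 2.0 * np.pi * np.fft.fftfreq(n, d=length / n)
    kx, ky = np.meshgrid(k, k, indexing="ij")
    k2 = kx**2 + ky**2

    # In Fourier space the Laplacian is multiplication by -k^2.
    mu_hat = np.fft.fft2(mu_h(u)) + kappa * k2 * np.fft.fft2(u)

    # Flux components D(u) * d(mu)/dx and D(u) * d(mu)/dy in real space.
    flux_x = D(u) * np.fft.ifft2(1j * kx * mu_hat).real
    flux_y = D(u) * np.fft.ifft2(1j * ky * mu_hat).real

    # Divergence of the flux, again via Fourier differentiation.
    div_hat = 1j * kx * np.fft.fft2(flux_x) + 1j * ky * np.fft.fft2(flux_y)
    return np.fft.ifft2(div_hat).real


rng = np.random.default_rng(0)
u = 0.5 + 0.05 * rng.standard_normal((16, 16))
mu_h = lambda u: np.log(u / (1.0 - u)) + 3.0 * (1.0 - 2.0 * u)  # regular-solution model
D = lambda u: u * (1.0 - u)
rhs = cahn_hilliard_rhs(u, kappa=1e-3, mu_h=mu_h, D=D, length=1.0)
```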

rhs_fourier(state, t)[source]
rhs_fd(state, t)[source]
__init__(domain: Domain, kappa: float, mu: Callable | equinox.Module, D: Callable | equinox.Module, derivs: str = 'fd') → None
class pde_opt.CahnHilliard3DPeriodic(domain: Domain, kappa: float, mu: Callable | equinox.Module, D: Callable | equinox.Module, derivs: str = 'fd')[source]

Cahn–Hilliard equation in 3D with periodic boundary conditions.

The Cahn-Hilliard equation describes phase separation and coarsening dynamics. The equation is:

\[\frac{\partial u}{\partial t} = \nabla \cdot (D(u) \nabla \mu)\]

where u is the concentration, D(u) is the mobility, and μ is the chemical potential. The chemical potential is given by:

\[\mu = \mu_h(u) - \kappa \nabla^2 u\]

domain: Domain

Domain of the equation

kappa: float

Gradient energy coefficient

mu: Callable | equinox.Module

Function for the chemical potential

D: Callable | equinox.Module

Function for the mobility

derivs: str = 'fd'

Type of derivative computation

fft = None
ifft = None
fourier_symbol = None
rhs(state, t)[source]

Right hand side of the equation.

rhs_fourier(state, t)[source]
rhs_fd(state, t)[source]
__init__(domain: Domain, kappa: float, mu: Callable | equinox.Module, D: Callable | equinox.Module, derivs: str = 'fd') → None
class pde_opt.CahnHilliard2DSmoothedBoundary(domain: Domain, kappa: float, f: Callable | equinox.Module, mu: Callable | equinox.Module, D: Callable | equinox.Module, theta: Callable | equinox.Module, flux: Callable | equinox.Module, derivs: str = 'fd')[source]

Cahn–Hilliard equation with smoothed boundary method for arbitrary geometries.

This class implements the Cahn-Hilliard equation using the smoothed boundary method, which allows for complex domain geometries through a smooth level-set function ψ.

The equation is:

\[\frac{\partial u}{\partial t} = \frac{1}{\psi} \nabla \cdot (\psi D(u) \nabla \mu) + \frac{|\nabla \psi|}{\psi} J_n\]

where the chemical potential includes boundary effects:

\[\mu = \mu_h(u) - \frac{\kappa}{\psi} \nabla \cdot (\psi \nabla u) - \sqrt{\kappa} \frac{|\nabla \psi|}{\psi} \sqrt{2f} \cos(\theta)\]

__init__(domain: Domain, kappa: float, f: Callable | equinox.Module, mu: Callable | equinox.Module, D: Callable | equinox.Module, theta: Callable | equinox.Module, flux: Callable | equinox.Module, derivs: str = 'fd') → None
domain: Domain

Domain of the equation

kappa: float

Gradient energy coefficient

f: Callable | equinox.Module

Function for the free energy density

mu: Callable | equinox.Module

Function for the chemical potential

D: Callable | equinox.Module

Function for the mobility

theta: Callable | equinox.Module

Function for the contact angle

flux: Callable | equinox.Module

Function for the normal flux

derivs: str = 'fd'

Type of derivative computation

rhs(state, t)[source]

Right hand side of the equation.

rhs_fd(state, t)[source]
class pde_opt.GPE2DTSControl(domain: Domain, k: float, e: float, lights: Callable, trap_factor: float = 1.0)[source]

Gross-Pitaevskii equation in 2D with time-splitting and control.

The Gross-Pitaevskii equation describes the dynamics of Bose-Einstein condensates. The equation is:

\[i\hbar \frac{\partial \psi}{\partial t} = \left[-\frac{\hbar^2}{2m}\nabla^2 + V(\mathbf{r}, t) + g|\psi|^2\right]\psi\]

where ψ is the wave function, V is the external potential, and g is the interaction strength. The external potential includes a harmonic trap and control field:

\[V(\mathbf{r}, t) = \frac{1}{2}m\omega^2\left[(1+\epsilon)x^2 + (1-\epsilon)y^2\right] + V_{control}(\mathbf{r}, t)\]

domain: Domain

Domain of the equation

k: float

Interaction strength parameter

e: float

Trap ellipticity parameter

lights: Callable

Function for the control field

trap_factor: float = 1.0

Scaling factor for the harmonic trap

fft = None
ifft = None
A_term = None
dx = None
A_terms(state, t)[source]

A terms of the equation.

B_terms(state, t)[source]

B terms of the equation.

rhs(state, t)[source]

Right hand side of the equation.

__init__(domain: Domain, k: float, e: float, lights: Callable, trap_factor: float = 1.0) → None
class pde_opt.GPE2DTSRot(domain: Domain, k: float, e: float, omega: float)[source]

Gross-Pitaevskii equation in 2D with time-splitting and rotation.

The Gross-Pitaevskii equation describes the dynamics of Bose-Einstein condensates. The equation is:

\[i\hbar \frac{\partial \psi}{\partial t} = \left[-\frac{\hbar^2}{2m}\nabla^2 + V(\mathbf{r}) + g|\psi|^2 - \Omega L_z\right]\psi\]

where ψ is the wave function, V is the external potential, g is the interaction strength, and Ω is the rotation frequency with L_z being the angular momentum operator. The external potential includes a harmonic trap:

\[V(\mathbf{r}) = \frac{1}{2}m\omega^2\left[(1+\epsilon)x^2 + (1-\epsilon)y^2\right]\]

domain: Domain

Domain of the equation

k: float

Interaction strength parameter

e: float

Trap ellipticity parameter

omega: float

Rotation frequency

A_terms(state_hat, t)[source]

A terms of the equation.

B_terms(state, t)[source]

B terms of the equation.

__init__(domain: Domain, k: float, e: float, omega: float) → None
class pde_opt.Domain(points: Tuple[int, ...], box: Tuple[Tuple[float, float], ...], units: str, geometry: Shape | None = None)[source]

Sets up a simulation domain for the model.

The following information is stored in a Domain:

  • points[i] is the number of collocation points in the i-th dimension

  • dx[i] is the spacing between collocation points in the i-th dimension

  • box[i] is the bounds of the simulation box in the i-th dimension

  • units is the length unit in which these values are stored
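To make the relationship between points, box, and dx concrete, the sketch below derives grid axes and Fourier wavenumbers from them in NumPy. This page does not state the conventions pde_opt uses, so treat every choice here as an assumption, in particular the excluded right endpoint and the factor of 2π in the wavenumbers.

```python
import numpy as np

points = (8, 4)
box = ((0.0, 2.0), (0.0, 1.0))

# Assumed convention: the right endpoint is excluded, as on a periodic grid.
dx = tuple((hi - lo) / n for n, (lo, hi) in zip(points, box))
axes = tuple(lo + d * np.arange(n) for n, d, (lo, _) in zip(points, dx, box))
mesh = np.meshgrid(*axes, indexing="ij")

# Angular wavenumbers matching the grid, as used by Fourier-based derivatives.
fft_axes = tuple(2.0 * np.pi * np.fft.fftfreq(n, d=d) for n, d in zip(points, dx))
```

Under these assumptions both spacings are 0.25, and the lowest nonzero wavenumber along the first axis is 2π divided by the box length 2, that is π.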

points: Tuple[int, ...]
box: Tuple[Tuple[float, float], ...]
units: str
geometry: Shape | None = None
axes() → Tuple[jax.Array, ...][source]
fft_axes() → Tuple[jax.Array, ...][source]
rfft_axes() → Tuple[jax.Array, ...][source]
mesh() → Tuple[jax.Array, ...][source]
fft_mesh() → Tuple[jax.Array, ...][source]
rfft_mesh() → Tuple[jax.Array, ...][source]
__init__(points: Tuple[int, ...], box: Tuple[Tuple[float, float], ...], units: str, geometry: Shape | None = None) → None
class pde_opt.Shape(binary: jax.Array, refine_factor: float | None = None, refine_edge: float | None = None, dx: Tuple[float, float] | None = (1.0, 1.0), smooth_epsilon: float = 1.0, smooth_curvature: float = 0.0, smooth_dt: float = 0.1, smooth_tf: float = 1.0)[source]

Sets up a geometry (shape) on which to solve a PDE with the smoothed boundary method.

The user creates a shape by providing a binary representation and an optional smoothing parameter.

binary: jax.Array
refine_factor: float | None = None
refine_edge: float | None = None
dx: Tuple[float, float] | None = (1.0, 1.0)
smooth_epsilon: float = 1.0
smooth_curvature: float = 0.0
smooth_dt: float = 0.1
smooth_tf: float = 1.0
refine_binary_mask() → jax.Array[source]

Refine the binary mask by upsampling and adding edge padding.

Parameters:
  • refine_factor – Factor by which to upsample the binary mask (e.g., 2.0 for 2x upsampling)

  • refine_edge – Percentage to increase the edges (e.g., 0.5 for 50% increase)

Returns:

Refined binary mask with upsampling and edge padding applied

smooth_shape() → jax.Array[source]

Smooths the shape using the Allen-Cahn equation with curvature minimization.

laplacian_from_mask(periodic: bool = False)[source]

Unnormalized graph Laplacian (4-neighbour) from a 0/1 mask. Nodes are entries where mask==1. Two nodes connect if they are up/down/left/right neighbours and both are 1.

Returns:

Tuple containing:
  • L: (n_nodes, n_nodes) CSR Laplacian

  • ids: (H, W) array holding the node index in [0, n_nodes), or -1 where the pixel is not a node
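The construction can be sketched with a dense matrix on a tiny mask. This is a NumPy stand-in for illustration: the function name is hypothetical, the matrix is dense rather than CSR, and the periodic option is omitted.

```python
import numpy as np


def dense_laplacian_from_mask(mask):
    """Dense 4-neighbour graph Laplacian of the 1-valued pixels (no periodic wrap)."""
    mask = np.asarray(mask).astype(bool)
    ids = -np.ones(mask.shape, dtype=int)
    ids[mask] = np.arange(mask.sum())  # node index per pixel, -1 elsewhere

    n_nodes = int(mask.sum())
    L = np.zeros((n_nodes, n_nodes))
    H, W = mask.shape
    for i in range(H):
        for j in range(W):
            if not mask[i, j]:
                continue
            for di, dj in ((1, 0), (0, 1)):  # down and right visit each edge once
                ni, nj = i + di, j + dj
                if ni < H and nj < W and mask[ni, nj]:
                    a, b = ids[i, j], ids[ni, nj]
                    L[a, b] -= 1.0
                    L[b, a] -= 1.0
                    L[a, a] += 1.0
                    L[b, b] += 1.0
    return L, ids


mask = np.array([[1, 1, 0],
                 [0, 1, 1]])
L, ids = dense_laplacian_from_mask(mask)
eigenvalues, modes = np.linalg.eigh(L)
```

The four active pixels form a path graph, so the diagonal holds the degrees 1, 2, 2, 1 and every row sums to zero. The columns of `modes` are eigenvectors of this matrix, the kind of quantity that get_shape_modes describes.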

get_shape_modes(N: int | None = None)[source]

Get the first N eigenvectors of the graph Laplacian of the binary mask.

Creates a graph where nodes are the 1-valued pixels, with edges between adjacent pixels (left, right, top, bottom neighbors).

Parameters:
  • N – Number of eigenvectors to return. If None, returns all eigenvectors.

  • downsampling_factor – If provided, downsample binary by this factor before computing modes, then upsample results back to original size. This can significantly reduce memory usage and computation time for large binary masks.

Returns:

Array of shape (num_nodes, N) containing the first N eigenvectors

__init__(binary: jax.Array, refine_factor: float | None = None, refine_edge: float | None = None, dx: Tuple[float, float] | None = (1.0, 1.0), smooth_epsilon: float = 1.0, smooth_curvature: float = 0.0, smooth_dt: float = 0.1, smooth_tf: float = 1.0) → None
class pde_opt.PeriodicCNN(*args: Any, **kwargs: Any)[source]

Stack of periodic conv blocks; final conv returns requested channels.

Translation-equivariant on a torus so long as stride=1 and only pointwise nonlinearities are used. Accepts (C,H,W) or (B,C,H,W); returns same spatial size.
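The equivariance property can be verified numerically for a single periodic convolution. The sketch below is a single-channel NumPy stand-in for one layer, not the PeriodicCNN implementation; `periodic_conv2d` is a hypothetical helper that assumes an odd kernel size.

```python
import numpy as np


def periodic_conv2d(image, kernel):
    """Stride-1 cross-correlation of an (H, W) image with wrap-around padding."""
    kh, kw = kernel.shape
    padded = np.pad(image, ((kh // 2, kh // 2), (kw // 2, kw // 2)), mode="wrap")
    out = np.zeros_like(image, dtype=float)
    for i in range(kh):
        for j in range(kw):
            out += kernel[i, j] * padded[i:i + image.shape[0], j:j + image.shape[1]]
    return out


rng = np.random.default_rng(0)
image = rng.standard_normal((6, 6))
kernel = rng.standard_normal((3, 3))

shift = (2, -1)
shifted_then_conv = periodic_conv2d(np.roll(image, shift, axis=(0, 1)), kernel)
conv_then_shifted = np.roll(periodic_conv2d(image, kernel), shift, axis=(0, 1))
```

Shifting the input and then convolving gives the same array as convolving and then shifting. Pointwise nonlinearities preserve this property, which is why a stack of such layers stays equivariant.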

__init__(in_channels: int, hidden_channels: Sequence[int] = (32, 64, 64), out_channels: int | None = None, kernel_size: int = 3, act: Callable[[jax.Array], jax.Array] = jax.nn.gelu, *, key: jax.Array)[source]
layers: Tuple[equinox.Module, ...]
class pde_opt.LegendrePolynomialExpansion(*args: Any, **kwargs: Any)[source]

Legendre polynomial expansion.

__init__(params: jax.Array)[source]
params: jax.Array
max_degree: int
class pde_opt.DiffusionLegendrePolynomials(*args: Any, **kwargs: Any)[source]

Diffusion Legendre polynomials.

Uses exp to ensure positivity.
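The idea of exponentiating a Legendre series can be sketched with numpy.polynomial.legendre. The function names below are hypothetical, and the mapping of u from [0, 1] to [-1, 1] is an assumption that this page does not confirm.

```python
import numpy as np
from numpy.polynomial import legendre


def legendre_expansion(params, x):
    """sum_n params[n] * P_n(x) for x in [-1, 1]."""
    return legendre.legval(x, params)


def positive_diffusivity(params, u):
    """exp of a Legendre series; u in [0, 1] is mapped to [-1, 1] (assumed)."""
    return np.exp(legendre_expansion(params, 2.0 * u - 1.0))


params = np.array([0.0, 1.0, 2.0])  # 0*P0 + 1*P1 + 2*P2
series_at_half = legendre_expansion(params, 0.5)
D_values = positive_diffusivity(params, np.linspace(0.0, 1.0, 5))
```

At x = 0.5, P1 is 0.5 and P2 is −0.125, so the series evaluates to 0.25. The exponential keeps every entry of `D_values` positive whatever the coefficients are.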

__init__(params: jax.Array)[source]
expansion: LegendrePolynomialExpansion
class pde_opt.ChemicalPotentialLegendrePolynomials(*args: Any, **kwargs: Any)[source]

Chemical potential Legendre polynomials.

__init__(params: jax.Array, prior_fn: Callable | None = None)[source]
expansion: LegendrePolynomialExpansion
prior_fn: Callable
class pde_opt.Mixer2d(*args: Any, **kwargs: Any)[source]
__init__(img_size, patch_size, hidden_size, mix_patch_size, mix_hidden_size, num_blocks, *, key)[source]
conv_in: equinox.nn.Conv2d
conv_out: equinox.nn.ConvTranspose2d
blocks: list
norm: equinox.nn.LayerNorm
class pde_opt.SemiImplicitFourierSpectral(*args: Any, **kwargs: Any)[source]

Semi-implicit Fourier spectral method.

This solver implements a semi-implicit Fourier spectral method for phase-field simulations with variable mobility.

Required Equation Attributes:

  • fourier_symbol: Fourier space representation of the highest order differential operator.

  • fft: Forward Fourier transform function.

  • ifft: Inverse Fourier transform function.

Parameters:

A (float) – Constant for splitting the mobility term.
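One common form of the semi-implicit update from the reference cited below treats the stiff term A κ k⁴ implicitly and everything else explicitly. The 1D NumPy sketch below illustrates that idea under assumptions; it is not the solver's step method, and how pde_opt defines fourier_symbol is not stated on this page. The function name, the potential, and the parameter values are hypothetical.

```python
import numpy as np


def semi_implicit_step(u, dt, kappa, A, mu_h, D, length):
    """One stabilized step for du/dt = d/dx(D(u) d(mu)/dx) on a periodic 1D grid."""
    n = u.size
    k = 2.0 * np.pi * np.fft.fftfreq(n, d=length / n)
    u_hat = np.fft.fft(u)

    # Explicit right-hand side evaluated at the current state.
    mu_hat = np.fft.fft(mu_h(u)) + kappa * k**2 * u_hat
    flux = D(u) * np.fft.ifft(1j * k * mu_hat).real
    rhs_hat = 1j * k * np.fft.fft(flux)

    # The stiff fourth-order operator A * kappa * k^4 is treated implicitly.
    u_hat_new = u_hat + dt * rhs_hat / (1.0 + A * dt * kappa * k**4)
    return np.fft.ifft(u_hat_new).real


rng = np.random.default_rng(0)
u0 = 0.5 + 0.01 * rng.standard_normal(64)
mu_h = lambda u: 2.0 * u * (1.0 - u) * (1.0 - 2.0 * u)
D = lambda u: np.ones_like(u)  # constant mobility, so A = 1 matches it exactly

u = u0.copy()
for _ in range(5):
    u = semi_implicit_step(u, dt=1e-4, kappa=1e-4, A=1.0, mu_h=mu_h, D=D, length=1.0)
```

The zero-wavenumber mode is never modified, so the mean of u is conserved to rounding error. For variable mobility, the reference suggests choosing A between the minimum and maximum of D.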

References

Zhu, Jingzhi, et al. “Coarsening kinetics from a variable-mobility Cahn-Hilliard equation: Application of a semi-implicit Fourier spectral method.” Physical Review E 60.4 (1999): 3564.

required_equation_attrs = ['fourier_symbol', 'fft', 'ifft']
A: float
fourier_symbol: jax.Array
fft: Callable
ifft: Callable
order(terms)[source]
init(terms, t0, t1, y0, args)[source]
step(terms, t0, t1, y0, args, solver_state, made_jump)[source]
func(terms, t0, y0, args)[source]
class pde_opt.StrangSplitting(*args: Any, **kwargs: Any)[source]

Strang splitting method for time-dependent PDEs with separable operators.
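The splitting idea can be sketched for a dimensionless Gross-Pitaevskii equation, i ∂ψ/∂t = [−½∇² + V + g|ψ|²]ψ. The NumPy code below is a generic textbook-style illustration, not this class's step method. How it maps onto the A_term, A_terms, and B_terms of the equation classes is a guess, and the grid and parameter values are hypothetical.

```python
import numpy as np


def strang_step(psi, dt, potential, k2, g):
    """Half step of the local terms, full kinetic step in Fourier space, half step again."""
    # Local part: i dpsi/dt = (V + g |psi|^2) psi, which leaves |psi| unchanged.
    psi = psi * np.exp(-0.5j * dt * (potential + g * np.abs(psi) ** 2))
    # Kinetic part: i dpsi/dt = -(1/2) lap(psi), diagonal in Fourier space.
    psi = np.fft.ifft2(np.exp(-0.5j * dt * k2) * np.fft.fft2(psi))
    # Second local half step, using the updated density.
    return psi * np.exp(-0.5j * dt * (potential + g * np.abs(psi) ** 2))


n, length = 32, 16.0
x = -length / 2 + (length / n) * np.arange(n)
X, Y = np.meshgrid(x, x, indexing="ij")
k = 2.0 * np.pi * np.fft.fftfreq(n, d=length / n)
KX, KY = np.meshgrid(k, k, indexing="ij")

potential = 0.5 * (X**2 + Y**2)  # isotropic harmonic trap
psi = np.exp(-(X**2 + Y**2) / 2).astype(complex)
norm_before = np.sum(np.abs(psi) ** 2)

for _ in range(20):
    psi = strang_step(psi, dt=0.01, potential=potential, k2=KX**2 + KY**2, g=1.0)
norm_after = np.sum(np.abs(psi) ** 2)
```

Each substep multiplies by a unit-modulus factor, in real space or in Fourier space, so the discrete norm is conserved to rounding error.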

References

Bao, Weizhu, and Yongyong Cai. “Mathematical theory and numerical methods for Bose-Einstein condensation.” arXiv preprint arXiv:1212.5341 (2012).

required_equation_attrs = ['A_term', 'dx', 'fft', 'ifft']
A_term: jax.Array
dx: float
fft: Callable
ifft: Callable
time_scale: float
order(terms)[source]
init(terms, t0, t1, y0, args)[source]
step(terms, t0, t1, y0, args, solver_state, made_jump)[source]
func(terms, t0, y0, args)[source]

Modules

numerics

Numerical methods for PDEs.

pde_env

pde_model

utils

This module contains general utility functions.