pde_opt.pde_model

Classes

PDEModel(equation_type, domain, solver_type)

Manage the solving and optimization of partial differential equations (PDEs).

class pde_opt.pde_model.PDEModel(equation_type: Type[BaseEquation], domain: Domain, solver_type: Type[diffrax.AbstractSolver])

Manage the solving and optimization of partial differential equations (PDEs).

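The PDEModel class provides a unified interface for solving PDEs and optimizing their parameters using gradient-based methods. It supports forward simulation (solve()), parameter estimation from observed data (train()), and minimization of a user-defined objective of the solution (optimize()).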

The class is designed to work with JAX-based PDE implementations and leverages automatic differentiation for efficient gradient computation during optimization.

equation_type

The equation class to optimize. Must be a subclass of BaseEquation from numerics.equations.

Type:

Type[BaseEquation]

domain

The spatial domain for solving the equation. Contains grid information, boundary conditions, and coordinate systems.

Type:

domains.Domain

solver_type

The numerical solver class for time integration. Must be a subclass of dfx.AbstractSolver; this can be a built-in diffrax solver such as Tsit5 or a custom solver such as those defined in numerics.solvers.

Type:

Type[dfx.AbstractSolver]

For examples, see the documentation notebooks.

__init__(equation_type: Type[BaseEquation], domain: Domain, solver_type: Type[diffrax.AbstractSolver])

Initialize the PDE optimization model.

Parameters:
  • equation_type (Type[BaseEquation]) – The equation class to optimize. Must be a subclass of BaseEquation.

  • domain (domains.Domain) – The spatial domain for the PDE. Defines the grid resolution, spatial bounds, and coordinate system.

  • solver_type (Type[dfx.AbstractSolver]) – The numerical solver class for time integration. Must be a subclass of dfx.AbstractSolver; this can be a built-in diffrax solver such as Tsit5 or a custom solver such as those defined in numerics.solvers.

Raises:

ValueError – If the equation and solver are incompatible (e.g., the solver requires attributes that the equation does not provide).

Note

Compatibility between the solver and the equation is checked automatically during initialization. Some solvers require specific attributes from the equation (e.g., fourier_symbol, fft, and ifft for semi-implicit spectral methods).
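A minimal sketch of such a compatibility check, assuming (hypothetically) that each solver class lists the equation attributes it needs in a required_equation_attributes class attribute. Neither that attribute nor check_compatibility is confirmed to exist in pde_opt; the toy classes only illustrate the idea.

```python
def check_compatibility(equation_type: type, solver_type: type) -> None:
    """Raise ValueError if equation_type lacks attributes solver_type needs.

    Assumes a hypothetical `required_equation_attributes` tuple on the solver;
    solvers without it are treated as having no requirements.
    """
    required = getattr(solver_type, "required_equation_attributes", ())
    missing = [name for name in required if not hasattr(equation_type, name)]
    if missing:
        raise ValueError(
            f"{solver_type.__name__} requires {missing}, "
            f"which {equation_type.__name__} does not provide"
        )


# Hypothetical semi-implicit spectral solver and two toy equation classes.
class SpectralSolver:
    required_equation_attributes = ("fourier_symbol", "fft", "ifft")


class SpectralEquation:
    fourier_symbol = None

    def fft(self, u):
        return u

    def ifft(self, u):
        return u


class PlainEquation:
    pass


check_compatibility(SpectralEquation, SpectralSolver)  # passes silently
```

Calling check_compatibility(PlainEquation, SpectralSolver) raises ValueError, because PlainEquation defines none of the three required attributes.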

solve(parameters: Dict[str, Any], y0, ts, solver_parameters: Dict[str, Any] = {}, adjoint=diffrax.ForwardMode, dt0=1e-06, max_steps=1000000, stepsize_controller=diffrax.ConstantStepSize)

Solve the PDE with given parameters and initial conditions.

This method performs forward simulation of the PDE using the specified solver and parameters. The solution is computed at the requested time points and returned.

Parameters:
  • parameters (Dict[str, Any]) – Dictionary of equation parameters. Keys should match the parameter names expected by the equation class.

  • y0 – Initial condition array. Shape should match the spatial dimensions of the domain.

  • ts – Time points at which to save the solution. Should be a 1D array of increasing time values. The solver will integrate from ts[0] to ts[-1] and return solutions at all specified time points.

  • solver_parameters (Dict[str, Any], optional) – Additional parameters for the solver. These are passed directly to the solver constructor.

  • adjoint (dfx.AbstractAdjoint, optional) – Adjoint mode for automatic differentiation. Defaults to ForwardMode() for forward-mode AD. Use RecursiveCheckpointAdjoint() for reverse-mode AD when the number of parameters is large.

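  • dt0 (float, optional) – Initial time step for the solver. Defaults to 1e-6. With the default ConstantStepSize controller, the solver takes fixed steps of size dt0; with an adaptive stepsize_controller, it may adjust the step size during integration.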

  • max_steps (int, optional) – Maximum number of integration steps. Defaults to 1,000,000.

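  • stepsize_controller (dfx.AbstractStepSizeController, optional) – Step-size controller for the time integration. Defaults to ConstantStepSize(), which takes fixed steps of size dt0; pass an adaptive controller such as dfx.PIDController for error-controlled step sizes.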

Returns:

Solution array with shape (len(ts), *y0.shape).
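With the default ConstantStepSize controller, integration advances in fixed steps and the saved states are stacked into an array of shape (len(ts), *y0.shape). The NumPy sketch below mimics that contract with forward Euler on a toy decay equation. solve_fixed_step is a hypothetical stand-in for illustration, not the diffrax solver that solve() actually calls, and it shrinks the step slightly so that each interval between save points is covered by whole steps.

```python
import numpy as np


def solve_fixed_step(rhs, y0, ts, dt0):
    """Forward-Euler integration from ts[0] to ts[-1], saving at each ts.

    Hypothetical stand-in for a fixed-step solve; returns an array of shape
    (len(ts), *y0.shape) whose first entry is y0 itself.
    """
    ts = np.asarray(ts, dtype=float)
    y = np.asarray(y0, dtype=float)
    saved = [y.copy()]
    for t_prev, t_next in zip(ts[:-1], ts[1:]):
        # Split each save interval into equal steps of roughly dt0 or smaller.
        n_steps = max(1, int(np.ceil((t_next - t_prev) / dt0)))
        h = (t_next - t_prev) / n_steps
        for i in range(n_steps):
            y = y + h * rhs(t_prev + i * h, y)
        saved.append(y.copy())
    return np.stack(saved)


# Toy problem: dy/dt = -k * y on an 8-point grid, exact solution exp(-k t) y0.
k = 1.0
y0 = np.linspace(0.0, 1.0, 8)
ts = np.linspace(0.0, 1.0, 5)
sol = solve_fixed_step(lambda t, y: -k * y, y0, ts, dt0=1e-3)
print(sol.shape)  # (5, 8)
```

The first saved row is always the initial condition, which is why the residual methods compare observations only against predicted[1:].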

residual_single(parameters, solver_parameters, y0, values, ts, adjoint=diffrax.ForwardMode)

Compute residuals for a single trajectory.

This method computes the difference between model predictions and observed data for a single initial condition and trajectory. It’s used internally by the batched residuals computation.

Parameters:
  • parameters (Dict[str, Any]) – Current parameter values for the equation.

  • solver_parameters (Dict[str, Any]) – Parameters for the numerical solver.

  • y0 (jax.Array) – Initial condition.

  • values (jax.Array) – Observed values for computing residuals.

  • ts (jax.Array) – Time points with shape (timepoints,).

  • adjoint (dfx.AbstractAdjoint, optional) – Adjoint mode for automatic differentiation. Defaults to ForwardMode().

Returns:

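Residuals array with shape (timepoints - 1, *y0.shape).

The residuals are computed as values - predicted[1:]. The prediction at ts[0], which is the initial condition itself, is dropped, so values must not include the initial condition and should have shape (timepoints - 1, *y0.shape).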
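The residual convention can be illustrated in NumPy, with a precomputed prediction standing in for the solve() call that residual_single performs internally. residual_from_prediction is a hypothetical helper, not part of pde_opt.

```python
import numpy as np


def residual_from_prediction(predicted, values):
    """values - predicted[1:], dropping the prediction at ts[0] (= y0).

    predicted: shape (timepoints, *y0.shape), e.g. the output of solve().
    values:    shape (timepoints - 1, *y0.shape), observations after ts[0].
    """
    predicted = np.asarray(predicted)
    values = np.asarray(values)
    if values.shape != predicted[1:].shape:
        raise ValueError(
            f"values has shape {values.shape}, expected {predicted[1:].shape}"
        )
    return values - predicted[1:]


# Three time points on a 4-point grid: only the last two are compared.
predicted = np.arange(12.0).reshape(3, 4)
values = np.ones((2, 4))
res = residual_from_prediction(predicted, values)
print(res.shape)  # (2, 4)
```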

regularization(parameters, weights, lambda_reg)

Compute parameter regularization term.

This method computes a weighted L2 regularization term for the parameters to prevent overfitting and improve parameter stability during optimization. The regularization is computed as:

\[\text{reg} = \lambda \sum_i w_i p_i^2\]

where λ is the regularization coefficient, w_i are the weights, and p_i are the parameter values.

Parameters:
  • parameters (Dict[str, Any]) – Current parameter values for the equation. Can contain nested structures (pytrees) of parameters.

  • weights (Dict[str, Any]) – Regularization weights with the same structure as parameters. Higher weights penalize large parameter values more strongly. None values in weights are ignored.

  • lambda_reg (float) – Regularization coefficient controlling the overall strength of the regularization term.

Returns:

Scalar regularization term to be added to the loss function.
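A minimal NumPy sketch of this weighted L2 penalty, assuming parameters and weights are nested dicts, lists, or tuples of scalars or arrays with matching structure, and that a None weight switches off the penalty for everything beneath it. weighted_l2 is a hypothetical name; the actual method may traverse pytrees differently (for example with jax.tree_util).

```python
import numpy as np


def weighted_l2(parameters, weights, lambda_reg):
    """lambda_reg * sum_i w_i * p_i**2 over matching nested containers."""

    def walk(p, w):
        if w is None:  # None weights are ignored (no penalty)
            return 0.0
        if isinstance(p, dict):
            return sum(walk(p[key], w[key]) for key in p)
        if isinstance(p, (list, tuple)):
            return sum(walk(p_i, w_i) for p_i, w_i in zip(p, w))
        # Leaf: scalar or array; w broadcasts against p elementwise.
        return float(np.sum(np.asarray(w) * np.asarray(p) ** 2))

    return lambda_reg * walk(parameters, weights)


params = {"D": 2.0, "net": [np.array([1.0, 2.0]), 3.0]}
weights = {"D": 1.0, "net": [0.5, None]}
# 0.1 * (1.0 * 2.0**2 + 0.5 * (1.0**2 + 2.0**2) + 0) = 0.1 * 6.5
print(weighted_l2(params, weights, lambda_reg=0.1))
```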

residuals(parameters, y0s__values, solver_parameters, ts, weights, lambda_reg, adjoint=diffrax.ForwardMode)

Compute batched residuals and regularization for parameter optimization.

This method computes the difference between model predictions and observed data for multiple trajectories simultaneously, along with parameter regularization. It’s used internally by the optimization algorithms in the train() method.

Parameters:
  • parameters (Dict[str, Any]) – Current parameter values for the equation.

  • y0s__values (Tuple[jax.Array, jax.Array]) – Tuple (y0s, values) containing:

    • y0s: Batch of initial conditions

    • values: Batch of observed values

  • solver_parameters (Dict[str, Any]) – Parameters for the numerical solver.

  • ts (jax.Array) – Time points for the simulation.

  • weights (Dict[str, jax.Array]) – Regularization weights for each parameter. Should match the structure of parameters.

  • lambda_reg (float) – Regularization coefficient controlling the strength of parameter penalties.

  • adjoint (dfx.AbstractAdjoint, optional) – Adjoint mode for automatic differentiation. Defaults to ForwardMode().

Returns:

Tuple containing:
  • batch_residuals: Residuals array with shape (batch_size, len(ts) - 1, *y0.shape)

  • reg: Scalar regularization term

Return type:

Tuple[jax.Array, float]
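The batching can be sketched as one residual block per trajectory, stacked along a leading batch axis and returned together with the penalty. In this NumPy sketch, predict is a hypothetical callable standing in for solve(parameters, y0, ts), the penalty is simplified to flat dicts without None weights, and a JAX implementation would more typically map over the batch with jax.vmap than with a Python loop.

```python
import numpy as np


def batched_residuals(predict, y0s, values, parameters, weights, lambda_reg):
    """Return (batch_residuals, reg) for a batch of trajectories.

    predict(y0) must return shape (len(ts), *y0.shape); values[b] holds the
    observations of trajectory b after ts[0], shape (len(ts) - 1, *y0.shape).
    """
    batch_residuals = np.stack(
        [obs - predict(y0)[1:] for y0, obs in zip(y0s, values)]
    )
    # Simplified penalty: flat dicts only, no None weights.
    reg = lambda_reg * sum(
        weights[name] * np.sum(np.asarray(parameters[name]) ** 2)
        for name in parameters
    )
    return batch_residuals, float(reg)


def predict(y0):
    # Toy model: the "solution" at the three saved times is y0, 2*y0, 3*y0.
    return np.stack([y0, 2 * y0, 3 * y0])


y0s = np.ones((4, 5))         # batch of 4 initial conditions on a 5-point grid
values = np.zeros((4, 2, 5))  # observations at the two later times
res, reg = batched_residuals(
    predict, y0s, values, {"D": np.array([1.0, 2.0])}, {"D": 0.5}, 2.0
)
print(res.shape, reg)  # (4, 2, 5) 5.0
```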

mse(parameters, y0s__values, solver_parameters, ts, weights, lambda_reg, adjoint=diffrax.RecursiveCheckpointAdjoint)

Compute the mean squared error loss for parameter optimization.

This method computes the mean squared error between model predictions and observed data, plus a regularization term. It’s used as the objective function for the “mse” optimization method in the train() method.

The loss function is:

\[\text{MSE} = \text{mean}\big((\text{predicted} - \text{observed})^2\big) + \lambda \sum_i w_i p_i^2\]

Parameters:
  • parameters (Dict[str, Any]) – Current parameter values for the equation.

  • y0s__values (Tuple[jax.Array, jax.Array]) – Tuple (y0s, values) containing:

    • y0s: Batch of initial conditions

    • values: Batch of observed values

  • solver_parameters (Dict[str, Any]) – Parameters for the numerical solver.

  • ts (jax.Array) – Time points for the simulation.

  • weights (Dict[str, jax.Array]) – Regularization weights for each parameter. Should match the structure of parameters.

  • lambda_reg (float) – Regularization coefficient controlling the strength of parameter penalties.

  • adjoint (dfx.AbstractAdjoint, optional) – Adjoint mode for automatic differentiation. Defaults to RecursiveCheckpointAdjoint().

Returns:

Mean squared error loss including regularization term.

Return type:

float
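Given the batched residuals and a penalty that is already scaled by λ (as regularization() returns it), the loss reduces to a mean over every residual entry plus that penalty. mse_from_residuals is a hypothetical helper showing the arithmetic on a small batch, not the library's code.

```python
import numpy as np


def mse_from_residuals(batch_residuals, reg):
    """Mean of squared residuals over all entries plus the lambda-scaled penalty."""
    return float(np.mean(np.square(batch_residuals)) + reg)


# Squared residuals are 1, 1, 4, 0 -> mean 1.5; adding reg = 0.25 gives 1.75.
batch_residuals = np.array([[[1.0, -1.0]], [[2.0, 0.0]]])  # shape (2, 1, 2)
print(mse_from_residuals(batch_residuals, reg=0.25))  # 1.75
```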

train(data, inds, opt_parameters, other_parameters, solver_parameters, weights, lambda_reg, method='least_squares', max_steps=100)

Train the model by optimizing parameters to fit observed data.

This method performs parameter estimation by minimizing the difference between model predictions and observed data. It supports two optimization approaches: least-squares (which uses the Levenberg-Marquardt algorithm) and mean squared error minimization (which uses the BFGS algorithm).

Parameters:
  • data (Dict[str, List]) – Training data dictionary with keys:

    • “ys”: List of solution snapshots at different times

    • “ts”: List of corresponding time points

    Example: {“ys”: [y0, y1, y2, …], “ts”: [0, 0.1, 0.2, …]}

  • inds (List[List[int]]) – Indices into data specifying the snapshots used for each training trajectory. Each inner list describes one trajectory as [initial_time_index, …intermediate_indices…]: the first index selects the initial condition and the remaining indices select the snapshots the model is fitted to. Example: [[0, 1, 2], [0, 1, 2]] for two trajectories using time points 0, 1, 2.

  • opt_parameters (Dict[str, jax.Array]) – Parameters to optimize.

  • other_parameters (Dict[str, Any]) – Fixed parameters that won’t be optimized.

  • solver_parameters (Dict[str, Any]) – Parameters for the numerical solver during optimization. Passed to the solver constructor.

  • weights (Dict[str, jax.Array]) – Regularization weights for each parameter. Should have the same structure as opt_parameters.

  • lambda_reg (float) – Regularization coefficient. Controls the strength of parameter regularization.

  • method (str, optional) – Optimization method. Defaults to “least_squares”. Options:

    • “least_squares”: Uses the Levenberg-Marquardt algorithm with ForwardMode adjoint. Best when the number of parameters is small (e.g., no neural networks).

    • “mse”: Uses the BFGS algorithm with RecursiveCheckpointAdjoint. Better when the number of parameters is large (e.g., when using neural networks).

  • max_steps (int, optional) – Maximum number of optimization iterations. Defaults to 100.

Returns:

Optimized parameters combined with fixed parameters. The returned dictionary contains both the optimized parameters and the other_parameters, ready for use in the solve() method.

Return type:

Dict[str, Any]
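One plausible way to turn data and inds into batched inputs is sketched below in NumPy: the first index of each trajectory selects its initial condition and the remaining indices select the snapshots to fit. split_trajectories is hypothetical, and it assumes, as in the docstring example, that every trajectory uses the same time points.

```python
import numpy as np


def split_trajectories(data, inds):
    """Hypothetical helper: build (y0s, values, ts) from data and inds."""
    ys = np.asarray(data["ys"])
    ts_all = np.asarray(data["ts"])
    inds = np.asarray(inds)   # (n_trajectories, points_per_trajectory)
    y0s = ys[inds[:, 0]]      # initial condition of each trajectory
    values = ys[inds[:, 1:]]  # snapshots to fit, excluding the start
    ts = ts_all[inds[0]]      # assumption: all trajectories share time points
    return y0s, values, ts


# Five snapshots of a 3-point field; snapshot i is filled with the value i.
data = {"ys": [np.full(3, float(i)) for i in range(5)],
        "ts": [0.0, 0.1, 0.2, 0.3, 0.4]}
y0s, values, ts = split_trajectories(data, [[0, 1, 2], [0, 1, 2]])
print(y0s.shape, values.shape)  # (2, 3) (2, 2, 3)
```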

optimize(objective_function, y0, ts, opt_parameters, other_parameters, solver_parameters, weights, lambda_reg, max_steps=100)

Optimize parameters by minimizing a scalar function of the solution.

This method performs parameter optimization by minimizing a user-provided scalar function that takes the solution as input. It uses the BFGS algorithm for optimization and supports parameter regularization.

Parameters:
  • objective_function (Callable) – A function that takes the solution array (shape: (len(ts), *y0.shape)) and returns a scalar value to minimize. It must be JAX-compatible (traceable and differentiable) so that gradients can be computed by automatic differentiation.

  • y0 – Initial condition array. Shape should match the spatial dimensions of the domain.

  • ts – Time points at which to save the solution. Should be a 1D array of increasing time values.

  • opt_parameters (Dict[str, jax.Array]) – Parameters to optimize.

  • other_parameters (Dict[str, Any]) – Fixed parameters that won’t be optimized.

  • solver_parameters (Dict[str, Any]) – Parameters for the numerical solver during optimization. Passed to the solver constructor.

  • weights (Dict[str, jax.Array]) – Regularization weights for each parameter. Should have the same structure as opt_parameters.

  • lambda_reg (float) – Regularization coefficient. Controls the strength of parameter regularization.

  • max_steps (int, optional) – Maximum number of optimization iterations. Defaults to 100.

Returns:

Optimized parameters combined with fixed parameters. The returned dictionary contains both the optimized parameters and the other_parameters, ready for use in the solve() method.

Return type:

Dict[str, Any]
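As an illustration, an objective that drives the final state toward a target profile could look like the sketch below; make_target_objective and the zero target are hypothetical. It uses only array operators and methods, so the same function works on NumPy arrays for a quick check and should also trace under JAX.

```python
import numpy as np


def make_target_objective(target):
    """Hypothetical objective: mean squared gap between final state and target."""

    def objective(solution):
        # solution has shape (len(ts), *y0.shape); solution[-1] is the state
        # at ts[-1]. Only array operators are used, with no value-dependent
        # Python branching, so the function remains differentiable under JAX.
        return ((solution[-1] - target) ** 2).mean()

    return objective


objective = make_target_objective(np.zeros(4))
solution = np.stack([np.ones(4), 2 * np.ones(4)])  # two saved time points
print(objective(solution))  # 4.0
```

A function built this way could then be passed as objective_function to optimize().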