Source code for pde_opt.numerics.solvers

"""Custom numerical solvers for partial differential equations.

This module provides specialized numerical solvers that extend the diffrax library
for solving specific types of PDEs. The solvers are designed to work with the
equation classes defined in the numerics.equations module.

Available Solvers:
    SemiImplicitFourierSpectral: Semi-implicit Fourier spectral method.

    StrangSplitting: Strang splitting method for equations with separable operators
        (e.g., Gross-Pitaevskii equation).

All solvers inherit from diffrax.AbstractSolver and are compatible with the
diffrax integration framework.
"""

import diffrax as dfx
import jax
import jax.numpy as jnp
from typing import Callable


[docs] class SemiImplicitFourierSpectral(dfx.AbstractSolver): """Semi-implicit Fourier spectral method. This solver implements a semi-implicit Fourier spectral method for phase-field simulations with variable mobility. Required Equation Attributes: fourier_symbol: Fourier space representation of the highest order differential operator. fft: Forward Fourier transform function. ifft: Inverse Fourier transform function. Parameters: A (float): Constant for splitting the mobility term. References: Zhu, Jingzhi, et al. "Coarsening kinetics from a variable-mobility Cahn-Hilliard equation: Application of a semi-implicit Fourier spectral method." Physical Review E 60.4 (1999): 3564. """ required_equation_attrs = ["fourier_symbol", "fft", "ifft"] A: float fourier_symbol: jax.Array fft: Callable ifft: Callable term_structure = dfx.ODETerm interpolation_cls = dfx.LocalLinearInterpolation
[docs] def order(self, terms): return 1
[docs] def init(self, terms, t0, t1, y0, args): return None
[docs] def step(self, terms, t0, t1, y0, args, solver_state, made_jump): del solver_state, made_jump δt = t1 - t0 f0 = terms.vf(t0, y0, args) euler_y1 = y0 + δt * f0 tmp = 1.0 + self.A * δt * self.fourier_symbol y1 = y0 + δt * self.ifft(self.fft(f0) / tmp).real y_error = y1 - euler_y1 dense_info = dict(y0=y0, y1=y1) solver_state = None result = dfx.RESULTS.successful return y1, y_error, dense_info, solver_state, result
[docs] def func(self, terms, t0, y0, args): return terms.vf(t0, y0, args)
[docs] class StrangSplitting(dfx.AbstractSolver): """Strang splitting method for time-dependent PDEs with separable operators. References: Bao, Weizhu, and Yongyong Cai. "Mathematical theory and numerical methods for Bose-Einstein condensation." arXiv preprint arXiv:1212.5341 (2012). """ required_equation_attrs = ["A_term", "dx", "fft", "ifft"] A_term: jax.Array dx: float fft: Callable ifft: Callable time_scale: float term_structure = dfx.ODETerm interpolation_cls = dfx.LocalLinearInterpolation
[docs] def order(self, terms): return 1
[docs] def init(self, terms, t0, t1, y0, args): return None
[docs] def step(self, terms, t0, t1, y0, args, solver_state, made_jump): del solver_state, made_jump δt = (t1 - t0) * self.time_scale y0_ = y0[..., 0] + 1j * y0[..., 1] exp_A_term = jnp.exp(self.A_term * 0.5 * δt) tmp = self.fft(y0_) * exp_A_term # 1 tmp = self.ifft(tmp) # 2 b_term = terms.vf(t0, y0, args) tmp = tmp * jnp.exp((b_term[..., 0] + 1j * b_term[..., 1]) * δt) # 3 tmp /= jnp.sqrt(jnp.sum(jnp.abs(tmp) ** 2) * self.dx**2) # 4 tmp = self.fft(tmp) # 5 tmp = tmp * exp_A_term # 6 y1_ = self.ifft(tmp) # 7 y1 = jnp.stack([y1_.real, y1_.imag], axis=-1) # TODO: I should be able to change order and reduce number of ffts dense_info = dict(y0=y0, y1=y1) solver_state = None result = dfx.RESULTS.successful return y1, None, dense_info, solver_state, result
[docs] def func(self, terms, t0, y0, args): return NotImplementedError