"""
This module provides a SymPy-based helper that builds the exact solution and right-hand side of the 2D Allen–Cahn equation for use in tests.
"""
from dataclasses import dataclass
from typing import Callable
import sympy as sp
from sympy.utilities.lambdify import lambdify
import jax.numpy as jnp # only to return jnp arrays if you like; optional
@dataclass
class SymbolicAllenCahn2DPeriodic:
    """Exact solution and RHS of the 2D Allen–Cahn equation, for tests only.

    Given a manufactured solution ``u_star(x, y, t)``, ``__post_init__`` builds
    (symbolically, once)::

        mu  = mu_sym(u) - kappa * (u_xx + u_yy)
        rhs = -R_sym(u) * mu

    and caches NumPy-backed callables so both can be sampled on the grid
    returned by ``domain.mesh()``.

    Attributes:
        domain: Object whose ``mesh()`` returns the ``(X, Y)`` coordinate
            arrays to evaluate on.
        kappa: Coefficient of the Laplacian term in ``mu``.
        mu_sym: Maps the symbolic solution to the bulk part of ``mu``,
            e.g. ``lambda u: u**3 - u``.
        R_sym: Maps the symbolic solution to the factor multiplying ``-mu``,
            e.g. ``lambda u: 1``.
        u_star: Symbolic test solution written in symbols named ``x``, ``y``
            and ``t`` (any SymPy assumptions on them are accepted).
            NOTE(review): presumably expected to be periodic on ``domain``;
            this is not checked here.
    """

    domain: object
    kappa: float
    mu_sym: Callable[[sp.Expr], sp.Expr]  # e.g., lambda u: u**3 - u
    R_sym: Callable[[sp.Expr], sp.Expr]  # e.g., lambda u: 1
    u_star: sp.Expr  # test solution u*(x,y,t)

    def __post_init__(self):
        u = sp.sympify(self.u_star)
        # Differentiate w.r.t. the symbols that actually occur in u_star.
        # Fresh symbols with different assumptions would compare unequal and
        # make every derivative vanish, silently dropping the Laplacian.
        x, y, t = self._coordinate_symbols(u)
        laplacian = sp.diff(u, x, 2) + sp.diff(u, y, 2)
        mu_expr = self.mu_sym(u) - self.kappa * laplacian
        rhs_expr = -self.R_sym(u) * mu_expr
        # Cache fast array-callables; simplification is paid once, here.
        self._u_fn = lambdify((x, y, t), sp.simplify(u), "numpy")
        self._rhs_fn = lambdify((x, y, t), sp.simplify(rhs_expr), "numpy")

    @staticmethod
    def _coordinate_symbols(expr):
        """Return the ``(x, y, t)`` symbols to differentiate and lambdify with.

        Symbols are looked up by name among ``expr.free_symbols`` so the
        caller's assumptions (e.g. ``real=True`` or none) are respected; a
        name that does not occur gets a fresh ``real=True`` symbol. If
        ``expr`` contains two distinct symbols sharing one of these names,
        which of them is used is unspecified.
        """
        by_name = {
            sym.name: sym
            for sym in expr.free_symbols
            if isinstance(sym, sp.Symbol)
        }
        return tuple(
            by_name.get(name, sp.Symbol(name, real=True))
            for name in ("x", "y", "t")
        )

    def _on_mesh(self, fn, t):
        """Evaluate ``fn(X, Y, t)`` on ``domain.mesh()`` as a JAX array.

        The result is broadcast to the common shape of ``X`` and ``Y``:
        lambdify returns a bare scalar when the expression does not depend on
        ``x`` or ``y``, while callers compare against a full grid field.
        """
        X, Y = self.domain.mesh()
        values = jnp.asarray(fn(X, Y, float(t)))
        grid_shape = jnp.broadcast_shapes(jnp.shape(X), jnp.shape(Y))
        return jnp.broadcast_to(values, grid_shape)

    # ---- Public evaluators for tests ----

    def u_exact(self, t: float):
        """Exact solution ``u_star`` on the domain mesh at time ``t``.

        Returns a JAX array with the broadcast shape of the mesh arrays.
        """
        return self._on_mesh(self._u_fn, t)

    def rhs_exact(self, t: float):
        """Exact RHS ``-R(u) * (mu(u) - kappa * lap(u))`` on the mesh at ``t``.

        Returns a JAX array with the broadcast shape of the mesh arrays.
        """
        return self._on_mesh(self._rhs_fn, t)