# Source code for pde_opt.numerics.equations.allen_cahn

"""
This module contains various Allen-Cahn equation classes.
"""

import dataclasses
from typing import Callable, Union
import jax
import jax.numpy as jnp
import equinox as eqx

from ..domains import Domain
from .base_eq import BaseEquation
from ..utils.derivatives import _lap_2nd_2D
from ..utils.derivatives import (
    _gradx_c,
    _grady_c,
    _avgx_c2f,
    _avgy_c2f,
    _divx_f2c,
    _divy_f2c,
    _gradx_c2f,
    _grady_c2f,
)


@dataclasses.dataclass
class AllenCahn2DPeriodic(BaseEquation):
    r"""Allen-Cahn equation in 2D with periodic boundary conditions.

    The Allen-Cahn equation describes phase transitions and interface
    dynamics:

    .. math::
        \frac{\partial u}{\partial t} = -R(u) \mu

    where u is the concentration, R(u) is the reaction term and μ is the
    chemical potential

    .. math::
        \mu = \mu_h(u) - \kappa \nabla^2 u

    with κ the gradient energy coefficient.

    The Laplacian is evaluated either spectrally with FFTs
    (``derivs="fourier"``) or with the finite-difference helper
    ``_lap_2nd_2D`` (``derivs="fd"``). ``__post_init__`` replaces the
    placeholder ``rhs`` with a jit-compiled version of the chosen scheme.
    """

    domain: Domain  # The computational domain for the equation
    """Domain of the equation"""
    kappa: float
    """Gradient energy coefficient"""
    mu: Union[Callable, eqx.Module]  # Can be a callable or Equinox module
    """Function for the chemical potential"""
    R: Union[Callable, eqx.Module]  # Can be a callable or Equinox module
    """Function for the reaction term"""
    derivs: str = "fd"
    """Type of derivative computation"""

    def rhs(self, state, t):
        """Placeholder; overwritten on the instance by ``__post_init__``."""
        raise NotImplementedError("rhs method not implemented")

    def __post_init__(self):
        """Precompute spectral operators and bind the jitted right-hand side.

        Raises:
            ValueError: if ``derivs`` is neither ``"fourier"`` nor ``"fd"``.
        """
        # (2*pi*i*k)^2 is the Fourier symbol of a second derivative, so the sum
        # over both axes is the symbol of the Laplacian.
        self.kx, self.ky = self.domain.fft_mesh()
        self.two_pi_i_kx = 2j * jnp.pi * self.kx
        self.two_pi_i_ky = 2j * jnp.pi * self.ky
        self.two_pi_i_kx_2 = self.two_pi_i_kx**2
        self.two_pi_i_ky_2 = self.two_pi_i_ky**2
        self.two_pi_i_k_2 = self.two_pi_i_kx_2 + self.two_pi_i_ky_2
        self.fft, self.ifft = jnp.fft.fftn, jnp.fft.ifftn

        if self.derivs not in ("fourier", "fd"):
            raise ValueError(f"Invalid derivative type: {self.derivs}")
        scheme = self.rhs_fourier if self.derivs == "fourier" else self.rhs_fd
        self.rhs = jax.jit(scheme)

    def rhs_fourier(self, state, t):
        """Return -R(u) * mu with the Laplacian applied in Fourier space.

        ``t`` is accepted for the integrator interface but not used.
        """
        homogeneous_hat = self.fft(self.mu(state))
        gradient_hat = self.kappa * self.two_pi_i_k_2 * self.fft(state)
        chem_pot = self.ifft(homogeneous_hat - gradient_hat).real
        return -self.R(state) * chem_pot

    def rhs_fd(self, state, t):
        """Return -R(u) * mu with a finite-difference Laplacian.

        ``t`` is accepted for the integrator interface but not used.
        """
        dx, dy = self.domain.dx
        laplacian = _lap_2nd_2D(state, dx, dy)
        chem_pot = self.mu(state) - self.kappa * laplacian
        return -self.R(state) * chem_pot
@dataclasses.dataclass
class AllenCahn2DSmoothedBoundary(BaseEquation):
    r"""Allen-Cahn equation with smoothed boundary method for arbitrary geometries.

    This class implements the Allen-Cahn equation using the smoothed boundary
    method, which allows for complex domain geometries through a smooth
    level-set function ψ (``domain.geometry.smooth``). The equation is:

    .. math::
        \frac{\partial u}{\partial t} = -R(u) \mu

    where the chemical potential includes boundary effects:

    .. math::
        \mu = \mu_h(u) - \frac{\kappa}{\psi} \nabla \cdot (\psi \nabla u)
              - \sqrt{\kappa} \frac{|\nabla \psi|}{\psi} \sqrt{2f} \cos(\theta)

    NOTE: the contact-angle (last) term is multiplied by a mask that equals 1
    on the first ``wetting_columns`` grid columns (axis 1 of ψ) and 0
    elsewhere. It is therefore only active in that strip of the grid.

    Only finite differences (``derivs="fd"``) are supported. The divergence
    term is discretised with face-averaged ψ between cell-centred values.
    """

    domain: Domain
    """Domain of the equation"""
    kappa: float
    """Gradient energy coefficient"""
    f: Union[Callable, eqx.Module]  # Can be a callable or Equinox module
    """Function for the free energy density"""
    mu: Union[Callable, eqx.Module]  # Can be a callable or Equinox module
    """Function for the chemical potential"""
    R: Union[Callable, eqx.Module]  # Can be a callable or Equinox module
    """Function for the reaction term"""
    theta: Union[Callable, eqx.Module]  # Can be a callable or Equinox module
    """Function for the contact angle (called with the time ``t``)"""
    derivs: str = "fd"
    """Type of derivative computation"""
    wetting_columns: int = 100
    """Number of leading grid columns (axis 1) where the contact-angle term is applied"""

    def rhs(self, state, t):
        """Placeholder; overwritten on the instance by ``__post_init__``."""
        raise NotImplementedError("rhs method not implemented")

    def __post_init__(self):
        """Precompute geometry-dependent fields and bind the jitted rhs.

        Raises:
            ValueError: if ``derivs`` is not ``"fd"`` or ``wetting_columns``
                is negative.
        """
        if self.wetting_columns < 0:
            raise ValueError(
                f"wetting_columns must be non-negative, got {self.wetting_columns}"
            )
        self.psi = self.domain.geometry.smooth
        self.sqrt_kappa = jnp.sqrt(self.kappa)
        self.hx, self.hy = self.domain.dx
        # |grad psi| / psi, evaluated once since psi does not change in time.
        # NOTE(review): this divides by psi; presumably psi > 0 wherever the
        # result matters — confirm against the geometry definition.
        self.norm_grad_psi = (
            jnp.sqrt(
                _gradx_c(self.psi, self.hx) ** 2 + _grady_c(self.psi, self.hy) ** 2
            )
            / self.psi
        )
        # Mask restricting the contact-angle term to the first
        # `wetting_columns` columns (was hard-coded to 100; kept as default).
        self.left_half = (
            jnp.zeros_like(self.psi).at[:, : self.wetting_columns].set(1.0)
        )
        if self.derivs == "fd":
            self.rhs = jax.jit(self.rhs_fd)
        else:
            raise ValueError(f"Invalid derivative type: {self.derivs}")

    def rhs_fd(self, state, t):
        """Return -R(u) * mu with the smoothed-boundary chemical potential.

        Args:
            state: Cell-centred field u.
            t: Time; only used to evaluate the contact angle ``theta(t)``.
        """
        f = self.f(state)
        # Face-averaged psi weights the fluxes in div(psi grad u).
        psi_face_x = _avgx_c2f(self.psi)
        psi_face_y = _avgy_c2f(self.psi)
        div_psi_grad_u = _divx_f2c(
            psi_face_x * _gradx_c2f(state, self.hx), self.hx
        ) + _divy_f2c(psi_face_y * _grady_c2f(state, self.hy), self.hy)
        # Contact-angle (wetting) contribution, masked to the wetting strip.
        wetting = (
            self.sqrt_kappa
            * self.norm_grad_psi
            * jnp.sqrt(2.0 * f)
            * jnp.cos(self.theta(t))
            * self.left_half
        )
        mu = self.mu(state) + (
            -(self.kappa / self.psi) * div_psi_grad_u - wetting
        )
        return -self.R(state) * mu
@dataclasses.dataclass
class AllenCahn2DPeriodicButlerVolmer(BaseEquation):
    r"""Periodic 2D Allen-Cahn-type equation driven by Butler-Volmer kinetics.

    The right-hand side is

    .. math::
        j_0(u) \left[ e^{-\alpha \eta} - e^{(1 - \alpha) \eta} \right],
        \qquad \eta = \mu + v,
        \qquad \mu = \mu_h(u) - \kappa \nabla^2 u

    where ``v`` is supplied by the caller as the third argument of ``rhs``.
    The Laplacian is evaluated spectrally (``derivs="fourier"``) or with
    ``_lap_2nd_2D`` (``derivs="fd"``).
    """

    domain: Domain  # The computational domain for the equation
    """Domain of the equation"""
    kappa: float
    """Gradient energy coefficient"""
    mu: Union[Callable, eqx.Module]  # Can be a callable or Equinox module
    """Function for the chemical potential"""
    j0: Union[Callable, eqx.Module]  # Can be a callable or Equinox module
    """Function for the exchange current"""
    alpha: float
    """Symmetry factor"""
    derivs: str = "fd"
    """Type of derivative computation"""

    def rhs(self, state, t):
        """Placeholder; overwritten on the instance by ``__post_init__``."""
        raise NotImplementedError("rhs method not implemented")

    def __post_init__(self):
        """Precompute spectral operators and bind the jitted right-hand side.

        Raises:
            ValueError: if ``derivs`` is neither ``"fourier"`` nor ``"fd"``.
        """
        self.kx, self.ky = self.domain.fft_mesh()
        self.two_pi_i_kx = 2j * jnp.pi * self.kx
        self.two_pi_i_ky = 2j * jnp.pi * self.ky
        # Sum of (2*pi*i*k)^2 over both axes is the Fourier symbol of the Laplacian.
        self.two_pi_i_kx_2 = (self.two_pi_i_kx) ** 2
        self.two_pi_i_ky_2 = (self.two_pi_i_ky) ** 2
        self.two_pi_i_k_2 = self.two_pi_i_kx_2 + self.two_pi_i_ky_2
        self.fft = jnp.fft.fftn
        self.ifft = jnp.fft.ifftn
        if self.derivs == "fourier":
            self.rhs = jax.jit(self.rhs_fourier)
        elif self.derivs == "fd":
            self.rhs = jax.jit(self.rhs_fd)
        else:
            raise ValueError(f"Invalid derivative type: {self.derivs}")

    def _butler_volmer(self, state, mu, v):
        """Return the Butler-Volmer rate for the overpotential ``eta = mu + v``."""
        eta = mu + v
        return self.j0(state) * (
            jnp.exp(-self.alpha * eta) - jnp.exp((1.0 - self.alpha) * eta)
        )

    def rhs_fourier(self, state, t, v):
        """Butler-Volmer rhs with the Laplacian applied in Fourier space.

        Previously this referenced ``self.R``, which this class does not
        define, so ``derivs="fourier"`` always failed. It now mirrors
        ``rhs_fd``.

        Args:
            state: Field u.
            t: Time (unused; kept for the integrator interface).
            v: Potential shift added to mu to form the overpotential.
        """
        state_hat = self.fft(state)
        mu = self.ifft(
            self.fft(self.mu(state)) - self.kappa * (self.two_pi_i_k_2) * state_hat
        ).real
        return self._butler_volmer(state, mu, v)

    def rhs_fd(self, state, t, v):
        """Butler-Volmer rhs with a finite-difference Laplacian.

        Args:
            state: Field u.
            t: Time (unused; kept for the integrator interface).
            v: Potential shift added to mu to form the overpotential.
        """
        hx, hy = self.domain.dx
        mu = self.mu(state) - self.kappa * _lap_2nd_2D(state, hx, hy)
        return self._butler_volmer(state, mu, v)
@dataclasses.dataclass
class AllenCahn2DPeriodicButlerVolmerConstantCurrent(BaseEquation):
    r"""Periodic 2D Butler-Volmer Allen-Cahn equation at constant total current.

    The right-hand side is

    .. math::
        j_0(u) \left[ e^{-\alpha \eta} - e^{(1 - \alpha) \eta} \right],
        \qquad \eta = \mu + v,
        \qquad \mu = \mu_h(u) - \kappa \nabla^2 u

    where the spatially uniform shift ``v`` is recomputed from the current
    state by the closed-form solve in ``_voltage_from_mu``.

    NOTE: that closed form makes the domain integral of the rate equal to
    ``Crate`` exactly only when ``alpha == 0.5``. For other ``alpha`` the
    rhs still uses ``alpha``, but the constraint is not exactly enforced.
    """

    domain: Domain  # The computational domain for the equation
    """Domain of the equation"""
    kappa: float
    """Gradient energy coefficient"""
    mu: Union[Callable, eqx.Module]  # Can be a callable or Equinox module
    """Function for the chemical potential"""
    j0: Union[Callable, eqx.Module]  # Can be a callable or Equinox module
    """Function for the exchange current"""
    alpha: float
    """Symmetry factor"""
    Crate: float
    """Current"""
    derivs: str = "fd"
    """Type of derivative computation"""

    def rhs(self, state, t):
        """Placeholder; overwritten on the instance by ``__post_init__``."""
        raise NotImplementedError("rhs method not implemented")

    def __post_init__(self):
        """Precompute spectral operators and bind the jitted right-hand side.

        Raises:
            ValueError: if ``derivs`` is neither ``"fourier"`` nor ``"fd"``.
        """
        self.kx, self.ky = self.domain.fft_mesh()
        self.two_pi_i_kx = 2j * jnp.pi * self.kx
        self.two_pi_i_ky = 2j * jnp.pi * self.ky
        # Sum of (2*pi*i*k)^2 over both axes is the Fourier symbol of the Laplacian.
        self.two_pi_i_kx_2 = (self.two_pi_i_kx) ** 2
        self.two_pi_i_ky_2 = (self.two_pi_i_ky) ** 2
        self.two_pi_i_k_2 = self.two_pi_i_kx_2 + self.two_pi_i_ky_2
        self.fft = jnp.fft.fftn
        self.ifft = jnp.fft.ifftn
        if self.derivs == "fourier":
            self.rhs = jax.jit(self.rhs_fourier)
        elif self.derivs == "fd":
            self.rhs = jax.jit(self.rhs_fd)
        else:
            raise ValueError(f"Invalid derivative type: {self.derivs}")

    def _chemical_potential_fd(self, state):
        """Return mu_h(u) - kappa * lap(u) using the finite-difference Laplacian."""
        hx, hy = self.domain.dx
        return self.mu(state) - self.kappa * _lap_2nd_2D(state, hx, hy)

    def _chemical_potential_fourier(self, state):
        """Return mu_h(u) - kappa * lap(u) with the Laplacian applied spectrally."""
        state_hat = self.fft(state)
        return self.ifft(
            self.fft(self.mu(state)) - self.kappa * (self.two_pi_i_k_2) * state_hat
        ).real

    def _voltage_from_mu(self, state, mu):
        """Return the uniform shift v that enforces the constant-current constraint.

        With y = exp(v / 2), this takes the '+' root of
        ``int_plus * y**2 + Crate * y - int_minus = 0``. For alpha = 0.5 that
        makes ``sum(rate) * hx * hy == Crate``. The root is positive when
        both integrals are positive (presumably j0 >= 0 — confirm).
        """
        hx, hy = self.domain.dx
        j0 = self.j0(state)
        int_plus = jnp.sum(j0 * jnp.exp(0.5 * mu)) * hx * hy
        int_minus = jnp.sum(j0 * jnp.exp(-0.5 * mu)) * hx * hy
        y = (-self.Crate + jnp.sqrt(self.Crate**2 + 4.0 * int_plus * int_minus)) / (
            2.0 * int_plus
        )
        return 2.0 * jnp.log(y)

    def _butler_volmer(self, state, mu, v):
        """Return the Butler-Volmer rate for the overpotential ``eta = mu + v``."""
        eta = mu + v
        return self.j0(state) * (
            jnp.exp(-self.alpha * eta) - jnp.exp((1.0 - self.alpha) * eta)
        )

    def rhs_fourier(self, state, t):
        """Constant-current Butler-Volmer rhs with a spectral Laplacian.

        Previously this referenced ``self.R``, which this class does not
        define, so ``derivs="fourier"`` always failed. ``t`` is unused.
        """
        mu = self._chemical_potential_fourier(state)
        return self._butler_volmer(state, mu, self._voltage_from_mu(state, mu))

    def rhs_fd(self, state, t):
        """Constant-current Butler-Volmer rhs with a finite-difference Laplacian.

        ``t`` is unused; it is kept for the integrator interface.
        """
        mu = self._chemical_potential_fd(state)
        return self._butler_volmer(state, mu, self._voltage_from_mu(state, mu))

    def get_voltage(self, state):
        """Return the shift v that ``rhs`` would apply for ``state``.

        Uses the same derivative scheme as ``rhs``, so the reported voltage
        matches the one used in the dynamics.
        """
        if self.derivs == "fourier":
            mu = self._chemical_potential_fourier(state)
        else:
            mu = self._chemical_potential_fd(state)
        return self._voltage_from_mu(state, mu)
@dataclasses.dataclass
class AllenCahn2DSmoothedBoundaryButlerVolmerConstantCurrent(BaseEquation):
    r"""Smoothed-boundary Butler-Volmer Allen-Cahn equation at constant current.

    The chemical potential uses the smoothed-boundary operator with level-set
    ψ (``domain.geometry.smooth``):

    .. math::
        \mu = \mu_h(u) - \frac{\kappa}{\psi} \nabla \cdot (\psi \nabla u)

    and the right-hand side is the Butler-Volmer rate

    .. math::
        j_0(u) \left[ e^{-\alpha \eta} - e^{(1 - \alpha) \eta} \right],
        \qquad \eta = \mu + v

    where the uniform shift ``v`` is recomputed from the state so that the
    ψ-weighted integral of the rate equals ``Crate``.

    NOTE: the closed-form solve for ``v`` enforces the constraint exactly only
    for ``alpha == 0.5``.

    NOTE: the contact-angle (wetting) term of the smoothed-boundary method is
    intentionally omitted here; this class has no ``theta`` field. In the
    visible code, the ``f`` field and the ``left_half`` mask are therefore
    unused.
    """

    domain: Domain
    """Domain of the equation"""
    kappa: float
    """Gradient energy coefficient"""
    f: Union[Callable, eqx.Module]  # Can be a callable or Equinox module
    """Function for the free energy density"""
    mu: Union[Callable, eqx.Module]  # Can be a callable or Equinox module
    """Function for the chemical potential"""
    j0: Union[Callable, eqx.Module]  # Can be a callable or Equinox module
    """Function for the exchange current"""
    alpha: float
    """Symmetry factor"""
    Crate: float
    """Current"""
    derivs: str = "fd"
    """Type of derivative computation"""

    def rhs(self, state, t):
        """Placeholder; overwritten on the instance by ``__post_init__``."""
        raise NotImplementedError("rhs method not implemented")

    def __post_init__(self):
        """Precompute geometry-dependent fields and bind the jitted rhs.

        Raises:
            ValueError: if ``derivs`` is not ``"fd"``.
        """
        self.psi = self.domain.geometry.smooth
        self.sqrt_kappa = jnp.sqrt(self.kappa)
        self.hx, self.hy = self.domain.dx
        self.norm_grad_psi = (
            jnp.sqrt(
                _gradx_c(self.psi, self.hx) ** 2 + _grady_c(self.psi, self.hy) ** 2
            )
            / self.psi
        )
        # Kept for attribute compatibility; only the (omitted) wetting term
        # would use this mask over the first 100 columns.
        self.left_half = jnp.zeros_like(self.psi)
        self.left_half = self.left_half.at[:, :100].set(1.0)
        if self.derivs == "fd":
            self.rhs = jax.jit(self.rhs_fd)
        else:
            raise ValueError(f"Invalid derivative type: {self.derivs}")

    def _chemical_potential(self, state):
        """Return mu_h(u) - (kappa / psi) * div(psi grad u) via finite differences."""
        # Face-averaged psi weights the fluxes in div(psi grad u).
        psi_face_x = _avgx_c2f(self.psi)
        psi_face_y = _avgy_c2f(self.psi)
        div_psi_grad_u = _divx_f2c(
            psi_face_x * _gradx_c2f(state, self.hx), self.hx
        ) + _divy_f2c(psi_face_y * _grady_c2f(state, self.hy), self.hy)
        return self.mu(state) - (self.kappa / self.psi) * div_psi_grad_u

    def _voltage_from_mu(self, state, mu):
        """Return the uniform shift v enforcing the psi-weighted current constraint.

        With y = exp(v / 2), this takes the '+' root of
        ``int_plus * y**2 + Crate * y - int_minus = 0``. The integrals are
        weighted by psi so only the (smoothed) interior contributes.
        """
        j0 = self.j0(state)
        int_plus = (
            jnp.sum(j0 * jnp.exp(0.5 * mu) * self.psi) * self.hx * self.hy
        )
        int_minus = (
            jnp.sum(j0 * jnp.exp(-0.5 * mu) * self.psi) * self.hx * self.hy
        )
        y = (-self.Crate + jnp.sqrt(self.Crate**2 + 4.0 * int_plus * int_minus)) / (
            2.0 * int_plus
        )
        return 2.0 * jnp.log(y)

    def rhs_fd(self, state, t):
        """Constant-current Butler-Volmer rhs (``t`` is unused)."""
        mu = self._chemical_potential(state)
        eta = mu + self._voltage_from_mu(state, mu)
        return self.j0(state) * (
            jnp.exp(-self.alpha * eta) - jnp.exp((1.0 - self.alpha) * eta)
        )

    def get_voltage(self, state):
        """Return the shift v that ``rhs_fd`` would apply for ``state``."""
        return self._voltage_from_mu(state, self._chemical_potential(state))