Source code for pde_opt.numerics.equations.cahn_hilliard

"""
This module contains various Cahn-Hilliard equation classes.
"""

import dataclasses
from typing import Callable, Union
import jax
import jax.numpy as jnp
import equinox as eqx

from ..domains import Domain
from .base_eq import BaseEquation
from ..utils.derivatives import (
    _lap_2nd_2D,
    _lap_2nd_3D,
    _gradx_c2f,
    _grady_c2f,
    _gradz_c2f,
    _avgx_c2f,
    _avgy_c2f,
    _avgz_c2f,
    _divx_f2c,
    _divy_f2c,
    _divz_f2c,
    _gradx_c,
    _grady_c,
)


[docs] @dataclasses.dataclass class CahnHilliard2DPeriodic(BaseEquation): """Cahn–Hilliard equation in 2D with periodic boundary conditions. The Cahn-Hilliard equation describes phase separation and coarsening dynamics. The equation is: .. math:: \\frac{\\partial u}{\\partial t} = \\nabla \\cdot (D(u) \\nabla \\mu) where u is the concentration, D(u) is the mobility, and μ is the chemical potential. The chemical potential is given by: .. math:: \\mu = \\mu_h(u) - \\kappa \\nabla^2 u """ domain: Domain # The computational domain for the equation """Domain of the equation""" kappa: float """Gradient energy coefficient""" mu: Union[Callable, eqx.Module] # Can be a callable or Equinox module """Function for the chemical potential""" D: Union[Callable, eqx.Module] # Can be a callable or Equinox module """Function for the mobility""" derivs: str = "fd" """Type of derivative computation""" fft = None ifft = None fourier_symbol = None
[docs] def rhs(self, state, t): raise NotImplementedError("rhs method not implemented")
def __post_init__(self): self.kx, self.ky = self.domain.fft_mesh() self.two_pi_i_kx = 2j * jnp.pi * self.kx self.two_pi_i_ky = 2j * jnp.pi * self.ky self.two_pi_i_kx_2 = (self.two_pi_i_kx) ** 2 self.two_pi_i_ky_2 = (self.two_pi_i_ky) ** 2 self.two_pi_i_k_2 = self.two_pi_i_kx_2 + self.two_pi_i_ky_2 self.two_pi_i_k_4 = (self.two_pi_i_k_2) ** 2 self.fft = jnp.fft.fftn self.ifft = jnp.fft.ifftn self.fourier_symbol = self.kappa * self.two_pi_i_k_4 if self.derivs == "fourier": self.rhs = jax.jit(self.rhs_fourier) elif self.derivs == "fd": self.rhs = jax.jit(self.rhs_fd) else: raise ValueError(f"Invalid derivative type: {self.derivs}")
[docs] def rhs_fourier(self, state, t): state_hat = self.fft(state) tmp = self.fft(self.mu(state)) - self.kappa * (self.two_pi_i_k_2) * state_hat tmpx = self.fft(self.D(state) * self.ifft(self.two_pi_i_kx * tmp)) tmpy = self.fft(self.D(state) * self.ifft(self.two_pi_i_ky * tmp)) return self.ifft(self.two_pi_i_kx * tmpx + self.two_pi_i_ky * tmpy).real
[docs] def rhs_fd(self, state, t): hx, hy = self.domain.dx # chemical potential: μ = μ_nl(u) - κ Δu mu = self.mu(state) - self.kappa * _lap_2nd_2D(state, hx, hy) # gradients of μ at faces mux_f = _gradx_c2f(mu, hx) muy_f = _grady_c2f(mu, hy) # mobility at faces Du = self.D(state) Dx_f = _avgx_c2f(Du) Dy_f = _avgy_c2f(Du) # fluxes at faces Fx = Dx_f * mux_f Fy = Dy_f * muy_f # divergence back to centers return _divx_f2c(Fx, hx) + _divy_f2c(Fy, hy)
[docs] @dataclasses.dataclass class CahnHilliard3DPeriodic(BaseEquation): """Cahn–Hilliard equation in 3D with periodic boundary conditions. The Cahn-Hilliard equation describes phase separation and coarsening dynamics. The equation is: .. math:: \\frac{\\partial u}{\\partial t} = \\nabla \\cdot (D(u) \\nabla \\mu) where u is the concentration, D(u) is the mobility, and μ is the chemical potential. The chemical potential is given by: .. math:: \\mu = \\mu_h(u) - \\kappa \\nabla^2 u """ domain: Domain # The computational domain for the equation """Domain of the equation""" kappa: float """Gradient energy coefficient""" mu: Union[Callable, eqx.Module] # Can be a callable or Equinox module """Function for the chemical potential""" D: Union[Callable, eqx.Module] # Can be a callable or Equinox module """Function for the mobility""" derivs: str = "fd" """Type of derivative computation""" fft = None ifft = None fourier_symbol = None
[docs] def rhs(self, state, t): raise NotImplementedError("rhs method not implemented")
def __post_init__(self): self.kx, self.ky, self.kz = self.domain.fft_mesh() self.two_pi_i_kx = 2j * jnp.pi * self.kx self.two_pi_i_ky = 2j * jnp.pi * self.ky self.two_pi_i_kz = 2j * jnp.pi * self.kz self.two_pi_i_kx_2 = (self.two_pi_i_kx) ** 2 self.two_pi_i_ky_2 = (self.two_pi_i_ky) ** 2 self.two_pi_i_kz_2 = (self.two_pi_i_kz) ** 2 self.two_pi_i_k_2 = self.two_pi_i_kx_2 + self.two_pi_i_ky_2 + self.two_pi_i_kz_2 self.two_pi_i_k_4 = (self.two_pi_i_k_2) ** 2 self.fft = jnp.fft.fftn self.ifft = jnp.fft.ifftn self.fourier_symbol = self.kappa * self.two_pi_i_k_4 if self.derivs == "fourier": self.rhs = jax.jit(self.rhs_fourier) elif self.derivs == "fd": self.rhs = jax.jit(self.rhs_fd) else: raise ValueError(f"Invalid derivative type: {self.derivs}")
[docs] def rhs_fourier(self, state, t): state_hat = self.fft(state) tmp = self.fft(self.mu(state)) - self.kappa * (self.two_pi_i_k_2) * state_hat tmpx = self.fft(self.D(state) * self.ifft(self.two_pi_i_kx * tmp)) tmpy = self.fft(self.D(state) * self.ifft(self.two_pi_i_ky * tmp)) tmpz = self.fft(self.D(state) * self.ifft(self.two_pi_i_kz * tmp)) return self.ifft( self.two_pi_i_kx * tmpx + self.two_pi_i_ky * tmpy + self.two_pi_i_kz * tmpz ).real
[docs] def rhs_fd(self, state, t): hx, hy, hz = self.domain.dx # chemical potential: μ = μ_nl(u) - κ Δu mu = self.mu(state) - self.kappa * _lap_2nd_3D(state, hx, hy, hz) # gradients of μ at faces mux_f = _gradx_c2f(mu, hx) muy_f = _grady_c2f(mu, hy) muz_f = _gradz_c2f(mu, hz) # mobility at faces Du = self.D(state) Dx_f = _avgx_c2f(Du) Dy_f = _avgy_c2f(Du) Dz_f = _avgz_c2f(Du) # fluxes at faces Fx = Dx_f * mux_f Fy = Dy_f * muy_f Fz = Dz_f * muz_f # divergence back to centers return _divx_f2c(Fx, hx) + _divy_f2c(Fy, hy) + _divz_f2c(Fz, hz)
[docs] @dataclasses.dataclass class CahnHilliard2DSmoothedBoundary(BaseEquation): """Cahn–Hilliard equation with smoothed boundary method for arbitrary geometries. This class implements the Cahn-Hilliard equation using the smoothed boundary method, which allows for complex domain geometries through a smooth level-set function ψ. The equation is: .. math:: \\frac{\\partial u}{\\partial t} = \\frac{1}{\\psi} \\nabla \\cdot (\\psi D(u) \\nabla \\mu) + \\frac{|\\nabla \\psi|}{\\psi} J_n where the chemical potential includes boundary effects: .. math:: \\mu = \\mu_h(u) - \\frac{\\kappa}{\\psi} \\nabla \\cdot (\\psi \\nabla u) - \\sqrt{\\kappa} \\frac{|\\nabla \\psi|}{\\psi} \\sqrt{2f} \\cos(\\theta) """ domain: Domain """Domain of the equation""" kappa: float """Gradient energy coefficient""" f: Union[Callable, eqx.Module] # Can be a callable or Equinox module """Function for the free energy density""" mu: Union[Callable, eqx.Module] # Can be a callable or Equinox module """Function for the chemical potential""" D: Union[Callable, eqx.Module] # Can be a callable or Equinox module """Function for the mobility""" theta: Union[Callable, eqx.Module] # Can be a callable or Equinox module """Function for the contact angle""" flux: Union[Callable, eqx.Module] # Can be a callable or Equinox module """Function for the normal flux""" derivs: str = "fd" """Type of derivative computation"""
[docs] def rhs(self, state, t): raise NotImplementedError("rhs method not implemented")
def __post_init__(self): self.psi = self.domain.geometry.smooth self.sqrt_kappa = jnp.sqrt(self.kappa) self.hx, self.hy = self.domain.dx self.norm_grad_psi = ( jnp.sqrt( _gradx_c(self.psi, self.hx) ** 2 + _grady_c(self.psi, self.hy) ** 2 ) / self.psi ) self.left_half = jnp.zeros_like(self.psi) self.left_half = self.left_half.at[:50, :].set(1.0) if self.derivs == "fd": self.rhs = jax.jit(self.rhs_fd) else: raise ValueError(f"Invalid derivative type: {self.derivs}")
[docs] def rhs_fd(self, state, t): f = self.f(state) mu = self.mu(state) mask_avgx = _avgx_c2f(self.psi) mask_avgy = _avgy_c2f(self.psi) inner_term = ( mu - (self.kappa / self.psi) * ( _divx_f2c(mask_avgx * _gradx_c2f(state, self.hx), self.hx) + _divy_f2c(mask_avgy * _grady_c2f(state, self.hy), self.hy) ) - self.sqrt_kappa * self.norm_grad_psi * jnp.sqrt(2.0 * f) * ( jnp.cos(self.theta(t)) * self.left_half + jnp.cos(jnp.pi - self.theta(t)) * (1.0 - self.left_half) ) ) gradx_inner = _gradx_c2f(inner_term, self.hx) grady_inner = _grady_c2f(inner_term, self.hy) Du = self.D(state) Dx_f = _avgx_c2f(Du) Dy_f = _avgy_c2f(Du) Fx = mask_avgx * Dx_f * gradx_inner Fy = mask_avgy * Dy_f * grady_inner return ( _divx_f2c(Fx, self.hx) + _divy_f2c(Fy, self.hy) ) / self.psi + self.norm_grad_psi * self.flux(t)