Source code for pde_opt.numerics.functions.legendre

"""
This module contains Legendre polynomial expansion classes for representing functions (e.g. coefficients) in PDEs.
"""

import dataclasses
from typing import Callable, Optional

import equinox as eqx
import jax
import jax.numpy as jnp


class LegendrePolynomialExpansion(eqx.Module):
    """Legendre polynomial expansion ``sum_n params[n] * P_n(x)``.

    The polynomials are evaluated with Bonnet's three-term recurrence
    ``n P_n(x) = (2n - 1) x P_{n-1}(x) - (n - 1) P_{n-2}(x)``, starting from
    ``P_0 = 1`` and ``P_1 = x``. Inputs are expected to lie in ``[-1, 1]``
    (the Legendre domain); values outside that interval are not rejected.

    Attributes:
        params: 1-D array of coefficients; ``params[n]`` multiplies ``P_n``.
        max_degree: Highest degree, ``len(params) - 1``. Marked static so the
            Python-level loop bound stays a concrete int under ``jax.jit`` /
            ``jax.vmap``, not only under ``eqx.filter_jit``.
    """

    params: jax.Array  # shape (max_degree + 1,)
    max_degree: int = eqx.field(static=True)

    def __init__(self, params: jax.Array):
        """Store the coefficients and derive ``max_degree`` from their length.

        Args:
            params: Coefficient array; ``params[n]`` weights ``P_n``.

        Raises:
            ValueError: If ``params`` is empty (there would be no ``P_0`` term
                and evaluation could not proceed).
        """
        super().__init__()
        if len(params) == 0:
            raise ValueError("params must contain at least one coefficient")
        self.params = params
        self.max_degree = len(params) - 1

    def __call__(self, inputs):
        """Evaluate the expansion elementwise at ``inputs`` (assumed in [-1, 1])."""
        p_prev = jnp.ones_like(inputs)  # P_0
        result = self.params[0] * p_prev
        if self.max_degree >= 1:
            p_curr = inputs  # P_1
            result += self.params[1] * p_curr
            for n in range(2, self.max_degree + 1):
                # Bonnet recurrence for P_n from P_{n-1} and P_{n-2}.
                p_next = ((2 * n - 1) * inputs * p_curr - (n - 1) * p_prev) / n
                result += self.params[n] * p_next
                p_prev, p_curr = p_curr, p_next
        return result
class DiffusionLegendrePolynomials(eqx.Module):
    """Positive function ``exp(sum_n c_n P_n(2u - 1))`` built on a Legendre expansion.

    The Legendre expansion is exponentiated, so outputs are positive (up to
    floating-point underflow) for any choice of coefficients. The class name
    indicates it is meant as a diffusion coefficient. Inputs are mapped with
    ``x = 2u - 1``, which sends ``[0, 1]`` onto the Legendre domain ``[-1, 1]``.
    """

    expansion: LegendrePolynomialExpansion

    def __init__(self, params: jax.Array):
        """Wrap the coefficients ``params`` in a ``LegendrePolynomialExpansion``."""
        super().__init__()
        self.expansion = LegendrePolynomialExpansion(params)

    @eqx.filter_jit
    def __call__(self, inputs):
        """Return ``exp(expansion(2 * inputs - 1))`` elementwise."""
        log_value = self.expansion(2.0 * inputs - 1.0)
        return jnp.exp(log_value)
class ChemicalPotentialLegendrePolynomials(eqx.Module):
    """Chemical potential as a Legendre expansion plus an optional prior term.

    Evaluates ``sum_n c_n P_n(2u - 1) + prior_fn(u)``. The affine map
    ``2u - 1`` sends ``[0, 1]`` onto the Legendre domain ``[-1, 1]``. Note that
    ``prior_fn`` receives the original, *unscaled* inputs.
    """

    expansion: LegendrePolynomialExpansion
    prior_fn: Optional[Callable]

    def __init__(self, params: jax.Array, prior_fn: Optional[Callable] = None):
        """Build the expansion from ``params`` and store the optional prior.

        Args:
            params: Legendre coefficients; ``params[n]`` weights ``P_n``.
            prior_fn: Optional callable added to the expansion output; it is
                called on the unscaled inputs. ``None`` disables the prior.
        """
        super().__init__()
        self.expansion = LegendrePolynomialExpansion(params)
        self.prior_fn = prior_fn

    @eqx.filter_jit
    def __call__(self, inputs):
        """Evaluate the expansion at ``2 * inputs - 1`` and add the prior, if any."""
        scaled_inputs = 2.0 * inputs - 1.0
        result = self.expansion(scaled_inputs)
        if self.prior_fn is not None:
            result += self.prior_fn(inputs)
        return result
class FixedDegreeChemicalPotential(eqx.Module):
    """Chemical potential with a fixed degree-1000 expansion, zero-padded.

    This class always uses degree 1000 Legendre polynomials, but zeros out
    parameters beyond the number actually supplied. This keeps forward-pass
    cost constant while allowing gradient-computation scaling tests.

    NOTE: THIS IS ONLY USED FOR TESTING PURPOSES.
    """

    expansion: LegendrePolynomialExpansion
    prior_fn: Optional[Callable]
    num_active_params: int

    def __init__(self, params: jax.Array, prior_fn: Optional[Callable] = None):
        """Zero-pad ``params`` to 1001 coefficients and build the expansion.

        Args:
            params: Active Legendre coefficients (at most 1001 of them).
            prior_fn: Optional callable added to the expansion output; it is
                called on the unscaled inputs. ``None`` disables the prior.

        Raises:
            ValueError: If more than 1001 coefficients are supplied.
        """
        super().__init__()
        degree = 1000  # fixed expansion degree; degree + 1 coefficients
        if len(params) > degree + 1:
            raise ValueError(
                f"FixedDegreeChemicalPotential accepts at most {degree + 1} "
                f"parameters, got {len(params)}"
            )
        # Pad with zeros so the expansion size never depends on len(params).
        full_params = jnp.zeros(degree + 1).at[: len(params)].set(params)
        self.expansion = LegendrePolynomialExpansion(full_params)
        self.prior_fn = prior_fn
        self.num_active_params = len(params)

    @eqx.filter_jit
    def __call__(self, inputs):
        """Evaluate the expansion at ``2 * inputs - 1`` and add the prior, if any."""
        scaled_inputs = 2.0 * inputs - 1.0
        result = self.expansion(scaled_inputs)
        if self.prior_fn is not None:
            result += self.prior_fn(inputs)
        return result
class TestChemicalPotential(eqx.Module):
    """Test chemical potential: the constant ``sum(params)`` plus an optional prior.

    The output does not depend on the input values except through ``prior_fn``;
    the constant is broadcast to the shape of ``inputs``.
    """

    # The name starts with "Test"; without this flag pytest tries to collect
    # the class as a test suite wherever it is imported into a test module.
    __test__ = False

    params: jax.Array
    prior_fn: Optional[Callable]

    def __init__(self, params: jax.Array, prior_fn: Optional[Callable] = None):
        """Store ``params`` and the optional ``prior_fn`` (``None`` disables it)."""
        super().__init__()
        self.params = params
        self.prior_fn = prior_fn

    @eqx.filter_jit
    def __call__(self, inputs):
        """Return ``sum(params)`` broadcast to ``inputs``, plus ``prior_fn(inputs)``."""
        result = jnp.sum(self.params) * jnp.ones_like(inputs)
        if self.prior_fn is not None:
            result += self.prior_fn(inputs)
        return result
@dataclasses.dataclass
class LegendrePolynomials:
    """Jitted Legendre expansion of a fixed degree between 0 and 10.

    ``LegendrePolynomials(max_degree)(params, inputs)`` returns
    ``params[0]*T0(x) + params[1]*T1(x) + ... + params[max_degree]*T{max_degree}(x)``.
    Despite the ``T`` prefix, the ``T{k}`` methods implement the closed-form
    Legendre polynomials ``P_k`` (not Chebyshev polynomials).

    Attributes:
        max_degree: Highest polynomial degree; must be an integer in [0, 10],
            the range covered by the closed forms ``T0`` ... ``T10``.
    """

    max_degree: int

    def __post_init__(self):
        """Validate ``max_degree`` and build the jitted expansion ``self.func``.

        Raises:
            ValueError: If ``max_degree`` is not an integer in [0, 10].
                Previously such values were accepted and ``self.func`` was
                never set, so the first call failed with ``AttributeError``.
        """
        if self.max_degree not in range(11):
            raise ValueError(
                f"max_degree must be an integer in [0, 10], got {self.max_degree!r}"
            )
        degree = int(self.max_degree)

        def expansion(p, x):
            # Same left-to-right sum of products as a hand-written
            # p[0]*T0(x) + p[1]*T1(x) + ... chain.
            total = p[0] * self.T0(x)
            for k in range(1, degree + 1):
                total = total + p[k] * getattr(self, f"T{k}")(x)
            return total

        self.func = jax.jit(expansion)

    def __call__(self, params, inputs):
        """Evaluate the expansion; ``params`` needs ``max_degree + 1`` entries."""
        return self.func(params, inputs)

    def T0(self, x):
        return 1.0 * jnp.ones_like(x)

    def T1(self, x):
        return x

    def T2(self, x):
        return 0.5 * (3 * x**2 - 1.0)

    def T3(self, x):
        return 0.5 * (5 * x**3 - 3 * x)

    def T4(self, x):
        return 0.125 * (35 * x**4 - 30 * x**2 + 3)

    def T5(self, x):
        return 0.125 * (63 * x**5 - 70 * x**3 + 15 * x)

    def T6(self, x):
        return 0.0625 * (231 * x**6 - 315 * x**4 + 105 * x**2 - 5)

    def T7(self, x):
        return 0.0625 * (429 * x**7 - 693 * x**5 + 315 * x**3 - 35 * x)

    def T8(self, x):
        return 0.0078125 * (6435 * x**8 - 12012 * x**6 + 6930 * x**4 - 1260 * x**2 + 35)

    def T9(self, x):
        return 0.0078125 * (
            12155 * x**9 - 25740 * x**7 + 18018 * x**5 - 4620 * x**3 + 315 * x
        )

    def T10(self, x):
        return 0.00390625 * (
            46189 * x**10
            - 109395 * x**8
            + 90090 * x**6
            - 30030 * x**4
            + 3465 * x**2
            - 63
        )