"""
This module contains the Legendre polynomial expansion class for representing functions in PDEs.
"""
import equinox as eqx
import jax
import jax.numpy as jnp
from typing import Callable
import dataclasses
class LegendrePolynomialExpansion(eqx.Module):
    """Finite Legendre series ``sum_n params[n] * P_n(x)``.

    The basis polynomials are generated with Bonnet's recurrence
    ``n P_n(x) = (2n - 1) x P_{n-1}(x) - (n - 1) P_{n-2}(x)``, starting from
    ``P_0 = 1`` and ``P_1 = x``. Inputs are expected to lie in ``[-1, 1]``;
    nothing in this class rescales or checks them.

    Attributes:
        params: Series coefficients; ``params[n]`` multiplies ``P_n``.
        max_degree: Highest degree in the series, ``len(params) - 1`` at
            construction time.
    """

    params: jax.Array  # shape (max_degree+1,)
    max_degree: int

    def __init__(self, params: jax.Array):
        super().__init__()
        self.params = params
        self.max_degree = len(params) - 1

    def __call__(self, inputs):
        """Evaluate the series elementwise at ``inputs``.

        Args:
            inputs: Evaluation points, presumably already mapped to ``[-1, 1]``.

        Returns:
            Array shaped like ``inputs`` holding the series values. It is all
            zeros when ``params`` is empty.
        """
        # Take the number of terms from the coefficient array itself, not from
        # ``self.max_degree``. ``len`` of an array (even a traced one) is a
        # concrete Python int. The ``max_degree`` leaf, by contrast, becomes a
        # tracer when the module goes through plain ``jax.jit``/``jax.vmap``,
        # and it goes stale if ``params`` is replaced via ``eqx.tree_at``.
        num_terms = len(self.params)
        if num_terms == 0:
            # An empty series is identically zero.
            return jnp.zeros_like(inputs, dtype=jnp.result_type(float, inputs))

        # P_0 = 1 supplies the constant term.
        result = self.params[0] * jnp.ones_like(inputs)
        if num_terms == 1:
            return result

        # P_1 = x.
        result = result + self.params[1] * inputs
        p_prev = jnp.ones_like(inputs)
        p_curr = inputs
        for n in range(2, num_terms):
            # Bonnet's recurrence for P_n from P_{n-1} and P_{n-2}.
            p_next = ((2 * n - 1) * inputs * p_curr - (n - 1) * p_prev) / n
            result = result + self.params[n] * p_next
            p_prev, p_curr = p_curr, p_next
        return result
class DiffusionLegendrePolynomials(eqx.Module):
    """Strictly positive function built from a Legendre expansion.

    The input is mapped affinely by ``x -> 2x - 1``, which sends ``[0, 1]``
    onto ``[-1, 1]``; inputs are presumably in ``[0, 1]``. The expansion value
    at the mapped point is then exponentiated, so the output is positive for
    any choice of coefficients.
    """

    expansion: LegendrePolynomialExpansion

    def __init__(self, params: jax.Array):
        super().__init__()
        self.expansion = LegendrePolynomialExpansion(params)

    @eqx.filter_jit
    def __call__(self, inputs):
        # The expansion models the log of the output; exp guarantees positivity.
        log_value = self.expansion(2.0 * inputs - 1.0)
        return jnp.exp(log_value)
class ChemicalPotentialLegendrePolynomials(eqx.Module):
    """Legendre expansion plus an optional additive prior.

    Evaluates ``expansion(2 * x - 1) + prior_fn(x)``. The affine map sends
    ``[0, 1]`` onto ``[-1, 1]``. The prior, when given, is applied to the
    original (unmapped) inputs; when ``prior_fn`` is ``None`` only the
    expansion term is returned.
    """

    expansion: LegendrePolynomialExpansion
    prior_fn: Callable

    def __init__(self, params: jax.Array, prior_fn: Callable = None):
        super().__init__()
        self.expansion = LegendrePolynomialExpansion(params)
        self.prior_fn = prior_fn

    @eqx.filter_jit
    def __call__(self, inputs):
        learned = self.expansion(2.0 * inputs - 1.0)
        if self.prior_fn is None:
            return learned
        return learned + self.prior_fn(inputs)
class FixedDegreeChemicalPotential(eqx.Module):
    """Chemical potential with a fixed high-degree expansion (default degree 1000).

    The Legendre expansion always has ``degree + 1`` coefficient slots. The
    supplied ``params`` fill the lowest-order slots and the remaining slots
    are zero. The forward pass therefore evaluates the same number of basis
    polynomials regardless of ``len(params)``, keeping forward-pass cost
    constant while allowing gradient-computation scaling tests.

    NOTE: THIS IS ONLY USED FOR TESTING PURPOSES.
    """

    expansion: LegendrePolynomialExpansion
    prior_fn: Callable
    num_active_params: int

    def __init__(
        self,
        params: jax.Array,
        prior_fn: Callable = None,
        *,
        degree: int = 1000,
    ):
        """Embed ``params`` in a zero-padded coefficient vector.

        Args:
            params: Active coefficients; ``params[k]`` multiplies ``P_k``.
            prior_fn: Optional function of the unmapped inputs that is added
                to the expansion output.
            degree: Degree of the padded expansion (keyword-only). It gives
                ``degree + 1`` coefficient slots.

        Raises:
            ValueError: If ``degree`` is negative or ``params`` has more than
                ``degree + 1`` entries.
        """
        super().__init__()
        if degree < 0:
            raise ValueError(f"degree must be non-negative, got {degree}")
        num_active = len(params)
        num_slots = degree + 1  # constant term plus degrees 1..degree
        if num_active > num_slots:
            raise ValueError(
                f"got {num_active} params, but a degree-{degree} expansion "
                f"has only {num_slots} coefficients"
            )
        # Unused slots stay zero but are still evaluated by the expansion.
        full_params = jnp.zeros(num_slots).at[:num_active].set(params)
        self.expansion = LegendrePolynomialExpansion(full_params)
        self.prior_fn = prior_fn
        self.num_active_params = num_active

    @eqx.filter_jit
    def __call__(self, inputs):
        # Map inputs by x -> 2x - 1 ([0, 1] onto [-1, 1]) before expanding.
        scaled_inputs = 2.0 * inputs - 1.0
        result = self.expansion(scaled_inputs)
        if self.prior_fn is not None:
            # The prior sees the original, unmapped inputs.
            result += self.prior_fn(inputs)
        return result
class TestChemicalPotential(eqx.Module):
    """Constant chemical potential used in tests.

    Evaluates to ``sum(params)`` at every input point. When a prior is
    supplied, ``prior_fn(inputs)`` is added.
    """

    # Without this flag, the ``Test`` prefix makes pytest try to collect the
    # class wherever it is imported into a test module. pytest then warns,
    # because the class defines ``__init__``.
    __test__ = False

    params: jax.Array
    prior_fn: Callable

    def __init__(self, params: jax.Array, prior_fn: Callable = None):
        super().__init__()
        self.params = params
        self.prior_fn = prior_fn

    @eqx.filter_jit
    def __call__(self, inputs):
        # Broadcast the scalar total over the input's shape.
        result = jnp.sum(self.params) * jnp.ones_like(inputs)
        if self.prior_fn is not None:
            result += self.prior_fn(inputs)
        return result
@dataclasses.dataclass
class LegendrePolynomials:
    """Jitted evaluator for a Legendre series of fixed maximum degree.

    ``LegendrePolynomials(d)(p, x)`` returns ``sum_{k=0}^{d} p[k] * P_k(x)``.
    The compiled function is built once in ``__post_init__`` and stored as
    ``self.func``.

    Degrees 0-10 use the explicit closed forms ``T0``-``T10``. Despite the
    ``T`` prefix, these are Legendre polynomials ``P_n``, not Chebyshev
    polynomials. Higher degrees are generated from ``P_9`` and ``P_10`` with
    Bonnet's recurrence
    ``n P_n(x) = (2n - 1) x P_{n-1}(x) - (n - 1) P_{n-2}(x)``.

    Attributes:
        max_degree: Highest degree in the series. ``params`` passed to
            ``__call__`` needs at least ``max_degree + 1`` entries.

    Raises:
        ValueError: At construction, if ``max_degree`` is negative or not an
            integer value.
    """

    max_degree: int

    def __post_init__(self):
        degree = self.max_degree
        # Reject bad degrees here with a clear message. Otherwise ``self.func``
        # would never be set and the first call would fail obscurely.
        if degree < 0 or degree != int(degree):
            raise ValueError(
                f"max_degree must be a non-negative integer, got {degree!r}"
            )
        degree = int(degree)

        closed_forms = (
            self.T0,
            self.T1,
            self.T2,
            self.T3,
            self.T4,
            self.T5,
            self.T6,
            self.T7,
            self.T8,
            self.T9,
            self.T10,
        )
        num_closed = min(degree, len(closed_forms) - 1) + 1

        def series(p, x):
            basis = [closed_forms[k](x) for k in range(num_closed)]
            # Beyond the tabulated closed forms, extend with Bonnet's recurrence.
            for n in range(len(closed_forms), degree + 1):
                basis.append(
                    ((2 * n - 1) * x * basis[-1] - (n - 1) * basis[-2]) / n
                )
            # Accumulate left to right: p[0]*P_0 + p[1]*P_1 + ...
            total = p[0] * basis[0]
            for k in range(1, degree + 1):
                total = total + p[k] * basis[k]
            return total

        self.func = jax.jit(series)

    def __call__(self, params, inputs):
        """Evaluate the series with coefficients ``params`` at ``inputs``."""
        return self.func(params, inputs)

    def T0(self, x):
        """Legendre polynomial P_0(x) = 1, shaped like ``x``."""
        return 1.0 * jnp.ones_like(x)

    def T1(self, x):
        """Legendre polynomial P_1(x)."""
        return x

    def T2(self, x):
        """Legendre polynomial P_2(x)."""
        return 0.5 * (3 * x**2 - 1.0)

    def T3(self, x):
        """Legendre polynomial P_3(x)."""
        return 0.5 * (5 * x**3 - 3 * x)

    def T4(self, x):
        """Legendre polynomial P_4(x)."""
        return 0.125 * (35 * x**4 - 30 * x**2 + 3)

    def T5(self, x):
        """Legendre polynomial P_5(x)."""
        return 0.125 * (63 * x**5 - 70 * x**3 + 15 * x)

    def T6(self, x):
        """Legendre polynomial P_6(x)."""
        return 0.0625 * (231 * x**6 - 315 * x**4 + 105 * x**2 - 5)

    def T7(self, x):
        """Legendre polynomial P_7(x)."""
        return 0.0625 * (429 * x**7 - 693 * x**5 + 315 * x**3 - 35 * x)

    def T8(self, x):
        """Legendre polynomial P_8(x); 0.0078125 == 1/128."""
        return 0.0078125 * (6435 * x**8 - 12012 * x**6 + 6930 * x**4 - 1260 * x**2 + 35)

    def T9(self, x):
        """Legendre polynomial P_9(x); 0.0078125 == 1/128."""
        return 0.0078125 * (
            12155 * x**9 - 25740 * x**7 + 18018 * x**5 - 4620 * x**3 + 315 * x
        )

    def T10(self, x):
        """Legendre polynomial P_10(x); 0.00390625 == 1/256."""
        return 0.00390625 * (
            46189 * x**10
            - 109395 * x**8
            + 90090 * x**6
            - 30030 * x**4
            + 3465 * x**2
            - 63
        )