Source code for pde_opt.numerics.domains

"""
This module contains the Domain class, which is used to set up a simulation domain for the model.
"""

import dataclasses
from typing import Optional, Tuple
import jax
import jax.numpy as jnp

from .shapes import Shape


@dataclasses.dataclass
class Domain:
    """Sets up a simulation domain for the model.

    The following information is stored in a Domain:

    -- `points[i]` is the number of collocation points in the i'th dimension
    -- `dx[i]` is the spacing between each collocation point in the i'th dimension
    -- `box[i]` is the bounds of the simulation box in the i'th dimension
    -- `L[i]` is the length of the simulation box in the i'th dimension
    -- `units` are the length units these values are stored in
    -- `geometry` is an optional `Shape` attached to the domain

    `dx` and `L` are derived from `points` and `box` in `__post_init__`.
    They are plain attributes, not dataclass fields, so they are not
    recomputed if `points` or `box` are reassigned after construction.

    Raises:
        ValueError: if `points` and `box` do not have the same number of
            entries (one per dimension).
    """

    points: Tuple[int, ...]
    box: Tuple[Tuple[float, float], ...]
    units: str
    # Quoted forward reference: `Shape` is only needed by static type
    # checkers, so there is no need to evaluate it at class-creation time.
    geometry: Optional["Shape"] = None

    def __post_init__(self):
        # zip() below would silently truncate to the shorter sequence,
        # leaving dx / axes inconsistent with the declared dimensionality.
        if len(self.points) != len(self.box):
            raise ValueError(
                f"points has {len(self.points)} entries but box has "
                f"{len(self.box)}; they must match (one per dimension)."
            )
        self.dx = tuple(
            (up_bound - low_bound) / points
            for (low_bound, up_bound), points in zip(self.box, self.points)
        )
        self.L = tuple(up_bound - low_bound for (low_bound, up_bound) in self.box)

    def axes(self) -> Tuple[jax.Array, ...]:
        """Return the 1-D collocation-point coordinates along each dimension.

        Points are cell-centred. The i'th axis runs from
        ``box[i][0] + dx[i] / 2`` to ``box[i][1] - dx[i] / 2`` with
        ``points[i]`` evenly spaced entries. Consecutive entries are
        therefore ``dx[i]`` apart, and no point lies on the box boundary.
        """
        return tuple(
            jnp.linspace(low_bound + step / 2, up_bound - step / 2, num=points)
            for (low_bound, up_bound), points, step in zip(
                self.box, self.points, self.dx
            )
        )

    def fft_axes(self) -> Tuple[jax.Array, ...]:
        """Return the full-FFT sample frequencies along each dimension.

        Each axis is ``jnp.fft.fftfreq(points[i], dx[i])``. Frequencies are
        in cycles per unit length (multiply by ``2 * pi`` for angular
        wavenumbers) and use the standard FFT ordering: zero, then the
        positive frequencies, then the negative frequencies.
        """
        return tuple(
            jnp.fft.fftfreq(points, step) for points, step in zip(self.points, self.dx)
        )

    def rfft_axes(self) -> Tuple[jax.Array, ...]:
        """Return the one-sided real-FFT frequencies along each dimension.

        Each axis is ``jnp.fft.rfftfreq(points[i], dx[i])``, which has
        ``points[i] // 2 + 1`` entries. These are the frequencies of a 1-D
        real FFT taken along that dimension alone. For the layout of a
        multi-dimensional ``rfftn``, use ``rfft_mesh`` instead.
        """
        return tuple(
            jnp.fft.rfftfreq(points, step) for points, step in zip(self.points, self.dx)
        )

    def mesh(self) -> Tuple[jax.Array, ...]:
        """Return the collocation-point coordinates as an ``ij``-indexed mesh."""
        axes = self.axes()
        return tuple(jnp.meshgrid(*axes, indexing="ij"))

    def fft_mesh(self) -> Tuple[jax.Array, ...]:
        """Return the full-FFT frequencies as an ``ij``-indexed mesh."""
        fft_axes = self.fft_axes()
        return tuple(jnp.meshgrid(*fft_axes, indexing="ij"))

    def rfft_mesh(self) -> Tuple[jax.Array, ...]:
        """Return the frequency mesh matching the output of ``jnp.fft.rfftn``.

        A real N-D FFT over all dimensions halves only the LAST axis. The
        leading dimensions therefore use full ``fftfreq`` frequencies, and
        only the last dimension uses ``rfftfreq``. Each returned array has
        shape ``(points[0], ..., points[-2], points[-1] // 2 + 1)``. For a
        1-D domain the result is the same as meshing ``rfft_axes()``.
        """
        if not self.points:
            # No dimensions: nothing to mesh (same as meshgrid() with no args).
            return ()
        leading = tuple(
            jnp.fft.fftfreq(points, step)
            for points, step in zip(self.points[:-1], self.dx[:-1])
        )
        last = jnp.fft.rfftfreq(self.points[-1], self.dx[-1])
        return tuple(jnp.meshgrid(*leading, last, indexing="ij"))

    def __str__(self):
        return f"Domain with bounds {self.box} with units of {self.units} and {self.points} collocation points."