Solving a PDE
A core feature of this package is solving time-dependent PDEs. This example shows how to solve the Cahn-Hilliard equation in 2D on a periodic domain.
[2]:
import jax.numpy as jnp
from jax import random
import diffrax as dfx
import matplotlib.pyplot as plt
from IPython.display import HTML
import matplotlib.animation as animation
from pde_opt.numerics.domains import Domain
from pde_opt.numerics.equations.cahn_hilliard import CahnHilliard2DPeriodic
from pde_opt.numerics.solvers import SemiImplicitFourierSpectral
from pde_opt.pde_model import PDEModel
First, we create the domain that we will solve the PDE on.
[3]:
Nx, Ny = 128, 128
Lx = 0.01 * Nx
Ly = 0.01 * Ny
domain = Domain((Nx, Ny), ((-Lx / 2, Lx / 2), (-Ly / 2, Ly / 2)), "dimensionless")
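To make the geometry concrete, the sketch below rebuilds the grid spacing, cell coordinates, and Fourier wavenumbers for a periodic box of this size using plain `jax.numpy`. It does not use the `Domain` object, and it assumes grid points start at the left edge of the box; the package may place them differently.

```python
import jax.numpy as jnp

Nx, Ny = 128, 128
Lx, Ly = 0.01 * Nx, 0.01 * Ny

# On a periodic box the right edge is identified with the left edge,
# so N points cover the length L with spacing L / N.
dx, dy = Lx / Nx, Ly / Ny

# Point coordinates (assumption: the first point sits on the left edge).
x = -Lx / 2 + dx * jnp.arange(Nx)
y = -Ly / 2 + dy * jnp.arange(Ny)

# Angular wavenumbers matching the ordering used by jnp.fft.fft2.
kx = 2.0 * jnp.pi * jnp.fft.fftfreq(Nx, d=dx)
ky = 2.0 * jnp.pi * jnp.fft.fftfreq(Ny, d=dy)

print(dx, dy, float(jnp.max(jnp.abs(kx))))
```

The largest resolvable wavenumber is pi / dx, which is what limits the step size of a fully explicit scheme for a fourth-order equation such as Cahn-Hilliard.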
Now, let’s create our PDE model from this domain and the Cahn-Hilliard equation. We will use a semi-implicit Fourier spectral method for time stepping.
[4]:
model = PDEModel(CahnHilliard2DPeriodic, domain, SemiImplicitFourierSpectral)
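For readers unfamiliar with semi-implicit Fourier spectral time stepping, here is a minimal sketch of one stabilized step. It is not the code behind `SemiImplicitFourierSpectral`. It assumes the equation has the form dc/dt = div(D(c) grad(mu(c) - kappa lap(c))) and that `A` scales an implicit fourth-order stabilizing term; both are assumptions about the package, and the function name `semi_implicit_step` is hypothetical. The parameter values mirror the ones used in this example.

```python
import jax.numpy as jnp
from jax import random


def semi_implicit_step(c, dt, KX, KY, kappa, mu, D, A):
    """One stabilized semi-implicit step (illustrative, not the package's code)."""
    k2 = KX**2 + KY**2
    c_hat = jnp.fft.fft2(c)
    # Total chemical potential mu(c) - kappa * laplacian(c), in Fourier space.
    mu_hat = jnp.fft.fft2(mu(c)) + kappa * k2 * c_hat
    # Flux D(c) * grad(mu_total), formed in real space (pseudo-spectral product).
    jx = D(c) * jnp.real(jnp.fft.ifft2(1j * KX * mu_hat))
    jy = D(c) * jnp.real(jnp.fft.ifft2(1j * KY * mu_hat))
    # Divergence of the flux, back in Fourier space.
    rhs_hat = 1j * KX * jnp.fft.fft2(jx) + 1j * KY * jnp.fft.fft2(jy)
    # Explicit update damped by the implicit fourth-order stabilizing term.
    c_hat_new = c_hat + dt * rhs_hat / (1.0 + A * dt * kappa * k2**2)
    return jnp.real(jnp.fft.ifft2(c_hat_new))


# Small demo grid with the same spacing as the example (0.01 per cell).
N, h = 32, 0.01
k = 2.0 * jnp.pi * jnp.fft.fftfreq(N, d=h)
KX, KY = jnp.meshgrid(k, k, indexing="ij")

mu = lambda c: jnp.log(c / (1.0 - c)) + 3.0 * (1.0 - 2.0 * c)
D = lambda c: (1.0 - c) * c

c0 = 0.5 + 0.01 * random.normal(random.PRNGKey(0), (N, N))
c1 = semi_implicit_step(c0, 1e-6, KX, KY, 0.002, mu, D, 0.5)
print(float(jnp.mean(c0)), float(jnp.mean(c1)))
```

Because the right-hand side is a divergence, its zero-wavenumber component vanishes, so the mean concentration is conserved up to floating-point round-off.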
Next, we specify the starting time, ending time, initial step size, and time points we want to save the solution at.
[6]:
t_start = 0.0
t_final = 0.2
dt = 0.000001
ts_save = jnp.linspace(t_start, t_final, 200)
We also need to specify the parameters of the PDE and of the solver.
[7]:
pde_parameters = {
    "kappa": 0.002,
    "mu": lambda c: jnp.log(c / (1.0 - c)) + 3.0 * (1.0 - 2.0 * c),
    "D": lambda c: (1.0 - c) * c,
    "derivs": "fd",
}
solver_parameters = {
    "A": 0.5,
}
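The `mu` function above has the form of the derivative of a regular-solution free energy. The sketch below checks this with `jax.grad`. The `free_energy` function is an assumption inferred from the shape of `mu`; it is not something the package requires or defines.

```python
import jax
import jax.numpy as jnp


def free_energy(c):
    # Regular-solution form, assumed here for illustration only.
    return c * jnp.log(c) + (1.0 - c) * jnp.log(1.0 - c) + 3.0 * c * (1.0 - c)


mu = lambda c: jnp.log(c / (1.0 - c)) + 3.0 * (1.0 - 2.0 * c)

# Compare the automatic derivative of free_energy with mu on a few points in (0, 1).
c_test = jnp.linspace(0.05, 0.95, 19)
mu_autodiff = jax.vmap(jax.grad(free_energy))(c_test)
max_err = float(jnp.max(jnp.abs(mu_autodiff - mu(c_test))))

# Slope of mu at c = 0.5: analytically 1/c + 1/(1 - c) - 6 = -2.
dmu_at_half = float(jax.grad(mu)(0.5))
print(max_err, dmu_at_half)
```

A negative slope of `mu` at c = 0.5 means a uniform mixture at that composition is unstable to small perturbations in a Cahn-Hilliard model of the assumed form, which is why an initial condition near 0.5 separates into two phases.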
The last ingredient is an initial condition: a uniform concentration of 0.5 perturbed by small Gaussian noise.
[8]:
key = random.PRNGKey(0)
u0 = 0.5 * jnp.ones((Nx, Ny)) + 0.01 * random.normal(key, (Nx, Ny))
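Since `mu` contains `log(c / (1 - c))`, the initial field must stay strictly between 0 and 1. A quick sanity check, written as a self-contained sketch that rebuilds the same `u0`:

```python
import jax.numpy as jnp
from jax import random

Nx, Ny = 128, 128
key = random.PRNGKey(0)
u0 = 0.5 * jnp.ones((Nx, Ny)) + 0.01 * random.normal(key, (Nx, Ny))

# Every value must lie strictly in (0, 1) for the logarithm in mu to be defined.
in_bounds = bool(jnp.all((u0 > 0.0) & (u0 < 1.0)))
# The mean concentration sets the overall composition of the mixture.
mean_c = float(jnp.mean(u0))
print(in_bounds, mean_c)
```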
We can now solve the PDE, saving the solution at the times in ts_save.
[ ]:
solution = model.solve(
    pde_parameters,
    u0,
    ts_save,
    solver_parameters,
)
To visualize the result, we animate every second saved frame.
[11]:
fig, ax = plt.subplots(figsize=(4, 4))
ims = []
# Build one image artist per frame, skipping every other saved time point.
for i in range(0, len(solution), 2):
    im = ax.imshow(solution[i], animated=True, cmap='RdBu',
                   vmin=0.0, vmax=1.0,
                   extent=[domain.box[0][0], domain.box[0][1],
                           domain.box[1][0], domain.box[1][1]])
    ims.append([im])
ani = animation.ArtistAnimation(fig, ims, interval=100, blit=True)
plt.title('Cahn-Hilliard Evolution')
plt.xlabel('x')
plt.ylabel('y')
# Close the static figure so only the animation is displayed.
plt.close()
HTML(ani.to_jshtml())