Solving a PDE

A core feature of this package is solving time-dependent PDEs. In this example, we show how to solve the Cahn-Hilliard equation on a two-dimensional periodic domain.

[2]:
import jax.numpy as jnp
from jax import random
import diffrax as dfx
import matplotlib.pyplot as plt
from IPython.display import HTML
import matplotlib.animation as animation

from pde_opt.numerics.domains import Domain
from pde_opt.numerics.equations.cahn_hilliard import CahnHilliard2DPeriodic
from pde_opt.numerics.solvers import SemiImplicitFourierSpectral
from pde_opt.pde_model import PDEModel

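First, we create the domain to solve the PDE on: a 128 × 128 grid covering a square box centred at the origin, with side lengths chosen so that the grid spacing is 0.01 in each direction.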

[3]:
Nx, Ny = 128, 128
Lx = 0.01 * Nx
Ly = 0.01 * Ny
domain = Domain((Nx, Ny), ((-Lx / 2, Lx / 2), (-Ly / 2, Ly / 2)), "dimensionless")
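
To picture what such a domain represents, here is a plain-JAX sketch of a uniform periodic grid over the same box, together with the angular wavenumbers that a Fourier spectral solver works with. It is our own reconstruction under the assumption of a standard periodic grid (the right edge is identified with the left edge); it does not use or describe the internals of `pde_opt`'s `Domain`.

```python
import jax.numpy as jnp

# Hypothetical sketch of a uniform periodic grid over the notebook's box;
# this does not reproduce pde_opt's Domain class.
Nx, Ny = 128, 128
Lx, Ly = 0.01 * Nx, 0.01 * Ny

# On a periodic grid the point x = Lx/2 coincides with x = -Lx/2,
# so N points split the length L into N cells of width L / N.
dx, dy = Lx / Nx, Ly / Ny  # 0.01 in each direction
x = -Lx / 2 + dx * jnp.arange(Nx)
y = -Ly / 2 + dy * jnp.arange(Ny)
X, Y = jnp.meshgrid(x, y, indexing="ij")  # axis 0 is x, axis 1 is y

# Angular wavenumbers used by FFT-based derivatives: d/dx -> i * kx.
kx = 2.0 * jnp.pi * jnp.fft.fftfreq(Nx, d=dx)
ky = 2.0 * jnp.pi * jnp.fft.fftfreq(Ny, d=dy)
KX, KY = jnp.meshgrid(kx, ky, indexing="ij")
K2 = KX**2 + KY**2  # |k|^2; the Laplacian becomes multiplication by -K2
```

With 128 points per side, the wavenumber magnitudes run from 0 up to the Nyquist value π/dx ≈ 314.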

Now, let's create our PDE model from this domain and the Cahn-Hilliard equation, using a semi-implicit Fourier spectral method for time stepping.

[4]:
model = PDEModel(CahnHilliard2DPeriodic, domain, SemiImplicitFourierSpectral)
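
A semi-implicit Fourier spectral method treats the stiff high-order term implicitly in Fourier space and the nonlinear terms explicitly. The sketch below shows one time step of a widely used stabilized variant for variable-mobility Cahn-Hilliard, in the style of Zhu, Chen, Shen and Tikare (1999). It is not the implementation of `SemiImplicitFourierSpectral`: the function `semi_implicit_step` is hypothetical, the equation form is assumed, and we assume the solver parameter `A` plays the role of the stabilization constant.

```python
import jax
import jax.numpy as jnp


def semi_implicit_step(c, dt, kappa, A, mu, D, KX, KY):
    """One stabilized semi-implicit Fourier spectral step (hypothetical sketch).

    Advances dc/dt = div(D(c) grad(mu(c) - kappa * lap(c))) on a periodic grid.
    Nonlinear terms are explicit; the stabilizing A * kappa * |k|^4 term is implicit.
    """
    K2 = KX**2 + KY**2
    c_hat = jnp.fft.fft2(c)

    # Total chemical potential in Fourier space: -kappa * lap(c) -> +kappa |k|^2 c_hat.
    mu_tot_hat = jnp.fft.fft2(mu(c)) + kappa * K2 * c_hat

    # Spectral gradient of the chemical potential, back in real space.
    gx = jnp.real(jnp.fft.ifft2(1j * KX * mu_tot_hat))
    gy = jnp.real(jnp.fft.ifft2(1j * KY * mu_tot_hat))

    # Flux D(c) grad(mu_tot) at the current time level, then its divergence.
    Dc = D(c)
    div_hat = 1j * KX * jnp.fft.fft2(Dc * gx) + 1j * KY * jnp.fft.fft2(Dc * gy)

    # Stabilized update: (1 + A dt kappa |k|^4)(c_hat_new - c_hat) = dt * div_hat.
    c_hat_new = c_hat + dt * div_hat / (1.0 + A * dt * kappa * K2**2)
    return jnp.real(jnp.fft.ifft2(c_hat_new))


# Usage with the notebook's grid, parameters and initial condition.
N, dx = 128, 0.01
k = 2.0 * jnp.pi * jnp.fft.fftfreq(N, d=dx)
KX, KY = jnp.meshgrid(k, k, indexing="ij")
mu = lambda c: jnp.log(c / (1.0 - c)) + 3.0 * (1.0 - 2.0 * c)
D = lambda c: (1.0 - c) * c
c0 = 0.5 + 0.01 * jax.random.normal(jax.random.PRNGKey(0), (N, N))

c = c0
for _ in range(10):
    c = semi_implicit_step(c, 1e-6, 0.002, 0.5, mu, D, KX, KY)
```

The implicit stabilizing term is meant to damp the high-wavenumber modes that would otherwise force a very small explicit step. Because the update is zero at k = 0, the mean concentration is conserved by construction.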

Next, we specify the start time, end time, initial step size, and the times at which the solution is saved.

[6]:
t_start = 0.0
t_final = 0.2
dt = 0.000001
ts_save = jnp.linspace(t_start, t_final, 200)
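
A quick check of what these settings imply. `jnp.linspace` includes both endpoints, so 200 save times give 199 intervals of about 1.005e-3. If the step size stayed fixed at `dt` (an assumption: the text calls it only the initial step size, so the solver may adapt it), roughly 1,005 steps would separate consecutive saves.

```python
import jax.numpy as jnp

t_start, t_final, dt = 0.0, 0.2, 1e-6
ts_save = jnp.linspace(t_start, t_final, 200)

# linspace includes both endpoints: 200 save times -> 199 equal intervals.
save_interval = (t_final - t_start) / (len(ts_save) - 1)  # about 1.005e-3

# Only meaningful if the step size stayed fixed at dt (an assumption).
steps_per_save = save_interval / dt      # about 1005
total_steps = (t_final - t_start) / dt   # about 200,000

print(float(ts_save[1] - ts_save[0]), save_interval, round(steps_per_save), round(total_steps))
```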

We also need to specify the parameters of the PDE and of the solver.

[7]:
pde_parameters = {
    "kappa": 0.002,
    "mu": lambda c: jnp.log(c / (1.0 - c)) + 3.0 * (1.0 - 2.0 * c),
    "D": lambda c: (1.0 - c) * c,
    "derivs": "fd"
}

solver_parameters = {
    "A": 0.5,
}
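
For reference, the variable-mobility Cahn-Hilliard equation is commonly written as below. We assume, without having checked the package source, that `CahnHilliard2DPeriodic` uses this form, with `kappa` the gradient-energy coefficient κ, `mu` the bulk chemical potential μ, and `D` the mobility:

$$
\frac{\partial c}{\partial t} = \nabla \cdot \Big[ D(c)\, \nabla \big( \mu(c) - \kappa \nabla^2 c \big) \Big],
\qquad D(c) = c\,(1 - c),
$$

$$
\mu(c) = f'(c) = \ln\frac{c}{1 - c} + 3\,(1 - 2c),
\qquad f(c) = c \ln c + (1 - c)\ln(1 - c) + 3\,c\,(1 - c).
$$

The chosen `mu` is the derivative of a regular-solution free energy f with interaction strength 3, and the mobility D(c) = c(1 - c) vanishes in the pure phases c = 0 and c = 1.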

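Now we create the initial condition: a uniform concentration of 0.5 plus Gaussian noise with standard deviation 0.01. The perturbation is small enough to keep c well inside (0, 1), where the logarithm in `mu` is defined.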

[8]:
key = random.PRNGKey(0)
u0 = 0.5 * jnp.ones((Nx, Ny)) + 0.01 * random.normal(key, (Nx, Ny))
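
Why start from c = 0.5? Assuming `mu` is the derivative of the regular-solution free energy f(c) = c ln c + (1 - c) ln(1 - c) + 3c(1 - c) (our reconstruction, not package code), f''(c) = 1/(c(1 - c)) - 6 is negative between the spinodal points (1 ± 1/√3)/2 ≈ 0.211 and 0.789. The uniform state at 0.5 is therefore unstable and the small noise grows (spinodal decomposition). Linearizing the equation about c0 gives the growth rate σ(k) = -D(c0) k² (f''(c0) + κk²), which is largest at k² = -f''(c0)/(2κ). The sketch below checks these numbers with `jax.grad`.

```python
import jax
import jax.numpy as jnp

# Reconstructed homogeneous free energy (an assumption, not package code):
# regular-solution model whose derivative matches the notebook's "mu".
def f(c):
    return c * jnp.log(c) + (1.0 - c) * jnp.log(1.0 - c) + 3.0 * c * (1.0 - c)

mu = lambda c: jnp.log(c / (1.0 - c)) + 3.0 * (1.0 - 2.0 * c)  # as in pde_parameters
D = lambda c: (1.0 - c) * c                                   # as in pde_parameters
kappa = 0.002
c0 = 0.5

# mu should equal f' everywhere in (0, 1); compare on a few sample points.
cs = jnp.linspace(0.05, 0.95, 19)
max_err = jnp.max(jnp.abs(jax.vmap(jax.grad(f))(cs) - mu(cs)))

# Curvature f''(c0) = 1 / (c0 (1 - c0)) - 6, which is -2 at c0 = 0.5.
curvature = jax.grad(mu)(c0)

# Spinodal points, where f''(c) = 0, i.e. c (1 - c) = 1/6.
spinodal = (0.5 - 0.5 * jnp.sqrt(1.0 / 3.0), 0.5 + 0.5 * jnp.sqrt(1.0 / 3.0))

# Linear stability about c0: sigma(k) = -D(c0) k^2 (f''(c0) + kappa k^2),
# maximized at k^2 = -f''(c0) / (2 kappa).
k_star = jnp.sqrt(-curvature / (2.0 * kappa))     # about 22.4
wavelength = 2.0 * jnp.pi / k_star                # about 0.28
sigma_max = D(c0) * curvature**2 / (4.0 * kappa)  # about 125

print(max_err, curvature, spinodal, wavelength, sigma_max)
```

Linear theory thus predicts an initial pattern with a characteristic wavelength of about 0.28, so roughly four to five domains across the 1.28-wide box, growing with an e-folding time of about 1/125 = 0.008. Since that is much shorter than t_final = 0.2, the run should move well past the linear regime into coarsening.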
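
With the model, parameters, and initial condition in place, we solve the PDE, saving the solution at each time in ts_save.

[ ]: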
solution = model.solve(
    pde_parameters,
    u0,
    ts_save,
    solver_parameters,
)
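
Two cheap sanity checks on the result. On a periodic domain the Cahn-Hilliard equation conserves the total amount of c, so the spatial mean of every snapshot should stay close to its initial value (about 0.5), and the field should remain inside (0, 1), where `mu` is defined. The helper functions below are hypothetical (not part of `pde_opt`) and assume `solution` is an array of shape (number of saves, Nx, Ny), which is how the plotting cell indexes it.

```python
import jax.numpy as jnp


def mean_concentration(solution):
    """Spatial mean of every saved snapshot, shape (n_times,)."""
    # Assumes solution has shape (n_times, Nx, Ny), as the plotting cell implies.
    return jnp.asarray(solution).mean(axis=(1, 2))


def max_mass_drift(solution):
    """Largest deviation of the spatial mean from its initial value."""
    m = mean_concentration(solution)
    return float(jnp.max(jnp.abs(m - m[0])))


def concentration_range(solution):
    """Global (min, max) of the field; mu needs 0 < c < 1 because of the log."""
    s = jnp.asarray(solution)
    return float(s.min()), float(s.max())
```

For example, `max_mass_drift(solution)` and `concentration_range(solution)` can be evaluated right after the solve cell; a drift far above round-off or values outside (0, 1) would point to a problem with the run.
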
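
Finally, we animate the evolution, showing every second saved snapshot.

[11]:
fig, ax = plt.subplots(figsize=(4, 4))
ax.set_title('Cahn-Hilliard Evolution')
ax.set_xlabel('x')
ax.set_ylabel('y')

# Domain bounds in the order imshow expects: [x_min, x_max, y_min, y_max]
extent = [domain.box[0][0], domain.box[0][1],
          domain.box[1][0], domain.box[1][1]]

ims = []
for i in range(0, len(solution), 2):  # every second saved snapshot
    # Assuming array axis 0 is x and axis 1 is y: transpose for imshow (which
    # puts axis 0 on the vertical) and use origin='lower' so y increases upward.
    im = ax.imshow(solution[i].T, origin='lower', animated=True, cmap='RdBu',
                   vmin=0.0, vmax=1.0, extent=extent)
    ims.append([im])

# 100 ms between frames; blitting redraws only the artists that change
ani = animation.ArtistAnimation(fig, ims, interval=100, blit=True)

# Close the static figure so that only the animation below is displayed
plt.close(fig)

HTML(ani.to_jshtml())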
