"""
This module contains an MLP-Mixer architecture for representing functions in PDEs.
Based on the score-based diffusion example from the Equinox documentation:
https://docs.kidger.site/equinox/examples/score_based_diffusion/
"""
import equinox as eqx
import jax
import jax.random as jr
import einops
class MixerBlock(eqx.Module):
    """One MLP-Mixer block: patch mixing followed by hidden-channel mixing.

    Operates on a single (unbatched) array of shape ``(hidden_size, num_patches)``.
    Each of the two mixing steps is preceded by a LayerNorm and wrapped in a
    residual connection; the output has the same shape as the input.
    """

    # MLP applied along the patch axis (one call per hidden channel).
    patch_mixer: eqx.nn.MLP
    # MLP applied along the hidden axis (one call per patch).
    hidden_mixer: eqx.nn.MLP
    # Normalises the full (hidden_size, num_patches) input before patch mixing.
    norm1: eqx.nn.LayerNorm
    # Normalises the transposed (num_patches, hidden_size) array before hidden mixing.
    norm2: eqx.nn.LayerNorm

    def __init__(
        self, num_patches, hidden_size, mix_patch_size, mix_hidden_size, *, key
    ):
        """Build the block.

        Args:
            num_patches: Length of the patch axis.
            hidden_size: Length of the hidden (channel) axis.
            mix_patch_size: Hidden-layer width of the patch-mixing MLP.
            mix_hidden_size: Hidden-layer width of the hidden-mixing MLP.
            key: JAX PRNG key; split to initialise the two MLPs.
        """
        tkey, ckey = jr.split(key, 2)
        # depth=1: each MLP has a single hidden layer and maps a vector back
        # to its own size, so the residual additions below are shape-preserving.
        self.patch_mixer = eqx.nn.MLP(
            num_patches, num_patches, mix_patch_size, depth=1, key=tkey
        )
        self.hidden_mixer = eqx.nn.MLP(
            hidden_size, hidden_size, mix_hidden_size, depth=1, key=ckey
        )
        self.norm1 = eqx.nn.LayerNorm((hidden_size, num_patches))
        self.norm2 = eqx.nn.LayerNorm((num_patches, hidden_size))

    def __call__(self, y):
        """Apply the block to ``y`` of shape ``(hidden_size, num_patches)``.

        Returns:
            An array of the same shape as ``y``.
        """
        # Patch mixing: vmap over the leading (hidden) axis, so each row -- one
        # channel across all patches -- is passed through patch_mixer.
        y = y + jax.vmap(self.patch_mixer)(self.norm1(y))
        # Transpose so that each row is now one patch's vector of channels.
        y = einops.rearrange(y, "c p -> p c")
        y = y + jax.vmap(self.hidden_mixer)(self.norm2(y))
        # Restore the (hidden, patches) layout the caller passed in.
        return einops.rearrange(y, "p c -> c p")
class Mixer2d(eqx.Module):
    """MLP-Mixer that maps a 2-D field to a field of the same shape.

    The input is cut into non-overlapping ``patch_size x patch_size`` patches by a
    strided convolution, passed through a stack of ``MixerBlock``s and a final
    LayerNorm, and mapped back to the input resolution by a strided transposed
    convolution.
    """

    # Patch embedding: input_size channels -> hidden_size channels, stride = kernel.
    conv_in: eqx.nn.Conv2d
    # Inverse of the patch embedding: hidden_size channels -> input_size channels.
    conv_out: eqx.nn.ConvTranspose2d
    # MixerBlock instances, applied in order.
    blocks: list
    # Final normalisation over the (hidden_size, num_patches) token array.
    norm: eqx.nn.LayerNorm

    def __init__(
        self,
        img_size,
        patch_size,
        hidden_size,
        mix_patch_size,
        mix_hidden_size,
        num_blocks,
        *,
        key,
    ):
        """Build the model.

        Args:
            img_size: ``(input_size, height, width)`` of the field; ``input_size``
                is the number of channels.
            patch_size: Side length of each square patch; must divide both
                ``height`` and ``width``.
            hidden_size: Number of hidden channels per patch.
            mix_patch_size: Hidden-layer width of each block's patch-mixing MLP.
            mix_hidden_size: Hidden-layer width of each block's hidden-mixing MLP.
            num_blocks: Number of ``MixerBlock``s.
            key: JAX PRNG key used to initialise all sub-modules.

        Raises:
            ValueError: If ``patch_size`` is not positive or does not divide
                ``height`` and ``width``.
        """
        input_size, height, width = img_size
        # Explicit checks rather than ``assert`` so validation survives ``python -O``.
        if patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {patch_size}")
        if height % patch_size or width % patch_size:
            raise ValueError(
                f"patch_size={patch_size} must divide both height={height} "
                f"and width={width}"
            )
        num_patches = (height // patch_size) * (width // patch_size)
        inkey, outkey, *bkeys = jr.split(key, 2 + num_blocks)
        self.conv_in = eqx.nn.Conv2d(
            input_size, hidden_size, patch_size, stride=patch_size, key=inkey
        )
        self.conv_out = eqx.nn.ConvTranspose2d(
            hidden_size, input_size, patch_size, stride=patch_size, key=outkey
        )
        self.blocks = [
            MixerBlock(
                num_patches, hidden_size, mix_patch_size, mix_hidden_size, key=bkey
            )
            for bkey in bkeys
        ]
        self.norm = eqx.nn.LayerNorm((hidden_size, num_patches))

    def __call__(self, y):
        """Apply the model to a single (unbatched) field.

        Args:
            y: Either a ``(height, width)`` array, treated as a single-channel
                field (valid when ``input_size == 1``), or an
                ``(input_size, height, width)`` array.

        Returns:
            An array with the same shape as ``y``.
        """
        # A 2-D input gets a temporary channel axis, which is removed again
        # on output; a 3-D input already carries its channel axis.
        add_channel = y.ndim == 2
        if add_channel:
            y = y[None]
        y = self.conv_in(y)
        _, patch_height, patch_width = y.shape
        # Flatten the patch grid into a token axis for the mixer blocks.
        y = einops.rearrange(y, "c h w -> c (h w)")
        for block in self.blocks:
            y = block(y)
        y = self.norm(y)
        y = einops.rearrange(y, "c (h w) -> c h w", h=patch_height, w=patch_width)
        y = self.conv_out(y)
        return y[0] if add_channel else y