Source code for pde_opt.numerics.functions.mixer_mlp

"""
This module contains a Mixer MLP architecture for representing functions in PDEs.

Based on the code from https://docs.kidger.site/equinox/examples/score_based_diffusion/
"""

import equinox as eqx
import jax
import jax.random as jr
import einops


class MixerBlock(eqx.Module):
    """A single MLP-Mixer layer acting on a ``(hidden_size, num_patches)`` array.

    The layer consists of two residual sub-layers applied one after another:

    1. Patch mixing: ``patch_mixer`` is applied independently to every one of
       the ``hidden_size`` rows, so information flows between patches.
    2. Channel mixing: the array is transposed to ``(num_patches, hidden_size)``
       and ``hidden_mixer`` is applied independently to every patch, so
       information flows between channels.

    Each sub-layer normalises its input with a LayerNorm whose shape is the
    full array shape before the MLP is applied.
    """

    # NOTE(review): eqx.Module fields are dataclass fields, so their order very
    # likely fixes the pytree leaf order -- keep it stable if parameters are
    # serialised.
    patch_mixer: eqx.nn.MLP
    hidden_mixer: eqx.nn.MLP
    norm1: eqx.nn.LayerNorm
    norm2: eqx.nn.LayerNorm

    def __init__(
        self, num_patches, hidden_size, mix_patch_size, mix_hidden_size, *, key
    ):
        """Build the two mixing MLPs and their LayerNorms.

        Args:
            num_patches: Number of patches, i.e. the length of each input row.
            hidden_size: Number of channels, i.e. the number of input rows.
            mix_patch_size: Hidden width of the patch-mixing MLP.
            mix_hidden_size: Hidden width of the channel-mixing MLP.
            key: PRNG key; split in two, one half per MLP.
        """
        patch_key, channel_key = jr.split(key, 2)
        self.patch_mixer = eqx.nn.MLP(
            in_size=num_patches,
            out_size=num_patches,
            width_size=mix_patch_size,
            depth=1,
            key=patch_key,
        )
        self.hidden_mixer = eqx.nn.MLP(
            in_size=hidden_size,
            out_size=hidden_size,
            width_size=mix_hidden_size,
            depth=1,
            key=channel_key,
        )
        # norm1 sees the (channel, patch) layout, norm2 the transposed one.
        self.norm1 = eqx.nn.LayerNorm(shape=(hidden_size, num_patches))
        self.norm2 = eqx.nn.LayerNorm(shape=(num_patches, hidden_size))

    def __call__(self, y):
        """Apply patch mixing then channel mixing, each with a residual.

        Args:
            y: Array of shape ``(hidden_size, num_patches)``.

        Returns:
            Array of the same shape as ``y``.
        """
        # Residual patch mixing: vmap runs the MLP over each channel row.
        patch_update = jax.vmap(self.patch_mixer)(self.norm1(y))
        y = y + patch_update

        # Residual channel mixing, performed in the (patch, channel) layout.
        by_patch = einops.rearrange(y, "c p -> p c")
        channel_update = jax.vmap(self.hidden_mixer)(self.norm2(by_patch))
        by_patch = by_patch + channel_update

        return einops.rearrange(by_patch, "p c -> c p")
class Mixer2d(eqx.Module):
    """MLP-Mixer that maps a 2D field to a field of the same spatial size.

    Pipeline:

    1. ``conv_in`` -- a ``patch_size x patch_size`` convolution with stride
       ``patch_size`` embeds non-overlapping patches into ``hidden_size``
       channels.
    2. The patch grid is flattened and passed through ``num_blocks``
       :class:`MixerBlock` layers, then a final LayerNorm.
    3. ``conv_out`` -- a transposed convolution with the same kernel size and
       stride maps back to the input resolution and channel count.
    """

    conv_in: eqx.nn.Conv2d
    conv_out: eqx.nn.ConvTranspose2d
    blocks: list
    norm: eqx.nn.LayerNorm

    def __init__(
        self,
        img_size,
        patch_size,
        hidden_size,
        mix_patch_size,
        mix_hidden_size,
        num_blocks,
        *,
        key,
    ):
        """Construct the patch embedding, mixer blocks and output projection.

        Args:
            img_size: ``(channels, height, width)`` of the fields the model
                acts on.
            patch_size: Side length of the square, non-overlapping patches.
                Must divide both ``height`` and ``width``.
            hidden_size: Number of channels of the patch embedding.
            mix_patch_size: Hidden width of each block's patch-mixing MLP.
            mix_hidden_size: Hidden width of each block's channel-mixing MLP.
            num_blocks: Number of :class:`MixerBlock` layers.
            key: PRNG key used to initialise all sub-modules.

        Raises:
            ValueError: If ``height`` or ``width`` is not divisible by
                ``patch_size``.
        """
        input_size, height, width = img_size
        # Use an explicit exception rather than ``assert`` so the check is not
        # stripped when Python runs with ``-O``.
        if height % patch_size != 0 or width % patch_size != 0:
            raise ValueError(
                f"height ({height}) and width ({width}) must both be divisible "
                f"by patch_size ({patch_size})"
            )
        num_patches = (height // patch_size) * (width // patch_size)
        inkey, outkey, *bkeys = jr.split(key, 2 + num_blocks)

        self.conv_in = eqx.nn.Conv2d(
            input_size, hidden_size, patch_size, stride=patch_size, key=inkey
        )
        self.conv_out = eqx.nn.ConvTranspose2d(
            hidden_size, input_size, patch_size, stride=patch_size, key=outkey
        )
        self.blocks = [
            MixerBlock(
                num_patches, hidden_size, mix_patch_size, mix_hidden_size, key=bkey
            )
            for bkey in bkeys
        ]
        self.norm = eqx.nn.LayerNorm((hidden_size, num_patches))

    def __call__(self, y):
        """Apply the mixer to a single (unbatched) field.

        Args:
            y: Either a ``(height, width)`` array, treated as a single-channel
                field (requires ``img_size[0] == 1``), or a
                ``(channels, height, width)`` array with
                ``channels == img_size[0]``.

        Returns:
            An array with the same shape as ``y``.
        """
        # A 2D input is a single-channel field: add the channel axis here and
        # strip it again at the end. A 3D input already carries its channels.
        add_channel_axis = y.ndim == 2
        if add_channel_axis:
            y = y[None]

        y = self.conv_in(y)
        _, patch_height, patch_width = y.shape
        # Flatten the patch grid so each MixerBlock sees (channels, patches).
        y = einops.rearrange(y, "c h w -> c (h w)")
        for block in self.blocks:
            y = block(y)
        y = self.norm(y)
        y = einops.rearrange(y, "c (h w) -> c h w", h=patch_height, w=patch_width)
        y = self.conv_out(y)

        return y[0] if add_channel_axis else y