"""
This module contains the Shape class, which is used to set up a geometry/shape for solving PDE on with smoothed boundary method.
"""
import dataclasses
import jax
import jax.numpy as jnp
from typing import Tuple, Optional
import diffrax as dfx
import scipy
import numpy as np
from scipy.sparse import coo_matrix, csr_matrix
from .utils.derivatives import _gradx_c, _grady_c, _grad2x_c, _grad2y_c, _grad2xy_c
Array = jax.Array
# [docs]  (Sphinx source-link marker left over from HTML extraction; not code)
@dataclasses.dataclass
class Shape:
    """Set up a geometry/shape for solving PDEs with the smoothed boundary method.

    The user creates a shape by providing a binary (0/1) 2-D mask and optional
    refinement and smoothing parameters. On construction the mask is optionally
    refined (upsampled and/or zero-padded) and then smoothed; the smoothed
    field is stored in ``self.smooth``.

    Attributes set after construction:
        smooth: Smoothed version of ``binary``; values below 0.001 are raised
            to 0.001 and values above 0.99 are set to 1.0.
        shape_basis, shape_basis_evals: Only available after calling
            :meth:`get_shape_modes`.
    """

    # 2-D mask; entries > 0 are treated as "inside" by laplacian_from_mask.
    binary: Array
    # Upsampling factor applied to both axes of ``binary`` (None or 1.0: skip).
    refine_factor: Optional[float] = None
    # Fraction of the (possibly upsampled) size added as zero padding on every
    # side (None or <= 0: skip).
    refine_edge: Optional[float] = None
    # Grid spacings passed to the finite-difference helpers: dx[0] is used for
    # the x-derivatives, dx[1] for the y-derivatives.
    dx: Optional[Tuple[float, float]] = (1.0, 1.0)
    # Interface-width parameter of the smoothing equation.
    smooth_epsilon: float = 1.0
    # Blend weight in [0, 1] between the full Laplacian (1.0) and the
    # second derivative along the gradient direction (0.0).
    smooth_curvature: float = 0.0
    # Initial time step handed to the adaptive ODE solver.
    smooth_dt: float = 0.1
    # Final integration time of the smoothing.
    smooth_tf: float = 1.0

    def __post_init__(self):
        """Refine the mask if requested, then compute and clamp ``self.smooth``."""
        # Refine binary mask if parameters are provided
        if self.refine_factor is not None or self.refine_edge is not None:
            self.binary = self.refine_binary_mask()
        self.smooth = self.smooth_shape()
        # Keep the field strictly positive and snap the near-1 region to 1.
        self.smooth = jnp.where(self.smooth < 0.001, 0.001, self.smooth)
        self.smooth = jnp.where(self.smooth > 0.99, 1.0, self.smooth)

    def refine_binary_mask(self) -> Array:
        """Refine the binary mask by upsampling and adding edge padding.

        Takes no arguments; reads ``self.binary``, ``self.refine_factor`` (e.g.
        2.0 for 2x upsampling) and ``self.refine_edge`` (e.g. 0.5 to add 50% of
        the size as padding on each side). ``self.binary`` is not modified.

        NOTE(review): ``self.dx`` is not rescaled here, so after upsampling the
        grid spacing used by the smoothing is unchanged -- confirm intended.

        Returns:
            Refined binary mask with upsampling and edge padding applied.
        """
        binary = self.binary

        # Step 1: Upsample by refine_factor if specified
        if self.refine_factor is not None and self.refine_factor != 1.0:
            from scipy.interpolate import griddata

            # Convert to numpy for scipy interpolation
            binary_np = np.array(binary)
            h, w = binary_np.shape

            # New dimensions (truncated towards zero)
            new_h = int(h * self.refine_factor)
            new_w = int(w * self.refine_factor)

            # Old and new sample positions span the same index range [0, size-1]
            y_old = np.linspace(0, h - 1, h)
            x_old = np.linspace(0, w - 1, w)
            y_new = np.linspace(0, h - 1, new_h)
            x_new = np.linspace(0, w - 1, new_w)
            X_old, Y_old = np.meshgrid(x_old, y_old)
            X_new, Y_new = np.meshgrid(x_new, y_new)

            # Nearest-neighbour interpolation preserves the binary nature.
            # (No fill_value: it has no effect for method='nearest'.)
            points = np.column_stack((Y_old.ravel(), X_old.ravel()))
            values = binary_np.ravel()
            binary = jnp.array(
                griddata(points, values, (Y_new, X_new), method="nearest")
            )

        # Step 2: Add edge padding if specified
        if self.refine_edge is not None and self.refine_edge > 0.0:
            h, w = binary.shape

            # Padding per side, as a fraction of the current size
            pad_h = int(h * self.refine_edge)
            pad_w = int(w * self.refine_edge)

            # Zero-filled canvas with the current mask placed in the centre
            padded_binary = jnp.zeros((h + 2 * pad_h, w + 2 * pad_w))
            padded_binary = padded_binary.at[
                pad_h : pad_h + h, pad_w : pad_w + w
            ].set(binary)
            binary = padded_binary

        return binary

    def smooth_shape(self) -> Array:
        """Smooth the shape using the Allen-Cahn equation with curvature minimization.

        Integrates ``self.binary`` from t=0 to ``self.smooth_tf`` with an
        adaptive Tsit5 solver and returns the state at the final time.

        Returns:
            The smoothed field, same shape as ``self.binary``.
        """

        def potential(u):
            # 18 u (1-u) (1-2u) is the derivative of the double well 9 u^2 (1-u)^2.
            # NOTE(review): epsilon divides here and again in rhs, so the
            # reaction term scales as 1/epsilon**2 -- confirm intended.
            return 18.0 / self.smooth_epsilon * u * (1.0 - u) * (1.0 - 2.0 * u)

        @jax.jit
        def rhs(t, u, args):
            # Derivative helpers come from .utils.derivatives; presumably
            # central finite differences (the "_c" suffix) -- TODO confirm.
            gradx = _gradx_c(u, self.dx[0])
            grady = _grady_c(u, self.dx[1])
            grad2x = _grad2x_c(u, self.dx[0])
            grad2y = _grad2y_c(u, self.dx[1])
            grad2xy = _grad2xy_c(u, self.dx[0], self.dx[1])

            grad_norm_sq = gradx**2 + grady**2
            # Avoid dividing by ~0 where the field is flat.
            grad_norm_sq = jnp.where(grad_norm_sq < 1e-7, 1.0, grad_norm_sq)

            # Second derivative of u along its own gradient direction.
            norm_laplace = (
                grad2x * gradx**2 + 2.0 * grad2xy * gradx * grady + grad2y * grady**2
            ) / grad_norm_sq
            laplace = grad2x + grad2y

            return (
                2.0
                * (
                    self.smooth_curvature * laplace
                    + (1.0 - self.smooth_curvature) * norm_laplace
                )
                - potential(u) / self.smooth_epsilon
            )

        solution = dfx.diffeqsolve(
            dfx.ODETerm(rhs),
            dfx.Tsit5(),
            t0=0.0,
            t1=self.smooth_tf,
            dt0=self.smooth_dt,
            y0=self.binary,
            stepsize_controller=dfx.PIDController(rtol=1e-4, atol=1e-6),
            saveat=dfx.SaveAt(t1=True),
            max_steps=1000000,
        )
        # Only the final time is saved; take that single snapshot.
        return solution.ys[-1]

    def laplacian_from_mask(self, periodic: bool = False):
        """Build the unnormalized 4-neighbour graph Laplacian of the mask.

        Nodes are the entries where ``self.binary > 0``. Two nodes connect if
        they are up/down/left/right neighbours and both are nodes.

        Args:
            periodic: If True, neighbours wrap around the array edges.

        Returns:
            L: (n_nodes, n_nodes) CSR Laplacian (degree on the diagonal, -1 per
                edge off the diagonal).
            ids: (H, W) int64 array, node index in [0, n_nodes) or -1 if the
                entry is not a node.
        """
        # Work on a NumPy copy of the mask so all indexing below is pure NumPy,
        # regardless of whether self.binary is a JAX or NumPy array.
        mask = np.asarray(self.binary) > 0
        H, W = mask.shape
        n = int(mask.sum())

        ids = np.full((H, W), -1, dtype=np.int64)
        ids[mask] = np.arange(n, dtype=np.int64)
        if n == 0:
            return csr_matrix((0, 0)), ids

        def undirected_edges(dy, dx):
            """Return endpoints (u, v) for each undirected edge, listed once."""
            if periodic:
                # np.roll wraps around, pairing (i, j) with (i - dy, j - dx).
                m_both = mask & np.roll(mask, (dy, dx), axis=(0, 1))
                if not m_both.any():
                    return np.empty(0, np.int64), np.empty(0, np.int64)
                u = ids[m_both]
                v = np.roll(ids, (dy, dx), axis=(0, 1))[m_both]
                return u, v

            # Non-periodic: compare two shifted windows that stay in bounds.
            y0, y1 = max(0, dy), H + min(0, dy)
            x0, x1 = max(0, dx), W + min(0, dx)
            m1 = mask[y0:y1, x0:x1]
            m2 = mask[y0 - dy : y1 - dy, x0 - dx : x1 - dx]
            both = m1 & m2
            if not both.any():
                return np.empty(0, np.int64), np.empty(0, np.int64)
            u = ids[y0:y1, x0:x1][both]
            v = ids[y0 - dy : y1 - dy, x0 - dx : x1 - dx][both]
            return u, v

        # Build edges once using right and down neighbours, then symmetrize
        ur, vr = undirected_edges(0, +1)  # right
        ud, vd = undirected_edges(+1, 0)  # down
        u_one = np.concatenate([ur, ud])
        v_one = np.concatenate([vr, vd])

        # Degree from the undirected edges: each endpoint counted once per edge
        deg = np.bincount(np.concatenate([u_one, v_one]), minlength=n).astype(
            np.float64
        )

        # Off-diagonals: symmetrize edges (u, v) and (v, u)
        rows_off = np.concatenate([u_one, v_one])
        cols_off = np.concatenate([v_one, u_one])
        data_off = -np.ones(rows_off.shape[0], dtype=np.float64)

        # Append the diagonal; duplicate (row, col) entries are summed when
        # the COO matrix is converted to CSR.
        diag_idx = np.arange(n)
        rows = np.concatenate([rows_off, diag_idx])
        cols = np.concatenate([cols_off, diag_idx])
        data = np.concatenate([data_off, deg])
        L = coo_matrix((data, (rows, cols)), shape=(n, n)).tocsr()
        return L, ids

    def get_shape_modes(self, N: Optional[int] = None):
        """Compute the lowest N eigenmodes of the mask's graph Laplacian.

        The graph has one node per positive pixel of ``self.binary`` and edges
        between adjacent pixels (left, right, top, bottom); see
        :meth:`laplacian_from_mask`.

        Args:
            N: Number of eigenvectors to compute. If None, or if N is at least
                the number of nodes, all modes are computed with a dense
                eigensolver (memory grows as n_nodes**2). Otherwise a sparse
                shift-invert solver is used.

        Returns:
            None. Results are stored on the instance:
            ``self.shape_basis`` is a JAX array of shape (H, W, n_modes) holding
            each eigenvector scattered back onto the grid (zero outside the
            mask), and ``self.shape_basis_evals`` holds the matching
            eigenvalues in ascending order. ``n_modes`` is ``min(N, n_nodes)``,
            or ``n_nodes`` when N is None.

        Raises:
            ValueError: If N is not positive, or the Laplacian is not symmetric.
        """
        if N is not None and N <= 0:
            raise ValueError("N must be a positive integer or None")

        laplacian, node_ids = self.laplacian_from_mask()
        n = laplacian.shape[0]
        height, width = self.binary.shape

        if n == 0:
            # Empty mask: no nodes, hence no modes.
            self.shape_basis = jnp.array(np.zeros((height, width, 0)))
            self.shape_basis_evals = np.empty(0, dtype=np.float64)
            return

        # Check if Laplacian matrix is symmetric
        if (laplacian != laplacian.T).nnz != 0:
            raise ValueError("Laplacian matrix is not symmetric")

        n_modes = n if N is None else min(int(N), n)

        if n_modes >= n:
            # ARPACK needs k < n, so a full spectrum requires the dense solver.
            eigenvals, eigenvecs = np.linalg.eigh(laplacian.toarray())
        else:
            from scipy.sparse.linalg import eigsh

            # A scale-aware tiny shift: ~1e-8 times a typical diagonal
            # magnitude. Shift-invert with which="LM" then targets the
            # eigenvalues closest to sigma, i.e. the smallest ones.
            diag_mean = float(laplacian.diagonal().mean())
            sigma = max(diag_mean, 1.0) * 1e-8
            eigenvals, eigenvecs = eigsh(
                laplacian,
                k=n_modes,
                which="LM",
                sigma=sigma,
                tol=1e-8,
                maxiter=None,
            )

        # Guarantee ascending order on both solver paths.
        order = np.argsort(eigenvals, kind="stable")
        eigenvals = eigenvals[order]
        eigenvecs = eigenvecs[:, order]

        # Scatter every eigenvector onto the grid in one vectorised assignment;
        # entries outside the mask stay zero.
        output = np.zeros((height, width, n_modes))
        valid_mask = node_ids >= 0
        output[valid_mask, :] = eigenvecs[node_ids[valid_mask], :]

        self.shape_basis = jnp.array(output)
        self.shape_basis_evals = eigenvals