Source code for pde_opt.utils

"""
This module contains general utility functions.
"""


def check_equation_solver_compatibility(solver_type, equation_type):
    """Verify that an equation type provides everything a solver asks for.

    A solver may declare a ``required_equation_attrs`` attribute listing the
    attribute names it expects to find on the equation. Solvers without that
    attribute impose no requirements and always pass the check.

    Args:
        solver_type: The solver class to check against. Its ``__name__`` is
            used in the error message.
        equation_type: The equation class to inspect. Its ``__name__`` is
            used in the error message.

    Raises:
        ValueError: If the equation is missing any attribute required by
            the solver. The message lists every missing attribute name.
    """
    if hasattr(solver_type, "required_equation_attrs"):
        # Collect every missing name so the error reports them all at once.
        missing_attrs = [
            name
            for name in solver_type.required_equation_attrs
            if not hasattr(equation_type, name)
        ]

        if missing_attrs:
            raise ValueError(
                f"Equation type {equation_type.__name__} is missing required "
                f"attributes for solver {solver_type.__name__}: {missing_attrs}"
            )
def prepare_solver_params(solver_type, solver_parameters, equation):
    """Prepare solver parameters by extracting required equation attributes.

    Some solvers require attributes from the equation to be passed to them.
    This function copies ``solver_parameters`` and adds one entry per name
    listed in ``solver_type.required_equation_attrs``, with the value read
    from ``equation``. The mapping passed in is not modified.

    Args:
        solver_type: The solver class. If it has no
            ``required_equation_attrs`` attribute, the parameters are
            returned as a plain copy.
        solver_parameters (Dict[str, Any] | None): The solver parameters to
            use for the equation. ``None`` is treated as "no parameters".
        equation (BaseEquation): The equation to solve.

    Returns:
        Dict[str, Any]: The prepared solver parameters. When a required
        attribute name is already a key in ``solver_parameters``, the value
        taken from ``equation`` replaces it.

    Raises:
        AttributeError: If ``equation`` lacks one of the attributes listed
            in ``solver_type.required_equation_attrs``.
    """
    # Work on a copy so the caller's mapping is never mutated; accept None
    # as a convenient spelling of "no extra parameters".
    if solver_parameters is None:
        full_solver_params = {}
    else:
        full_solver_params = solver_parameters.copy()

    # Solvers that declare no requirements contribute nothing extra.
    for attr_name in getattr(solver_type, "required_equation_attrs", ()):
        full_solver_params[attr_name] = getattr(equation, attr_name)

    return full_solver_params