pde_opt.pde_env
Classes
PDEEnv: Reinforcement learning environment for controlling partial differential equations.
- class pde_opt.pde_env.PDEEnv(*args: Any, **kwargs: Any)
Reinforcement learning environment for controlling partial differential equations.
The PDEEnv class provides a Gymnasium-compatible environment for reinforcement learning control of PDEs. It allows agents to learn control policies by taking actions that modify equation parameters and receiving rewards based on the resulting system behavior.
- equation_type
The PDE equation class to control.
- Type:
Type[BaseEquation]
- domain
Spatial domain for the PDE simulation.
- Type:
Domain
- solver_type
Numerical solver for time integration.
- Type:
Type[diffrax.AbstractSolver]
- end_time
Maximum simulation time for each episode.
- Type:
float
- step_dt
Time duration for each RL step.
- Type:
float
- numeric_dt
Internal time step for numerical integration.
- Type:
float
- observation_space
Gymnasium observation space.
- Type:
gym.Space
- action_space
Gymnasium action space.
- Type:
gym.Space
For an example, see the documentation notebooks.
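The three time attributes nest inside one another: an episode covers end_time, each agent action covers step_dt, and the solver advances in increments of numeric_dt. The arithmetic below is a minimal sketch with illustrative values (the numbers are assumptions, not documented defaults). It also assumes fixed-size solver steps and intervals that divide evenly.

```python
# Illustrative values only; the PDEEnv documentation does not list defaults.
end_time = 1.0      # episode length in simulation time
step_dt = 0.1       # simulation time covered by one agent action
numeric_dt = 0.001  # internal solver time step

# round() guards against floating-point error in the divisions.
rl_steps_per_episode = round(end_time / step_dt)
solver_steps_per_rl_step = round(step_dt / numeric_dt)

print(rl_steps_per_episode)      # 10
print(solver_steps_per_rl_step)  # 100
```

With these values the agent acts 10 times per episode, and each action is held fixed over roughly 100 internal solver steps.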
- __init__(equation_type: Type[BaseEquation], domain: Domain, solver_type: Type[diffrax.AbstractSolver], end_time: float, step_dt: float, numeric_dt: float, state_to_observation_func: Callable, reward_function: Callable, reset_func: Callable, reset_control_value, update_control_value: Callable, update_control_parameter: Callable, action_space_config: Dict[str, Any], static_equation_parameters: Dict[str, Any], control_equation_parameter_name: str, solver_parameters: Dict[str, Any])
Initialize the PDE reinforcement learning environment.
- Parameters:
equation_type (Type[BaseEquation]) – The PDE equation class to control. Must be a subclass of BaseEquation from numerics.equations.
domain (domains.Domain) – Spatial domain for the PDE simulation. Defines the grid resolution, spatial bounds, and coordinate system.
solver_type (Type[diffrax.AbstractSolver]) – Numerical solver for time integration. Must be compatible with the equation type.
end_time (float) – Maximum simulation time for each episode. Episodes terminate when this time is reached.
step_dt (float) – Time duration for each RL step. This is the time interval between consecutive actions taken by the agent.
numeric_dt (float) – Internal time step for numerical integration. Typically smaller than step_dt, for numerical stability.
state_to_observation_func (Callable) – Function that converts the PDE state to an observation for the RL agent.
reward_function (Callable) – Function that computes the reward from the current PDE state.
reset_func (Callable) – Function that generates initial conditions for the PDE. Should accept (domain, seed=None) and return initial state.
reset_control_value – Initial value for the control parameter.
update_control_value (Callable) – Function that updates the control value based on the agent’s action. Should accept (action, old_value) and return the new control value.
update_control_parameter (Callable) – Function that converts control values to equation parameters. Should accept (old_value, new_value) and return the parameter value for the equation.
action_space_config (Dict[str, Any]) – Configuration for the action space. Should contain:
- “type”: “continuous” or “discrete”
- For continuous: “shape”, “low”, “high”
- For discrete: “num_actions”, “action_mapping”
static_equation_parameters (Dict[str, Any]) – Fixed parameters for the equation that are not controlled by the agent. Keys should match equation parameter names.
control_equation_parameter_name (str) – Name of the equation parameter that will be controlled by the agent’s actions.
solver_parameters (Dict[str, Any]) – Additional parameters for the numerical solver. Passed directly to the solver constructor.
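The callback parameters are easier to read with concrete functions in hand. The sketch below shows one possible set, using only NumPy. The (action, old_value) and (old_value, new_value) signatures and the continuous config keys follow the descriptions above. Everything else is an assumption for illustration: the increment-and-clip rule, the bounds, the reward definition, and the choice to give state_to_observation_func and reward_function the state as their only argument.

```python
import numpy as np

# Hypothetical continuous action space: one scalar action in [-1, 1].
action_space_config = {
    "type": "continuous",
    "shape": (1,),
    "low": -1.0,
    "high": 1.0,
}

def update_control_value(action, old_value):
    """Treat the action as an increment and keep the control in [0, 2]."""
    return float(np.clip(old_value + 0.1 * float(action[0]), 0.0, 2.0))

def update_control_parameter(old_value, new_value):
    """Pass the new control value straight through as the equation parameter."""
    return new_value

def state_to_observation_func(state):
    """Observe the field itself, cast to float32 (single-argument form assumed)."""
    return np.asarray(state, dtype=np.float32)

def reward_function(state):
    """Penalize the mean squared amplitude of the field (illustrative choice)."""
    return -float(np.mean(np.square(state)))

# Walk one action through the two control callbacks.
old_control = 1.0  # the kind of value passed as reset_control_value
new_control = update_control_value(np.array([0.5]), old_control)
parameter = update_control_parameter(old_control, new_control)
print(new_control, parameter)
```

Keeping update_control_value separate from update_control_parameter lets the agent act on a simple scalar while the equation receives whatever object it expects, such as a spatially varying field built from that scalar.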
- reset(seed: int | None = None, options: dict | None = None)
Reset the environment to its initial state.
This method resets the PDE environment to its initial state, generating new initial conditions and resetting the control parameters. It follows the standard Gymnasium reset interface.
- Parameters:
seed (Optional[int]) – Random seed for reproducible initial conditions.
options (Optional[dict]) – Additional options for reset.
- Returns:
- A tuple containing:
observation: Initial observation from the reset state
info: Additional information dictionary (currently empty)
- Return type:
Tuple[np.ndarray, dict]
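Reproducible resets depend on reset_func honoring the seed it receives. The sketch below shows one way a reset_func with the documented (domain, seed=None) signature might do that. The `points` attribute and the SimpleNamespace stand-in for the domain are hypothetical, since the real Domain class is not described here, and it is an assumption that reset forwards its seed to reset_func.

```python
from types import SimpleNamespace

import numpy as np

def reset_func(domain, seed=None):
    """Return a small random initial field on a 1D grid.

    `domain.points` is a hypothetical attribute giving the grid size;
    the real Domain class may expose this differently.
    """
    rng = np.random.default_rng(seed)  # seed=None draws fresh entropy
    return 0.01 * rng.standard_normal(domain.points)

domain = SimpleNamespace(points=64)  # hypothetical stand-in for Domain

a = reset_func(domain, seed=0)
b = reset_func(domain, seed=0)
c = reset_func(domain, seed=1)

print(np.array_equal(a, b))  # True: same seed, same initial condition
print(np.array_equal(a, c))  # False: different seed
```

Creating the generator inside the function keeps each call self-contained, so a fixed seed gives the same initial condition regardless of earlier calls.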
- step(action: numpy.ndarray)
Take a step in the environment.
This method executes one step of the RL environment by:
1. Converting the agent’s action to a control parameter update
2. Updating the PDE equation with the new control parameter
3. Simulating the PDE forward in time for one step_dt
4. Computing the reward and next observation
5. Checking for episode termination
- Parameters:
action (np.ndarray) – The action taken by the RL agent. Should be compatible with the environment’s action space. For continuous actions, this is typically a numpy array. For discrete actions, this is an integer.
- Returns:
- A tuple containing:
observation: Next observation from the updated state
reward: Reward value computed from the current state
terminated: Whether the episode has ended (reached end_time)
truncated: Whether the episode was truncated (currently always False)
info: Additional information dictionary (currently empty)
- Return type:
Tuple[np.ndarray, float, bool, bool, dict]
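The reset and step return values fit the usual Gymnasium rollout loop. The sketch below runs that loop against a toy stand-in class rather than PDEEnv itself, so that it is runnable without the library. StandInEnv, its decay "dynamics", and the fixed placeholder policy are all invented for illustration; only the shapes of the returned tuples follow the documentation above.

```python
import numpy as np

class StandInEnv:
    """Toy object mimicking the documented reset/step return shapes.

    Not PDEEnv: the "dynamics" only scale a random vector toward zero.
    """

    def __init__(self, end_time=1.0, step_dt=0.25):
        self.end_time = end_time
        self.step_dt = step_dt

    def reset(self, seed=None, options=None):
        self.t = 0.0
        self.state = np.random.default_rng(seed).standard_normal(8)
        return self.state.astype(np.float32), {}

    def step(self, action):
        # Advance the toy state and the clock by one step_dt.
        self.state = self.state * (1.0 - 0.5 * float(action[0]))
        self.t += self.step_dt
        reward = -float(np.mean(self.state ** 2))
        terminated = self.t >= self.end_time  # episode ends at end_time
        return self.state.astype(np.float32), reward, terminated, False, {}

env = StandInEnv()
obs, info = env.reset(seed=0)

total_reward, steps, done = 0.0, 0, False
while not done:
    action = np.array([0.5])  # fixed placeholder policy
    obs, reward, terminated, truncated, info = env.step(action)
    total_reward += reward
    steps += 1
    done = terminated or truncated

print(steps)  # 4 steps of 0.25 reach end_time = 1.0
```

Checking `terminated or truncated` keeps the loop correct even if truncation is implemented later, although for PDEEnv only `terminated` is currently expected to become True.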