waymax.env.wrappers.brax_wrapper#
Library for wrapping Waymax environments in a Brax-like interface.
For more information on the Brax interface see: https://github.com/google/brax.
The Waymax/Brax interface primarily differs from the Google/Brax interface in the reset function. Because Waymax uses data to instantiate a new episode, the reset function requires a SimulatorState argument, whereas the Google/Brax interface requires only a random key.
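For example, a minimal setup might look like the sketch below. It is illustrative only: the building blocks (waymax.dynamics.InvertibleBicycleModel, waymax.env.MultiAgentEnvironment, waymax.dataloader.simulator_state_generator) and the WOD_1_0_0_VALIDATION dataset constant are taken from the broader Waymax API and are assumptions here; settings such as max_num_objects may need to be aligned between the dataset and environment configs.

```python
from waymax import config as _config
from waymax import dataloader
from waymax import dynamics
from waymax import env as _env
from waymax.env.wrappers import brax_wrapper

# Assumed configuration: a bicycle dynamics model and a default environment
# config. Adjust max_num_objects etc. to match the dataset you load.
env_config = _config.EnvironmentConfig()
dynamics_model = dynamics.InvertibleBicycleModel()

base_env = _env.MultiAgentEnvironment(
    dynamics_model=dynamics_model, config=env_config
)
env = brax_wrapper.BraxWrapper(
    wrapped_env=base_env, dynamics_model=dynamics_model, config=env_config
)

# Unlike Brax, reset() consumes a SimulatorState built from logged data
# rather than a random key.
scenario = next(
    dataloader.simulator_state_generator(config=_config.WOD_1_0_0_VALIDATION)
)
timestep = env.reset(scenario)
```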
Module Contents#
Classes#
| Class | Description |
| --- | --- |
| TimeStep | Container class for Waymax transitions. |
| BraxWrapper | Brax-like interface wrapper for the Waymax environment. |
- class waymax.env.wrappers.brax_wrapper.TimeStep#
Container class for Waymax transitions.
- state#
The current simulation state of shape (…).
- observation#
The current observation of shape (…).
- reward#
The reward obtained in the current transition of shape (…, num_objects).
- done#
A boolean array denoting the end of an episode of shape (…).
- discount#
An array of discount values of shape (…).
- metrics#
Optional dictionary of metrics.
- info#
Optional dictionary of arbitrary logging information.
- property shape: tuple[int, Ellipsis]#
Shape of TimeStep.
- state: waymax.datatypes.SimulatorState#
- observation: waymax.env.typedefs.Observation#
- reward: jax.Array#
- done: jax.Array#
- discount: jax.Array#
- metrics: waymax.env.typedefs.Metrics#
- info: dict[str, Any]#
- __eq__(other: Any) → bool#
Return self==value.
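As an illustration of the shapes above, a TimeStep produced by BraxWrapper.reset() or BraxWrapper.step() can be inspected as follows (see the setup sketch above and the rollout sketch under step() below); the snippet assumes such a `timestep` is already in scope.

```python
# A minimal sketch, assuming `timestep` is a TimeStep returned by
# BraxWrapper.reset() or BraxWrapper.step().
print(timestep.shape)           # batch shape (...) shared by all fields
print(timestep.reward.shape)    # (..., num_objects): one reward per object
print(timestep.done.shape)      # (...): episode-end flag
print(timestep.metrics.keys())  # names of any metrics populated by the wrapper
```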
- class waymax.env.wrappers.brax_wrapper.BraxWrapper(wrapped_env: waymax.env.abstract_environment.AbstractEnvironment, dynamics_model: waymax.dynamics.DynamicsModel, config: waymax.config.EnvironmentConfig)#
Brax-like interface wrapper for the Waymax environment.
- metrics(state: waymax.datatypes.SimulatorState) → waymax.env.typedefs.Metrics#
Computes metrics (lower is better) from state.
- reset(state: waymax.datatypes.SimulatorState) → TimeStep#
Resets the environment and initializes the simulation state.
This initializer sets the initial timestep and fills the initial simulation trajectory with invalid values.
- Parameters:
state – An uninitialized state.
- Returns:
A TimeStep containing the initialized simulation state.
- observe(state: waymax.datatypes.SimulatorState) → waymax.env.typedefs.Observation#
Computes the observation for the given simulation state.
- step(timestep: TimeStep, action: waymax.datatypes.Action) → TimeStep#
Advances simulation by one timestep using the dynamics model.
- Parameters:
timestep – The timestep containing the current state.
action – The action to apply, of shape (…, num_objects). The actions.valid field denotes which objects are being controlled; objects whose valid flag is False fall back to the default behavior specified by self.dynamics.
- Returns:
The timestep corresponding to the transition taken.
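For instance, a rollout loop driven by a fixed placeholder action could look like the following sketch. It assumes the `env` and `scenario` from the setup example near the top of this page, that waymax.datatypes.Action exposes data and valid fields, and that action_spec() returns an Action holding dm_env-style specs with .shape and .dtype; a real agent would compute actions from timestep.observation instead.

```python
import jax.numpy as jnp

from waymax import datatypes

# Build a zero ("do nothing") action sized from the environment's action spec,
# with valid=True so every object is treated as controlled.
spec = env.action_spec()
placeholder_action = datatypes.Action(
    data=jnp.zeros(spec.data.shape, dtype=spec.data.dtype),
    valid=jnp.ones(spec.valid.shape, dtype=jnp.bool_),
)

timestep = env.reset(scenario)
while not bool(timestep.done.all()):
    # Objects whose valid flag is False would instead follow the dynamics
    # model's default behavior.
    timestep = env.step(timestep, placeholder_action)
```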
- reward(state: waymax.datatypes.SimulatorState, action: waymax.datatypes.Action) → jax.Array#
Computes the reward for a transition.
- Parameters:
state – The state used to compute the reward at state.timestep.
action – The action applied to state.
- Returns:
A (…, num_objects) tensor of rewards.
- termination(state: waymax.datatypes.SimulatorState) → jax.Array#
Returns whether the current state is an episode termination.
A termination marks the end of an episode where the cost-to-go from this state is 0.
The equivalent step type in DMEnv is dm_env.termination.
- Parameters:
state – The current simulator state.
- Returns:
A boolean (…) tensor indicating whether the current state is the end of an episode as a termination.
- truncation(state: waymax.datatypes.SimulatorState) → jax.Array#
Returns whether the current state should truncate the episode.
A truncation denotes that an episode has ended due to reaching the step limit of an episode. In these cases dynamic programming methods (e.g. Q-learning) should still compute cost-to-go assuming the episode will continue running.
The equivalent step type in DMEnv is dm_env.truncation.
- Parameters:
state – The current simulator state.
- Returns:
A boolean (…) tensor indicating whether the current state is the end of an episode as a truncation.
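The two signals serve different purposes in a learning loop: either one ends data collection, but only a termination implies zero cost-to-go. A hedged sketch, assuming a `timestep` from reset() or step() and the `env` from the setup example:

```python
import jax.numpy as jnp

terminated = env.termination(timestep.state)  # boolean, shape (...)
truncated = env.truncation(timestep.state)    # boolean, shape (...)

# Either condition ends the episode for data-collection purposes.
episode_over = jnp.logical_or(terminated, truncated)

# For value targets, only `terminated` implies zero cost-to-go; on truncation,
# bootstrapping from the value of the final state is still appropriate.
```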
- action_spec() → waymax.datatypes.Action#
Action spec of the environment.
- reward_spec() → dm_env.specs.Array#
Reward spec of the environment.
- discount_spec() → dm_env.specs.BoundedArray#
Discount spec of the environment.
- observation_spec() → waymax.env.typedefs.PyTree#
Observation spec of the environment.
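The specs can be used, for example, to preallocate rollout buffers. A sketch, assuming the `env` from the setup example, dm_env-style specs exposing .shape and .dtype, and an arbitrary rollout length chosen for illustration:

```python
import numpy as np

num_steps = 80  # illustrative rollout length, not prescribed by the API

reward_spec = env.reward_spec()
discount_spec = env.discount_spec()

# Per-step buffers whose trailing dimensions match the environment's specs.
rewards = np.zeros((num_steps,) + tuple(reward_spec.shape), dtype=reward_spec.dtype)
discounts = np.zeros((num_steps,) + tuple(discount_spec.shape), dtype=discount_spec.dtype)
```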