
Library for wrapping Waymax environments in a Brax-like interface.

For more information on the Brax interface see: https://github.com/google/brax.

The Waymax/Brax interface primarily differs from the Google/Brax interface in the reset function. Because Waymax uses data to instantiate a new episode, the reset function requires a SimulatorState argument, whereas the Google/Brax interface requires only a random key.
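Below is a minimal sketch of the two reset conventions. It uses hypothetical toy classes (ToyScenario, GoogleBraxStyleEnv, WaymaxStyleEnv), not the real Brax or Waymax types. The Brax-style reset draws its initial state from a random generator. The Waymax-style reset copies its initial state from a data-derived state object. Plain Python lists and random.Random stand in for JAX arrays and PRNG keys.

```python
import random
from dataclasses import dataclass


@dataclass
class ToyScenario:
    """Hypothetical stand-in for a SimulatorState built from logged data."""
    scenario_id: str
    initial_positions: list


class GoogleBraxStyleEnv:
    """Brax-style: a new episode is generated from randomness alone."""

    def reset(self, rng: random.Random) -> list:
        # The initial state is sampled, so only a random source is needed.
        return [rng.uniform(-1.0, 1.0) for _ in range(3)]


class WaymaxStyleEnv:
    """Waymax-style: a new episode is instantiated from a data-derived state."""

    def reset(self, state: ToyScenario) -> list:
        # The initial state comes from the logged scenario, not from an RNG.
        return list(state.initial_positions)


# Usage: each style needs a different kind of argument to start an episode.
brax_obs = GoogleBraxStyleEnv().reset(random.Random(0))
waymax_obs = WaymaxStyleEnv().reset(ToyScenario("scenario-0", [0.0, 1.5, 3.0]))
```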

Module Contents



TimeStep – Container class for Waymax transitions.

BraxWrapper – Brax-like interface wrapper for the Waymax environment.

class waymax.env.wrappers.brax_wrapper.TimeStep

Container class for Waymax transitions.


state – The current simulation state of shape (…).

observation – The current observation of shape (…).

reward – The reward obtained in the current transition, of shape (…, num_objects).

done – A boolean array of shape (…) denoting the end of an episode.

discount – An array of discount values of shape (…).

metrics – Optional dictionary of metrics.

info – Optional dictionary of arbitrary logging information.

property shape: tuple[int, ...]

Shape of TimeStep.

state: waymax.datatypes.SimulatorState
observation: waymax.env.typedefs.Observation
reward: jax.Array
done: jax.Array
discount: jax.Array
metrics: waymax.env.typedefs.Metrics
info: dict[str, Any]
__eq__(other: Any) → bool

Return self==value.
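The TimeStep fields can be pictured with a simplified stdlib dataclass. ToyTimeStep is hypothetical and not part of Waymax, and nested lists replace jax.Array values. Its shape property is an assumption made for illustration: it reports the batch shape, read here from done.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass(frozen=True)
class ToyTimeStep:
    """Hypothetical, simplified mirror of the TimeStep fields.

    done and discount have the batch shape (batch,);
    reward has shape (batch, num_objects).
    """
    state: Any
    observation: Any
    reward: list    # shape (batch, num_objects)
    done: list      # shape (batch,)
    discount: list  # shape (batch,)
    metrics: dict = field(default_factory=dict)
    info: dict = field(default_factory=dict)

    @property
    def shape(self) -> tuple:
        # Assumption: the TimeStep's shape is its batch shape, taken from `done`.
        return (len(self.done),)


# Usage: a batch of two transitions with two objects each.
ts = ToyTimeStep(
    state=None,
    observation=None,
    reward=[[0.0, 1.0], [0.5, 0.5]],
    done=[False, True],
    discount=[1.0, 0.0],
)
```

Because ToyTimeStep is a dataclass, __eq__ is generated to compare the instances field by field. This is analogous to the documented self==value comparison.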

class waymax.env.wrappers.brax_wrapper.BraxWrapper(wrapped_env: waymax.env.abstract_environment.AbstractEnvironment, dynamics_model: waymax.dynamics.DynamicsModel, config: waymax.config.EnvironmentConfig)

Brax-like interface wrapper for the Waymax environment.

metrics(state: waymax.datatypes.SimulatorState) → waymax.env.typedefs.Metrics

Computes metrics (lower is better) from state.

reset(state: waymax.datatypes.SimulatorState) → TimeStep

Resets the environment and initializes the simulation state.

This initializer sets the initial timestep and fills the initial simulation trajectory with invalid values.


Parameters:
state – An uninitialized state.

Returns:
A TimeStep containing the initialized simulation state.
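Below is a sketch of the reset behavior described above for a single object. It copies the logged x positions for the first init_steps timesteps and marks every later step as invalid until that step is simulated. The names toy_reset, ToyTrajectory, init_steps and the NaN sentinel are all hypothetical. The real SimulatorState holds more fields and may use a different invalid value.

```python
import math
from dataclasses import dataclass

# Hypothetical sentinel for "not yet simulated"; the real invalid value may differ.
INVALID = math.nan


@dataclass
class ToyTrajectory:
    x: list      # per-timestep x position, length num_timesteps
    valid: list  # per-timestep validity flag, length num_timesteps


def toy_reset(log_x: list, init_steps: int) -> tuple[int, ToyTrajectory]:
    """Copy the first `init_steps` logged steps and mark later steps invalid."""
    num_timesteps = len(log_x)
    x = [log_x[t] if t < init_steps else INVALID for t in range(num_timesteps)]
    valid = [t < init_steps for t in range(num_timesteps)]
    # The current timestep points at the last initialized step.
    current_timestep = init_steps - 1
    return current_timestep, ToyTrajectory(x=x, valid=valid)


# Usage: four logged steps, of which the first two seed the simulation.
current_t, traj = toy_reset([0.0, 1.0, 2.0, 3.0], init_steps=2)
```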

observe(state: waymax.datatypes.SimulatorState) → waymax.env.typedefs.Observation

Computes the observation for the given simulation state.

step(timestep: TimeStep, action: waymax.datatypes.Action) → TimeStep

Advances simulation by one timestep using the dynamics model.

Parameters:
  • timestep – The timestep containing the current state.

  • action – The action to apply, of shape (…, num_objects). The action.valid field denotes which objects are being controlled; objects whose valid flag is False fall back to the default behavior specified by self.dynamics.

Returns:
The timestep corresponding to the transition taken.
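The per-object fallback driven by the valid flag amounts to a masked selection. In this sketch, apply_actions and its default argument are hypothetical. The default argument stands in for whatever behavior self.dynamics specifies, for example replaying logged data. Plain lists stand in for (…, num_objects) arrays.

```python
def apply_actions(controlled: list, default: list, valid: list) -> list:
    """Select per object: the agent's action where valid is True, else the default."""
    if not (len(controlled) == len(default) == len(valid)):
        raise ValueError("all inputs must have length num_objects")
    return [a if v else d for a, d, v in zip(controlled, default, valid)]


# Usage: objects 0 and 2 are controlled; object 1 falls back to its default.
merged = apply_actions([1.0, 2.0, 3.0], [0.0, 0.0, 0.0], [True, False, True])
```

With JAX arrays, the same selection would typically be written as jnp.where over the valid mask, broadcast against the action dimensions.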

reward(state: waymax.datatypes.SimulatorState, action: waymax.datatypes.Action) → jax.Array

Computes the reward for a transition.

Parameters:
  • state – The state used to compute the reward at state.timestep.

  • action – The action applied to state.

Returns:
A (…, num_objects) tensor of rewards.

termination(state: waymax.datatypes.SimulatorState) → jax.Array

Returns whether the current state is an episode termination.

A termination marks the end of an episode where the cost-to-go from this state is 0.

The equivalent step type in DMEnv is dm_env.termination.


Parameters:
state – The current simulator state.

Returns:
A boolean (…) tensor indicating whether the current state is the end of an episode as a termination.

truncation(state: waymax.datatypes.SimulatorState) → jax.Array

Returns whether the current state should truncate the episode.

A truncation denotes that an episode has ended because it reached its step limit. In these cases, dynamic programming methods (e.g. Q-learning) should still compute the cost-to-go as if the episode continued running.

The equivalent step type in DMEnv is dm_env.truncation.


Parameters:
state – The current simulator state.

Returns:
A boolean (…) tensor indicating whether the current state is the end of an episode as a truncation.
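The distinction between termination and truncation matters when computing bootstrapped targets. The sketch below assumes the common convention that discount is 0 only on termination. Under that convention, a terminated step uses only the immediate reward. A truncated step still bootstraps from the next state's value. The helpers timestep_flags and td_target are hypothetical. The mapping is not necessarily how BraxWrapper fills TimeStep.done and TimeStep.discount.

```python
def timestep_flags(terminated: bool, truncated: bool) -> tuple[bool, float]:
    """Map termination/truncation to (done, discount).

    Assumption: discount is 0 only on termination, so truncated
    episodes keep bootstrapping from the next state's value.
    """
    done = terminated or truncated
    discount = 0.0 if terminated else 1.0
    return done, discount


def td_target(reward: float, next_value: float, discount: float, gamma: float = 0.99) -> float:
    """One-step TD target: r + gamma * discount * V(s')."""
    return reward + gamma * discount * next_value


# Usage: same reward and next-state value, different episode endings.
_, term_discount = timestep_flags(terminated=True, truncated=False)
_, trunc_discount = timestep_flags(terminated=False, truncated=True)
target_after_termination = td_target(1.0, 10.0, term_discount, gamma=0.5)
target_after_truncation = td_target(1.0, 10.0, trunc_discount, gamma=0.5)
```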

action_spec() → waymax.datatypes.Action

Action spec of the environment.

reward_spec() → dm_env.specs.Array

Reward spec of the environment.

discount_spec() → dm_env.specs.BoundedArray

Discount spec of the environment.

observation_spec() → waymax.env.typedefs.PyTree

Observation spec of the environment.
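To illustrate what a bounded spec constrains, here is a stdlib stand-in for a bounded array spec such as dm_env.specs.BoundedArray. ToyBoundedSpec is hypothetical. The [0, 1] discount bounds and the batch shape of two are assumptions made for the example, not documented properties of discount_spec().

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyBoundedSpec:
    """Hypothetical stand-in for a bounded array spec (shape plus bounds)."""
    shape: tuple
    minimum: float
    maximum: float

    def validate(self, values: list) -> list:
        # The number of flat values must match the product of the spec's shape.
        expected = 1
        for dim in self.shape:
            expected *= dim
        if len(values) != expected:
            raise ValueError(f"expected {expected} values, got {len(values)}")
        # Every value must lie within [minimum, maximum].
        for v in values:
            if not (self.minimum <= v <= self.maximum):
                raise ValueError(f"value {v} outside [{self.minimum}, {self.maximum}]")
        return values


# Usage: a discount spec for a hypothetical batch of two environments.
discount_spec = ToyBoundedSpec(shape=(2,), minimum=0.0, maximum=1.0)
checked = discount_spec.validate([1.0, 0.0])
```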