waymax.env.wrappers.brax_wrapper

Library for wrapping Waymax environments in a Brax-like interface.

For more information on the Brax interface see: https://github.com/google/brax.

The Waymax/Brax interface primarily differs from the Google/Brax interface in the reset function. Because Waymax uses data to instantiate a new episode, the reset function requires a SimulatorState argument, whereas the Google/Brax interface requires only a random key.
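Because each episode is seeded by data, a rollout loop draws a scenario from a dataset and passes it to reset. A Google/Brax loop would pass a random key instead. The minimal sketch below uses hypothetical toy classes (ToyState, ToyTimeStep, ToyDataDrivenEnv) written in plain Python rather than real Waymax types. It only illustrates that control flow and is not the module's implementation.

```python
from dataclasses import dataclass
from typing import Iterator


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for waymax.datatypes.SimulatorState."""
    scenario_id: str
    timestep: int = 0
    max_steps: int = 3


@dataclass(frozen=True)
class ToyTimeStep:
    """Hypothetical stand-in for TimeStep, keeping only the fields used here."""
    state: ToyState
    done: bool


class ToyDataDrivenEnv:
    """Toy environment whose reset() takes a state, like the Waymax/Brax reset."""

    def reset(self, state: ToyState) -> ToyTimeStep:
        # A Google/Brax env would take an rng key and sample an episode here;
        # in this sketch the episode is fully determined by the provided state.
        start = ToyState(state.scenario_id, timestep=0, max_steps=state.max_steps)
        return ToyTimeStep(state=start, done=False)

    def step(self, timestep: ToyTimeStep, action: float) -> ToyTimeStep:
        s = timestep.state
        nxt = ToyState(s.scenario_id, s.timestep + 1, s.max_steps)
        return ToyTimeStep(state=nxt, done=nxt.timestep >= nxt.max_steps)


def rollout(env: ToyDataDrivenEnv, scenarios: Iterator[ToyState]) -> list[int]:
    """Runs one episode per scenario and returns each episode's length."""
    lengths = []
    for scenario in scenarios:
        ts = env.reset(scenario)  # data, not a random key, seeds the episode
        while not ts.done:
            ts = env.step(ts, action=0.0)
        lengths.append(ts.state.timestep)
    return lengths
```

For example, `rollout(ToyDataDrivenEnv(), iter([ToyState("a"), ToyState("b", max_steps=5)]))` returns `[3, 5]`.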

Module Contents

Classes

TimeStep: Container class for Waymax transitions.

BraxWrapper: Brax-like interface wrapper for the Waymax environment.

class waymax.env.wrappers.brax_wrapper.TimeStep

Container class for Waymax transitions.

state

The current simulation state, of shape (…).

observation

The current observation, of shape (…).

reward

The reward obtained in the current transition, of shape (…, num_objects).

done

A boolean array of shape (…) indicating whether the episode has ended.

discount

An array of discount values, of shape (…).

metrics

Optional dictionary of metrics.

info

Optional dictionary of arbitrary logging information.

property shape: tuple[int, ...]

Shape of TimeStep.

state: waymax.datatypes.SimulatorState
observation: waymax.env.typedefs.Observation
reward: jax.Array
done: jax.Array
discount: jax.Array
metrics: waymax.env.typedefs.Metrics
info: dict[str, Any]
__eq__(other: Any) → bool

Returns whether this TimeStep is equal to other.

class waymax.env.wrappers.brax_wrapper.BraxWrapper(wrapped_env: waymax.env.abstract_environment.AbstractEnvironment, dynamics_model: waymax.dynamics.DynamicsModel, config: waymax.config.EnvironmentConfig)

Brax-like interface wrapper for the Waymax environment.

metrics(state: waymax.datatypes.SimulatorState) → waymax.env.typedefs.Metrics

Computes metrics (lower is better) from state.

reset(state: waymax.datatypes.SimulatorState) → TimeStep

Resets the environment and initializes the simulation state.

This initializer sets the initial timestep and fills the initial simulation trajectory with invalid values.

Parameters:

state – An uninitialized state.

Returns:

The initial TimeStep, wrapping the initialized simulation state.
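A wrapper's reset has two jobs. First, it lets the wrapped environment initialize the state. Second, it packages that state into the first TimeStep. The sketch below shows the second job for a single unbatched scenario. The callables passed in (env_reset, observe, termination, truncation, metrics) stand in for the wrapped environment and the wrapper methods documented here, and InitialTimeStep and sketch_reset are hypothetical names. Three details are assumptions, not confirmed by this page: the initial reward is zero for every object, the initial discount is 1, and done combines termination and truncation.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class InitialTimeStep:
    """Hypothetical, unbatched stand-in for the TimeStep returned by reset."""
    state: Any
    observation: Any
    reward: list[float]
    done: bool
    discount: float
    metrics: dict[str, float]


def sketch_reset(
    state: Any,
    env_reset: Callable[[Any], Any],
    observe: Callable[[Any], Any],
    termination: Callable[[Any], bool],
    truncation: Callable[[Any], bool],
    metrics: Callable[[Any], dict[str, float]],
    num_objects: int,
) -> InitialTimeStep:
    """Wraps the initialized state in a first TimeStep (assumed defaults)."""
    init = env_reset(state)  # the wrapped environment initializes the state
    return InitialTimeStep(
        state=init,
        observation=observe(init),
        reward=[0.0] * num_objects,  # assumed: no transition yet, zero reward
        done=termination(init) or truncation(init),  # assumed combination
        discount=1.0,  # assumed: nothing has terminated yet
        metrics=metrics(init),
    )
```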

observe(state: waymax.datatypes.SimulatorState) → waymax.env.typedefs.Observation

Computes the observation for the given simulation state.

step(timestep: TimeStep, action: waymax.datatypes.Action) → TimeStep

Advances simulation by one timestep using the dynamics model.

Parameters:
  • timestep – The timestep containing the current state.

  • action – The action to apply, of shape (…, num_objects). The action.valid field denotes which objects are being controlled; objects whose valid is False fall back to the default behavior specified by self.dynamics.

Returns:

The timestep corresponding to the transition taken.
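The per-object role of action.valid can be illustrated with a toy one-dimensional world. The function step_positions and its inputs are hypothetical. Treating an action as a velocity and representing the dynamics model's default behavior as a precomputed default_next value are simplifications made only for this sketch.

```python
from typing import Sequence


def step_positions(
    positions: Sequence[float],
    actions: Sequence[float],
    valid: Sequence[bool],
    default_next: Sequence[float],
    dt: float,
) -> list[float]:
    """Advances a toy 1-D world by one step, object by object.

    valid[i] True  -> object i is controlled: integrate actions[i] as a velocity.
    valid[i] False -> object i falls back to default_next[i], a hypothetical
                      stand-in for the dynamics model's default behavior.
    """
    n = len(positions)
    if not (len(actions) == len(valid) == len(default_next) == n):
        raise ValueError("all per-object inputs must have length num_objects")
    return [
        p + a * dt if v else d
        for p, a, v, d in zip(positions, actions, valid, default_next)
    ]
```

In batched JAX code, the list comprehension would typically become a single `jnp.where(valid, controlled_next, default_next)`, which keeps the selection jit-compatible.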

reward(state: waymax.datatypes.SimulatorState, action: waymax.datatypes.Action) → jax.Array

Computes the reward for a transition.

Parameters:
  • state – The state used to compute the reward at state.timestep.

  • action – The action applied to state.

Returns:

A (…, num_objects) tensor of rewards.
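Because reward is per object, with shape (…, num_objects), single-agent training code often reduces it over the objects it controls. The sketch below uses a purely hypothetical reward of -1 per infraction, built from boolean overlap and offroad flags of shape (batch, num_objects). It then averages that reward over the controlled objects, yielding shape (batch,). Neither function is part of this module.

```python
def sketch_reward(overlap: list[list[bool]], offroad: list[list[bool]]) -> list[list[float]]:
    """Hypothetical reward of shape (batch, num_objects): -1 per infraction."""
    return [
        [0.0 - (float(o) + float(r)) for o, r in zip(ov_row, off_row)]
        for ov_row, off_row in zip(overlap, offroad)
    ]


def mean_over_controlled(reward: list[list[float]], valid: list[list[bool]]) -> list[float]:
    """Reduces (batch, num_objects) rewards to (batch,) over controlled objects."""
    out = []
    for r_row, v_row in zip(reward, valid):
        picked = [r for r, v in zip(r_row, v_row) if v]
        # A scenario with no controlled objects contributes 0.0 in this sketch.
        out.append(sum(picked) / len(picked) if picked else 0.0)
    return out
```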

termination(state: waymax.datatypes.SimulatorState) → jax.Array

Returns whether the current state is an episode termination.

A termination marks the end of an episode where the cost-to-go from this state is 0.

The equivalent step type in DMEnv is dm_env.termination.

Parameters:

state – The current simulator state.

Returns:

A boolean (…) tensor indicating whether the current state is the end of an episode as a termination.

truncation(state: waymax.datatypes.SimulatorState) → jax.Array

Returns whether the current state should truncate the episode.

A truncation denotes that an episode has ended because it reached its step limit. In this case, dynamic programming methods (e.g., Q-learning) should still compute the cost-to-go as if the episode continued running.

The equivalent step type in DMEnv is dm_env.truncation.

Parameters:

state – The current simulator state.

Returns:

A boolean (…) tensor indicating whether the current state is the end of an episode as a truncation.
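The distinction between termination and truncation matters when forming a bootstrapped target. The following is a minimal sketch that assumes the usual convention, which is also followed by dm_env.termination (discount 0) and dm_env.truncation (discount 1 by default). Under that convention, either event sets done, but only a termination zeroes the discount. The function names and the gamma parameter are illustrative and are not part of this module.

```python
def done_and_discount(terminated: bool, truncated: bool) -> tuple[bool, float]:
    """Either event ends the episode; only a termination zeroes the discount.

    If both are True, termination wins and the discount is 0.0.
    """
    done = terminated or truncated
    discount = 0.0 if terminated else 1.0
    return done, discount


def td_target(reward: float, discount: float, next_value: float, gamma: float = 0.99) -> float:
    """One-step TD target r + gamma * discount * V(s')."""
    return reward + gamma * discount * next_value
```

After a truncation, the target still bootstraps from V(s'), which matches the instruction above to keep computing cost-to-go. After a termination, the target reduces to the immediate reward.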

action_spec() → waymax.datatypes.Action

Action spec of the environment.

reward_spec() → dm_env.specs.Array

Reward spec of the environment.

discount_spec() → dm_env.specs.BoundedArray

Discount spec of the environment.

observation_spec() → waymax.env.typedefs.PyTree

Observation spec of the environment.
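Spec objects describe the shape, dtype, and (for BoundedArray) the value range of what the environment produces. Callers typically use them to build placeholder values or to validate outputs. The real dm_env.specs.Array and dm_env.specs.BoundedArray classes provide generate_value() and validate() for this purpose. The stdlib stand-ins below (ArraySpecSketch, BoundedSpecSketch) only mimic the idea so the example runs without dm_env installed. The default bounds of [0, 1] for a discount spec are an assumption made for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ArraySpecSketch:
    """Hypothetical stdlib stand-in for dm_env.specs.Array."""
    shape: tuple[int, ...]
    dtype: str  # recorded only; zeros() below always produces Python floats

    def zeros(self):
        """Nested-list zeros of this shape, usable as a placeholder value."""
        def build(dims):
            if not dims:
                return 0.0
            return [build(dims[1:]) for _ in range(dims[0])]
        return build(self.shape)


@dataclass(frozen=True)
class BoundedSpecSketch(ArraySpecSketch):
    """Hypothetical stand-in for dm_env.specs.BoundedArray with inclusive bounds."""
    minimum: float = 0.0  # assumed discount range [0, 1]
    maximum: float = 1.0

    def validate(self, value: float) -> float:
        """Returns value unchanged if it lies within bounds, else raises."""
        if not self.minimum <= value <= self.maximum:
            raise ValueError(f"{value} outside [{self.minimum}, {self.maximum}]")
        return value
```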