waymax.env#

Reinforcement learning environment interfaces.

Package Contents#

Classes#

AbstractEnvironment

A stateless environment interface for Waymax.

BaseEnvironment

Waymax environment for multi-agent scenarios.

PlanningAgentDynamics

A dynamics wrapper for converting multi-agent dynamics to single-agent.

PlanningAgentEnvironment

An environment wrapper allowing for controlling a single agent.

PlanningAgentSimulatorState

Simulator state for the planning agent environment.

RolloutOutput

Rollout output data structure used as the output of the rollout function.

BraxWrapper

Brax-like interface wrapper for the Waymax environment.

DMEnvWrapper

A stateful environment wrapper implementing the DMEnv interface.

Functions#

rollout(→ RolloutOutput)

Performs a rollout from the beginning of a run segment.

rollout_log_by_expert_sdc(→ RolloutOutput)

Rolls out a state using logged expert actions specified by dynamics_model.

Attributes#

MultiAgentEnvironment

Metrics

Observation

ObservationFn

PyTree

RewardFn

class waymax.env.AbstractEnvironment#

Bases: abc.ABC

A stateless environment interface for Waymax.

abstract reset(scenario: waymax.env.typedefs.GenericScenario, rng: jax.Array | None = None) waymax.env.typedefs.GenericState#

Initializes a simulation state.

This method allows the environment to perform optional postprocessing on the state before the episode begins. By default this method is a no-op.

Parameters:
  • scenario – Scenario used to generate the initial state.

  • rng – Optional random number generator for stochastic environments.

Returns:

The initialized simulation state.

abstract step(state: waymax.env.typedefs.GenericState, actions: waymax.env.typedefs.GenericAction, rng: jax.Array | None = None) waymax.env.typedefs.GenericState#

Advances the simulation by one timestep.

Parameters:
  • state – The current state of the simulator.

  • actions – Action to apply to the state to produce the updated simulator state.

  • rng – Optional random number generator for stochastic environments.

Returns:

The next simulation state after taking an action.

abstract reward(state: waymax.env.typedefs.GenericState, action: waymax.env.typedefs.GenericAction) jax.Array#

Computes the reward for a transition.

Parameters:
  • state – The state used to compute the reward.

  • action – The action applied to state.

Returns:

A (…, num_objects) tensor of rewards.

abstract metrics(state: waymax.env.typedefs.GenericState) waymax.env.typedefs.Metrics#

Computes a set of metrics which score a given simulator state.

Parameters:

state – The state used to compute the metrics.

Returns:

A mapping from metric name to metrics which evaluate the simulator state at state.timestep, where all of the metrics are of shape (…, num_objects).

abstract observe(state: waymax.env.typedefs.GenericState) waymax.env.typedefs.Observation#

Computes the observation of the simulator for the actor.

Parameters:

state – The state used to compute the observation.

Returns:

An observation of the simulator state for the given timestep, of shape (…).

abstract action_spec() waymax.env.typedefs.GenericAction#

Returns the action specs of the environment without batch dimension.

Returns:

The action specs represented as a PyTree where the leaves are instances of specs.Array.

abstract reward_spec() dm_env.specs.Array#

Returns the reward specs of the environment without batch dimension.

abstract discount_spec() dm_env.specs.BoundedArray#

Returns the discount specs of the environment without batch dimension.

abstract observation_spec() waymax.env.typedefs.PyTree#

Returns the observation specs of the environment without batch dimension.

Returns:

The observation specs represented as a PyTree where the leaves are instances of specs.Array.

termination(state: waymax.env.typedefs.GenericState) jax.Array#

Returns whether the current state is an episode termination.

A termination marks the end of an episode where the cost-to-go from this state is 0.

The equivalent step type in DMEnv is dm_env.termination.

Parameters:

state – The current simulator state.

Returns:

A boolean (…) tensor indicating whether the current state is the end of an episode as a termination.

truncation(state: waymax.env.typedefs.GenericState) jax.Array#

Returns whether the current state should truncate the episode.

A truncation denotes that an episode has ended due to reaching the step limit of an episode. In these cases dynamic programming methods (e.g. Q-learning) should still compute cost-to-go assuming the episode will continue running.

The equivalent step type in DMEnv is dm_env.truncation.

Parameters:

state – The current simulator state.

Returns:

A boolean (…) tensor indicating whether the current state is the end of an episode as a truncation.
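
The two signals are typically combined as follows; this is a minimal sketch (the helper name compute_done_and_discount is hypothetical, not part of the Waymax API):

    import jax.numpy as jnp

    def compute_done_and_discount(env, state):
        # Hypothetical helper, not part of the Waymax API. An episode ends
        # on either signal, but only a true termination zeroes the
        # cost-to-go; after a truncation, bootstrapping methods should
        # keep discounting as if the episode continued.
        terminated = env.termination(state)  # boolean (...) tensor
        truncated = env.truncation(state)    # boolean (...) tensor
        done = jnp.logical_or(terminated, truncated)
        discount = jnp.where(terminated, 0.0, 1.0)
        return done, discount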

class waymax.env.BaseEnvironment(dynamics_model: waymax.dynamics.DynamicsModel, config: waymax.config.EnvironmentConfig)#

Bases: waymax.env.abstract_environment.AbstractEnvironment

Waymax environment for multi-agent scenarios.
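
A minimal construction sketch, assuming the StateDynamics model from waymax.dynamics and a default EnvironmentConfig (the max_num_objects value is illustrative):

    from waymax import config as _config
    from waymax import dynamics
    from waymax import env as _env

    # BaseEnvironment is also exposed as waymax.env.MultiAgentEnvironment.
    env = _env.BaseEnvironment(
        dynamics_model=dynamics.StateDynamics(),
        config=_config.EnvironmentConfig(max_num_objects=32),
    )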

property dynamics: waymax.dynamics.DynamicsModel#
metrics(state: waymax.datatypes.SimulatorState) waymax.env.typedefs.Metrics#

Computes metrics (lower is better) from state.

reset(state: waymax.datatypes.SimulatorState, rng: jax.Array | None = None) waymax.datatypes.SimulatorState#

Initializes the simulation state.

This initializer sets the initial timestep and fills the initial simulation trajectory with invalid values.

Parameters:
  • state – An uninitialized state of shape (…).

  • rng – Optional random number generator for stochastic environments.

Returns:

The initialized simulation state of shape (…).

observe(state: waymax.datatypes.SimulatorState) waymax.env.typedefs.Observation#

Computes the observation for the given simulation state.

Here we assume that the default observation is just the simulator state. We leave this for the user to override in order to provide a user-specific observation function. A user can use this to move some of their model-specific post-processing into the environment rollout in the actor nodes. If they want this post-processing on the accelerator, they can keep this the same and implement it on the learner side. We provide some helper functions in datatypes/observation.py to help write your own observation functions.

Parameters:

state – Current state of the simulator of shape (…).

Returns:

Simulator state as an observation without modifications of shape (…).

step(state: waymax.datatypes.SimulatorState, action: waymax.datatypes.Action, rng: jax.Array | None = None) waymax.datatypes.SimulatorState#

Advances simulation by one timestep using the dynamics model.

Parameters:
  • state – The current state of the simulator of shape (…).

  • action – The action to apply, of shape (…, num_objects). The action.valid field denotes which objects are being controlled; objects whose valid is False fall back to the default behavior specified by self.dynamics.

  • rng – Optional random number generator for stochastic environments.

Returns:

The next simulation state after taking an action of shape (…).
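
Because the environment is stateless, step is a pure function of (state, action) and can be wrapped in jax.jit. A hedged sketch; the action shapes below are purely illustrative and depend on the chosen dynamics model:

    import jax
    import jax.numpy as jnp
    from waymax import datatypes

    # env.step is a pure function of (state, action), so it can be jitted.
    jit_step = jax.jit(env.step)

    # Illustrative shapes only. The valid mask marks which objects are
    # controlled; invalid objects fall back to the dynamics model's
    # default behavior.
    num_objects, action_dim = 32, 2  # assumed values for illustration
    action = datatypes.Action(
        data=jnp.zeros((num_objects, action_dim)),
        valid=jnp.ones((num_objects, 1), dtype=jnp.bool_),
    )
    next_state = jit_step(state, action)  # state from env.reset(...)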

reward(state: waymax.datatypes.SimulatorState, action: waymax.datatypes.Action) jax.Array#

Computes the reward for a transition.

Parameters:
  • state – The state used to compute the reward at state.timestep of shape (…).

  • action – The action applied to state of shape (…, num_objects, dim).

Returns:

An array of rewards of shape (…, num_objects).

action_spec() waymax.datatypes.Action#

Returns the action specs of the environment without batch dimension.

Returns:

The action specs represented as a PyTree where the leaves are instances of specs.Array.

reward_spec() dm_env.specs.Array#

Returns the reward specs of the environment without batch dimension.

discount_spec() dm_env.specs.BoundedArray#

Returns the discount specs of the environment without batch dimension.

abstract observation_spec() waymax.env.typedefs.Observation#

Returns the observation specs of the environment without batch dimension.

Returns:

The observation specs represented as a PyTree where the leaves are instances of specs.Array.

waymax.env.MultiAgentEnvironment#

Alias of BaseEnvironment.

exception waymax.env.EpisodeAlreadyFinishedError#

Bases: RuntimeError

Error thrown when attempting to advance an episode that is finished.

exception waymax.env.SimulationNotInitializedError#

Bases: RuntimeError

Error thrown when attempting to advance an episode before reset.
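
These exceptions are raised by the stateful wrappers (e.g. DMEnvWrapper below), not by the stateless environments. A small hedged sketch of handling them (stateful_env and action are assumed to exist):

    from waymax import env as _env

    try:
        timestep = stateful_env.step(action)
    except _env.SimulationNotInitializedError:
        # step() was called before the first reset().
        timestep = stateful_env.reset()
    except _env.EpisodeAlreadyFinishedError:
        # The episode has ended; start a new one.
        timestep = stateful_env.reset()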

class waymax.env.PlanningAgentDynamics(multi_agent_dynamics: waymax.dynamics.DynamicsModel)#

Bases: waymax.dynamics.DynamicsModel

A dynamics wrapper for converting multi-agent dynamics to single-agent.

action_spec() dm_env.specs.BoundedArray#

Action spec of the action containing the bounds.

compute_update(action: waymax.datatypes.Action, trajectory: waymax.datatypes.Trajectory) waymax.datatypes.TrajectoryUpdate#

Computes the pose and velocity updates at timestep.

forward(action: waymax.datatypes.Action, trajectory: waymax.datatypes.Trajectory, log_trajectory: waymax.datatypes.Trajectory, is_controlled: jax.Array, timestep: int, allow_new_objects: bool = True) waymax.datatypes.Trajectory#

Updates a simulated trajectory to the next timestep given an update.

Runs the forward model for the planning agent by taking a single object’s action, tiling it across all objects, and then running the wrapped dynamics model.

Parameters:
  • action – Actions to be applied to the trajectory to produce updates at the next timestep of shape (…, dim).

  • trajectory – Simulated trajectory up to the current timestep. This trajectory will be updated by this function with the trajectory update. It is expected that this trajectory has been updated up to timestep. This is of shape: (…, num_objects, num_timesteps).

  • log_trajectory – Logged trajectory for all objects over the entire run segment. Certain fields such as valid are optionally taken from this trajectory. This is of shape: (…, num_objects, num_timesteps).

  • is_controlled – Boolean array specifying which objects are to be controlled by the trajectory update of shape (…, num_objects).

  • timestep – Timestep of the current simulation.

  • allow_new_objects – Whether to allow new objects to enter the scene. If this is set to False, all objects that are not valid at the current timestep will not be valid at the next timestep, and vice versa.

Returns:

Updated trajectory given the update from a dynamics model at timestep + 1, of shape (…, num_objects, num_timesteps).

inverse(trajectory: waymax.datatypes.Trajectory, metadata: waymax.datatypes.ObjectMetadata, timestep: int) waymax.datatypes.Action#

Computes actions converting traj[timestep] to traj[timestep+1].

Runs the wrapped dynamics inverse and slices out the sdc’s action specifically.

Parameters:
  • trajectory – Full trajectory to compute the inverse actions from, of shape (…, num_objects, num_timesteps). This trajectory covers the entire simulation so that dynamics models can use sophisticated optimization techniques to find the best-fitting actions.

  • metadata – Metadata on all objects in the scene which contains information about what types of objects are in the scene of shape (…, num_objects).

  • timestep – Current timestep of the simulation.

Returns:

Action which will take a set of objects from trajectory[timestep] to trajectory[timestep + 1], of shape (…, num_objects, dim).
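
A short sketch of the wrapper, assuming the InvertibleBicycleModel multi-agent dynamics from waymax.dynamics:

    from waymax import dynamics
    from waymax import env as _env

    # Wrap multi-agent bicycle dynamics so that the action carries a
    # single object's controls; forward() tiles that action across all
    # objects before running the wrapped dynamics.
    sdc_dynamics = _env.PlanningAgentDynamics(dynamics.InvertibleBicycleModel())
    single_agent_spec = sdc_dynamics.action_spec()  # bounds, no object axis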

class waymax.env.PlanningAgentEnvironment(dynamics_model: waymax.dynamics.DynamicsModel, config: waymax.config.EnvironmentConfig, sim_agent_actors: Sequence[waymax.agents.actor_core.WaymaxActorCore] = (), sim_agent_params: Sequence[waymax.agents.actor_core.Params] = ())#

Bases: waymax.env.abstract_environment.AbstractEnvironment

An environment wrapper allowing for controlling a single agent.

The PlanningAgentEnvironment builds on the multi-agent BaseEnvironment to provide a single-agent environment, returning only the observations and rewards corresponding to the ego agent (i.e. the ADV).

Note that while the action and reward no longer have an obj dimension as expected for a single agent env, the observation retains the obj dimension set to 1 to conform with the observation datastructure.
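
A hedged construction sketch; sim_agent_actors can optionally supply WaymaxActorCore instances that drive the non-ego objects inside step() (none are used here):

    from waymax import config as _config
    from waymax import dynamics
    from waymax import env as _env

    env = _env.PlanningAgentEnvironment(
        dynamics_model=dynamics.InvertibleBicycleModel(),
        config=_config.EnvironmentConfig(),
        # sim_agent_actors=[...], sim_agent_params=[...]  # optional
    )
    # Note: observations keep an object axis of size 1, while actions
    # and rewards drop it entirely.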

property dynamics: waymax.dynamics.DynamicsModel#
reset(state: waymax.datatypes.SimulatorState, rng: jax.Array | None = None) PlanningAgentSimulatorState#

Initializes the simulation state.

This initializer sets the initial timestep and fills the initial simulation trajectory with invalid values.

Parameters:
  • state – An uninitialized state of shape (…).

  • rng – Optional random number generator for stochastic environments.

Returns:

The initialized simulation state of shape (…).

observe(state: PlanningAgentSimulatorState) waymax.env.typedefs.Observation#

Computes the observation for the given simulation state.

Here we assume that the default observation is just the simulator state. We leave this for the user to override in order to provide a user-specific observation function. A user can use this to move some of their model-specific post-processing into the environment rollout in the actor nodes. If they want this post-processing on the accelerator, they can keep this the same and implement it on the learner side. We provide some helper functions in datatypes/observation.py to help write your own observation functions.

Parameters:

state – Current state of the simulator of shape (…).

Returns:

Simulator state as an observation without modifications of shape (…).

metrics(state: PlanningAgentSimulatorState) waymax.env.typedefs.Metrics#

Computes the metrics for the single agent wrapper.

The metrics to be computed are based on those specified by the configuration passed into the environment. This runs metrics that may be specific to the planning agent case.

Parameters:

state – State of simulation to compute the metrics for. This will compute metrics for the timestep corresponding to state.timestep of shape (…).

Returns:

Dictionary from metric name to metrics.MetricResult, representing the metrics calculated at state.timestep. All metrics are assumed to be of shape (…, num_objects=1) unless specified otherwise in the metric’s implementation.

reward(state: PlanningAgentSimulatorState, action: waymax.datatypes.Action) jax.Array#

Computes the reward for a transition.

Parameters:
  • state – State of the simulation to compute the reward for. This computes the reward for the timestep corresponding to state.timestep, of shape (…).

  • action – The action applied to the state.

Returns:

A float (…) tensor of rewards for the single agent.

action_spec() waymax.datatypes.Action#

Returns the action specs of the environment without batch dimension.

Returns:

The action specs represented as a PyTree where the leaves are instances of specs.Array.

step(state: PlanningAgentSimulatorState, action: waymax.datatypes.Action, rng: jax.Array | None = None) PlanningAgentSimulatorState#

Advances simulation by one timestep using the dynamics model.

Parameters:
  • state – The current state of the simulator of shape (…).

  • action – The action to apply, of shape (…, num_objects). The action.valid field denotes which objects are being controlled; objects whose valid is False fall back to the default behavior specified by self.dynamics.

  • rng – Optional random number generator for stochastic environments.

Returns:

The next simulation state after taking an action of shape (…).

reward_spec() dm_env.specs.Array#

Returns the reward spec for just a single object.

discount_spec() dm_env.specs.BoundedArray#

Returns the discount specs of the environment without batch dimension.

abstract observation_spec() waymax.env.typedefs.Observation#

Returns the observation specs of the environment without batch dimension.

Returns:

The observation specs represented as a PyTree where the leaves are instances of specs.Array.

class waymax.env.PlanningAgentSimulatorState#

Bases: waymax.datatypes.SimulatorState

Simulator state for the planning agent environment.

sim_agent_actor_states#

State of the sim agents that are run inside of the environment step function. If sim agent states are provided, they will be updated. The list of sim agent states must have the same length and order as the sim agents run in the environment.

sim_agent_actor_states: Sequence[waymax.agents.actor_core.ActorState] = ()#
waymax.env.rollout(scenario: waymax.env.typedefs.GenericScenario, actor: waymax.agents.actor_core.WaymaxActorCore, env: waymax.env.abstract_environment.AbstractEnvironment, rng: jax.Array, rollout_num_steps: int = 1, actor_params: waymax.agents.actor_core.Params | None = None) RolloutOutput#

Performs a rollout from the beginning of a run segment.

Parameters:
  • scenario – Initial SimulatorState to start the rollout, of shape (…).

  • actor – The action function used to select actions during the rollout.

  • env – A stateless Waymax environment used for computing steps, observations, and rewards.

  • rng – Random key used to generate stochastic actions if needed.

  • rollout_num_steps – Number of rollout steps.

  • actor_params – Parameters used by actor to select actions. It can be None if the actor does not require parameters.

Returns:

Stacked rollout output of shape (rollout_num_steps + 1, …) from the simulator when taking actions given the action_fn. There is one extra element in the time dimension compared to rollout_num_steps because we prepend the initial timestep to the timestep field and append an invalid action to the action field.
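
A hedged usage sketch, assuming env is a stateless environment built as above, scenario is a SimulatorState from the dataloader, and agents.create_expert_actor is the expert actor constructor in waymax.agents:

    import jax
    from waymax import agents
    from waymax import env as _env

    expert = agents.create_expert_actor(dynamics_model=env.dynamics)
    output = _env.rollout(
        scenario=scenario,
        actor=expert,
        env=env,
        rng=jax.random.PRNGKey(0),
        rollout_num_steps=80,
    )
    # One extra element in time: the initial state is prepended.
    assert output.reward.shape[0] == 80 + 1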

waymax.env.rollout_log_by_expert_sdc(scenario: waymax.env.typedefs.GenericScenario, env: waymax.env.abstract_environment.AbstractEnvironment, dynamics_model: waymax.dynamics.DynamicsModel, rollout_num_steps: int = 1) RolloutOutput#

Rolls out a state using logged expert actions specified by dynamics_model.

class waymax.env.RolloutOutput#

Rollout output data structure used as the output of the rollout function.

action#

Action produced by an action function corresponding to ActionFuncType which, when applied via environment.step(action), produces the next timestep’s information. This is aggregated over a number of timesteps, so the shape is (num_timesteps, …, num_objects, action_dim). The … part of the shape corresponds to any kind of batching prefix that might be applied.

state#

Temporally aggregated information of the output of the simulation after calling environment.step(action). This information represents the important information from the simulation aggregated through the rollout of shape (num_timesteps, …). The first element of state corresponds to the initial simulation state.

observation#

Temporally aggregated information of the output of the simulation after calling observe(environment.step(action)). This information represents the observation of the agent of the simulator state aggregated through the rollout of shape (num_timesteps, …). The first element of observation corresponds to the initial simulation state.

metrics#

Mapping from metric name to metric, containing metrics computed on the simulator states aggregated over time, of shape (num_timesteps, …). These metrics are computed by the env.metrics(state) function. As this is a mapping, it may be empty if the environment decides not to produce metrics, for example for speed reasons during the rollout.

reward#

Scalar value of shape (num_timesteps, …, num_objects) which represents the reward achieved at a certain simulator state at the given state.timestep.

property shape: tuple[int, ...]#

Returns the shape prefix for the rollout type.

action: waymax.env.typedefs.GenericAction#
state: waymax.env.typedefs.GenericState#
observation: waymax.env.typedefs.Observation#
metrics: waymax.env.typedefs.Metrics#
reward: jax.Array#
validate()#

Validates the shape prefix of the actions and timesteps.
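
Continuing the rollout sketch above, a minimal example of reducing the stacked fields:

    import jax
    import jax.numpy as jnp

    # Undiscounted return per object: sum rewards over the leading time
    # axis, (num_timesteps, ..., num_objects) -> (..., num_objects).
    episode_return = jnp.sum(output.reward, axis=0)

    # The stacked state is a PyTree with a leading time axis; slice out
    # the final simulator state.
    final_state = jax.tree_util.tree_map(lambda x: x[-1], output.state)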

waymax.env.Metrics#
waymax.env.Observation#
waymax.env.ObservationFn#
waymax.env.PyTree#
waymax.env.RewardFn#
class waymax.env.BraxWrapper(wrapped_env: waymax.env.abstract_environment.AbstractEnvironment, dynamics_model: waymax.dynamics.DynamicsModel, config: waymax.config.EnvironmentConfig)#

Brax-like interface wrapper for the Waymax environment.

metrics(state: waymax.datatypes.SimulatorState) waymax.env.typedefs.Metrics#

Computes metrics (lower is better) from state.

reset(state: waymax.datatypes.SimulatorState) TimeStep#

Resets the environment and initializes the simulation state.

This initializer sets the initial timestep and fills the initial simulation trajectory with invalid values.

Parameters:

state – An uninitialized state.

Returns:

The initialized simulation state.

observe(state: waymax.datatypes.SimulatorState) waymax.env.typedefs.Observation#

Computes the observation for the given simulation state.

step(timestep: TimeStep, action: waymax.datatypes.Action) TimeStep#

Advances simulation by one timestep using the dynamics model.

Parameters:
  • timestep – The timestep containing the current state.

  • action – The action to apply, of shape (…, num_objects). The action.valid field denotes which objects are being controlled; objects whose valid is False fall back to the default behavior specified by self.dynamics.

Returns:

The timestep corresponding to the transition taken.

reward(state: waymax.datatypes.SimulatorState, action: waymax.datatypes.Action) jax.Array#

Computes the reward for a transition.

Parameters:
  • state – The state used to compute the reward at state.timestep.

  • action – The action applied to state.

Returns:

A (…, num_objects) tensor of rewards.

termination(state: waymax.datatypes.SimulatorState) jax.Array#

Returns whether the current state is an episode termination.

A termination marks the end of an episode where the cost-to-go from this state is 0.

The equivalent step type in DMEnv is dm_env.termination.

Parameters:

state – The current simulator state.

Returns:

A boolean (…) tensor indicating whether the current state is the end of an episode as a termination.

truncation(state: waymax.datatypes.SimulatorState) jax.Array#

Returns whether the current state should truncate the episode.

A truncation denotes that an episode has ended due to reaching the step limit of an episode. In these cases dynamic programming methods (e.g. Q-learning) should still compute cost-to-go assuming the episode will continue running.

The equivalent step type in DMEnv is dm_env.truncation.

Parameters:

state – The current simulator state.

Returns:

A boolean (…) tensor indicating whether the current state is the end of an episode as a truncation.

action_spec() waymax.datatypes.Action#

Action spec of the environment.

reward_spec() dm_env.specs.Array#

Reward spec of the environment.

discount_spec() dm_env.specs.BoundedArray#

Discount spec of the environment.

observation_spec() waymax.env.typedefs.PyTree#

Observation spec of the environment.
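
A hedged usage sketch; policy is a hypothetical function, and the TimeStep returned by reset/step is assumed to expose observation, reward, and done fields:

    from waymax import config as _config
    from waymax import dynamics
    from waymax import env as _env

    brax_env = _env.BraxWrapper(
        wrapped_env=env,  # any AbstractEnvironment, e.g. from above
        dynamics_model=dynamics.StateDynamics(),
        config=_config.EnvironmentConfig(),
    )
    ts = brax_env.reset(state)  # state: an uninitialized SimulatorState
    for _ in range(80):
        action = policy(ts.observation)  # hypothetical policy
        ts = brax_env.step(ts, action)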

class waymax.env.DMEnvWrapper(data_generator: Iterator[waymax.datatypes.SimulatorState], stateless_env: waymax.env.abstract_environment.AbstractEnvironment, squeeze_scalar_actions: bool = True)#

Bases: dm_env.Environment

A stateful environment wrapper implementing the DMEnv interface.

property config: waymax.config.EnvironmentConfig#
property simulation_state: waymax.datatypes.SimulatorState#

The current simulation state.

property stateless_env: waymax.env.abstract_environment.AbstractEnvironment#

The underlying stateless Waymax environment.

observe(state: waymax.datatypes.SimulatorState) waymax.env.typedefs.Observation#

Runs the stateless environment observation function.

reset() dm_env.TimeStep#

Resets the environment and returns the initial TimeStep.

step(action: jax.Array) dm_env.TimeStep#

Advances the state given an action.

Parameters:

action – An action with shape compatible with self.action_spec().

Returns:

The TimeStep corresponding to the transition taken by applying action to the current state.

Raises:
  • SimulationNotInitializedError – If step() is called before reset().

  • EpisodeAlreadyFinishedError – If step() is called after the episode has already finished.

action_spec() dm_env.specs.BoundedArray#

The action specs of this environment, without batch dimension.

discount_spec() dm_env.specs.BoundedArray#

The discount specs of this environment, without batch dimension.

observation_spec() waymax.env.typedefs.PyTree#

The observation specs of this environment, without batch dimension.

reward_spec() dm_env.specs.Array#

The reward specs of this environment, without batch dimension.
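
A hedged end-to-end sketch of the stateful interface; data_config and policy are assumptions, while dataloader.simulator_state_generator and the dm_env TimeStep.last() protocol come from the respective libraries:

    from waymax import dataloader
    from waymax import env as _env

    stateful_env = _env.DMEnvWrapper(
        data_generator=dataloader.simulator_state_generator(config=data_config),
        stateless_env=env,  # any AbstractEnvironment
    )
    timestep = stateful_env.reset()
    while not timestep.last():  # standard dm_env episode loop
        action = policy(timestep.observation)  # hypothetical policy
        timestep = stateful_env.step(action)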