waymax.agents
Waymax agent interfaces and sim agent implementations.
Submodules
Package Contents
Classes
- WaymaxActorCore – Interface that defines actor functionality for inference.
- WaymaxActorOutput – Output of the Waymax actor including an action and its internal state.
- FrozenSimPolicy – A sim agent policy that does not update object trajectories.
- SimAgentActor – Base class for simulated agents.
- IDMRoutePolicy – A policy implementing the intelligent driver model (IDM).
- WaypointFollowingPolicy – A base class for all waypoint-following sim agents.
Functions
- actor_core_factory – Creates a WaymaxActorCore from pure functions.
- merge_actions – Combines multiple actor_outputs into one action instance.
- create_sim_agents_from_config – Constructs sim agent WaymaxActorCore objects from a config.
- create_constant_speed_actor – Creates an actor with constant speed without changing objects' heading.
- create_expert_actor – Creates an expert agent using the WaymaxActorCore interface.
- waymax.agents.actor_core_factory(init: Callable[[jax.Array, waymax.datatypes.SimulatorState], ActorState], select_action: Callable[[Params, waymax.datatypes.SimulatorState, ActorState, jax.Array], WaymaxActorOutput], name: str = 'WaymaxActorCore') → WaymaxActorCore
Creates a WaymaxActorCore from pure functions.
- Parameters:
init – A function that initializes the actor's internal state. The state is a generic type that can contain anything the agent needs to pass through to the next call. The init function takes a random key to help randomize initialization and the initial simulator state, and returns the agent-specific internal state.
select_action – A function that selects an action given the actor parameters, the current simulator state, the previous actor state, and a random key. It returns a WaymaxActorOutput containing the action and the updated internal actor state.
name – Name of the agent used for inspection and logging.
- Returns:
An actor core instance defined by init and select_action.
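As a rough illustration of the pattern, the sketch below bundles an init function and a select_action function into one object exposing the name, init and select_action members of the interface. It is plain Python, not the library's implementation: PureFunctionActor, make_actor and SimpleActorOutput are hypothetical stand-ins for the object returned by actor_core_factory and for WaymaxActorOutput, and the "simulator state" is just a list of object names.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class SimpleActorOutput:
    """Hypothetical stand-in for WaymaxActorOutput."""
    actor_state: Any
    action: Any
    is_controlled: list


class PureFunctionActor:
    """Hypothetical stand-in for the actor built by actor_core_factory."""

    def __init__(self, init_fn: Callable, select_action_fn: Callable, name: str):
        self._init_fn = init_fn
        self._select_action_fn = select_action_fn
        self._name = name

    @property
    def name(self) -> str:
        return self._name

    def init(self, rng, state):
        # Delegate to the user-supplied pure init function.
        return self._init_fn(rng, state)

    def select_action(self, params, state, actor_state, rng):
        # Delegate to the user-supplied pure select_action function.
        return self._select_action_fn(params, state, actor_state, rng)


def make_actor(init: Callable, select_action: Callable,
               name: str = "WaymaxActorCore") -> PureFunctionActor:
    return PureFunctionActor(init, select_action, name)


# Usage: an actor whose internal state counts the steps taken so far.
def counter_init(rng, state):
    return 0


def counter_select_action(params, state, actor_state, rng):
    num_objects = len(state)
    return SimpleActorOutput(
        actor_state=actor_state + 1,         # updated internal state
        action=[0.0] * num_objects,          # a zero action for every object
        is_controlled=[True] * num_objects,  # this actor controls all objects
    )


actor = make_actor(counter_init, counter_select_action, name="counter")
actor_state = actor.init(None, ["car_a", "car_b"])
output = actor.select_action(None, ["car_a", "car_b"], actor_state, None)
```

Keeping both functions pure means the actor object holds no per-step data; everything that changes between calls travels explicitly through actor_state.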
- waymax.agents.merge_actions(actor_outputs: Sequence[WaymaxActorOutput]) → waymax.datatypes.Action
Combines multiple actor_outputs into one action instance.
- Parameters:
actor_outputs – A sequence of WaymaxActorOutput to be combined, each corresponding to a different actor. Different actors should not control the same object (i.e. the is_controlled flags of different actors should be disjoint), and all actors must use the same dynamics model.
- Returns:
An Action that combines the information from all actor outputs.
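The disjointness requirement can be made concrete with a small sketch: each object's action is taken from the single actor whose is_controlled flag is set for it. This is a plain-Python illustration under stated assumptions, not the library's code: SimpleOutput and merge_outputs are hypothetical, each action is one scalar per object, objects controlled by no actor are marked invalid, and the sketch raises on overlapping masks (the docstring only says the masks should be disjoint, not what happens otherwise).

```python
from dataclasses import dataclass


@dataclass
class SimpleOutput:
    """Hypothetical actor output: one scalar action per object plus a mask."""
    action: list
    is_controlled: list


def merge_outputs(outputs):
    """Combines per-actor outputs whose is_controlled masks are disjoint."""
    num_objects = len(outputs[0].action)
    merged_action = [0.0] * num_objects
    merged_valid = [False] * num_objects
    for out in outputs:
        for i in range(num_objects):
            if not out.is_controlled[i]:
                continue
            if merged_valid[i]:
                raise ValueError(f"object {i} is controlled by more than one actor")
            # Take this object's action from the actor that controls it.
            merged_action[i] = out.action[i]
            merged_valid[i] = True
    return merged_action, merged_valid


sdc_actor = SimpleOutput(action=[1.0, 2.0, 3.0], is_controlled=[True, False, False])
sim_agents = SimpleOutput(action=[9.0, 8.0, 7.0], is_controlled=[False, True, True])
merged_action, merged_valid = merge_outputs([sdc_actor, sim_agents])
```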
- class waymax.agents.WaymaxActorCore
Bases: abc.ABC
Interface that defines actor functionality for inference.
- abstract property name: str
Name of the agent used for inspection and logging.
- abstract init(rng: jax.Array, state: waymax.datatypes.SimulatorState) → ActorState
Initializes the actor’s internal state.
ActorState is a generic type that can contain anything the agent needs to pass through to the next call, e.g. a recurrent state or batch normalization statistics. The init function takes a random key to help randomize initialization and the initial simulator state.
- Parameters:
rng – A random key.
state – The initial simulator state.
- Returns:
The actor’s initial state.
- abstract select_action(params: Params, state: waymax.datatypes.SimulatorState, actor_state: ActorState, rng: jax.Array) → WaymaxActorOutput
Selects an action given the current simulator state.
- Parameters:
params – Actor parameters, e.g. neural network weights.
state – The current simulator state.
actor_state – The actor state, e.g. recurrent state or batch normalization.
rng – A random key.
- Returns:
An actor output containing the next action and actor state.
- class waymax.agents.WaymaxActorOutput
Output of the Waymax actor including an action and its internal state.
- actor_state: ActorState
Internal state for whatever the agent needs to keep as its state. This can be recurrent embeddings or accounting information.
- action
Action of shape (…, num_objects) predicted by the Waymax actor at the most recent simulation step, given the inputs to the select_action method of WaymaxActorCore.
- is_controlled: jax.Array
A binary indicator of shape (…, num_objects) representing which objects are controlled by the actor.
- validate()
Validates shapes.
- waymax.agents.create_sim_agents_from_config(config: waymax.config.SimAgentConfig) → waymax.agents.actor_core.WaymaxActorCore
Constructs sim agent WaymaxActorCore objects from a config.
- Parameters:
config – Waymax sim agent config specifying the agent type and the type of objects it controls.
- Returns:
The constructed sim agent actor.
- waymax.agents.create_constant_speed_actor(dynamics_model: waymax.dynamics.DynamicsModel, is_controlled_func: Callable[[waymax.datatypes.SimulatorState], jax.Array], speed: float | None = None) → waymax.agents.actor_core.WaymaxActorCore
Creates an actor with constant speed without changing objects’ heading.
Unlike ConstantSpeedPolicy, this actor requires a dynamics model as input; the policy does not, because it assumes StateDynamics.
- Parameters:
dynamics_model – The dynamics model used by the actor; it defines the format of the actions the actor outputs.
is_controlled_func – Defines which objects are controlled by this actor.
speed – Speed of the actor. If None, the speed from the previous step is used.
- Returns:
A stateless actor that drives the controlled objects at constant speed.
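The kinematics behind "constant speed without changing heading" fit in a few lines. The sketch below is hypothetical: constant_speed_step works on one object's scalar pose, assumes a 0.1 s step and m/s units for illustration, and only computes the target next state. The real actor additionally expresses that state as an action for its dynamics_model, which is omitted here.

```python
import math
from typing import Optional


def constant_speed_step(x: float, y: float, yaw: float, prev_speed: float,
                        dt: float = 0.1, speed: Optional[float] = None):
    """Advances one object for one step at constant speed and fixed heading.

    If speed is None, the speed from the previous step is reused.
    """
    v = prev_speed if speed is None else speed
    next_x = x + v * math.cos(yaw) * dt
    next_y = y + v * math.sin(yaw) * dt
    return next_x, next_y, yaw, v  # heading is left unchanged


# Reuse the previous speed (5 m/s) while heading along +x.
print(constant_speed_step(0.0, 0.0, 0.0, prev_speed=5.0))
# Override the speed to 10 m/s while heading along +y.
print(constant_speed_step(0.0, 0.0, math.pi / 2, prev_speed=5.0, speed=10.0))
```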
- waymax.agents.create_expert_actor(dynamics_model: waymax.dynamics.DynamicsModel, is_controlled_func: Callable[[waymax.datatypes.SimulatorState], jax.Array] = _IS_SDC_FUNC) → waymax.agents.actor_core.WaymaxActorCore
Creates an expert agent using the WaymaxActorCore interface.
This actor infers the expert's action from the logged data by calling the inverse function of the given dynamics_model. The returned action is in the format defined by that dynamics model.
- Parameters:
dynamics_model – Dynamics model whose inverse function will be used to infer the expert action given the logged states.
is_controlled_func – A function that maps state to a controlled objects mask of shape (…, num_objects).
- Returns:
A stateless Waymax actor that returns an expert action for all controlled objects (defined by is_controlled_func) by inferring the best-fit action given the logged state.
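To see why the inverse of a dynamics model yields an "expert" action, consider a toy model whose action is simply the pose displacement between steps. Pose, DeltaPoseDynamics and expert_action are hypothetical and much simpler than the library's dynamics models. Applying the inferred action through the forward model reproduces the logged next pose, which is the property the expert actor relies on.

```python
from dataclasses import dataclass


@dataclass
class Pose:
    x: float
    y: float
    yaw: float


class DeltaPoseDynamics:
    """Hypothetical dynamics model whose action is a global pose displacement."""

    def forward(self, current: Pose, action: tuple) -> Pose:
        dx, dy, dyaw = action
        return Pose(current.x + dx, current.y + dy, current.yaw + dyaw)

    def inverse(self, current: Pose, next_logged: Pose) -> tuple:
        # The action that exactly reproduces the logged next pose.
        return (next_logged.x - current.x,
                next_logged.y - current.y,
                next_logged.yaw - current.yaw)


def expert_action(dynamics, logged_poses, t):
    """Infers the expert action at step t from two consecutive logged poses."""
    return dynamics.inverse(logged_poses[t], logged_poses[t + 1])


dynamics = DeltaPoseDynamics()
log = [Pose(0.0, 0.0, 0.0), Pose(1.0, 0.5, 0.1), Pose(2.0, 1.2, 0.15)]
action = expert_action(dynamics, log, 1)
replayed = dynamics.forward(log[1], action)
```

For models that cannot reproduce every logged transition exactly, the inverse would instead return the best-fit action, as the Returns entry above describes.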
- class waymax.agents.FrozenSimPolicy(is_controlled_func: Callable[[waymax.datatypes.SimulatorState], jax.Array] | None = None)
Bases: SimAgentActor
A sim agent policy that does not update object trajectories.
This class is primarily intended to be used for testing or debugging purposes.
- update_trajectory(state: waymax.datatypes.SimulatorState) → waymax.datatypes.TrajectoryUpdate
Returns the current sim trajectory as the next update.
- class waymax.agents.SimAgentActor(is_controlled_func: Callable[[waymax.datatypes.SimulatorState], jax.Array] | None = None)
Bases: waymax.agents.actor_core.WaymaxActorCore
Base class for simulated agents.
Subclasses should implement the update_trajectory method. As SimAgentActor outputs TrajectoryUpdate actions, it is primarily intended to be used with the StateDynamics dynamics model.
- property name: str
Name of the agent used for inspection and logging.
- abstract update_trajectory(state: waymax.datatypes.SimulatorState) → waymax.datatypes.TrajectoryUpdate
Updates the trajectory for all simulated agents.
- Parameters:
state – The current simulator state.
- Returns:
A trajectory update of shape (…, num_objects, num_timesteps=1) that contains the updated positions and velocities for all simulated agents for the next timestep.
- init(rng: jax.Array, state: waymax.datatypes.SimulatorState)
Returns an empty initial state.
- select_action(params: waymax.agents.actor_core.Params, state: waymax.datatypes.SimulatorState, actor_state: Any, rng: jax.Array) → waymax.agents.actor_core.WaymaxActorOutput
Selects an action given the current simulator state.
- Parameters:
params – Actor parameters, e.g. neural network weights.
state – The current simulator state.
actor_state – The actor state, e.g. recurrent state or batch normalization.
rng – A random key.
- Returns:
An actor output containing the next action and actor state.
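The division of labor between SimAgentActor and its subclasses can be sketched as follows: the base class turns whatever update_trajectory returns into an actor output, and a subclass such as FrozenSimPolicy only decides the next positions. SimAgentSketch, FrozenSketch and SketchOutput are hypothetical, the state is a dict of 2D positions, and the default of controlling every object when no is_controlled_func is given is an assumption, since the library's default is not stated here.

```python
import abc
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class SketchOutput:
    """Hypothetical stand-in for WaymaxActorOutput."""
    actor_state: Any
    action: Any
    is_controlled: list


class SimAgentSketch(abc.ABC):
    """Hypothetical, simplified mirror of SimAgentActor."""

    def __init__(self, is_controlled_func: Optional[Callable] = None):
        # Assumption: without a function, every object is controlled.
        self._is_controlled_func = is_controlled_func or (
            lambda state: [True] * len(state["positions"]))

    @abc.abstractmethod
    def update_trajectory(self, state):
        """Returns the next-step positions for all objects."""

    def init(self, rng, state):
        return None  # empty initial state: the sim agent keeps no state

    def select_action(self, params, state, actor_state, rng):
        # The "action" is the trajectory update itself, as with StateDynamics.
        return SketchOutput(
            actor_state=None,
            action=self.update_trajectory(state),
            is_controlled=self._is_controlled_func(state),
        )


class FrozenSketch(SimAgentSketch):
    """Mirrors FrozenSimPolicy: the next positions equal the current ones."""

    def update_trajectory(self, state):
        return list(state["positions"])


state = {"positions": [(0.0, 0.0), (5.0, 1.0)]}
frozen = FrozenSketch(is_controlled_func=lambda s: [False, True])
output = frozen.select_action(None, state, frozen.init(None, state), None)
```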
- class waymax.agents.IDMRoutePolicy(is_controlled_func: Callable[[waymax.datatypes.SimulatorState], jax.Array] | None = None, desired_vel: float = 30.0, min_spacing: float = 2.0, safe_time_headway: float = 2.0, max_accel: float = 2.0, max_decel: float = 4.0, delta: float = 4.0, max_lookahead: int = 10, lookahead_from_current_position: bool = True, additional_lookahead_points: int = 10, additional_lookahead_distance: float = 10.0, invalidate_on_end: bool = False)
Bases: WaypointFollowingPolicy
A policy implementing the intelligent driver model (IDM).
This policy uses IDM to compute the acceleration and velocities for the agent while it follows its own logged future.
- update_speed(state: waymax.datatypes.SimulatorState, dt: float = _DEFAULT_TIME_DELTA) → tuple[jax.Array, jax.Array]
Returns the new speed for each agent in the current simulation step.
- Parameters:
state – The simulator state of shape (…).
dt – Delta between timesteps of the simulator state.
- Returns:
A (…, num_objects) float array of speeds. valids: A (…, num_objects) bool array of valids.
- Return type:
speeds
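For reference, the standard IDM acceleration law, with the constructor parameters mapped onto its symbols (desired_vel as v0, min_spacing as s0, safe_time_headway as T, max_accel as a, max_decel as b, delta as δ), can be written for a single vehicle as below. This is a scalar sketch of the textbook formula, not IDMRoutePolicy's implementation: how the policy finds the leader, clips accelerations, or handles a missing leader is not stated here. The 0.1 s step is an illustrative assumption, since the value of _DEFAULT_TIME_DELTA is not given on this page.

```python
import math
from typing import Optional


def idm_accel(speed: float,
              lead_speed: Optional[float] = None,
              gap: Optional[float] = None,
              desired_vel: float = 30.0,
              min_spacing: float = 2.0,
              safe_time_headway: float = 2.0,
              max_accel: float = 2.0,
              max_decel: float = 4.0,
              delta: float = 4.0) -> float:
    """Standard IDM acceleration for one vehicle."""
    # Free-road term: 1 - (v / v0) ** delta.
    free_road = 1.0 - (speed / desired_vel) ** delta
    if lead_speed is None or gap is None:
        return max_accel * free_road
    # Desired gap s* = s0 + max(0, v*T + v*dv / (2*sqrt(a*b))), dv = v - v_lead.
    closing_speed = speed - lead_speed
    desired_gap = min_spacing + max(
        0.0,
        speed * safe_time_headway
        + speed * closing_speed / (2.0 * math.sqrt(max_accel * max_decel)))
    # Interaction term grows as the actual gap shrinks relative to s*.
    return max_accel * (free_road - (desired_gap / gap) ** 2)


def idm_next_speed(speed: float, dt: float = 0.1, **idm_kwargs) -> float:
    """One explicit Euler step of the speed, clipped so it never goes negative."""
    accel = idm_accel(speed, **idm_kwargs)
    return max(0.0, speed + accel * dt)


# A vehicle at 10 m/s, 5 m behind a stopped leader, brakes hard.
print(idm_next_speed(10.0, lead_speed=0.0, gap=5.0))
```

On a free road the vehicle accelerates at max_accel from standstill and stops accelerating at desired_vel; a close, slower leader makes the interaction term dominate and produces braking.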
- _get_accel(log_waypoints: waymax.datatypes.Trajectory, cur_position: jax.Array, cur_speed: jax.Array, obj_curr_traj: waymax.datatypes.Trajectory) → jax.Array
Computes vehicle accelerations according to IDM for a single vehicle.
Note that log_waypoints and obj_curr_traj contain the same set of objects, so an object's collision with itself must be excluded when computing pairwise collisions.
- Parameters:
log_waypoints – A trajectory of the agents’ future of shape (…, num_objects, num_timesteps).
cur_position – Current positions for the agents of shape (…, num_objects, 3).
cur_speed – Current speeds for the agents of shape (…, num_objects).
obj_curr_traj – Trajectory containing the state for all current objects of shape (…, num_objects, num_timesteps=1).
- Returns:
A vector of shape (…, num_objects) containing all vehicles' solved accelerations.
- _compute_lead_velocity(future_speeds: jax.Array, collisions_per_agent: jax.Array, future_speeds_valid: jax.Array | None = None) → jax.Array
Computes the velocity of the object at the closest collision.
- Parameters:
future_speeds – Future speeds per agent of shape (…, num_objects, num_timesteps).
collisions_per_agent – Future collision indications of shape (…, num_objects, num_timesteps).
future_speeds_valid – Boolean mask for future speeds of shape (…, num_objects, num_timesteps).
- Returns:
An array of shape (…) containing the velocity of the colliding object at the closest collision.
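A per-agent reading of this computation: scan the agent's future timesteps in order, and at the first timestep where some object's collision flag (and validity flag, if given) is set, report that object's speed. The sketch assumes "closest" means earliest along the agent's future and uses plain nested lists indexed as [object][timestep]; lead_velocity is a hypothetical name, ties at the same timestep go to the lowest object index, and returning None when nothing collides is an assumption rather than documented behavior.

```python
from typing import Optional, Sequence


def lead_velocity(future_speeds: Sequence[Sequence[float]],
                  collisions: Sequence[Sequence[bool]],
                  future_speeds_valid: Optional[Sequence[Sequence[bool]]] = None
                  ) -> Optional[float]:
    """Speed of the object hit at the earliest colliding future timestep.

    future_speeds[j][t] and collisions[j][t] are indexed by object j and
    timestep t. Returns None when no valid collision exists (an assumption).
    """
    num_objects = len(future_speeds)
    num_timesteps = len(future_speeds[0]) if num_objects else 0
    for t in range(num_timesteps):  # earliest timestep first
        for j in range(num_objects):
            valid = future_speeds_valid is None or future_speeds_valid[j][t]
            if collisions[j][t] and valid:
                return future_speeds[j][t]
    return None


speeds = [[5.0, 5.0, 5.0], [3.0, 2.0, 1.0]]
hits = [[False, False, True], [False, True, False]]
print(lead_velocity(speeds, hits))
```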
- _compute_lead_distance(agent_future: jax.Array, collision_indicator: jax.Array, agent_future_valid: jax.Array | None = None, current_position: jax.Array | None = None, use_arclength=False) → jax.Array
Computes the distance between the agent and the nearest collision.
- Parameters:
agent_future – Agent’s future positions {x, y, z} of shape (…, num_timesteps, 3).
collision_indicator – Collision indications of shape (…, num_timesteps).
agent_future_valid – Boolean mask for agent’s future positions of shape (…, num_timesteps).
current_position – Array of the vehicle’s current positions {x, y, z} of shape (…, 1, 3). If None, will use the first element of agent_future as the current position.
use_arclength – Whether to use arclength when computing the distance to the collision. Arclength is more accurate but is not robust to futures with mixed validity.
- Returns:
An array of distances to the agent’s closest collision of shape (…).
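The difference between the two distance modes can be illustrated for a single agent. In this hypothetical sketch (lead_distance is not the library's function), the straight-line mode measures from the current position directly to the first valid colliding waypoint, while the arclength mode sums segment lengths along the valid waypoints up to that point. Returning infinity when nothing collides is an assumption. Because invalid waypoints are simply skipped in the arclength sum, gaps in validity distort the path length, which hints at why arclength is described as less robust to mixed validity.

```python
import math
from typing import Optional, Sequence, Tuple

Point = Tuple[float, float, float]


def lead_distance(agent_future: Sequence[Point],
                  collision_indicator: Sequence[bool],
                  agent_future_valid: Optional[Sequence[bool]] = None,
                  current_position: Optional[Point] = None,
                  use_arclength: bool = False) -> float:
    """Distance from the agent to its first valid colliding waypoint."""
    if current_position is None:
        current_position = agent_future[0]
    valid = (agent_future_valid if agent_future_valid is not None
             else [True] * len(agent_future))
    hit = next((t for t, (c, v) in enumerate(zip(collision_indicator, valid))
                if c and v), None)
    if hit is None:
        return math.inf  # assumption: no collision means an unbounded gap
    if not use_arclength:
        # Straight-line distance to the colliding waypoint.
        return math.dist(current_position, agent_future[hit])
    # Path length along the valid waypoints up to and including the hit.
    path = [current_position] + [
        p for p, v in zip(agent_future[: hit + 1], valid) if v]
    return sum(math.dist(a, b) for a, b in zip(path, path[1:]))


future = [(1.0, 0.0, 0.0), (2.0, 0.0, 0.0), (2.0, 1.0, 0.0)]
hits = [False, False, True]
origin = (0.0, 0.0, 0.0)
print(lead_distance(future, hits, current_position=origin))
print(lead_distance(future, hits, current_position=origin, use_arclength=True))
```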
- class waymax.agents.WaypointFollowingPolicy(is_controlled_func: Callable[[waymax.datatypes.SimulatorState], jax.Array] | None = None, invalidate_on_end: bool = False)
Bases: waymax.agents.sim_agent.SimAgentActor
A base class for all waypoint-following sim agents.
WaypointFollowingPolicy forces sim agents to travel along a predefined path (the agent's future in the logged trajectory). Each vehicle's behavior is determined by the speed set via the update_speed() method, which updates the vehicle's velocity.
- update_trajectory(state: waymax.datatypes.SimulatorState) → waymax.datatypes.TrajectoryUpdate
Returns a trajectory update of shape (…, num_objects, 1).
- _get_next_trajectory_by_projection(log_traj: waymax.datatypes.Trajectory, cur_sim_traj: waymax.datatypes.Trajectory, new_speed: jax.Array, new_speed_valid: jax.Array, dt: float = _DEFAULT_TIME_DELTA) → waymax.datatypes.Trajectory
Computes the next trajectory.
- Parameters:
log_traj – Logged trajectory for the simulation of shape (…, num_objects, num_timesteps).
cur_sim_traj – Current simulated trajectory for the simulation of shape (…, num_objects, num_timesteps=1).
new_speed – Updated speed for the agents after solving for velocity of shape (…, num_objects).
new_speed_valid – Updated validity for the speed updates of the agents, of shape (…, num_objects).
dt – Delta between timesteps of the simulator state.
- Returns:
The next Trajectory projected onto log_traj, of shape (…, num_objects, num_timesteps=1).
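One way to picture "projection onto the logged trajectory" for a single agent: project the current simulated position onto the logged polyline to get an arclength, advance that arclength by new_speed * dt, and interpolate the point there. The sketch below is a hypothetical 2D illustration of that idea (project_to_arclength, point_at_arclength and next_position_by_projection are not library functions). It ignores yaw, velocity and validity, clamps to the last logged waypoint, and assumes a 0.1 s step, since _DEFAULT_TIME_DELTA's value is not given here.

```python
import math


def project_to_arclength(waypoints, pos):
    """Arclength of the point on the polyline closest to pos."""
    best_s, best_d, travelled = 0.0, math.inf, 0.0
    for a, b in zip(waypoints, waypoints[1:]):
        dx, dy = b[0] - a[0], b[1] - a[1]
        seg_sq = dx * dx + dy * dy
        # Fraction along the segment of the orthogonal projection, clamped.
        t = 0.0 if seg_sq == 0 else max(0.0, min(1.0, (
            (pos[0] - a[0]) * dx + (pos[1] - a[1]) * dy) / seg_sq))
        closest = (a[0] + t * dx, a[1] + t * dy)
        d = math.dist(pos, closest)
        if d < best_d:
            best_d, best_s = d, travelled + t * math.sqrt(seg_sq)
        travelled += math.sqrt(seg_sq)
    return best_s


def point_at_arclength(waypoints, s):
    """Linearly interpolates the point at arclength s along the polyline."""
    if s <= 0:
        return waypoints[0]
    travelled = 0.0
    for a, b in zip(waypoints, waypoints[1:]):
        seg = math.dist(a, b)
        if seg > 0 and travelled + seg >= s:
            frac = (s - travelled) / seg
            return (a[0] + frac * (b[0] - a[0]), a[1] + frac * (b[1] - a[1]))
        travelled += seg
    return waypoints[-1]  # assumption: clamp to the end of the log


def next_position_by_projection(waypoints, current_position, new_speed, dt=0.1):
    s = project_to_arclength(waypoints, current_position)
    return point_at_arclength(waypoints, s + new_speed * dt)


log_path = [(0.0, 0.0), (10.0, 0.0), (10.0, 10.0)]
print(next_position_by_projection(log_path, (3.0, 0.5), new_speed=10.0))
```

Because the agent is always placed back on the logged path, off-path drift in the simulated position is removed at each step; only the speed decides how far along the log the agent moves.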
- abstract update_speed(state: waymax.datatypes.SimulatorState, dt: float = _DEFAULT_TIME_DELTA) → tuple[jax.Array, jax.Array]
Updates the speed for each agent in the current simulation step.
- Parameters:
state – The simulator state of shape (…).
dt – Delta between timesteps of the simulator state.
- Returns:
A (…, num_objects) float array of speeds. valids: A (…, num_objects) bool array of valids.
- Return type:
speeds