waymax.visualization.viz

Visualization functions for Waymax data structures.

Module Contents

Functions

_plot_bounding_boxes(→ None)

Helper function to plot multiple bounding boxes across time.

_index_pytree(→ Any)

Helper function to get the idx-th example in a batch.

plot_trajectory(→ None)

Plots a Trajectory, using different colors for controlled and context objects.

plot_roadgraph_points(→ None)

Plots road graph as points.

plot_traffic_light_signals_as_points(→ None)

Plots traffic lights at a given timestep.

_plot_path_points(→ None)

Plots on/off route paths.

plot_simulator_state(→ numpy.ndarray)

Plots a SimulatorState and returns the image as a NumPy array.

plot_observation(→ numpy.ndarray)

Plots an Observation and returns the image as a NumPy array.

plot_single_agent_brax_timestep(→ numpy.ndarray)

Plots a Brax TimeStep with metrics and returns the image as a NumPy array.

Attributes

waymax.visualization.viz._RoadGraphShown = (1, 2, 3, 15, 16, 17, 18, 19)
waymax.visualization.viz._RoadGraphDefaultColor = (0.9, 0.9, 0.9)
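
These two private constants read naturally as a whitelist of roadgraph type ids and a fallback RGB color. The sketch below shows one plausible way to apply them, which is an assumption and not taken from the module's source: keep only valid points whose type id is in the whitelist, then color them with the default color. The helper select_shown_points and the flat array layout (xy, types, valid) are hypothetical.

```python
import numpy as np

# Values copied from the module attributes above.
ROAD_GRAPH_SHOWN = (1, 2, 3, 15, 16, 17, 18, 19)
ROAD_GRAPH_DEFAULT_COLOR = (0.9, 0.9, 0.9)


def select_shown_points(xy, types, valid, shown=ROAD_GRAPH_SHOWN):
    """Keep valid points whose type id appears in `shown`.

    Hypothetical helper; assumes flat arrays xy (N, 2), types (N,), valid (N,).
    """
    mask = valid & np.isin(types, shown)
    return xy[mask], types[mask]


xy = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]])
types = np.array([1, 6, 15, 18])
valid = np.array([True, True, True, False])
pts, kept_types = select_shown_points(xy, types, valid)
# Fallback: give every kept point the default RGB color, shape (num_kept, 3).
colors = np.tile(ROAD_GRAPH_DEFAULT_COLOR, (len(pts), 1))
```

Type 6 is dropped because it is not in the whitelist, and the type-18 point is dropped because it is invalid.
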
waymax.visualization.viz._plot_bounding_boxes(ax: matplotlib.axes.Axes, traj_5dof: numpy.ndarray, time_idx: int, is_controlled: numpy.ndarray, valid: numpy.ndarray, add_label: bool = False) → None

Helper function to plot multiple bounding boxes across time.
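
Drawing a box from the 5-DoF layout (center x, center y, length, width, yaw) first requires turning each row into four corner points. The function below is a hypothetical NumPy sketch of that conversion, not the module's own code. It assumes yaw is in radians, measured counter-clockwise from the +x axis, and that length runs along the heading.

```python
import numpy as np


def box_corners(traj_5dof) -> np.ndarray:
    """Return the 4 corners of each 5-DoF box.

    traj_5dof: (..., 5) array of [center_x, center_y, length, width, yaw].
    Returns: (..., 4, 2) corners ordered front-left, rear-left, rear-right,
    front-right (counter-clockwise in the box frame).
    """
    traj = np.asarray(traj_5dof, dtype=float)
    # Slice with i:i+1 so each field keeps a trailing axis of size 1.
    cx, cy, length, width, yaw = (traj[..., i:i + 1] for i in range(5))
    half_l, half_w = length / 2.0, width / 2.0
    # Corner offsets in the box frame, shape (..., 4).
    dx = np.concatenate([half_l, -half_l, -half_l, half_l], axis=-1)
    dy = np.concatenate([half_w, half_w, -half_w, -half_w], axis=-1)
    cos, sin = np.cos(yaw), np.sin(yaw)
    # Rotate by yaw, then translate to the box center.
    x = cx + dx * cos - dy * sin
    y = cy + dx * sin + dy * cos
    return np.stack([x, y], axis=-1)
```

For an (A, T, 5) trajectory the result has shape (A, T, 4, 2). Appending the first corner to each set of four closes the polygon for a line plot.
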

waymax.visualization.viz._index_pytree(inputs: Any, idx: int) → Any

Helper function to get the idx-th example in a batch.
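
Waymax datatypes are JAX pytrees, so getting "the idx-th example in a batch" means indexing the leading axis of every array leaf. In JAX this is usually a one-liner, jax.tree_util.tree_map(lambda x: x[idx], inputs). Whether _index_pytree is written exactly that way is an assumption. The dependency-free sketch below does the same over dicts, lists, and plain tuples of NumPy arrays so it runs without JAX. The name index_pytree is hypothetical.

```python
from typing import Any

import numpy as np


def index_pytree(inputs: Any, idx: int) -> Any:
    """Return the idx-th example of every array leaf in a nested structure."""
    if isinstance(inputs, dict):
        return {k: index_pytree(v, idx) for k, v in inputs.items()}
    # Only plain lists/tuples are recursed into; namedtuples are not handled.
    if type(inputs) in (list, tuple):
        return type(inputs)(index_pytree(v, idx) for v in inputs)
    # Leaf: index the leading (batch) axis.
    return np.asarray(inputs)[idx]
```
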

waymax.visualization.viz.plot_trajectory(ax: matplotlib.axes.Axes, traj: waymax.datatypes.Trajectory, is_controlled: numpy.ndarray, time_idx: int | None = None, indices: numpy.ndarray | None = None, add_label: bool = False) → None

Plots a Trajectory, using different colors for controlled and context objects.

Full bounding boxes are drawn only at step time_idx, and overlapping boxes are highlighted.

Notation: A: number of agents; T: number of time steps; 5 degrees of freedom (5-DoF): center x, center y, length, width, yaw.

Parameters:
  • ax – matplotlib axes.

  • traj – a Trajectory with shape (A, T).

  • is_controlled – binary mask for controlled objects, shape (A,).

  • time_idx – step index at which to draw bounding boxes; -1 selects the last step. Defaults to None, which draws no bounding boxes.

  • indices – IDs to display for each agent if not None, shape (A,).

  • add_label – whether to add labels for the different agent categories: ‘controlled’, ‘overlap’, ‘history’, ‘context’.
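
In the notation above (A agents, T steps, 5-DoF boxes), time_idx and is_controlled amount to selecting one step of the trajectory and splitting agents by the mask. The NumPy sketch below assumes plain arrays instead of a Trajectory; the name boxes_at_step and the extra valid mask are hypothetical, and whether plot_trajectory drops invalid boxes this way is an assumption.

```python
import numpy as np


def boxes_at_step(traj_5dof, valid, is_controlled, time_idx):
    """Split the valid boxes at one step into controlled and context sets.

    traj_5dof: (A, T, 5); valid: (A, T) bool; is_controlled: (A,) bool.
    time_idx: step index; negative values count from the end (-1 = last).
    """
    if time_idx is None:
        # Mirrors the documented default: no bounding boxes are drawn.
        return None, None
    boxes = traj_5dof[:, time_idx]  # (A, 5)
    ok = valid[:, time_idx]  # (A,)
    controlled = boxes[ok & is_controlled]
    context = boxes[ok & ~is_controlled]
    return controlled, context
```
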

waymax.visualization.viz.plot_roadgraph_points(ax: matplotlib.axes.Axes, rg_pts: waymax.datatypes.RoadgraphPoints, verbose: bool = False) → None

Plots road graph as points.

Parameters:
  • ax – matplotlib axes.

  • rg_pts – a RoadgraphPoints with shape (1,)

  • verbose – if True, print the number of roadgraph points.
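
With verbose=True the function prints a roadgraph point count. The exact message is not documented here. The sketch below only shows one way to compute such counts from flat types and valid arrays (names assumed), both overall and per type id.

```python
import numpy as np


def count_roadgraph_points(types, valid):
    """Count valid roadgraph points overall and per type id (hypothetical helper)."""
    valid_types = types[valid]
    ids, counts = np.unique(valid_types, return_counts=True)
    return int(valid.sum()), dict(zip(ids.tolist(), counts.tolist()))
```
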

waymax.visualization.viz.plot_traffic_light_signals_as_points(ax: matplotlib.axes.Axes, tls: waymax.datatypes.TrafficLights, timestep: int = 0, verbose: bool = False) → None

Plots traffic lights at the given timestep.

Parameters:
  • ax – matplotlib axes.

  • tls – a TrafficLights instance to show.

  • timestep – draw traffic lights at this timestep.

  • verbose – if True, print the number of traffic lights.
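
Drawing lights "at this given timestep" means slicing one step out of per-light, per-step arrays and keeping only valid entries. The sketch below assumes x, y, state, and valid arrays of shape (num_lights, T). The integer state codes and colors in STATE_COLORS are placeholders for illustration, not Waymax's own mapping, and lights_at_step is hypothetical.

```python
import numpy as np

# Placeholder state-code -> color mapping (illustrative only).
STATE_COLORS = {0: "gray", 4: "red", 5: "yellow", 6: "green"}


def lights_at_step(x, y, state, valid, timestep=0):
    """Return (x, y, color) for each valid light at `timestep`.

    x, y, state, valid: (num_lights, T) arrays (assumed layout).
    Unknown state codes fall back to gray.
    """
    ok = valid[:, timestep]
    return [
        (float(px), float(py), STATE_COLORS.get(int(s), "gray"))
        for px, py, s in zip(x[ok, timestep], y[ok, timestep], state[ok, timestep])
    ]
```
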

waymax.visualization.viz._plot_path_points(ax: matplotlib.axes.Axes, paths: waymax.datatypes.Paths) → None

Plots on/off route paths.
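
Plotting on-route and off-route paths differently starts with separating the two groups of points. The sketch below assumes each path carries one boolean on-route flag and a per-point validity mask. That layout is chosen for illustration and is not read from waymax.datatypes.Paths, and split_paths is hypothetical.

```python
import numpy as np


def split_paths(x, y, valid, on_route):
    """Split path points into on-route and off-route (N, 2) point sets.

    x, y, valid: (num_paths, num_points); on_route: (num_paths,) bool.
    """
    # Broadcast the per-path flag across that path's points.
    on = valid & on_route[:, None]
    off = valid & ~on_route[:, None]
    return np.stack([x[on], y[on]], -1), np.stack([x[off], y[off]], -1)
```
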

waymax.visualization.viz.plot_simulator_state(state: waymax.datatypes.SimulatorState, use_log_traj: bool = True, viz_config: dict[str, Any] | None = None, batch_idx: int = -1, highlight_obj: waymax.config.ObjectType = waymax_config.ObjectType.SDC) → numpy.ndarray

Plots a SimulatorState and returns the image as a NumPy array.

Parameters:
  • state – A SimulatorState instance.

  • use_log_traj – if True, plot the logged trajectory; otherwise, plot the simulated trajectory.

  • viz_config – dict for optional config.

  • batch_idx – optional batch index.

  • highlight_obj – the type of objects to highlight in the color.COLOR_DICT[‘controlled’] color.

Returns:

The rendered image as a NumPy array.
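
Before drawing, use_log_traj and batch_idx decide which trajectory is plotted and from which batch element. The sketch below uses a plain dict in place of SimulatorState; its keys are illustrative. It also treats a negative batch_idx as meaning "the state is already unbatched", which is an assumed reading of the -1 default. select_trajectory is hypothetical.

```python
from typing import Any

import numpy as np


def select_trajectory(state: dict[str, Any], use_log_traj: bool = True,
                      batch_idx: int = -1) -> np.ndarray:
    """Pick the logged or simulated trajectory, optionally from one batch element.

    Assumes the arrays' leading axis is the batch when batch_idx >= 0.
    """
    traj = state["log_trajectory"] if use_log_traj else state["sim_trajectory"]
    if batch_idx >= 0:
        # Assumed semantics: only non-negative indices select a batch element.
        traj = traj[batch_idx]
    return np.asarray(traj)
```
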

waymax.visualization.viz.plot_observation(obs: waymax.datatypes.Observation, obj_idx: int, viz_config: dict[str, Any] | None = None, batch_idx: int = -1, highlight_obj: waymax.config.ObjectType = waymax_config.ObjectType.SDC) → numpy.ndarray

Plots an Observation and returns the image as a NumPy array.

Parameters:
  • obs – An Observation instance with shape (…, obs_A), where obs_A is the number of objects that each have their own observation of the scene, including other objects, the roadgraph, and traffic lights.

  • obj_idx – Index of the observing object along the obs_A axis.

  • viz_config – Dict for optional config.

  • batch_idx – Optional batch index.

  • highlight_obj – The type of objects to highlight in the color.COLOR_DICT[‘controlled’] color.

Returns:

The rendered image as a NumPy array.
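
An Observation stacks one view per observing object along the obs_A axis, which comes right after any batch axes. Picking the view for obj_idx therefore means indexing that axis in every field. The NumPy sketch below handles a single field. The name select_observer and the explicit batch_ndim argument are assumptions made for illustration.

```python
import numpy as np


def select_observer(field: np.ndarray, obj_idx: int, batch_ndim: int) -> np.ndarray:
    """Take one observing object's view out of an Observation-like field.

    field: array of shape (*batch, obs_A, *rest); batch_ndim = len(batch).
    """
    # np.take removes the obs_A axis: result shape is (*batch, *rest).
    return np.take(field, obj_idx, axis=batch_ndim)
```
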

waymax.visualization.viz.plot_single_agent_brax_timestep(waymax_ts: waymax.env.wrappers.brax_wrapper.TimeStep, use_log_traj: bool = False, viz_config: dict[str, Any] | None = None, batch_idx: int = -1) → numpy.ndarray

Plots a Brax TimeStep with metrics and returns the image as a NumPy array.

Currently supports only single-agent environment outputs.

Parameters:
  • waymax_ts – Timestep returned from Waymax env step or reset.

  • use_log_traj – if True, plot the logged trajectory; otherwise, plot the simulated trajectory.

  • viz_config – dict for optional config.

  • batch_idx – optional batch index.

Returns:

The rendered image as a NumPy array.
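
The image produced by this function includes metrics from the TimeStep. How they are laid out is not documented here. Assuming the metrics can be read as a mapping from name to scalar, a text block for the figure could be built as follows; format_metrics is hypothetical.

```python
def format_metrics(metrics: dict[str, float], precision: int = 2) -> str:
    """Render metrics as 'name: value' lines, sorted by name."""
    return "\n".join(
        f"{name}: {value:.{precision}f}" for name, value in sorted(metrics.items())
    )
```

The resulting string could then be passed to Matplotlib's Axes.set_title or Axes.text.
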