waymax.dataloader.womd_utils#

General settings and utility functions specifically for the WOMD data.

These functions are mainly intended to be internal to the Waymax data library.

WOMD (Waymo Open Motion Dataset, go/womd) is the data we typically use for simulation in our environment. See https://waymo.com/open/data/motion/tfexample for definitions of many of the data fields mentioned in this file.

Module Contents#

Functions#

aggregate_time_tensors(→ dict[str, tensorflow.Tensor])

Combines all past/current/future fields into an 'all' field.

get_features_description(→ dict[str, ...)

Returns a dictionary of all features to be extracted.

simulator_state_to_womd_dict_tensorflow(→ dict[str, ...)

Tensorflow version of the simulator state to WOMD dict converter.

simulator_state_to_womd_dict(→ dict[str, jax.Array])

Converts a simulator state into the WOMD tensor format.

_roadgraph_to_dict(→ dict[str, jax.Array])

Gets the roadgraph mpdata from the simulator state.

_trajectory_to_dict(→ dict[str, jax.Array])

Generates the mpdata fields for Trajectory data.

_traffic_light_to_dict(→ dict[str, jax.Array])

Generates the corresponding mpdata for TrafficLights fields.

_object_metadata_to_dict(→ dict[str, jax.Array])

Converts object metadata to the original tf.Example format.

_get_invalid_future_trajectory(→ dict[str, jax.Array])

Gets an invalid trajectory representing the future.

_get_invalid_future_traffic_light(→ dict[str, jax.Array])

Gets an invalid traffic light representing the future.

Attributes#

waymax.dataloader.womd_utils.DEFAULT_FLOAT#
waymax.dataloader.womd_utils.DEFAULT_INT#
waymax.dataloader.womd_utils.DEFAULT_BOOL = False#
waymax.dataloader.womd_utils._TF_TO_JNP_DTYPE#
waymax.dataloader.womd_utils.TL_TIMESTAMP_STEP_AXIS#
waymax.dataloader.womd_utils.TL_STEP_AXIS#
waymax.dataloader.womd_utils.aggregate_time_tensors(decoded_tensors: dict[str, tensorflow.Tensor]) dict[str, tensorflow.Tensor]#

Combines all past/current/future fields into an ‘all’ field.

Note the original past/current/future keys are removed in the returned dict.

Parameters:

decoded_tensors – input dict of tensors keyed by string.

Returns:

A new dict of tensors keyed by updated strings, in which the past/current/future fields are merged into all.
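The merge described above can be illustrated with a minimal sketch. It is not the library's implementation: `merge_time_keys` is a hypothetical helper, NumPy arrays stand in for TensorFlow tensors, and the key layout (`state/past/x`, `state/current/x`, `state/future/x`, with time on the last axis) is assumed from the WOMD tf.Example format.

```python
import numpy as np


def merge_time_keys(tensors, time_axis=-1):
    """Hypothetical helper: merge '.../past|current|future/...' keys into '.../all/...'."""
    merged = {}
    for key, value in tensors.items():
        if "/current/" not in key:
            if "/past/" not in key and "/future/" not in key:
                merged[key] = value  # Non-temporal keys pass through unchanged.
            continue  # Past/future keys are consumed via their 'current' sibling.
        past = tensors[key.replace("/current/", "/past/")]
        future = tensors[key.replace("/current/", "/future/")]
        # Concatenate in chronological order along the assumed time axis.
        merged[key.replace("/current/", "/all/")] = np.concatenate(
            [past, value, future], axis=time_axis)
    return merged


example = {
    "state/past/x": np.zeros((2, 10)),
    "state/current/x": np.ones((2, 1)),
    "state/future/x": np.full((2, 80), 2.0),
    "state/id": np.array([7.0, 8.0]),
}
merged = merge_time_keys(example)
print(sorted(merged))              # ['state/all/x', 'state/id']
print(merged["state/all/x"].shape)  # (2, 91)
```

As in the documented behaviour, the original past/current/future keys do not appear in the result. Fields whose time axis is not the last one (traffic light fields, for example) would need a different `time_axis` in this sketch.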

waymax.dataloader.womd_utils.get_features_description(max_num_objects: int = 128, max_num_rg_points: int = 30000, include_sdc_paths: bool = False, num_paths: int | None = 45, num_points_per_path: int | None = 800, num_tls: int | None = 16) dict[str, tensorflow.io.FixedLenFeature]#

Returns a dictionary of all features to be extracted.

Parameters:
  • max_num_objects – Max number of objects.

  • max_num_rg_points – Max number of sampled roadgraph points.

  • include_sdc_paths – Whether to include roadgraph traversal paths for the SDC.

  • num_paths – Optional number of SDC paths. Must be defined if include_sdc_paths is True.

  • num_points_per_path – Optional number of points per SDC path. Must be defined if include_sdc_paths is True.

  • num_tls – Maximum number of traffic lights.

Returns:

Dictionary of all features to be extracted.

Raises:

ValueError – If include_sdc_paths is True but either num_paths or num_points_per_path is None.
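The argument contract above can be sketched as a small standalone check. `check_sdc_path_args` is a hypothetical stand-in written for illustration; it mirrors only the documented `ValueError` condition and does not build any feature dictionary.

```python
from typing import Optional


def check_sdc_path_args(include_sdc_paths: bool,
                        num_paths: Optional[int],
                        num_points_per_path: Optional[int]) -> None:
    """Hypothetical stand-in for the documented argument check."""
    if include_sdc_paths and (num_paths is None or num_points_per_path is None):
        raise ValueError(
            "num_paths and num_points_per_path must be defined when "
            "include_sdc_paths is True.")


check_sdc_path_args(False, None, None)  # Accepted: SDC paths are not requested.
check_sdc_path_args(True, 45, 800)      # Accepted: both sizes are defined.

try:
    check_sdc_path_args(True, None, 800)
    rejected = False
except ValueError:
    rejected = True
print(rejected)  # True
```

Because num_paths and num_points_per_path default to 45 and 800, the error can only occur when a caller passes None explicitly while setting include_sdc_paths to True.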

waymax.dataloader.womd_utils.simulator_state_to_womd_dict_tensorflow(state: waymax.datatypes.simulator_state.SimulatorState, feature_description: dict[str, tensorflow.io.FixedLenFeature], validate: bool = False) dict[str, tensorflow.Tensor]#

Tensorflow version of the simulator state to WOMD dict converter.

waymax.dataloader.womd_utils.simulator_state_to_womd_dict(state: waymax.datatypes.simulator_state.SimulatorState, feature_description: dict[str, tensorflow.io.FixedLenFeature], validate: bool = False) dict[str, jax.Array]#

Converts a simulator state into the WOMD tensor format.

See https://waymo.com/open/data/motion/tfexample for the tf.Example format which will be returned from this function. Note: This function is compatible with jax2tf.

Parameters:
  • state – State of the simulator from the environment. Should contain at least num_history + 1 elements in the time dimension for all temporal components of the simulated trajectory.

  • feature_description – Feature description that the returned dictionary is expected to match. It is used to infer the shapes of the expected fields, such as the number of agents and the amount of history.

  • validate – Validate whether the simulation has progressed far enough to ensure that an adequate amount of history is present.

Returns:

A dictionary matching fields as if they were read from the WOMD dataset.

Raises:

ValueError – If validate is set to True and the amount of history stored in the observations is not num_history + 1.
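The history requirement can be made concrete with a rough sketch. Everything here is an assumption for illustration: `split_history` is a hypothetical helper, NumPy arrays stand in for JAX arrays, the time axis is assumed to be last, and the output keys follow the WOMD naming. How the real function slices the state is not described in this reference.

```python
import numpy as np


def split_history(x, num_history, validate=False):
    """Hypothetical helper: split [num_objects, time] data into past and current."""
    num_steps = x.shape[-1]
    if validate and num_steps != num_history + 1:
        raise ValueError(
            f"Expected {num_history + 1} time steps, got {num_steps}.")
    # Keep the most recent num_history + 1 steps (an assumption of this sketch).
    window = x[..., -(num_history + 1):]
    return {
        "state/past/x": window[..., :-1],     # num_history steps
        "state/current/x": window[..., -1:],  # 1 step
    }


x = np.arange(24).reshape(2, 12)
out = split_history(x, num_history=10)
print(out["state/past/x"].shape)     # (2, 10)
print(out["state/current/x"].shape)  # (2, 1)
```

With validate=True the same 12-step input would raise ValueError in this sketch, since 12 is not num_history + 1 = 11.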

waymax.dataloader.womd_utils._roadgraph_to_dict(rg: waymax.datatypes.roadgraph.RoadgraphPoints, prefix: str = 'roadgraph_samples') dict[str, jax.Array]#

Gets the roadgraph mpdata from the simulator state.

waymax.dataloader.womd_utils._trajectory_to_dict(trajectory: waymax.datatypes.object_state.Trajectory, time_prefix: str) dict[str, jax.Array]#

Generates the mpdata fields for Trajectory data.

waymax.dataloader.womd_utils._traffic_light_to_dict(tls: waymax.datatypes.traffic_lights.TrafficLights, time_prefix: str, timestamp_micros: jax.Array) dict[str, jax.Array]#

Generates the corresponding mpdata for TrafficLights fields.

waymax.dataloader.womd_utils._object_metadata_to_dict(metadata: waymax.datatypes.object_state.ObjectMetadata) dict[str, jax.Array]#

Converts object metadata to the original tf.Example format.

waymax.dataloader.womd_utils._get_invalid_future_trajectory(feature_description: dict[str, tensorflow.io.FixedLenFeature]) dict[str, jax.Array]#

Gets an invalid trajectory representing the future.

waymax.dataloader.womd_utils._get_invalid_future_traffic_light(feature_description: dict[str, tensorflow.io.FixedLenFeature]) dict[str, jax.Array]#

Gets an invalid traffic light representing the future.