waymax.dataloader.womd_utils
General settings and utility functions specifically for the WOMD data.
These functions are mainly intended to be internal to the Waymax data library.
WOMD (go/womd) is the dataset we typically use for simulation in our environment. See https://waymo.com/open/data/motion/tfexample for definitions of many of the data fields mentioned in this file.
Module Contents
Functions
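- aggregate_time_tensors(decoded_tensors) – Combines all past/current/future fields into an 'all' field.
- get_features_description(max_num_objects, max_num_rg_points, include_sdc_paths, num_paths, num_points_per_path, num_tls) – Returns a dictionary of all features to be extracted.
- simulator_state_to_womd_dict_tensorflow(state, feature_description, validate) – TensorFlow version of the simulator-state-to-WOMD-dict converter.
- simulator_state_to_womd_dict(state, feature_description, validate) – Converts a simulator state into the WOMD tensor format.
- _roadgraph_to_dict(rg, prefix) – Gets the roadgraph data in WOMD format.
- _trajectory_to_dict(trajectory, time_prefix) – Generates the WOMD-format fields for Trajectory data.
- _traffic_light_to_dict(tls, time_prefix, timestamp_micros) – Generates the corresponding WOMD-format fields for TrafficLights.
- _object_metadata_to_dict(metadata) – Converts object metadata to the original tf.Example format.
- _get_invalid_future_trajectory(feature_description) – Gets an invalid trajectory representing the future.
- _get_invalid_future_traffic_light(feature_description) – Gets an invalid traffic light representing the future.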
Attributes
- waymax.dataloader.womd_utils.DEFAULT_FLOAT
- waymax.dataloader.womd_utils.DEFAULT_INT
- waymax.dataloader.womd_utils.DEFAULT_BOOL = False
- waymax.dataloader.womd_utils._TF_TO_JNP_DTYPE
- waymax.dataloader.womd_utils.TL_TIMESTAMP_STEP_AXIS
- waymax.dataloader.womd_utils.TL_STEP_AXIS
- waymax.dataloader.womd_utils.aggregate_time_tensors(decoded_tensors: dict[str, tensorflow.Tensor]) → dict[str, tensorflow.Tensor]
Combines all past/current/future fields into an ‘all’ field.
Note the original past/current/future keys are removed in the returned dict.
- Parameters:
decoded_tensors – Input dict of tensors keyed by string.
- Returns:
A new dict of tensors keyed by updated strings, in which the past/current/future fields are merged into 'all'.
- Return type:
dict[str, tensorflow.Tensor]
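The merge can be pictured with a small NumPy sketch. It assumes the WOMD key layout <prefix>/<past|current|future>/<field> (for example state/past/x) and that every temporal field shares one time axis. The real function presumably has to treat traffic-light fields differently (their time axis is described by TL_STEP_AXIS), so aggregate_time_tensors_sketch and its time_axis parameter are hypothetical, not the module's code.

```python
import numpy as np

# Chronological order in which the time slices are concatenated.
_TIME_PARTS = ("past", "current", "future")


def aggregate_time_tensors_sketch(decoded, time_axis=-1):
    """Hypothetical sketch: merges '<prefix>/<past|current|future>/<field>' into '<prefix>/all/<field>'.

    Assumes every temporal field uses the same time axis (WOMD object state
    fields are shaped [num_objects, num_steps], so time is the last axis).
    """
    merged = {}
    groups = {}
    for key, value in decoded.items():
        parts = key.split("/")
        if len(parts) == 3 and parts[1] in _TIME_PARTS:
            # Group the slices of the same field, e.g. ('state', 'x').
            groups.setdefault((parts[0], parts[2]), {})[parts[1]] = value
        else:
            # Non-temporal keys (e.g. 'state/id') pass through unchanged.
            merged[key] = value
    for (prefix, field), slices in groups.items():
        pieces = [slices[t] for t in _TIME_PARTS if t in slices]
        merged[f"{prefix}/all/{field}"] = np.concatenate(pieces, axis=time_axis)
    # The original past/current/future keys are not copied into the result.
    return merged


decoded = {
    "state/past/x": np.zeros((2, 10)),
    "state/current/x": np.ones((2, 1)),
    "state/future/x": np.full((2, 80), 2.0),
    "state/id": np.array([7.0, 8.0]),
}
merged = aggregate_time_tensors_sketch(decoded)
print(sorted(merged))  # ['state/all/x', 'state/id']
print(merged["state/all/x"].shape)  # (2, 91)
```

Building a new dict rather than mutating the input mirrors the documented behavior that the past/current/future keys are absent from the returned dict.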
- waymax.dataloader.womd_utils.get_features_description(max_num_objects: int = 128, max_num_rg_points: int = 30000, include_sdc_paths: bool = False, num_paths: int | None = 45, num_points_per_path: int | None = 800, num_tls: int | None = 16) → dict[str, tensorflow.io.FixedLenFeature]
Returns a dictionary of all features to be extracted.
- Parameters:
max_num_objects – Max number of objects.
max_num_rg_points – Max number of sampled roadgraph points.
include_sdc_paths – Whether to include roadgraph traversal paths for the SDC.
num_paths – Optional number of SDC paths. Must be defined if include_sdc_paths is True.
num_points_per_path – Optional number of points per SDC path. Must be defined if include_sdc_paths is True.
num_tls – Maximum number of traffic lights.
- Returns:
Dictionary of all features to be extracted.
- Raises:
ValueError – If include_sdc_paths is True but either num_paths or num_points_per_path is None.
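The following sketch shows how the parameters could shape a feature description and where the documented ValueError fits. To run without TensorFlow it uses a hypothetical FeatureSpec tuple in place of tf.io.FixedLenFeature, lists only a few representative keys, and adds num_past/num_future parameters that are not in the real signature. The path_samples/xyz key name is also an assumption.

```python
from typing import NamedTuple, Optional


class FeatureSpec(NamedTuple):
    """Stand-in for tf.io.FixedLenFeature so the sketch runs without TensorFlow."""
    shape: tuple
    dtype: str


def features_description_sketch(
    max_num_objects: int = 128,
    max_num_rg_points: int = 30000,
    include_sdc_paths: bool = False,
    num_paths: Optional[int] = 45,
    num_points_per_path: Optional[int] = 800,
    num_tls: int = 16,
    num_past: int = 10,  # not in the real signature; WOMD stores 10 past steps
    num_future: int = 80,  # not in the real signature; WOMD stores 80 future steps
) -> dict:
    # Same precondition as the documented ValueError.
    if include_sdc_paths and (num_paths is None or num_points_per_path is None):
        raise ValueError(
            "num_paths and num_points_per_path must be set when include_sdc_paths is True."
        )
    features = {
        "roadgraph_samples/xyz": FeatureSpec((max_num_rg_points, 3), "float32"),
        "state/id": FeatureSpec((max_num_objects,), "float32"),
    }
    for part, steps in (("past", num_past), ("current", 1), ("future", num_future)):
        # Object fields are [num_objects, num_steps]; only 'x' is shown here.
        features[f"state/{part}/x"] = FeatureSpec((max_num_objects, steps), "float32")
        # Traffic-light fields are time-major: [num_steps, num_tls].
        features[f"traffic_light_state/{part}/state"] = FeatureSpec((steps, num_tls), "int64")
    if include_sdc_paths:
        # Hypothetical key name for the SDC path samples.
        features["path_samples/xyz"] = FeatureSpec(
            (num_paths, num_points_per_path, 3), "float32"
        )
    return features


desc = features_description_sketch()
print(desc["state/past/x"].shape)  # (128, 10)
```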
- waymax.dataloader.womd_utils.simulator_state_to_womd_dict_tensorflow(state: waymax.datatypes.simulator_state.SimulatorState, feature_description: dict[str, tensorflow.io.FixedLenFeature], validate: bool = False) → dict[str, tensorflow.Tensor]
TensorFlow version of the simulator-state-to-WOMD-dict converter.
- waymax.dataloader.womd_utils.simulator_state_to_womd_dict(state: waymax.datatypes.simulator_state.SimulatorState, feature_description: dict[str, tensorflow.io.FixedLenFeature], validate: bool = False) → dict[str, jax.Array]
Converts a simulator state into the WOMD tensor format.
See https://waymo.com/open/data/motion/tfexample for the tf.Example format which will be returned from this function. Note: This function is compatible with jax2tf.
- Parameters:
state – State of the simulator from the environment. Should contain at least num_history + 1 elements in the time dimension for all temporal components of the simulated trajectory.
feature_description – Feature description expected out of the dictionary. This is used to understand the shape of the fields expected such as number of agents and amount of history.
validate – Validate whether the simulation has progressed far enough to ensure that an adequate amount of history is present.
- Returns:
A dictionary matching fields as if they were read from the WOMD dataset.
- Raises:
ValueError – If validate is set to True and the number of history steps stored in the observations is not num_history + 1.
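A minimal sketch of the history handling, assuming num_history can be read from the last dimension of the state/past/x feature shape and that the window of num_history + 1 steps ends at the current timestep. num_history_from_description and split_past_current are hypothetical helpers operating on plain NumPy arrays, not the module's implementation.

```python
import numpy as np


def num_history_from_description(feature_shapes: dict) -> int:
    # Assumption: the past window length is the last dimension of 'state/past/x'.
    return feature_shapes["state/past/x"][-1]


def split_past_current(x: np.ndarray, timestep: int, num_history: int, validate: bool = False):
    """Hypothetical sketch: slices the history window ending at `timestep`.

    x has shape [num_objects, num_timesteps]; returns (past, current) with
    shapes [num_objects, num_history] and [num_objects, 1].
    """
    available = timestep + 1  # steps 0..timestep have been simulated
    if validate and available < num_history + 1:
        raise ValueError(
            f"Need {num_history + 1} steps of history, but only {available} are available."
        )
    # Without validation, a too-short history yields a shorter past slice.
    start = max(timestep - num_history, 0)
    window = x[..., start : timestep + 1]
    return window[..., :-1], window[..., -1:]


num_history = num_history_from_description({"state/past/x": (128, 10)})
x = np.arange(2 * 91).reshape(2, 91)
past, current = split_past_current(x, timestep=10, num_history=num_history, validate=True)
print(past.shape, current.shape)  # (2, 10) (2, 1)
```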
- waymax.dataloader.womd_utils._roadgraph_to_dict(rg: waymax.datatypes.roadgraph.RoadgraphPoints, prefix: str = 'roadgraph_samples') → dict[str, jax.Array]
Gets the roadgraph data in WOMD format from the simulator state.
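As an illustration of the target layout, this sketch packs per-point roadgraph arrays of shape [N] into WOMD-style roadgraph_samples fields: coordinates and directions become [N, 3], and scalar per-point fields become [N, 1]. The input argument names (x, y, z, dir_x, dir_y, dir_z, types, ids, valid) are assumed stand-ins for the RoadgraphPoints fields, and roadgraph_to_dict_sketch is hypothetical.

```python
import numpy as np


def roadgraph_to_dict_sketch(x, y, z, dir_x, dir_y, dir_z, types, ids, valid,
                             prefix="roadgraph_samples"):
    """Hypothetical sketch: packs per-point roadgraph arrays of shape [N] into WOMD-style fields."""
    return {
        # Coordinates and directions are stacked into [N, 3] arrays.
        f"{prefix}/xyz": np.stack([x, y, z], axis=-1).astype(np.float32),
        f"{prefix}/dir": np.stack([dir_x, dir_y, dir_z], axis=-1).astype(np.float32),
        # Scalar per-point fields get a trailing axis: [N] -> [N, 1].
        f"{prefix}/type": np.asarray(types, dtype=np.int64)[..., np.newaxis],
        f"{prefix}/id": np.asarray(ids, dtype=np.int64)[..., np.newaxis],
        f"{prefix}/valid": np.asarray(valid, dtype=np.int64)[..., np.newaxis],
    }


n = 4
rg = roadgraph_to_dict_sketch(
    np.zeros(n), np.ones(n), np.zeros(n),  # x, y, z
    np.ones(n), np.zeros(n), np.zeros(n),  # dir_x, dir_y, dir_z
    np.arange(n), np.arange(n), np.ones(n, dtype=bool),  # types, ids, valid
)
print(rg["roadgraph_samples/xyz"].shape)  # (4, 3)
```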
- waymax.dataloader.womd_utils._trajectory_to_dict(trajectory: waymax.datatypes.object_state.Trajectory, time_prefix: str) → dict[str, jax.Array]
Generates the WOMD-format fields for Trajectory data.
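Conceptually, much of this conversion is a renaming of [num_objects, num_steps] arrays into state/<time_prefix>/<field> keys. The sketch below assumes Trajectory-style names (vel_x, yaw, and so on) and a mapping onto the WOMD fields velocity_x, bbox_yaw, etc. Both the mapping table and trajectory_to_dict_sketch are hypothetical illustrations, not the module's exact field list.

```python
import numpy as np

# Assumed mapping from Trajectory-style field names to WOMD 'state/<time>/<name>' fields.
_TRAJ_TO_WOMD = {
    "x": "x",
    "y": "y",
    "z": "z",
    "vel_x": "velocity_x",
    "vel_y": "velocity_y",
    "yaw": "bbox_yaw",
    "length": "length",
    "width": "width",
    "height": "height",
    "timestamp_micros": "timestamp_micros",
    "valid": "valid",
}


def trajectory_to_dict_sketch(fields: dict, time_prefix: str) -> dict:
    """Hypothetical sketch: renames [num_objects, num_steps] arrays to WOMD keys."""
    out = {}
    for name, womd_name in _TRAJ_TO_WOMD.items():
        if name not in fields:
            continue  # the sketch tolerates partial inputs
        value = np.asarray(fields[name])
        if name == "valid":
            value = value.astype(np.int64)  # WOMD stores validity as int64
        out[f"state/{time_prefix}/{womd_name}"] = value
    return out


traj = {
    "x": np.zeros((2, 10)),
    "vel_x": np.ones((2, 10)),
    "yaw": np.zeros((2, 10)),
    "valid": np.ones((2, 10), dtype=bool),
}
out = trajectory_to_dict_sketch(traj, "past")
print(sorted(out))
# ['state/past/bbox_yaw', 'state/past/valid', 'state/past/velocity_x', 'state/past/x']
```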
- waymax.dataloader.womd_utils._traffic_light_to_dict(tls: waymax.datatypes.traffic_lights.TrafficLights, time_prefix: str, timestamp_micros: jax.Array) → dict[str, jax.Array]
Generates the corresponding WOMD-format fields for TrafficLights.
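WOMD stores traffic-light fields time-major ([num_steps, num_tls]), which is why the module defines a separate TL_STEP_AXIS. The sketch assumes the in-memory traffic-light arrays are shaped [..., num_tls, num_steps] and swaps the last two axes. The lane_ids to id mapping and traffic_light_to_dict_sketch are hypothetical, and the timestamp_micros argument is left out for brevity.

```python
import numpy as np

# Assumed mapping from TrafficLights-style field names to WOMD field names.
_TL_TO_WOMD = {"x": "x", "y": "y", "z": "z", "state": "state", "lane_ids": "id", "valid": "valid"}


def traffic_light_to_dict_sketch(fields: dict, time_prefix: str) -> dict:
    """Hypothetical sketch: converts [num_tls, num_steps] arrays to WOMD [num_steps, num_tls] fields."""
    out = {}
    for name, womd_name in _TL_TO_WOMD.items():
        if name in fields:
            # Assumption: inputs are [..., num_tls, num_steps]; WOMD is time-major,
            # so the last two axes are swapped.
            value = np.swapaxes(np.asarray(fields[name]), -1, -2)
            out[f"traffic_light_state/{time_prefix}/{womd_name}"] = value
    return out


tls = {
    "state": np.arange(6).reshape(3, 2),  # 3 lights, 2 steps
    "valid": np.ones((3, 2), dtype=bool),
}
out = traffic_light_to_dict_sketch(tls, "past")
print(out["traffic_light_state/past/state"].shape)  # (2, 3)
```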
- waymax.dataloader.womd_utils._object_metadata_to_dict(metadata: waymax.datatypes.object_state.ObjectMetadata) → dict[str, jax.Array]
Converts object metadata to the original tf.Example format.
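Converting metadata back to the tf.Example layout amounts mostly to renaming fields and restoring WOMD dtypes (WOMD stores state/id and state/type as floats and the flag fields as int64). The field names ids, object_types, is_sdc, is_modeled and objects_of_interest, and the pairing of is_modeled with state/tracks_to_predict, are assumptions for illustration. object_metadata_to_dict_sketch is hypothetical.

```python
import numpy as np

# Assumed mapping: metadata field -> (WOMD key, WOMD dtype).
_METADATA_TO_WOMD = {
    "ids": ("state/id", np.float32),
    "object_types": ("state/type", np.float32),
    "is_sdc": ("state/is_sdc", np.int64),
    "is_modeled": ("state/tracks_to_predict", np.int64),
    "objects_of_interest": ("state/objects_of_interest", np.int64),
}


def object_metadata_to_dict_sketch(metadata: dict) -> dict:
    """Hypothetical sketch: renames per-object metadata and casts it to WOMD dtypes."""
    out = {}
    for name, (key, dtype) in _METADATA_TO_WOMD.items():
        if name in metadata:
            out[key] = np.asarray(metadata[name]).astype(dtype)
    return out


meta = {"ids": [101, 102], "is_sdc": [True, False], "is_modeled": [True, True]}
out = object_metadata_to_dict_sketch(meta)
print(out["state/is_sdc"].tolist())  # [1, 0]
```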
- waymax.dataloader.womd_utils._get_invalid_future_trajectory(feature_description: dict[str, tensorflow.io.FixedLenFeature]) → dict[str, jax.Array]
Gets an invalid trajectory representing the future.
- waymax.dataloader.womd_utils._get_invalid_future_traffic_light(feature_description: dict[str, tensorflow.io.FixedLenFeature]) → dict[str, jax.Array]
Gets an invalid traffic light representing the future.
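Both invalid-future helpers can be understood as building placeholders whose shapes come from the feature description and whose validity masks are all zero. In this sketch a plain mapping from key to shape stands in for the FixedLenFeature description. The fill_value is a hypothetical placeholder, because the values of DEFAULT_FLOAT and DEFAULT_INT are not shown on this page, and all non-validity fields are stored as float32 for brevity.

```python
import numpy as np


def invalid_future_sketch(feature_shapes: dict, fill_value: float = -1.0) -> dict:
    """Hypothetical sketch: builds all-invalid placeholders for every '*/future/*' feature.

    feature_shapes maps WOMD keys to shapes (a stand-in for the FixedLenFeature
    description). fill_value is an assumption, not the module's default.
    """
    out = {}
    for key, shape in feature_shapes.items():
        if "/future/" not in key:
            continue  # past/current fields come from the simulator state instead
        if key.endswith("/valid"):
            out[key] = np.zeros(shape, dtype=np.int64)  # every future step is invalid
        else:
            # Simplification: all other fields are float32 here.
            out[key] = np.full(shape, fill_value, dtype=np.float32)
    return out


shapes = {
    "state/past/x": (128, 10),
    "state/future/x": (128, 80),
    "state/future/valid": (128, 80),
    "traffic_light_state/future/state": (80, 16),
    "traffic_light_state/future/valid": (80, 16),
}
future = invalid_future_sketch(shapes)
print(sorted(future))
```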