
Utility functions for general data loading.

Module Contents


generate_sharded_filenames(→ Sequence[str])

Returns the filenames of individual sharded files.

tf_examples_dataset(→ tensorflow.data.Dataset)

Returns a dataset of Open Motion Dataset TFExamples.

get_data_generator(→ Iterator[T])

Iterator that yields the desired object returned by postprocess_fn.


waymax.dataloader.dataloader_utils.generate_sharded_filenames(path: str) → Sequence[str]

Returns the filenames of individual sharded files.

A sharded file is a set of files of the format filename-XXXXX-of-YYYYY, where XXXXX is a placeholder for the index of the shard, and YYYYY is the total number of shards. These files are collectively referred to by a sharded path filename@YYYYY.

For example, the sharded path myfile@100 refers to the set of files:
  • myfile-00000-of-00100

  • myfile-00001-of-00100

  • ...

  • myfile-00098-of-00100

  • myfile-00099-of-00100


Parameters: path – A path to a sharded file, with the format filename@shards, where shards is an integer giving the total number of shards.


Returns: The complete set of filenames that the path refers to, each with the format filename-XXXXX-of-YYYYY.
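The naming scheme described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the waymax implementation: the helper name `expand_sharded_path` is hypothetical, and the choices to split on the last `@` and to raise `ValueError` on malformed input are assumptions made for the example.

```python
from typing import Sequence


def expand_sharded_path(path: str) -> Sequence[str]:
    """Hypothetical helper: expands 'name@N' into N shard filenames."""
    # Split on the last '@' so that an '@' earlier in the path is preserved.
    base, sep, shards = path.rpartition("@")
    if not sep or not shards.isdecimal() or int(shards) < 1:
        raise ValueError(f"Expected 'filename@shards', got {path!r}")
    num_shards = int(shards)
    # Zero-pad both the shard index and the shard count to five digits.
    return [f"{base}-{i:05d}-of-{num_shards:05d}" for i in range(num_shards)]


names = expand_sharded_path("myfile@100")
print(len(names), names[0], names[-1])
# prints: 100 myfile-00000-of-00100 myfile-00099-of-00100
```

Shard indices start at zero, so the last file of `myfile@100` carries index 00099 rather than 00100.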

waymax.dataloader.dataloader_utils.tf_examples_dataset(path: str, data_format: waymax.config.DataFormat, preprocess_fn: Callable[[bytes], dict[str, tensorflow.Tensor]], shuffle_seed: int | None = None, shuffle_buffer_size: int = 100, repeat: int | None = None, batch_dims: Sequence[int] = (), num_shards: int = 1, deterministic: bool = True, drop_remainder: bool = True, tf_data_service_address: str | None = None, batch_by_scenario: bool = True) → tensorflow.data.Dataset

Returns a dataset of Open Motion Dataset TFExamples.

Each TFExample contains data for the trajectory of all objects, the roadgraph, and traffic light states. See https://waymo.com/open/data/motion/tfexample for the data format definition.

Parameters:

  • path – The path to the dataset.

  • data_format – Data format of the dataset.

  • preprocess_fn – Function for parsing and preprocessing individual examples.

  • shuffle_seed – Seed for shuffling. If left at the default (None), the dataset is not shuffled.

  • shuffle_buffer_size – The size of the shuffle buffer.

  • repeat – Number of times to repeat the dataset. Default (None) will repeat infinitely.

  • batch_dims – Sizes of the batch dimensions. Multiple batch dimensions can be used to provide inputs for multiple devices, e.g. [jax.local_device_count(), batch_size_per_device].

  • num_shards – Number of shards for parallel loading. Has no effect on the data returned.

  • deterministic – Whether to use deterministic parallel processing.

  • drop_remainder – Argument for tf.data.Dataset.batch. Set to True to drop the last batch if it does not contain enough examples.

  • tf_data_service_address – Address of a tf.data service. If set, the tf.data service is used.

  • batch_by_scenario – If True, one example in a returned batch is an entire scenario containing all objects; if False, the dataset treats each individual object trajectory as a training example rather than an entire scenario.


Returns: A tf.data.Dataset of Open Motion Dataset tf.Example elements.
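To make the interaction between batch_dims and drop_remainder concrete, here is a small sketch with plain Python lists in place of a tf.data.Dataset. It assumes batches are formed from the innermost dimension outward, so that batch_dims=(2, 2) yields elements with two leading dimensions of size 2. The helpers `batch` and `nested_batch` are hypothetical and only illustrate the documented shapes; they are not how waymax builds its pipeline.

```python
from typing import Iterable, Iterator, List, Sequence


def batch(items: Iterable, size: int, drop_remainder: bool) -> Iterator[List]:
    """Groups consecutive items into lists of length `size`."""
    chunk: List = []
    for item in items:
        chunk.append(item)
        if len(chunk) == size:
            yield chunk
            chunk = []
    # A short final chunk is kept only when drop_remainder is False.
    if chunk and not drop_remainder:
        yield chunk


def nested_batch(items: Iterable, batch_dims: Sequence[int],
                 drop_remainder: bool = True) -> List:
    """Hypothetical helper: applies one batching pass per batch dimension."""
    stream: Iterable = items
    # Assumption: the innermost (last) dimension is batched first.
    for size in reversed(batch_dims):
        stream = batch(stream, size, drop_remainder)
    return list(stream)


# Ten examples, two "devices", two examples per device.
batches = nested_batch(range(10), batch_dims=(2, 2))
print(batches)
# prints: [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
```

With drop_remainder=True the trailing examples 8 and 9 are discarded because they cannot fill a complete outer batch. With drop_remainder=False they would appear as a final, smaller element.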

waymax.dataloader.dataloader_utils.get_data_generator(config: waymax.config.DatasetConfig, preprocess_fn: Callable[[bytes], dict[str, tensorflow.Tensor | dict[str, tensorflow.Tensor]]] | None, postprocess_fn: Callable[[dict[str, jax.Array]], T] | None = None) → Iterator[T]

Iterator that yields the desired object returned by postprocess_fn.

It parses the data using preprocess_fn and returns a generator whose elements have the data structure defined by postprocess_fn.

Parameters:

  • config – Config for the dataset and preprocessing.

  • preprocess_fn – Function that preprocesses the serialized data into a dictionary mapping str to tf Tensor.

  • postprocess_fn – Function that converts a dict of jnp arrays to the desired data class. Note that for distributed training, this function will be pmap-ed and executed in the main process.


Returns: An iterator over objects of the desired data class.
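The two-stage flow described above (parse each serialized record, then convert the parsed dictionary into a data class) can be sketched without TensorFlow or JAX. Everything in this example is hypothetical: `data_generator`, `parse`, and `Sample` are stand-ins invented for illustration, and yielding the parsed dictionary unchanged when no postprocess_fn is given is an assumption rather than documented behavior.

```python
import dataclasses
from typing import Any, Callable, Dict, Iterable, Iterator, Optional


@dataclasses.dataclass
class Sample:
    """Hypothetical data class produced by the postprocessing step."""
    name: str
    length: int


def data_generator(
    records: Iterable[bytes],
    preprocess_fn: Callable[[bytes], Dict[str, Any]],
    postprocess_fn: Optional[Callable[[Dict[str, Any]], Any]] = None,
) -> Iterator[Any]:
    """Hypothetical stand-in: parse each record, then optionally convert it."""
    for record in records:
        parsed = preprocess_fn(record)
        # Assumption: without a postprocess_fn the parsed dict is yielded as is.
        yield postprocess_fn(parsed) if postprocess_fn is not None else parsed


def parse(record: bytes) -> Dict[str, Any]:
    """Hypothetical preprocess_fn: decodes bytes into a small dictionary."""
    text = record.decode("utf-8")
    return {"name": text, "length": len(text)}


samples = list(data_generator([b"car", b"cyclist"], parse, lambda d: Sample(**d)))
print(samples)
# prints: [Sample(name='car', length=3), Sample(name='cyclist', length=7)]
```

Keeping parsing and conversion as separate callables lets the same parsed dictionaries feed different output types by swapping only the second function.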