waymax.dataloader.dataloader_utils

Util functions for general dataloading.

Module Contents

Functions

generate_sharded_filenames(→ Sequence[str])

Returns the filenames of individual sharded files.

tf_examples_dataset(→ tensorflow.data.Dataset)

Returns a dataset of Open Motion Dataset TFExamples.

get_data_generator(→ Iterator[T])

Iterator that yields the desired object returned by postprocess_fn.

Attributes

waymax.dataloader.dataloader_utils.T
waymax.dataloader.dataloader_utils.AUTOTUNE

waymax.dataloader.dataloader_utils.generate_sharded_filenames(path: str) → Sequence[str]

Returns the filenames of individual sharded files.

A sharded file is a set of files of the format filename-XXXXX-of-YYYYY, where XXXXX is a placeholder for the index of the shard, and YYYYY is the total number of shards. These files are collectively referred to by a sharded path filename@YYYYY.

For example, the sharded path myfile@100 refers to the set of files
  • myfile-00000-of-00100

  • myfile-00001-of-00100

  • ...

  • myfile-00098-of-00100

  • myfile-00099-of-00100

Parameters:

path – A path to a sharded file, with format filename@shards, where shards is an integer denoting the number of total shards.

Returns:

A sequence containing the complete set of filenames that the path refers to, with each filename having the format filename-XXXXX-of-YYYYY.
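
The naming scheme described above can be illustrated with a short standalone sketch. The function name expand_sharded_path is hypothetical and this is not the library's implementation; it only assumes the filename@YYYYY convention and the five-digit zero padding shown in the example list.

```python
from typing import Sequence


def expand_sharded_path(path: str) -> Sequence[str]:
    """Hypothetical stand-in that expands `filename@N` into N shard names."""
    # Split on the last '@' so that an '@' earlier in the path is preserved.
    base, sep, shards = path.rpartition("@")
    if not sep or not base or not shards.isdecimal():
        raise ValueError(f"Expected a path of the form filename@shards, got {path!r}")
    num_shards = int(shards)
    # Both the shard index and the shard count are zero-padded to five digits.
    return [f"{base}-{i:05d}-of-{num_shards:05d}" for i in range(num_shards)]


names = expand_sharded_path("myfile@100")
print(names[0])   # first shard name
print(names[-1])  # last shard name
print(len(names))
```

Shard indices start at zero, so the last file of a 100-shard set carries index 00099.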

waymax.dataloader.dataloader_utils.tf_examples_dataset(path: str, data_format: waymax.config.DataFormat, preprocess_fn: Callable[[bytes], dict[str, tensorflow.Tensor]], shuffle_seed: int | None = None, shuffle_buffer_size: int = 100, repeat: int | None = None, batch_dims: Sequence[int] = (), num_shards: int = 1, deterministic: bool = True, drop_remainder: bool = True, tf_data_service_address: str | None = None, batch_by_scenario: bool = True) → tensorflow.data.Dataset

Returns a dataset of Open Motion Dataset TFExamples.

Each TFExample contains data for the trajectory of all objects, the roadgraph, and traffic light states. See https://waymo.com/open/data/motion/tfexample for the data format definition.

Parameters:
  • path – The path to the dataset.

  • data_format – Data format of the dataset.

  • preprocess_fn – Function for parsing and preprocessing individual examples.

  • shuffle_seed – Seed for shuffling. If left at the default (None), the dataset is not shuffled.

  • shuffle_buffer_size – The size of the shuffle buffer.

  • repeat – Number of times to repeat the dataset. Default (None) will repeat infinitely.

  • batch_dims – List of batch dimension sizes. Multiple batch dimensions can be used to provide inputs for multiple devices, e.g. [jax.local_device_count(), batch_size_per_device].

  • num_shards – Number of shards for parallel loading. This has no effect on the data returned.

  • deterministic – Whether to use deterministic parallel processing.

  • drop_remainder – Argument for tf.data.Dataset.batch. Set to True to drop the last batch if it does not contain enough examples.

  • tf_data_service_address – Address of the tf.data service. If set, the tf.data service is used.

  • batch_by_scenario – If True, one example in a returned batch is the entire scenario containing all objects; if False, the dataset will treat individual object trajectories as a training example rather than an entire scenario.

Returns:

A tf.data.Dataset of Open Motion Dataset tf.Example elements.
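
The interaction between batch_dims and drop_remainder can be sketched in plain Python, without TensorFlow. The helpers batch and nested_batch below are hypothetical, and the sketch assumes that batching is applied once per entry of batch_dims, starting from the innermost (last) dimension; the actual implementation may differ.

```python
from typing import Any, Iterable, Iterator, Sequence


def batch(items: Iterable[Any], size: int, drop_remainder: bool = True) -> Iterator[list]:
    """Groups items into lists of `size`, mimicking a single batching step."""
    buf = []
    for item in items:
        buf.append(item)
        if len(buf) == size:
            yield buf
            buf = []
    # A trailing partial batch is only emitted when drop_remainder is False.
    if buf and not drop_remainder:
        yield buf


def nested_batch(
    items: Iterable[Any], batch_dims: Sequence[int], drop_remainder: bool = True
) -> Iterable[Any]:
    """Applies one batching step per entry in batch_dims, innermost first (assumed order)."""
    out = items
    for size in reversed(batch_dims):
        out = batch(out, size, drop_remainder)
    return out


# 14 examples with batch_dims=(2, 3): think of 2 devices with 3 examples each.
full = list(nested_batch(range(14), batch_dims=(2, 3), drop_remainder=True))
partial = list(nested_batch(range(14), batch_dims=(2, 3), drop_remainder=False))
print(len(full), len(partial))
```

With drop_remainder=True every emitted batch has the full (2, 3) shape, which is usually what a fixed-shape, multi-device computation needs; with drop_remainder=False a final, smaller batch is kept.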

waymax.dataloader.dataloader_utils.get_data_generator(config: waymax.config.DatasetConfig, preprocess_fn: Callable[[bytes], dict[str, tensorflow.Tensor | dict[str, tensorflow.Tensor]]] | None, postprocess_fn: Callable[[dict[str, jax.Array]], T] | None = None) → Iterator[T]

Iterator that yields the desired object returned by postprocess_fn.

It parses the data using preprocess_fn and returns a generator whose items have the data structure defined by postprocess_fn.

Parameters:
  • config – Config for the dataset and preprocessing.

  • preprocess_fn – Function that preprocesses the serialized data into a dictionary mapping str to tf Tensor.

  • postprocess_fn – Function that converts a dict of jnp arrays to the desired data class. Note that for distributed training, this function will be pmap-ed and executed in the main process.

Yields:

Instances of the desired data class, as returned by postprocess_fn.
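
The preprocess/postprocess pattern can be sketched with a simplified, dependency-free generator. Everything here (toy_data_generator, parse, Point) is hypothetical and only mirrors the shape of the signature above; it takes an in-memory list of records instead of a DatasetConfig, and it assumes that a postprocess_fn of None passes the preprocessed dictionary through unchanged.

```python
from dataclasses import dataclass
from typing import Any, Callable, Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")


def toy_data_generator(
    records: Iterable[bytes],
    preprocess_fn: Callable[[bytes], dict],
    postprocess_fn: Optional[Callable[[dict], T]] = None,
) -> Iterator[Any]:
    """Hypothetical stand-in: parse each record, then optionally convert it."""
    for record in records:
        example = preprocess_fn(record)
        # Assumption: without a postprocess_fn the dictionary is yielded as-is.
        yield example if postprocess_fn is None else postprocess_fn(example)


@dataclass(frozen=True)
class Point:
    x: float
    y: float


def parse(record: bytes) -> dict:
    """Toy preprocess_fn: turns b"1,2" into {"x": 1.0, "y": 2.0}."""
    x, y = record.decode().split(",")
    return {"x": float(x), "y": float(y)}


points = list(toy_data_generator([b"1,2", b"3,4"], parse, lambda d: Point(**d)))
raw = list(toy_data_generator([b"1,2"], parse))
print(points)
print(raw)
```

Keeping the conversion to the final data class in postprocess_fn means the parsing step stays independent of the type the caller wants to consume, which is what the type variable T in the signature expresses.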