@@ -381,12 +381,53 @@ def load_rtx_episodes(
381
381
import tensorflow_datasets as tfds
382
382
383
383
from fog_x .rlds .utils import dataset2path
384
-
385
384
b = tfds .builder_from_directory (builder_dir = dataset2path (name ))
385
+ self ._build_rtx_episodes_from_tfds_builder (
386
+ b ,
387
+ split = split ,
388
+ additional_metadata = additional_metadata ,
389
+ )
390
+
391
def load_rtx_episodes_local(
    self,
    path: str,
    split: str = "all",
    additional_metadata: Optional[Dict[str, Any]] = None,
):
    """
    Load robot data in RT-X format from a local Tensorflow Datasets directory.

    Args:
        path (str): Path to the RT-X episodes
        split (optional str): the portion of data to load, see [Tensorflow Split API](https://www.tensorflow.org/datasets/splits)
        additional_metadata (optional Dict[str, Any]): additional metadata to be associated with the loaded episodes

    Example:
        ```
        >>> dataset.load_rtx_episodes_local(path="~/Downloads/berkeley_autolab_ur5")
        >>> dataset.load_rtx_episodes_local(path="~/Downloads/berkeley_autolab_ur5", split="train[:10]", additional_metadata={"data_collector": "Alice", "custom_tag": "sample"})
        ```
    """
    # FIX: the original default was `dict()`, a mutable default argument shared
    # across every call; any downstream mutation would leak between calls.
    # Use the None sentinel and create a fresh dict per call instead.
    if additional_metadata is None:
        additional_metadata = {}

    # Imported lazily: tensorflow_datasets is only required when the
    # RT-X (tfds) format is used.
    import tensorflow_datasets as tfds

    # Build the dataset from a local directory path rather than the
    # registered dataset name used by load_rtx_episodes.
    builder = tfds.builder(path)
    self._build_rtx_episodes_from_tfds_builder(
        builder,
        split=split,
        additional_metadata=additional_metadata,
    )
421
+
422
+ def _build_rtx_episodes_from_tfds_builder (
423
+ builder ,
424
+ ):
425
+ """
426
+ construct the dataset from the tfds builder
427
+ """
428
+ ds = builder .as_dataset (split = split )
429
+
430
+ data_type = builder .info .features ["steps" ]
390
431
391
432
for tf_episode in ds :
392
433
logger .info (tf_episode )
0 commit comments