Skip to content

Commit 33d2706

Browse files
committed
feat: separate builder
1 parent 1c6e2f8 commit 33d2706

File tree

1 file changed

+44
-3
lines changed

1 file changed

+44
-3
lines changed

fog_x/dataset.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -381,12 +381,53 @@ def load_rtx_episodes(
381381
import tensorflow_datasets as tfds
382382

383383
from fog_x.rlds.utils import dataset2path
384-
385384
b = tfds.builder_from_directory(builder_dir=dataset2path(name))
385+
self._build_rtx_episodes_from_tfds_builder(
386+
b,
387+
split=split,
388+
additional_metadata=additional_metadata,
389+
)
390+
391+
def load_rtx_episodes_local(
392+
self,
393+
path: str,
394+
split: str = "all",
395+
additional_metadata: Optional[Dict[str, Any]] = dict(),
396+
):
397+
"""
398+
Load robot data from Tensorflow Datasets.
386399
387-
ds = b.as_dataset(split=split)
400+
Args:
401+
path (str): Path to the RT-X episodes
402+
split (optional str): the portion of data to load, see [Tensorflow Split API](https://www.tensorflow.org/datasets/splits)
403+
additional_metadata (optional Dict[str, Any]): additional metadata to be associated with the loaded episodes
388404
389-
data_type = b.info.features["steps"]
405+
Example:
406+
```
407+
>>> dataset.load_rtx_episodes_local(path="~/Downloads/berkeley_autolab_ur5")
408+
>>> dataset.load_rtx_episodes_local(path="~/Downloads/berkeley_autolab_ur5", split="train[:10]", additional_metadata={"data_collector": "Alice", "custom_tag": "sample"})
409+
```
410+
"""
411+
412+
# this is only required if rtx format is used
413+
import tensorflow_datasets as tfds
414+
415+
b = tfds.builder(path)
416+
self._build_rtx_episodes_from_tfds_builder(
417+
b,
418+
split=split,
419+
additional_metadata=additional_metadata,
420+
)
421+
422+
def _build_rtx_episodes_from_tfds_builder(
423+
builder,
424+
):
425+
"""
426+
construct the dataset from the tfds builder
427+
"""
428+
ds = builder.as_dataset(split=split)
429+
430+
data_type = builder.info.features["steps"]
390431

391432
for tf_episode in ds:
392433
logger.info(tf_episode)

0 commit comments

Comments
 (0)