from typing import Any, Dict, List, Optional, Tuple

import numpy as np
+ import polars
+ import pandas

from fog_x.database import (
    DatabaseConnector,
@@ -36,6 +38,22 @@ def __init__(
        step_data_connector: DatabaseConnector = None,
        storage: Optional[str] = None,
    ) -> None:
+ """
42
+
43
+ Args:
44
+ name (str): Name of this dataset. Used as the directory name when exporting.
45
+ path (str): Required. Local path of where this dataset should be stored.
46
+ features (optional Dict[str, FeatureType]): Description of `param1`.
47
+ enable_feature_inference (bool): enable inferring additional FeatureTypes
48
+
49
+ Example:
50
+ ```
51
+ >>> dataset = fog_x.Dataset('my_dataset', path='~/fog_x/my_dataset`)
52
+ ```
53
+
54
+ TODO:
55
+ * is replace_existing actually used anywhere?
56
+ """
        self.name = name
        path = os.path.expanduser(path)
        self.path = path
@@ -55,23 +73,24 @@ def __init__(
            if not os.path.exists(f"{path}/{name}"):
                os.makedirs(f"{path}/{name}")
            step_data_connector = LazyFrameConnector(f"{path}/{name}")
-        self.db_manager = DatabaseManager(
-            episode_info_connector, step_data_connector
-        )
+        self.db_manager = DatabaseManager(episode_info_connector, step_data_connector)
        self.db_manager.initialize_dataset(self.name, features)

        self.storage = storage
        self.obs_keys = []
        self.act_keys = []
        self.step_keys = []

-    def new_episode(
-        self, metadata: Optional[Dict[str, Any]] = None
-    ) -> Episode:
+    def new_episode(self, metadata: Optional[Dict[str, Any]] = None) -> Episode:
        """
        Create a new episode / trajectory.
-        TODO #1: support multiple processes writing to the same episode
-        TODO #2: close the previous episode if not closed
+
+        Returns:
+            Episode
+
+        TODO:
+            * support multiple processes writing to the same episode
+            * close the previous episode if not closed
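+
+        Example (illustrative sketch; `dataset` is a `fog_x.Dataset` and `"arm_action"` is a hypothetical feature name):
+        ```
+        >>> episode = dataset.new_episode()
+        >>> episode.add(feature="arm_action", value=np.array([0.1, 0.2]))
+        >>> episode.close()
+        ```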
        """
        return Episode(
            metadata=metadata,
@@ -113,6 +132,10 @@ def export(
    ) -> None:
        """
        Export the dataset.
+
+        Args:
+            export_path (optional str): location of exported data. Uses `dataset.path/export` by default.
+            format (str): Supported formats are `rtx`, `open-x`, and `rlds`.
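+
+        Example (illustrative; relies on the default `export_path`):
+        ```
+        >>> dataset.export(format="rtx")
+        ```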
        """
        if format == "rtx" or format == "open-x" or format == "rlds":
            if export_path == None:
@@ -207,20 +230,14 @@ def export(
                        and feature_spec.shape != ()
                    ):
                        # reverse the process
-                        value = np.load(io.BytesIO(v)).astype(
-                            feature_spec.np_dtype
-                        )
+                        value = np.load(io.BytesIO(v)).astype(feature_spec.np_dtype)
                    elif (
                        isinstance(feature_spec, tfds.core.features.Tensor)
                        and feature_spec.shape == ()
                    ):
                        value = np.array(v, dtype=feature_spec.np_dtype)
-                    elif isinstance(
-                        feature_spec, tfds.core.features.Image
-                    ):
-                        value = np.load(io.BytesIO(v)).astype(
-                            feature_spec.np_dtype
-                        )
+                    elif isinstance(feature_spec, tfds.core.features.Image):
+                        value = np.load(io.BytesIO(v)).astype(feature_spec.np_dtype)
                    else:
                        value = v
@@ -265,7 +282,18 @@ def load_rtx_episodes(
        additional_metadata: Optional[Dict[str, Any]] = None,
    ):
        """
-        Load the dataset.
+        Load robot data from Tensorflow Datasets.
+
+        Args:
+            name (str): Name of RT-X episodes, which can be found at [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog) under the Robotics category
+            split (optional str): the portion of data to load, see [Tensorflow Split API](https://www.tensorflow.org/datasets/splits)
+            additional_metadata (optional Dict[str, Any]): additional metadata to be associated with the loaded episodes
+
+        Example:
+        ```
+        >>> dataset.load_rtx_episodes(name="berkeley_autolab_ur5")
+        >>> dataset.load_rtx_episodes(
+        ...     name="berkeley_autolab_ur5",
+        ...     split="train[:10]",
+        ...     additional_metadata={"data_collector": "Alice", "custom_tag": "sample"},
+        ... )
+        ```
        """

        # this is only required if rtx format is used
@@ -325,26 +353,36 @@ def load_rtx_episodes(
                    fog_epsiode.add(
                        feature=str(k),
                        value=v.numpy(),
-                        feature_type=FeatureType(
-                            tf_feature_spec=data_type[k]
-                        ),
+                        feature_type=FeatureType(tf_feature_spec=data_type[k]),
                    )
                    self.step_keys.append(k)
            fog_epsiode.close()

-    def get_episode_info(self):
+    def get_episode_info(self) -> pandas.DataFrame:
        """
-        Return the metadata as pandas dataframe.
+        Returns:
+            metadata of all episodes as `pandas.DataFrame`
        """
        return self.db_manager.get_episode_info_table()

-    def get_step_data(self):
+    def get_step_data(self) -> polars.LazyFrame:
        """
-        Return the all step data as lazy dataframe.
+        Returns:
+            step data of all episodes
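+
+        Example (illustrative; `collect()` materializes the returned polars LazyFrame):
+        ```
+        >>> steps = dataset.get_step_data()
+        >>> steps.collect()
+        ```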
        """
        return self.db_manager.get_step_table_all()

-    def get_step_data_by_episode_ids(self, episode_ids: List[int], as_lazy_frame=True):
+    def get_step_data_by_episode_ids(
+        self, episode_ids: List[int], as_lazy_frame=True
+    ) -> List[polars.LazyFrame] | List[polars.DataFrame]:
+        """
+        Args:
+            episode_ids (List[int]): list of episode ids
+            as_lazy_frame (bool): whether to return polars.LazyFrame or polars.DataFrame
+
+        Returns:
+            step data of each episode
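+
+        Example (illustrative; episode ids 0 and 1 are placeholders):
+        ```
+        >>> lazy_frames = dataset.get_step_data_by_episode_ids([0, 1])
+        >>> dataframes = dataset.get_step_data_by_episode_ids([0, 1], as_lazy_frame=False)
+        ```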
+        """
        episodes = []
        for episode_id in episode_ids:
            if episode_id == None:
@@ -354,8 +392,17 @@ def get_step_data_by_episode_ids(self, episode_ids: List[int], as_lazy_frame = T
            else:
                episodes.append(self.db_manager.get_step_table(episode_id).collect())
        return episodes
-
-    def read_by(self, episode_info: Any = None):
+
+    def read_by(self, episode_info: Any = None) -> List[polars.LazyFrame]:
+        """
+        To be used with `Dataset.get_episode_info`.
+
+        Args:
+            episode_info (pandas.DataFrame): episode metadata used to determine which episodes to read
+
+        Returns:
+            episodes filtered by `episode_info`
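+
+        Example (illustrative; `data_collector` is a hypothetical metadata column set via `additional_metadata`):
+        ```
+        >>> info = dataset.get_episode_info()
+        >>> episodes = dataset.read_by(info[info["data_collector"] == "Alice"])
+        ```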
+        """
        episode_ids = list(episode_info["episode_id"])
        logger.info(f"Reading episodes as order: {episode_ids}")
        episodes = []
@@ -375,6 +422,11 @@ def get_episodes_from_metadata(self, metadata: Any = None):
        return episodes

    def pytorch_dataset_builder(self, metadata=None, **kwargs):
+        """
+        Load the current dataset as a PyTorch dataset.
+        To be used with `torch.utils.data.DataLoader`.
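+
+        Example (illustrative; batch size is an arbitrary choice):
+        ```
+        >>> from torch.utils.data import DataLoader
+        >>> loader = DataLoader(dataset.pytorch_dataset_builder(), batch_size=2)
+        ```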
+        """
+
        import torch
        from torch.utils.data import Dataset
@@ -414,15 +466,22 @@ def __getitem__(self, idx):
        return pytorch_dataset

    def get_as_huggingface_dataset(self):
+        """
+        Load the current dataset as a HuggingFace dataset.
+
+        TODO:
+            * currently the support for HuggingFace datasets is limited;
+              it only shows the capability of easily returning a hf dataset
+            * add features from the episode metadata
+            * allow selecting episodes based on queries;
+              doing so requires creating a new copy of the dataset on disk
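+
+        Example (illustrative):
+        ```
+        >>> hf_dataset = dataset.get_as_huggingface_dataset()
+        ```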
+        """
        import datasets

-        # TODO: currently the support for huggingg face dataset is limited
-        # it only shows its capability of easily returning a hf dataset
-        # TODO #1: add features from the episode metadata
-        # TODO #2: allow selecting episodes based on queries
-        # doing so requires creating a new copy of the dataset on disk
        dataset_path = self.path + "/" + self.name
-        parquet_files = [os.path.join(dataset_path, f) for f in os.listdir(dataset_path)]
+        parquet_files = [
+            os.path.join(dataset_path, f) for f in os.listdir(dataset_path)
+        ]

-        hf_dataset = datasets.load_dataset('parquet', data_files=parquet_files)
+        hf_dataset = datasets.load_dataset("parquet", data_files=parquet_files)
        return hf_dataset