Commit e585698

add docstrings
1 parent 9f211de commit e585698

2 files changed: +139 -36 lines changed

fog_x/dataset.py

Lines changed: 94 additions & 35 deletions
@@ -4,6 +4,8 @@
 from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
+import polars
+import pandas

 from fog_x.database import (
     DatabaseConnector,
@@ -36,6 +38,22 @@ def __init__(
         step_data_connector: DatabaseConnector = None,
         storage: Optional[str] = None,
     ) -> None:
+        """
+
+        Args:
+            name (str): Name of this dataset. Used as the directory name when exporting.
+            path (str): Required. Local path of where this dataset should be stored.
+            features (optional Dict[str, FeatureType]): features of this dataset and their types
+            enable_feature_inference (bool): enable inferring additional FeatureTypes
+
+        Example:
+        ```
+        >>> dataset = fog_x.Dataset('my_dataset', path='~/fog_x/my_dataset')
+        ```
+
+        TODO:
+            * is replace_existing actually used anywhere?
+        """
         self.name = name
         path = os.path.expanduser(path)
         self.path = path
@@ -55,23 +73,24 @@ def __init__(
         if not os.path.exists(f"{path}/{name}"):
             os.makedirs(f"{path}/{name}")
             step_data_connector = LazyFrameConnector(f"{path}/{name}")
-        self.db_manager = DatabaseManager(
-            episode_info_connector, step_data_connector
-        )
+        self.db_manager = DatabaseManager(episode_info_connector, step_data_connector)
         self.db_manager.initialize_dataset(self.name, features)

         self.storage = storage
         self.obs_keys = []
         self.act_keys = []
         self.step_keys = []

-    def new_episode(
-        self, metadata: Optional[Dict[str, Any]] = None
-    ) -> Episode:
+    def new_episode(self, metadata: Optional[Dict[str, Any]] = None) -> Episode:
         """
         Create a new episode / trajectory.
-        TODO #1: support multiple processes writing to the same episode
-        TODO #2: close the previous episode if not closed
+
+        Returns:
+            Episode
+
+        TODO:
+            * support multiple processes writing to the same episode
+            * close the previous episode if not closed
         """
         return Episode(
             metadata=metadata,
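
For reference, a minimal usage sketch of `new_episode` together with `Episode.add` and `Episode.close` (the metadata key and feature name here are illustrative, not part of the API):

```
>>> episode = dataset.new_episode(metadata={"collector": "Alice"})
>>> episode.add(feature="arm_position", value=[0.1, 0.2, 0.3])
>>> episode.close()
```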
@@ -113,6 +132,10 @@ def export(
     ) -> None:
         """
         Export the dataset.
+
+        Args:
+            export_path (optional str): location of exported data. Uses dataset.path/export by default.
+            format (str): Supported formats are `rtx`, `open-x`, and `rlds`.
         """
         if format == "rtx" or format == "open-x" or format == "rlds":
             if export_path == None:
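
A usage sketch for `export`, assuming the dataset from the examples above (the `rtx`/`open-x`/`rlds` formats require the optional Tensorflow Datasets dependency; the explicit path below is illustrative):

```
>>> dataset.export(format="rtx")  # writes to dataset.path/export by default
>>> dataset.export(export_path="~/my_exports", format="rlds")
```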
@@ -207,20 +230,14 @@ def export(
                         and feature_spec.shape != ()
                     ):
                         # reverse the process
-                        value = np.load(io.BytesIO(v)).astype(
-                            feature_spec.np_dtype
-                        )
+                        value = np.load(io.BytesIO(v)).astype(feature_spec.np_dtype)
                     elif (
                         isinstance(feature_spec, tfds.core.features.Tensor)
                         and feature_spec.shape == ()
                     ):
                         value = np.array(v, dtype=feature_spec.np_dtype)
-                    elif isinstance(
-                        feature_spec, tfds.core.features.Image
-                    ):
-                        value = np.load(io.BytesIO(v)).astype(
-                            feature_spec.np_dtype
-                        )
+                    elif isinstance(feature_spec, tfds.core.features.Image):
+                        value = np.load(io.BytesIO(v)).astype(feature_spec.np_dtype)
                     else:
                         value = v

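The `np.load(io.BytesIO(v))` calls above reverse an `np.save`-style serialization of stored tensor and image features. A self-contained sketch of that round trip, independent of fog_x:

```
import io
import numpy as np

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
buf = io.BytesIO()
np.save(buf, arr)       # serialize the array to .npy bytes
v = buf.getvalue()      # roughly how a stored value looks

restored = np.load(io.BytesIO(v)).astype(np.float32)  # "reverse the process"
assert (restored == arr).all()
```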
@@ -265,7 +282,18 @@ def load_rtx_episodes(
         additional_metadata: Optional[Dict[str, Any]] = None,
     ):
         """
-        Load the dataset.
+        Load robot data from Tensorflow Datasets.
+
+        Args:
+            name (str): Name of RT-X episodes, which can be found at [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog) under the Robotics category
+            split (optional str): the portion of data to load, see [Tensorflow Split API](https://www.tensorflow.org/datasets/splits)
+            additional_metadata (optional Dict[str, Any]): additional metadata to be associated with the loaded episodes
+
+        Example:
+        ```
+        >>> dataset.load_rtx_episodes(name="berkeley_autolab_ur5")
+        >>> dataset.load_rtx_episodes(name="berkeley_autolab_ur5", split="train[:10]", additional_metadata={"data_collector": "Alice", "custom_tag": "sample"})
+        ```
         """

         # this is only required if rtx format is used
@@ -325,26 +353,36 @@ def load_rtx_episodes(
                     fog_epsiode.add(
                         feature=str(k),
                         value=v.numpy(),
-                        feature_type=FeatureType(
-                            tf_feature_spec=data_type[k]
-                        ),
+                        feature_type=FeatureType(tf_feature_spec=data_type[k]),
                     )
                     self.step_keys.append(k)
            fog_epsiode.close()

-    def get_episode_info(self):
+    def get_episode_info(self) -> pandas.DataFrame:
         """
-        Return the metadata as pandas dataframe.
+        Returns:
+            metadata of all episodes as `pandas.DataFrame`
         """
         return self.db_manager.get_episode_info_table()

-    def get_step_data(self):
+    def get_step_data(self) -> polars.LazyFrame:
         """
-        Return the all step data as lazy dataframe.
+        Returns:
+            step data of all episodes
         """
         return self.db_manager.get_step_table_all()

-    def get_step_data_by_episode_ids(self, episode_ids: List[int], as_lazy_frame = True):
+    def get_step_data_by_episode_ids(
+        self, episode_ids: List[int], as_lazy_frame=True
+    ) -> List[polars.LazyFrame] | List[polars.DataFrame]:
+        """
+        Args:
+            episode_ids (List[int]): list of episode ids
+            as_lazy_frame (bool): whether to return polars.LazyFrame or polars.DataFrame
+
+        Returns:
+            step data of each episode
+        """
         episodes = []
         for episode_id in episode_ids:
             if episode_id == None:
@@ -354,8 +392,17 @@ def get_step_data_by_episode_ids(self, episode_ids: List[int], as_lazy_frame = True):
             else:
                 episodes.append(self.db_manager.get_step_table(episode_id).collect())
         return episodes
-
-    def read_by(self, episode_info: Any = None):
+
+    def read_by(self, episode_info: Any = None) -> List[polars.LazyFrame]:
+        """
+        To be used with `Dataset.get_episode_info`.
+
+        Args:
+            episode_info (pandas.DataFrame): episode metadata information to determine which episodes to read
+
+        Returns:
+            episodes filtered by `episode_info`
+        """
         episode_ids = list(episode_info["episode_id"])
         logger.info(f"Reading episodes as order: {episode_ids}")
         episodes = []
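
A usage sketch combining `get_episode_info` with `read_by` (the `data_collector` column is illustrative; it exists only if such metadata was attached to the episodes):

```
>>> info = dataset.get_episode_info()
>>> selected = info[info["data_collector"] == "Alice"]
>>> episodes = dataset.read_by(selected)
```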
@@ -375,6 +422,11 @@ def get_episodes_from_metadata(self, metadata: Any = None):
         return episodes

     def pytorch_dataset_builder(self, metadata=None, **kwargs):
+        """
+        Used for loading current dataset as a PyTorch dataset.
+        To be used with `torch.utils.data.DataLoader`.
+        """
+
         import torch
         from torch.utils.data import Dataset

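A usage sketch for the builder with `torch.utils.data.DataLoader` (batch size is illustrative; whether default collation works depends on the shapes of the stored features):

```
>>> import torch
>>> pytorch_dataset = dataset.pytorch_dataset_builder()
>>> loader = torch.utils.data.DataLoader(pytorch_dataset, batch_size=2)
>>> batch = next(iter(loader))
```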
@@ -414,15 +466,22 @@ def __getitem__(self, idx):
         return pytorch_dataset

     def get_as_huggingface_dataset(self):
+        """
+        Load current dataset as a HuggingFace dataset.
+
+        TODO:
+            * currently the support for Hugging Face datasets is limited.
+              it only shows its capability of easily returning a hf dataset
+            * add features from the episode metadata
+            * allow selecting episodes based on queries.
+              doing so requires creating a new copy of the dataset on disk
+        """
         import datasets

-        # TODO: currently the support for huggingg face dataset is limited
-        # it only shows its capability of easily returning a hf dataset
-        # TODO #1: add features from the episode metadata
-        # TODO #2: allow selecting episodes based on queries
-        # doing so requires creating a new copy of the dataset on disk
         dataset_path = self.path + "/" + self.name
-        parquet_files = [os.path.join(dataset_path, f) for f in os.listdir(dataset_path)]
+        parquet_files = [
+            os.path.join(dataset_path, f) for f in os.listdir(dataset_path)
+        ]

-        hf_dataset = datasets.load_dataset('parquet', data_files=parquet_files)
+        hf_dataset = datasets.load_dataset("parquet", data_files=parquet_files)
         return hf_dataset
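
A usage sketch; `datasets.load_dataset("parquet", data_files=...)` places the data under a default `train` split:

```
>>> hf_dataset = dataset.get_as_huggingface_dataset()
>>> hf_dataset["train"][0]
```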

fog_x/episode.py

Lines changed: 45 additions & 1 deletion
@@ -28,6 +28,22 @@ def add(
         timestamp: Optional[int] = None,
         feature_type: Optional[FeatureType] = None,
     ) -> None:
+        """
+        Add one feature's step data.
+        To add multiple features at each step, call this multiple times or use
+        `add_by_dict` to ensure the same timestamp is used for each feature.
+
+        Args:
+            feature (str): name of the feature
+            value (Any): value associated with the feature
+            timestamp (optional int): nanoseconds since the Epoch.
+                If not provided, the current time is used.
+
+        Examples:
+            >>> episode.add('feature1', 'image1.jpg')
+            >>> episode.add('feature2', 100)
+        """
+
         if timestamp is None:
             timestamp = time.time_ns()

@@ -45,18 +61,46 @@ def add_by_dict(
         self, data: Dict[str, Any], timestamp: Optional[int] = None
     ) -> None:
         """
-        add the same timestamp for all features
+        Add multiple features as step data.
+        Ensures that the same timestamp is used for each feature.
+
+        Args:
+            data (Dict[str, Any]): feature names and their values
+            timestamp (optional int): nanoseconds since the Epoch.
+                If not provided, the current time is used.
+
+        Examples:
+            >>> episode.add_by_dict({'feature1': 'image1.jpg', 'feature2': 100})
         """
         if timestamp is None:
             timestamp = time.time_ns()
         for feature, value in data.items():
             self.add(feature, value, timestamp)

     def compact(self) -> None:
+        """
+        Creates a table for the compacted data.
+
+        TODO:
+            * compact should not be run more than once?
+            * expand docstring description
+        """
         self.db_manager.compact()

     def get_steps(self) -> List[Dict[str, Any]]:
+        """
+        Retrieves the episode's step data.
+
+        Returns:
+            the step data
+
+        TODO:
+            * get_steps not in db_manager; db_manager.get_step_table_all returns a `LazyFrame`, not `List[Dict[str, Any]]`
+        """
         return self.db_manager.get_steps()

     def close(self) -> None:
+        """
+        Saves the episode object.
+        """
         self.db_manager.close()
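
Taken together, a minimal lifecycle sketch for an `Episode` (assuming `dataset` is a `fog_x.Dataset`; feature names are illustrative):

```
>>> episode = dataset.new_episode()
>>> episode.add_by_dict({'camera': 'frame0.jpg', 'gripper_open': 1})
>>> episode.add('camera', 'frame1.jpg')
>>> episode.close()  # saves the episode
```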
