From 219c7e4e8baacc3e98c7a6d84aadaa68bad948bc Mon Sep 17 00:00:00 2001 From: LennoxFu Date: Mon, 23 Sep 2024 23:18:48 -0700 Subject: [PATCH] RLDS frame slicing --- benchmarks/openx_by_frame.py | 47 ++++++++++---------- evaluation.sh | 4 +- fog_x/loader/rlds.py | 83 ++++++++++++++++++++++++++++++++++++ 3 files changed, 109 insertions(+), 25 deletions(-) diff --git a/benchmarks/openx_by_frame.py b/benchmarks/openx_by_frame.py index 94d3715..eb2df9e 100644 --- a/benchmarks/openx_by_frame.py +++ b/benchmarks/openx_by_frame.py @@ -9,6 +9,7 @@ import fog_x import csv import stat +from fog_x.loader.rlds import RLDSLoader_ByFrame from fog_x.loader.lerobot import LeRobotLoader_ByFrame from fog_x.loader.vla import get_vla_dataloader from fog_x.loader.hdf5 import get_hdf5_dataloader @@ -194,7 +195,7 @@ def __init__( self.file_extension = ".tfrecord" def get_loader(self): - return RLDSLoader(self.dataset_dir, split="train", batch_size=self.batch_size) + return RLDSLoader_ByFrame(self.dataset_dir, split="train", batch_size=1, slice_length=self.batch_size) def _recursively_load_data(self, data): log_level = self.log_level @@ -359,34 +360,34 @@ def evaluation(args): logger.debug(f"Evaluating dataset: {dataset_name}") handlers = [ - VLAHandler( - args.exp_dir, - dataset_name, - args.num_batches, - args.batch_size, - args.log_frequency, - ), - HDF5Handler( - args.exp_dir, - dataset_name, - args.num_batches, - args.batch_size, - args.log_frequency, - ), - LeRobotHandler( - args.exp_dir, - dataset_name, - args.num_batches, - args.batch_size, - args.log_frequency, - ), - # RLDSHandler( + # VLAHandler( # args.exp_dir, # dataset_name, # args.num_batches, # args.batch_size, # args.log_frequency, # ), + # HDF5Handler( + # args.exp_dir, + # dataset_name, + # args.num_batches, + # args.batch_size, + # args.log_frequency, + # ), + # LeRobotHandler( + # args.exp_dir, + # dataset_name, + # args.num_batches, + # args.batch_size, + # args.log_frequency, + # ), + RLDSHandler( + args.exp_dir, + dataset_name, + args.num_batches, + args.batch_size, + args.log_frequency, + ), # FFV1Handler( # args.exp_dir, # dataset_name, diff --git a/evaluation.sh b/evaluation.sh index 28ee235..ed2d91c 100755 --- a/evaluation.sh +++ b/evaluation.sh @@ -16,7 +16,7 @@ do echo "Running benchmarks with batch size: $batch_size" # python3 benchmarks/openx.py --dataset_names nyu_door_opening_surprising_effectiveness --num_batches $num_batches --batch_size $batch_size - # python3 benchmarks/openx_by_frame.py --dataset_names berkeley_cable_routing --num_batches $num_batches --batch_size $batch_size - python3 benchmarks/openx_by_frame.py --dataset_names bridge --num_batches $num_batches --batch_size $batch_size + python3 benchmarks/openx_by_frame.py --dataset_names berkeley_cable_routing --num_batches $num_batches --batch_size $batch_size + # python3 benchmarks/openx_by_frame.py --dataset_names bridge --num_batches $num_batches --batch_size $batch_size # python3 benchmarks/openx.py --dataset_names berkeley_autolab_ur5 --num_batches $num_batches --batch_size $batch_size done \ No newline at end of file diff --git a/fog_x/loader/rlds.py b/fog_x/loader/rlds.py index 9390308..8403580 100644 --- a/fog_x/loader/rlds.py +++ b/fog_x/loader/rlds.py @@ -73,6 +73,89 @@ def __next__(self): raise StopIteration return data + def __getitem__(self, idx): + batch = next(iter(self.ds.skip(idx).take(1))) + return self._convert_traj_to_numpy(batch) + +class RLDSLoader_ByFrame(BaseLoader): + def __init__(self, path, split, batch_size=1, shuffle_buffer=10, shuffling = True, slice_length=16): + super(RLDSLoader_ByFrame, self).__init__(path) + + try: + import tensorflow as tf + import tensorflow_datasets as tfds + except ImportError: + raise ImportError( + "Please install tensorflow and tensorflow_datasets to use rlds loader" + ) + + self.batch_size = batch_size + builder = tfds.builder_from_directory(path) + self.ds = builder.as_dataset(split) + self.length = len(self.ds) + self.shuffling = shuffling + if shuffling: + self.ds = self.ds.repeat() + self.ds = self.ds.shuffle(shuffle_buffer) + self.slice_length = slice_length + self.iterator = iter(self.ds) + + self.split = split + self.index = 0 + + def __len__(self): + try: + import tensorflow as tf + except ImportError: + raise ImportError("Please install tensorflow to use rlds loader") + + return self.length + + def __iter__(self): + return self + + def get_batch(self): + batch = self.ds.take(self.batch_size) + self.index += self.batch_size + if not self.shuffling and self.index >= self.length: + raise StopIteration + data = [] + for b in batch: + data.append(self._convert_traj_to_numpy(b)) + return data + + def _convert_traj_to_numpy(self, traj): + import tensorflow as tf + + def to_numpy(step_data): + step = {} + for key in step_data: + val = step_data[key] + if isinstance(val, dict): + step[key] = {k: np.array(v) for k, v in val.items()} + else: + step[key] = np.array(val) + return step + + # Random step / frame slicing + trajectory = [] + num_frames = len(traj["steps"]) + if num_frames >= self.slice_length: + random_from = np.random.randint(0, num_frames - self.slice_length + 1) + trajs = traj["steps"].skip(random_from).take(self.slice_length) + else: + trajs = traj["steps"] + for step in trajs: + trajectory.append(to_numpy(step)) + return trajectory + + def __next__(self): + data = [self._convert_traj_to_numpy(next(self.iterator))] + self.index += 1 + if self.index >= self.length: + raise StopIteration + return data + def __getitem__(self, idx): batch = next(iter(self.ds.skip(idx).take(1))) return self._convert_traj_to_numpy(batch) \ No newline at end of file