Skip to content

Commit

Permalink
RLDS frame slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
Lenoplus42 committed Sep 24, 2024
1 parent e573046 commit 219c7e4
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 25 deletions.
47 changes: 24 additions & 23 deletions benchmarks/openx_by_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import fog_x
import csv
import stat
from fog_x.loader.rlds import RLDSLoader_ByFrame
from fog_x.loader.lerobot import LeRobotLoader_ByFrame
from fog_x.loader.vla import get_vla_dataloader
from fog_x.loader.hdf5 import get_hdf5_dataloader
Expand Down Expand Up @@ -194,7 +195,7 @@ def __init__(
self.file_extension = ".tfrecord"

def get_loader(self):
# NOTE(review): the two return lines below are the diff's old/new pair.
# Old (removed): whole-trajectory RLDS loading, batch_size = trajectories per batch.
return RLDSLoader(self.dataset_dir, split="train", batch_size=self.batch_size)
# New (added): frame-sliced loading — batch_size is repurposed as slice_length,
# so each yielded item is a fixed-length window of frames from one trajectory.
return RLDSLoader_ByFrame(self.dataset_dir, split="train", batch_size=1, slice_length=self.batch_size)

def _recursively_load_data(self, data):
log_level = self.log_level
Expand Down Expand Up @@ -359,34 +360,34 @@ def evaluation(args):
logger.debug(f"Evaluating dataset: {dataset_name}")

handlers = [
VLAHandler(
args.exp_dir,
dataset_name,
args.num_batches,
args.batch_size,
args.log_frequency,
),
HDF5Handler(
args.exp_dir,
dataset_name,
args.num_batches,
args.batch_size,
args.log_frequency,
),
LeRobotHandler(
args.exp_dir,
dataset_name,
args.num_batches,
args.batch_size,
args.log_frequency,
),
# RLDSHandler(
# VLAHandler(
# args.exp_dir,
# dataset_name,
# args.num_batches,
# args.batch_size,
# args.log_frequency,
# ),
# HDF5Handler(
# args.exp_dir,
# dataset_name,
# args.num_batches,
# args.batch_size,
# args.log_frequency,
# ),
# LeRobotHandler(
# args.exp_dir,
# dataset_name,
# args.num_batches,
# args.batch_size,
# args.log_frequency,
# ),
RLDSHandler(
args.exp_dir,
dataset_name,
args.num_batches,
args.batch_size,
args.log_frequency,
),
# FFV1Handler(
# args.exp_dir,
# dataset_name,
Expand Down
4 changes: 2 additions & 2 deletions evaluation.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ do
echo "Running benchmarks with batch size: $batch_size"

# python3 benchmarks/openx.py --dataset_names nyu_door_opening_surprising_effectiveness --num_batches $num_batches --batch_size $batch_size
# python3 benchmarks/openx_by_frame.py --dataset_names berkeley_cable_routing --num_batches $num_batches --batch_size $batch_size
python3 benchmarks/openx_by_frame.py --dataset_names bridge --num_batches $num_batches --batch_size $batch_size
python3 benchmarks/openx_by_frame.py --dataset_names berkeley_cable_routing --num_batches $num_batches --batch_size $batch_size
# python3 benchmarks/openx_by_frame.py --dataset_names bridge --num_batches $num_batches --batch_size $batch_size
# python3 benchmarks/openx.py --dataset_names berkeley_autolab_ur5 --num_batches $num_batches --batch_size $batch_size
done
83 changes: 83 additions & 0 deletions fog_x/loader/rlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,89 @@ def __next__(self):
raise StopIteration
return data

def __getitem__(self, idx):
    """Random access: return the idx-th trajectory converted to numpy.

    Builds a one-element pipeline (skip idx records, keep one) and
    materialises that single record.
    """
    single_record_ds = self.ds.skip(idx).take(1)
    record = next(iter(single_record_ds))
    return self._convert_traj_to_numpy(record)

class RLDSLoader_ByFrame(BaseLoader):
    """RLDS loader that yields fixed-length random frame slices of trajectories.

    Each produced item is a list of up to ``slice_length`` consecutive steps
    (dicts of numpy arrays) sampled at a uniformly random offset inside a
    trajectory; trajectories shorter than ``slice_length`` are returned whole.
    """

    def __init__(self, path, split, batch_size=1, shuffle_buffer=10,
                 shuffling=True, slice_length=16):
        """
        Args:
            path: directory containing a TFDS-built RLDS dataset.
            split: dataset split name, e.g. ``"train"``.
            batch_size: trajectories returned per ``get_batch`` call.
            shuffle_buffer: buffer size for ``tf.data`` shuffling.
            shuffling: when True, repeat + shuffle the dataset indefinitely.
            slice_length: number of consecutive frames sampled per trajectory.

        Raises:
            ImportError: if tensorflow / tensorflow_datasets are missing.
        """
        super().__init__(path)

        try:
            import tensorflow as tf  # noqa: F401 -- backs the tfds pipeline
            import tensorflow_datasets as tfds
        except ImportError:
            raise ImportError(
                "Please install tensorflow and tensorflow_datasets to use rlds loader"
            )

        self.batch_size = batch_size
        builder = tfds.builder_from_directory(path)
        self.ds = builder.as_dataset(split)
        # Epoch size is captured BEFORE .repeat(), which would make the
        # dataset infinite and len() meaningless.
        self.length = len(self.ds)
        self.shuffling = shuffling
        if shuffling:
            self.ds = self.ds.repeat()
            self.ds = self.ds.shuffle(shuffle_buffer)
        self.slice_length = slice_length
        self.iterator = iter(self.ds)

        self.split = split
        # Trajectories consumed so far; bounds iteration to one epoch even
        # when the underlying dataset repeats.
        self.index = 0

    def __len__(self):
        """Number of trajectories in one epoch of the split."""
        # The stored length is a plain int; no tensorflow import is needed
        # (the original re-imported tf here for no reason).
        return self.length

    def __iter__(self):
        return self

    def get_batch(self):
        """Return the next ``batch_size`` trajectories as frame slices.

        Draws from the shared iterator so successive calls advance through
        the dataset. (The original re-ran ``self.ds.take(batch_size)`` each
        call, which restarts from the dataset head — without shuffling every
        call returned the same leading records.)
        """
        if not self.shuffling and self.index >= self.length:
            raise StopIteration
        data = []
        for _ in range(self.batch_size):
            try:
                record = next(self.iterator)
            except StopIteration:
                break  # non-repeating dataset exhausted mid-batch
            data.append(self._convert_traj_to_numpy(record))
            self.index += 1
        if not data:
            raise StopIteration
        return data

    def _convert_traj_to_numpy(self, traj):
        """Convert one RLDS trajectory to a list of numpy step dicts,
        slicing a random ``slice_length``-frame window when possible."""

        def to_numpy(step_data):
            step = {}
            for key in step_data:
                val = step_data[key]
                if isinstance(val, dict):
                    step[key] = {k: np.array(v) for k, v in val.items()}
                else:
                    step[key] = np.array(val)
            return step

        # Random frame slicing: pick a uniform start offset such that the
        # window of slice_length consecutive steps fits inside the trajectory.
        trajectory = []
        num_frames = len(traj["steps"])
        if num_frames >= self.slice_length:
            start = np.random.randint(0, num_frames - self.slice_length + 1)
            steps = traj["steps"].skip(start).take(self.slice_length)
        else:
            steps = traj["steps"]
        for step in steps:
            trajectory.append(to_numpy(step))
        return trajectory

    def __next__(self):
        # Check the epoch bound BEFORE fetching: the original incremented and
        # checked AFTER converting, so the epoch's final trajectory was
        # converted and then discarded by StopIteration (off-by-one).
        if self.index >= self.length:
            raise StopIteration
        data = [self._convert_traj_to_numpy(next(self.iterator))]
        self.index += 1
        return data

    def __getitem__(self, idx):
        """Random access: return the idx-th trajectory converted to numpy."""
        batch = next(iter(self.ds.skip(idx).take(1)))
        return self._convert_traj_to_numpy(batch)

0 comments on commit 219c7e4

Please sign in to comment.