Refactor evaluation script for improved code organization and performance
KeplerC committed Sep 2, 2024
1 parent fcd8f2d commit 3516491
Showing 5 changed files with 70 additions and 65 deletions.
10 changes: 5 additions & 5 deletions benchmarks/Visualization.ipynb

Large diffs are not rendered by default.

60 changes: 28 additions & 32 deletions benchmarks/openx.py
@@ -73,12 +73,14 @@ def measure_average_trajectory_size(self):
def clear_cache(self):
"""Clears the cache directory."""
if os.path.exists(CACHE_DIR):
logger.info(f"Clearing cache directory: {CACHE_DIR}")
subprocess.run(["rm", "-rf", CACHE_DIR], check=True)

def clear_os_cache(self):
"""Clears the OS cache."""
subprocess.run(["sync"], check=True)
subprocess.run(["sudo", "sh", "-c", "echo 3 > /proc/sys/vm/drop_caches"], check=True)
logger.info(f"Cleared OS cache")

def _recursively_load_data(self, data):
logger.debug(f"Data summary for loader {self.dataset_type.upper()}")
@@ -135,19 +137,21 @@ def write_result(self, format_name, elapsed_time, index):
def measure_random_loading_time(self):
start_time = time.time()
loader = self.get_loader()

last_batch_time = time.time()
for batch_num, data in enumerate(loader):
if batch_num >= self.num_batches:
break
self._recursively_load_data(data)
current_batch_time = time.time()
elapsed_time = current_batch_time - last_batch_time
last_batch_time = current_batch_time

elapsed_time = time.time() - start_time
self.write_result(
f"{self.dataset_type.upper()}", elapsed_time, batch_num
)
if batch_num % self.log_frequency == 0:
logger.info(
f"{self.dataset_type.upper()} - Loaded {batch_num} random {self.batch_size} batches from {self.dataset_name}, Time: {elapsed_time:.2f} s, Average Time: {elapsed_time / (batch_num + 1):.2f} s"
f"{self.dataset_type.upper()} - Loaded {batch_num} random {self.batch_size} batches from {self.dataset_name}, Time: {elapsed_time:.2f} s, Total Average Time: {(current_batch_time - start_time) / (batch_num + 1):.2f} s, Batch Average Time: {elapsed_time / self.batch_size:.2f} s"
)

return time.time() - start_time
@@ -276,12 +280,6 @@ def get_loader(self):
return LeRobotLoader(path, self.dataset_name, batch_size=self.batch_size)


def prepare(args):
# Clear the cache directory
if os.path.exists(CACHE_DIR):
subprocess.run(["rm", "-rf", CACHE_DIR], check=True)


def evaluation(args):

csv_file = "format_comparison_results.csv"
@@ -296,34 +294,34 @@ def evaluation(args):
logger.debug(f"Evaluating dataset: {dataset_name}")

handlers = [
VLAHandler(
args.exp_dir,
dataset_name,
args.num_batches,
args.batch_size,
args.log_frequency,
),
# VLAHandler(
# args.exp_dir,
# dataset_name,
# args.num_batches,
# args.batch_size,
# args.log_frequency,
# ),
HDF5Handler(
args.exp_dir,
dataset_name,
args.num_batches,
args.batch_size,
args.log_frequency,
),
LeRobotHandler(
args.exp_dir,
dataset_name,
args.num_batches,
args.batch_size,
args.log_frequency,
),
RLDSHandler(
args.exp_dir,
dataset_name,
args.num_batches,
args.batch_size,
args.log_frequency,
),
# LeRobotHandler(
# args.exp_dir,
# dataset_name,
# args.num_batches,
# args.batch_size,
# args.log_frequency,
# ),
# RLDSHandler(
# args.exp_dir,
# dataset_name,
# args.num_batches,
# args.batch_size,
# args.log_frequency,
# ),
]

for handler in handlers:
@@ -389,6 +387,4 @@ def evaluation(args):
)
args = parser.parse_args()

if args.prepare:
prepare(args)
evaluation(args)
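For reference, a minimal sketch of the per-batch timing pattern introduced in measure_random_loading_time above. The loader argument is a hypothetical stand-in for the benchmark loaders, and the printed fields mirror the log line rather than the exact write_result CSV output; this is a sketch under those assumptions, not the repository's implementation.

import time

def time_random_loading(loader, num_batches, batch_size, log_frequency=10):
    # Track the wall-clock gap between consecutive batches as well as the
    # running average since the start of iteration.
    start_time = time.time()
    last_batch_time = start_time
    for batch_num, data in enumerate(loader):
        if batch_num >= num_batches:
            break
        current_batch_time = time.time()
        elapsed_time = current_batch_time - last_batch_time  # time spent on this batch
        last_batch_time = current_batch_time
        if batch_num % log_frequency == 0:
            total_avg = (current_batch_time - start_time) / (batch_num + 1)
            per_sample = elapsed_time / batch_size
            print(f"batch {batch_num}: {elapsed_time:.2f}s, "
                  f"total avg {total_avg:.2f}s/batch, {per_sample:.2f}s/sample")
    return time.time() - start_time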
10 changes: 5 additions & 5 deletions evaluation.sh
@@ -4,18 +4,18 @@ sudo echo "Use sudo access for clearning cache"
rm *.csv

# Define a list of batch sizes to iterate through
batch_sizes=(1 8 16 32)
batch_sizes=(1 8)
# batch_sizes=(1 2)

num_batches=10
num_batches=1000

# Iterate through each batch size
for batch_size in "${batch_sizes[@]}"
do
echo "Running benchmarks with batch size: $batch_size"

python3 benchmarks/openx.py --dataset_names nyu_door_opening_surprising_effectiveness --num_batches $num_batches --batch_size $batch_size
# python3 benchmarks/openx.py --dataset_names nyu_door_opening_surprising_effectiveness --num_batches $num_batches --batch_size $batch_size
python3 benchmarks/openx.py --dataset_names berkeley_autolab_ur5 --num_batches $num_batches --batch_size $batch_size
python3 benchmarks/openx.py --dataset_names berkeley_cable_routing --num_batches $num_batches --batch_size $batch_size
python3 benchmarks/openx.py --dataset_names bridge --num_batches $num_batches --batch_size $batch_size
# python3 benchmarks/openx.py --dataset_names berkeley_cable_routing --num_batches $num_batches --batch_size $batch_size
# python3 benchmarks/openx.py --dataset_names bridge --num_batches $num_batches --batch_size $batch_size
done
41 changes: 23 additions & 18 deletions fog_x/loader/hdf5.py
@@ -1,17 +1,18 @@
import torch
from torch.utils.data import IterableDataset, DataLoader
from . import BaseLoader
import numpy as np
import glob
import numpy as np
import glob
import h5py
import asyncio
import random
import multiprocessing as mp
import time
import logging


# flatten the data such that all data starts with root level tree (observation and action)
def _flatten(data, parent_key='', sep='/'):
def _flatten(data, parent_key="", sep="/"):
items = {}
for k, v in data.items():
new_key = parent_key + sep + k if parent_key else k
@@ -20,15 +21,16 @@ def _flatten(data, parent_key='', sep='/'):
else:
items[new_key] = v
return items



def recursively_read_hdf5_group(group):
if isinstance(group, h5py.Dataset):
return np.array(group)
elif isinstance(group, h5py.Group):
return {key: recursively_read_hdf5_group(value) for key, value in group.items()}
else:
raise TypeError("Unsupported HDF5 group type")


class HDF5Loader(BaseLoader):
def __init__(self, path, batch_size=1, buffer_size=100, num_workers=4):
@@ -65,19 +67,23 @@ def get_batch(self):

while len(batch) < self.batch_size:
if time.time() - start_time > timeout:
logging.warning(f"Timeout reached while getting batch. Batch size: {len(batch)}")
logging.warning(
f"Timeout reached while getting batch. Batch size: {len(batch)}"
)
break

try:
item = self.buffer.get(timeout=1)
batch.append(item)
except mp.queues.Empty:
if all(not p.is_alive() for p in self.processes) and self.buffer.empty():
if (
all(not p.is_alive() for p in self.processes)
and self.buffer.empty()
):
if len(batch) == 0:
return None
else:
break

return batch

def __next__(self):
@@ -100,7 +106,7 @@ def _read_hdf5(self, data_path):

def __iter__(self):
return self

def __len__(self):
return len(self.files)

@@ -114,9 +120,11 @@ def __del__(self):
p.terminate()
p.join()


class HDF5IterableDataset(IterableDataset):
def __init__(self, path, batch_size=1):
self.hdf5_loader = HDF5Loader(path, batch_size)
# Note: batch size = 1 is to bypass the dataloader without pytorch dataloader
self.hdf5_loader = HDF5Loader(path, 1)

def __iter__(self):
return self
@@ -128,20 +136,17 @@ def __next__(self):
except StopIteration:
raise StopIteration


def hdf5_collate_fn(batch):
# Convert data to PyTorch tensors
return batch
return batch

def get_hdf5_dataloader(
path: str,
batch_size: int = 1,
num_workers: int = 0
):

def get_hdf5_dataloader(path: str, batch_size: int = 1, num_workers: int = 0):
dataset = HDF5IterableDataset(path, batch_size)
return DataLoader(
dataset,
batch_size=batch_size,
collate_fn=hdf5_collate_fn,
num_workers=num_workers
num_workers=num_workers,
)
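
For context, a minimal usage sketch of get_hdf5_dataloader as defined above. The dataset path is a hypothetical example, and the exact structure of each yielded item depends on parts of HDF5Loader not shown in this diff.

from fog_x.loader.hdf5 import get_hdf5_dataloader

# Hypothetical directory of *.h5 trajectory files; replace with a real export.
loader = get_hdf5_dataloader("/tmp/hdf5_export", batch_size=8, num_workers=0)

for batch in loader:
    # hdf5_collate_fn is an identity function, so `batch` is the plain list of
    # items produced by HDF5IterableDataset for this step.
    print(type(batch), len(batch))
    break  # inspect only the first batch in this sketch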

14 changes: 9 additions & 5 deletions fog_x/loader/rlds.py
@@ -1,22 +1,25 @@
from . import BaseLoader
import numpy as np


class RLDSLoader(BaseLoader):
def __init__(self, path, split, batch_size=1, shuffle_buffer=50):
super(RLDSLoader, self).__init__(path)

try:
import tensorflow as tf
import tensorflow_datasets as tfds
except ImportError:
raise ImportError("Please install tensorflow and tensorflow_datasets to use rlds loader")
raise ImportError(
"Please install tensorflow and tensorflow_datasets to use rlds loader"
)

self.batch_size = batch_size
builder = tfds.builder_from_directory(path)
self.ds = builder.as_dataset(split)
self.length = len(self.ds)
self.ds = self.ds.shuffle(shuffle_buffer)
self.ds = self.ds.repeat()
self.ds = self.ds.shuffle(shuffle_buffer)
self.iterator = iter(self.ds)

self.split = split
@@ -27,11 +30,12 @@ def __len__(self):
import tensorflow as tf
except ImportError:
raise ImportError("Please install tensorflow to use rlds loader")

return self.length

def __iter__(self):
return self

def get_batch(self):
batch = self.ds.take(self.batch_size)
self.index += self.batch_size
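For context, a minimal construction sketch for RLDSLoader as defined above. The dataset directory is a hypothetical example, and iteration details depend on code truncated from this diff, so only construction and the cached length are shown.

from fog_x.loader.rlds import RLDSLoader

# Hypothetical TFDS-style dataset directory; tensorflow and tensorflow_datasets
# must be installed, as the constructor checks.
loader = RLDSLoader("/tmp/rlds/berkeley_autolab_ur5/0.1.0", split="train", batch_size=8)
print(len(loader))  # number of episodes in the split, cached at construction time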
