Refactor evaluation script for improved code organization and performance
KeplerC committed Sep 2, 2024
1 parent c4d7150 commit fcd8f2d
Showing 4 changed files with 179 additions and 33 deletions.
87 changes: 81 additions & 6 deletions benchmarks/Visualization.ipynb

Large diffs are not rendered by default.

23 changes: 14 additions & 9 deletions benchmarks/openx.py
@@ -22,8 +22,9 @@
"berkeley_autolab_ur5",
"bridge",
]
DEFAULT_DATASET_NAMES = ["bridge"]
CACHE_DIR = "/tmp/fog_x/cache/"
# DEFAULT_DATASET_NAMES = ["bridge"]
# CACHE_DIR = "/tmp/fog_x/cache/"
CACHE_DIR = "/mnt/data/fog_x/cache/"
DEFAULT_LOG_FREQUENCY = 20

# suppress tensorflow warnings
@@ -117,6 +118,7 @@ def write_result(self, format_name, elapsed_time, index):
"Format": format_name,
"AverageTrajectorySize(MB)": self.measure_average_trajectory_size(),
"LoadingTime(s)": elapsed_time,
"AverageLoadingTime(s)": elapsed_time / (index + 1),
"Index": index,
"BatchSize": self.batch_size,
}
@@ -141,11 +143,11 @@ def measure_random_loading_time(self):

elapsed_time = time.time() - start_time
self.write_result(
f"{self.dataset_type.upper()}-RandomLoad", elapsed_time, batch_num
f"{self.dataset_type.upper()}", elapsed_time, batch_num
)
if batch_num % self.log_frequency == 0:
logger.debug(
f"{self.dataset_type.upper()}-RandomLoad - Loaded {batch_num} random batches, Time: {elapsed_time:.2f} s"
logger.info(
f"{self.dataset_type.upper()} - Loaded {batch_num} random {self.batch_size} batches from {self.dataset_name}, Time: {elapsed_time:.2f} s, Average Time: {elapsed_time / (batch_num + 1):.2f} s"
)

return time.time() - start_time
@@ -333,13 +335,16 @@ def evaluation(args):
new_results.append(
{
"Dataset": dataset_name,
"Format": f"{handler.dataset_type.upper()}-RandomLoad",
"Format": f"{handler.dataset_type.upper()}",
"AverageTrajectorySize(MB)": avg_traj_size,
"LoadingTime(s)": random_load_time,
"AverageLoadingTime(s)": random_load_time / (args.num_batches + 1),
"Index": args.num_batches,
"BatchSize": args.batch_size,
}
)
logger.debug(
f"{handler.dataset_type.upper()}-RandomLoad - Average Trajectory Size: {avg_traj_size:.2f} MB, Loading Time: {random_load_time:.2f} s"
f"{handler.dataset_type.upper()} - Average Trajectory Size: {avg_traj_size:.2f} MB, Loading Time: {random_load_time:.2f} s"
)

# Combine existing and new results
@@ -376,11 +381,11 @@ def evaluation(args):
parser.add_argument(
"--num_batches",
type=int,
default=1,
default=1000,
help="Number of batches to load for each loader.",
)
parser.add_argument(
"--batch_size", type=int, default=8, help="Batch size for loaders."
"--batch_size", type=int, default=16, help="Batch size for loaders."
)
args = parser.parse_args()

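For context, the new AverageLoadingTime(s) column recorded above is simply the cumulative elapsed time divided by the number of batches loaded so far (index + 1). A minimal standalone sketch of that bookkeeping, with an illustrative run_random_load helper that is not part of the benchmark class:

    import time

    def run_random_load(loader, num_batches, log_frequency=20):
        # Illustrative only: mirrors how the benchmark tracks cumulative vs. average loading time.
        results = []
        start_time = time.time()
        for index, batch in enumerate(loader):
            if index >= num_batches:
                break
            elapsed_time = time.time() - start_time  # cumulative time since the first batch
            results.append({
                "LoadingTime(s)": elapsed_time,
                "AverageLoadingTime(s)": elapsed_time / (index + 1),  # the new column
                "Index": index,
            })
            if index % log_frequency == 0:
                print(f"Loaded {index + 1} batches, average {elapsed_time / (index + 1):.2f} s/batch")
        return results
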
21 changes: 21 additions & 0 deletions evaluation.sh
@@ -0,0 +1,21 @@
# ask for sudo access
sudo echo "Use sudo access for clearing cache"

rm *.csv

# Define a list of batch sizes to iterate through
batch_sizes=(1 8 16 32)
# batch_sizes=(1 2)

num_batches=10

# Iterate through each batch size
for batch_size in "${batch_sizes[@]}"
do
echo "Running benchmarks with batch size: $batch_size"

python3 benchmarks/openx.py --dataset_names nyu_door_opening_surprising_effectiveness --num_batches $num_batches --batch_size $batch_size
python3 benchmarks/openx.py --dataset_names berkeley_autolab_ur5 --num_batches $num_batches --batch_size $batch_size
python3 benchmarks/openx.py --dataset_names berkeley_cable_routing --num_batches $num_batches --batch_size $batch_size
python3 benchmarks/openx.py --dataset_names bridge --num_batches $num_batches --batch_size $batch_size
done
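Each benchmark run appears to write its result rows to a CSV in the working directory, which is why the script clears *.csv before starting. Assuming that layout, the runs across batch sizes could be compared with a short pandas sketch like the one below; only the column names come from the write_result() dictionary in the openx.py diff above, the glob pattern and aggregation are illustrative:

    import glob

    import pandas as pd

    # Illustrative aggregation of the per-run benchmark CSVs.
    frames = [pd.read_csv(path) for path in glob.glob("*.csv")]
    results = pd.concat(frames, ignore_index=True)
    summary = results.groupby(["Dataset", "Format", "BatchSize"])["AverageLoadingTime(s)"].mean()
    print(summary.sort_values())
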
81 changes: 63 additions & 18 deletions fog_x/loader/hdf5.py
@@ -5,6 +5,10 @@
import glob
import h5py
import asyncio
import random
import multiprocessing as mp
import time
import logging

# flatten the data such that all data starts with root level tree (observation and action)
def _flatten(data, parent_key='', sep='/'):
@@ -27,33 +31,64 @@ def recursively_read_hdf5_group(group):


class HDF5Loader(BaseLoader):
def __init__(self, path, batch_size=1):
def __init__(self, path, batch_size=1, buffer_size=100, num_workers=4):
super(HDF5Loader, self).__init__(path)
self.index = 0
self.files = glob.glob(self.path, recursive=True)
self.batch_size = batch_size
async def _read_hdf5_async(self, data_path):
return await asyncio.to_thread(self._read_hdf5, data_path)

async def get_batch(self):
tasks = []
for _ in range(self.batch_size):
if self.index < len(self.files):
file_path = self.files[self.index]
self.index += 1
tasks.append(self._read_hdf5_async(file_path))
else:
self.buffer_size = buffer_size
self.buffer = mp.Queue(maxsize=buffer_size)
self.num_workers = num_workers
self.processes = []
random.shuffle(self.files)
self._start_workers()

def _worker(self):
while True:
if not self.files:
logging.info("Worker finished")
break
file_path = random.choice(self.files)
data = self._read_hdf5(file_path)
self.buffer.put(data)

def _start_workers(self):
for _ in range(self.num_workers):
p = mp.Process(target=self._worker)
p.start()
logging.debug(f"Started worker {p.pid}")
self.processes.append(p)

def get_batch(self):
batch = []
timeout = 5
start_time = time.time()

while len(batch) < self.batch_size:
if time.time() - start_time > timeout:
logging.warning(f"Timeout reached while getting batch. Batch size: {len(batch)}")
break
return await asyncio.gather(*tasks)

try:
item = self.buffer.get(timeout=1)
batch.append(item)
except mp.queues.Empty:
if all(not p.is_alive() for p in self.processes) and self.buffer.empty():
if len(batch) == 0:
return None
else:
break

return batch

def __next__(self):
if self.index >= len(self.files):
self.index = 0
batch = self.get_batch()
if batch is None:
random.shuffle(self.files)
self._start_workers()
raise StopIteration
return asyncio.run(self.get_batch())
return batch

def _read_hdf5(self, data_path):

with h5py.File(data_path, "r") as f:
data_unflattened = recursively_read_hdf5_group(f)

@@ -69,6 +104,16 @@ def __iter__(self):
def __len__(self):
return len(self.files)

def peek(self):
if self.buffer.empty():
return None
return self.buffer.get()

def __del__(self):
for p in self.processes:
p.terminate()
p.join()

class HDF5IterableDataset(IterableDataset):
def __init__(self, path, batch_size=1):
self.hdf5_loader = HDF5Loader(path, batch_size)
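The refactored loader replaces the earlier asyncio-based batch reads with a pool of multiprocessing workers that prefetch whole trajectories into a bounded queue, so get_batch() only drains that queue. A rough usage sketch, assuming the import path fog_x.loader.hdf5 and an illustrative glob pattern; note that a batch may contain fewer than batch_size items if the 5-second timeout in get_batch() is hit first:

    from fog_x.loader.hdf5 import HDF5Loader

    # Workers begin prefetching into the shared queue as soon as the loader is built.
    loader = HDF5Loader(
        "/mnt/data/datasets/bridge/**/*.h5",  # illustrative path, matched with glob(..., recursive=True)
        batch_size=16,
        buffer_size=100,
        num_workers=4,
    )

    for i, batch in enumerate(loader):
        # Each element is a flattened trajectory dict produced by _read_hdf5().
        if batch:
            print(f"batch {i}: {len(batch)} trajectories, keys: {list(batch[0].keys())[:3]}")
        if i >= 10:  # stop after a fixed number of batches; the workers keep sampling files
            break
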
