
Commit fcd8f2d

Refactor evaluation script for improved code organization and performance
1 parent c4d7150

File tree

4 files changed: +179 -33 lines changed

benchmarks/Visualization.ipynb

Lines changed: 81 additions & 6 deletions
Large diffs are not rendered by default.

benchmarks/openx.py

Lines changed: 14 additions & 9 deletions
@@ -22,8 +22,9 @@
     "berkeley_autolab_ur5",
     "bridge",
 ]
-DEFAULT_DATASET_NAMES = ["bridge"]
-CACHE_DIR = "/tmp/fog_x/cache/"
+# DEFAULT_DATASET_NAMES = ["bridge"]
+# CACHE_DIR = "/tmp/fog_x/cache/"
+CACHE_DIR = "/mnt/data/fog_x/cache/"
 DEFAULT_LOG_FREQUENCY = 20
 
 # suppress tensorflow warnings
@@ -117,6 +118,7 @@ def write_result(self, format_name, elapsed_time, index):
             "Format": format_name,
             "AverageTrajectorySize(MB)": self.measure_average_trajectory_size(),
             "LoadingTime(s)": elapsed_time,
+            "AverageLoadingTime(s)": elapsed_time / (index + 1),
             "Index": index,
             "BatchSize": self.batch_size,
         }
@@ -141,11 +143,11 @@ def measure_random_loading_time(self):
 
             elapsed_time = time.time() - start_time
             self.write_result(
-                f"{self.dataset_type.upper()}-RandomLoad", elapsed_time, batch_num
+                f"{self.dataset_type.upper()}", elapsed_time, batch_num
             )
             if batch_num % self.log_frequency == 0:
-                logger.debug(
-                    f"{self.dataset_type.upper()}-RandomLoad - Loaded {batch_num} random batches, Time: {elapsed_time:.2f} s"
+                logger.info(
+                    f"{self.dataset_type.upper()} - Loaded {batch_num} random {self.batch_size} batches from {self.dataset_name}, Time: {elapsed_time:.2f} s, Average Time: {elapsed_time / (batch_num + 1):.2f} s"
                 )
 
         return time.time() - start_time
@@ -333,13 +335,16 @@ def evaluation(args):
         new_results.append(
             {
                 "Dataset": dataset_name,
-                "Format": f"{handler.dataset_type.upper()}-RandomLoad",
+                "Format": f"{handler.dataset_type.upper()}",
                 "AverageTrajectorySize(MB)": avg_traj_size,
                 "LoadingTime(s)": random_load_time,
+                "AverageLoadingTime(s)": random_load_time / (args.num_batches + 1),
+                "Index": args.num_batches,
+                "BatchSize": args.batch_size,
             }
         )
         logger.debug(
-            f"{handler.dataset_type.upper()}-RandomLoad - Average Trajectory Size: {avg_traj_size:.2f} MB, Loading Time: {random_load_time:.2f} s"
+            f"{handler.dataset_type.upper()} - Average Trajectory Size: {avg_traj_size:.2f} MB, Loading Time: {random_load_time:.2f} s"
         )
 
     # Combine existing and new results
@@ -376,11 +381,11 @@ def evaluation(args):
     parser.add_argument(
         "--num_batches",
         type=int,
-        default=1,
+        default=1000,
        help="Number of batches to load for each loader.",
     )
     parser.add_argument(
-        "--batch_size", type=int, default=8, help="Batch size for loaders."
+        "--batch_size", type=int, default=16, help="Batch size for loaders."
     )
     args = parser.parse_args()

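The new AverageLoadingTime(s) column is a running average: elapsed_time is cumulative (measured from a single start_time), so dividing by index + 1 gives seconds per batch so far, not the cost of the current batch. A minimal sketch of that bookkeeping, with a hypothetical load_one_batch standing in for the real loader call:

import time

def load_one_batch():
    time.sleep(0.01)  # hypothetical stand-in for pulling one batch

num_batches = 10  # assumed value for illustration
start_time = time.time()
for index in range(num_batches):
    load_one_batch()
    elapsed_time = time.time() - start_time   # cumulative, not per-batch
    average = elapsed_time / (index + 1)      # the AverageLoadingTime(s) value
    print(f"batch {index}: total {elapsed_time:.2f} s, average {average:.3f} s/batch")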
evaluation.sh

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+# ask for sudo access
+sudo echo "Use sudo access for clearing cache"
+
+rm *.csv
+
+# Define a list of batch sizes to iterate through
+batch_sizes=(1 8 16 32)
+# batch_sizes=(1 2)
+
+num_batches=10
+
+# Iterate through each batch size
+for batch_size in "${batch_sizes[@]}"
+do
+    echo "Running benchmarks with batch size: $batch_size"
+
+    python3 benchmarks/openx.py --dataset_names nyu_door_opening_surprising_effectiveness --num_batches $num_batches --batch_size $batch_size
+    python3 benchmarks/openx.py --dataset_names berkeley_autolab_ur5 --num_batches $num_batches --batch_size $batch_size
+    python3 benchmarks/openx.py --dataset_names berkeley_cable_routing --num_batches $num_batches --batch_size $batch_size
+    python3 benchmarks/openx.py --dataset_names bridge --num_batches $num_batches --batch_size $batch_size
+done
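Note that as committed the script only runs sudo echo, which caches sudo credentials but does not itself clear anything. If the intent is cold-cache measurements, the usual Linux approach would be something like sync && echo 3 | sudo tee /proc/sys/vm/drop_caches between runs; that is an assumption about the intent, not part of this commit.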

fog_x/loader/hdf5.py

Lines changed: 63 additions & 18 deletions
@@ -5,6 +5,10 @@
 import glob
 import h5py
 import asyncio
+import random
+import multiprocessing as mp
+import time
+import logging
 
 # flatten the data such that all data starts with root level tree (observation and action)
 def _flatten(data, parent_key='', sep='/'):
@@ -27,33 +31,64 @@ def recursively_read_hdf5_group(group):
 
 
 class HDF5Loader(BaseLoader):
-    def __init__(self, path, batch_size=1):
+    def __init__(self, path, batch_size=1, buffer_size=100, num_workers=4):
         super(HDF5Loader, self).__init__(path)
-        self.index = 0
         self.files = glob.glob(self.path, recursive=True)
         self.batch_size = batch_size
-    async def _read_hdf5_async(self, data_path):
-        return await asyncio.to_thread(self._read_hdf5, data_path)
-
-    async def get_batch(self):
-        tasks = []
-        for _ in range(self.batch_size):
-            if self.index < len(self.files):
-                file_path = self.files[self.index]
-                self.index += 1
-                tasks.append(self._read_hdf5_async(file_path))
-            else:
+        self.buffer_size = buffer_size
+        self.buffer = mp.Queue(maxsize=buffer_size)
+        self.num_workers = num_workers
+        self.processes = []
+        random.shuffle(self.files)
+        self._start_workers()
+
+    def _worker(self):
+        while True:
+            if not self.files:
+                logging.info("Worker finished")
+                break
+            file_path = random.choice(self.files)
+            data = self._read_hdf5(file_path)
+            self.buffer.put(data)
+
+    def _start_workers(self):
+        for _ in range(self.num_workers):
+            p = mp.Process(target=self._worker)
+            p.start()
+            logging.debug(f"Started worker {p.pid}")
+            self.processes.append(p)
+
+    def get_batch(self):
+        batch = []
+        timeout = 5
+        start_time = time.time()
+
+        while len(batch) < self.batch_size:
+            if time.time() - start_time > timeout:
+                logging.warning(f"Timeout reached while getting batch. Batch size: {len(batch)}")
                 break
-        return await asyncio.gather(*tasks)
+
+            try:
+                item = self.buffer.get(timeout=1)
+                batch.append(item)
+            except mp.queues.Empty:
+                if all(not p.is_alive() for p in self.processes) and self.buffer.empty():
+                    if len(batch) == 0:
+                        return None
+                    else:
+                        break
+
+        return batch
 
     def __next__(self):
-        if self.index >= len(self.files):
-            self.index = 0
+        batch = self.get_batch()
+        if batch is None:
+            random.shuffle(self.files)
+            self._start_workers()
             raise StopIteration
-        return asyncio.run(self.get_batch())
+        return batch
 
     def _read_hdf5(self, data_path):
-
         with h5py.File(data_path, "r") as f:
             data_unflattened = recursively_read_hdf5_group(f)
 
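The refactor replaces the asyncio-based sequential reader with a producer-consumer design: each worker process repeatedly picks a random file, decodes it via _read_hdf5, and pushes the result into a bounded multiprocessing.Queue, so decoding overlaps with consumption. get_batch drains that queue with a 1 s per-item timeout under a 5 s overall deadline, so it may return a batch shorter than batch_size on slow storage, and it returns None only once every worker has exited and the queue is empty; __next__ maps None to StopIteration after reshuffling the file list and restarting workers for the next pass.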
@@ -69,6 +104,16 @@ def __iter__(self):
     def __len__(self):
         return len(self.files)
 
+    def peek(self):
+        if self.buffer.empty():
+            return None
+        return self.buffer.get()
+
+    def __del__(self):
+        for p in self.processes:
+            p.terminate()
+            p.join()
+
 class HDF5IterableDataset(IterableDataset):
     def __init__(self, path, batch_size=1):
         self.hdf5_loader = HDF5Loader(path, batch_size)
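A minimal usage sketch of the refactored loader; the glob pattern and the buffer/worker sizes here are assumptions for illustration, not values from the commit, and iteration assumes __iter__ returns the loader itself:

from fog_x.loader.hdf5 import HDF5Loader

# Hypothetical recursive pattern and sizes; substitute your own dataset path.
loader = HDF5Loader("/mnt/data/hdf5/**/*.h5", batch_size=16,
                    buffer_size=100, num_workers=4)
for batch in loader:   # each batch is a list of up to 16 flattened trajectory dicts
    print(len(batch))

One design note: because __del__ terminates and joins the workers, dropping the last reference to the loader is what reclaims the prefetch processes.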
