import asyncio
import glob
import logging
import multiprocessing as mp
import queue
import random
import time

import h5py
812
913# flatten the data such that all data starts with root level tree (observation and action)
1014def _flatten (data , parent_key = '' , sep = '/' ):
@@ -27,33 +31,64 @@ def recursively_read_hdf5_group(group):
2731
2832
2933class HDF5Loader (BaseLoader ):
30- def __init__ (self , path , batch_size = 1 ):
34+ def __init__ (self , path , batch_size = 1 , buffer_size = 100 , num_workers = 4 ):
3135 super (HDF5Loader , self ).__init__ (path )
32- self .index = 0
3336 self .files = glob .glob (self .path , recursive = True )
3437 self .batch_size = batch_size
35- async def _read_hdf5_async (self , data_path ):
36- return await asyncio .to_thread (self ._read_hdf5 , data_path )
37-
38- async def get_batch (self ):
39- tasks = []
40- for _ in range (self .batch_size ):
41- if self .index < len (self .files ):
42- file_path = self .files [self .index ]
43- self .index += 1
44- tasks .append (self ._read_hdf5_async (file_path ))
45- else :
38+ self .buffer_size = buffer_size
39+ self .buffer = mp .Queue (maxsize = buffer_size )
40+ self .num_workers = num_workers
41+ self .processes = []
42+ random .shuffle (self .files )
43+ self ._start_workers ()
44+
45+ def _worker (self ):
46+ while True :
47+ if not self .files :
48+ logging .info ("Worker finished" )
49+ break
50+ file_path = random .choice (self .files )
51+ data = self ._read_hdf5 (file_path )
52+ self .buffer .put (data )
53+
54+ def _start_workers (self ):
55+ for _ in range (self .num_workers ):
56+ p = mp .Process (target = self ._worker )
57+ p .start ()
58+ logging .debug (f"Started worker { p .pid } " )
59+ self .processes .append (p )
60+
61+ def get_batch (self ):
62+ batch = []
63+ timeout = 5
64+ start_time = time .time ()
65+
66+ while len (batch ) < self .batch_size :
67+ if time .time () - start_time > timeout :
68+ logging .warning (f"Timeout reached while getting batch. Batch size: { len (batch )} " )
4669 break
47- return await asyncio .gather (* tasks )
70+
71+ try :
72+ item = self .buffer .get (timeout = 1 )
73+ batch .append (item )
74+ except mp .queues .Empty :
75+ if all (not p .is_alive () for p in self .processes ) and self .buffer .empty ():
76+ if len (batch ) == 0 :
77+ return None
78+ else :
79+ break
80+
81+ return batch
4882
4983 def __next__ (self ):
50- if self .index >= len (self .files ):
51- self .index = 0
84+ batch = self .get_batch ()
85+ if batch is None :
86+ random .shuffle (self .files )
87+ self ._start_workers ()
5288 raise StopIteration
53- return asyncio . run ( self . get_batch ())
89+ return batch
5490
5591 def _read_hdf5 (self , data_path ):
56-
5792 with h5py .File (data_path , "r" ) as f :
5893 data_unflattened = recursively_read_hdf5_group (f )
5994
@@ -69,6 +104,16 @@ def __iter__(self):
    def __len__(self):
        """Number of HDF5 files matched by the glob pattern."""
        return len(self.files)
71106
107+ def peek (self ):
108+ if self .buffer .empty ():
109+ return None
110+ return self .buffer .get ()
111+
112+ def __del__ (self ):
113+ for p in self .processes :
114+ p .terminate ()
115+ p .join ()
116+
72117class HDF5IterableDataset (IterableDataset ):
73118 def __init__ (self , path , batch_size = 1 ):
74119 self .hdf5_loader = HDF5Loader (path , batch_size )
0 commit comments