@@ -67,13 +67,18 @@ def get_fens(self):
67
67
return strings
68
68
69
69
FenBatchPtr = ctypes .POINTER (FenBatch )
70
- # EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, const char* filename , int batch_size, bool cyclic, bool filtered, int random_fen_skipping, bool wld_filtered, int param_index)
70
+ # EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, int num_files, const char* const* filenames , int batch_size, bool cyclic, bool filtered, int random_fen_skipping, bool wld_filtered, int early_fen_skipping , int param_index)
71
71
create_fen_batch_stream = dll .create_fen_batch_stream
72
72
create_fen_batch_stream .restype = ctypes .c_void_p
73
- create_fen_batch_stream .argtypes = [ctypes .c_int , ctypes .c_char_p , ctypes .c_int , ctypes .c_bool , ctypes .c_bool , ctypes .c_int , ctypes .c_bool , ctypes .c_int , ctypes .c_int ]
73
+ create_fen_batch_stream .argtypes = [ctypes .c_int , ctypes .c_int , ctypes . POINTER ( ctypes . c_char_p ) , ctypes .c_int , ctypes .c_bool , ctypes .c_bool , ctypes .c_int , ctypes .c_bool , ctypes .c_int , ctypes .c_int ]
74
74
destroy_fen_batch_stream = dll .destroy_fen_batch_stream
75
75
destroy_fen_batch_stream .argtypes = [ctypes .c_void_p ]
76
76
77
def make_fen_batch_stream(concurrency, filenames, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index):
    """Create a native FEN batch stream that reads from one or more files.

    Wraps ``create_fen_batch_stream``: the Python file names are marshalled
    into a ``c_char_p`` array and passed together with their count, which is
    the ``(num_files, filenames)`` pair the native signature expects.

    Args:
        concurrency: Forwarded unchanged as the native ``concurrency`` argument.
        filenames: Paths to read from. Each entry may be a ``str`` (encoded
            as UTF-8) or already-encoded ``bytes``. A single ``str``/``bytes``
            path is accepted and treated as a one-element list.
        batch_size: Forwarded unchanged.
        cyclic: Forwarded unchanged.
        filtered: Forwarded unchanged.
        random_fen_skipping: Forwarded unchanged.
        wld_filtered: Forwarded unchanged.
        early_fen_skipping: Forwarded unchanged.
        param_index: Forwarded unchanged.

    Returns:
        Whatever ``create_fen_batch_stream`` returns (its ``restype`` is
        ``ctypes.c_void_p``, i.e. an opaque stream handle).
    """
    # A bare path would otherwise be iterated character by character.
    if isinstance(filenames, (str, bytes)):
        filenames = [filenames]

    # Callers may hold names that were already encoded; bytes has no .encode().
    encoded = [
        name if isinstance(name, bytes) else name.encode('utf-8')
        for name in filenames
    ]

    # NOTE(review): this array only lives for the duration of the call, so the
    # native side is assumed to copy the names before returning -- confirm.
    filenames_ = (ctypes.c_char_p * len(encoded))(*encoded)
    return create_fen_batch_stream(
        concurrency,
        len(encoded),
        filenames_,
        batch_size,
        cyclic,
        filtered,
        random_fen_skipping,
        wld_filtered,
        early_fen_skipping,
        param_index,
    )
81
+
77
82
fetch_next_fen_batch = dll .fetch_next_fen_batch
78
83
fetch_next_fen_batch .restype = FenBatchPtr
79
84
fetch_next_fen_batch .argtypes = [ctypes .c_void_p ]
@@ -103,9 +108,9 @@ def __init__(
103
108
self .param_index = param_index
104
109
105
110
if batch_size :
106
- self .stream = create_fen_batch_stream (self .num_workers , self .filename , batch_size , cyclic , filtered , random_fen_skipping , wld_filtered , early_fen_skipping , param_index )
111
+ self .stream = make_fen_batch_stream (self .num_workers , [ self .filename ] , batch_size , cyclic , filtered , random_fen_skipping , wld_filtered , early_fen_skipping , param_index )
107
112
else :
108
- self .stream = create_fen_batch_stream (self .num_workers , self .filename , cyclic , filtered , random_fen_skipping , wld_filtered , early_fen_skipping , param_index )
113
+ self .stream = make_fen_batch_stream (self .num_workers , [ self .filename ] , cyclic , filtered , random_fen_skipping , wld_filtered , early_fen_skipping , param_index )
109
114
110
115
def __iter__ (self ):
111
116
return self
@@ -131,7 +136,7 @@ def __init__(
131
136
destroy_stream ,
132
137
fetch_next ,
133
138
destroy_part ,
134
- filename ,
139
+ filenames ,
135
140
cyclic ,
136
141
num_workers ,
137
142
batch_size = None ,
@@ -147,7 +152,7 @@ def __init__(
147
152
self .destroy_stream = destroy_stream
148
153
self .fetch_next = fetch_next
149
154
self .destroy_part = destroy_part
150
- self .filename = filename . encode ( 'utf-8' )
155
+ self .filenames = filenames
151
156
self .cyclic = cyclic
152
157
self .num_workers = num_workers
153
158
self .batch_size = batch_size
@@ -158,9 +163,9 @@ def __init__(
158
163
self .device = device
159
164
160
165
if batch_size :
161
- self .stream = self .create_stream (self .feature_set , self .num_workers , self .filename , batch_size , cyclic , filtered , random_fen_skipping , wld_filtered , early_fen_skipping , param_index )
166
+ self .stream = self .create_stream (self .feature_set , self .num_workers , self .filenames , batch_size , cyclic , filtered , random_fen_skipping , wld_filtered , early_fen_skipping , param_index )
162
167
else :
163
- self .stream = self .create_stream (self .feature_set , self .num_workers , self .filename , cyclic , filtered , random_fen_skipping , wld_filtered , early_fen_skipping , param_index )
168
+ self .stream = self .create_stream (self .feature_set , self .num_workers , self .filenames , cyclic , filtered , random_fen_skipping , wld_filtered , early_fen_skipping , param_index )
164
169
165
170
def __iter__ (self ):
166
171
return self
@@ -178,14 +183,19 @@ def __next__(self):
178
183
def __del__(self):
    """Release the native stream held by this object, if one was created."""
    # The finalizer also runs on objects whose __init__ raised part-way
    # through, in which case these attributes may never have been assigned.
    stream = getattr(self, 'stream', None)
    destroy_stream = getattr(self, 'destroy_stream', None)
    if stream is not None and destroy_stream is not None:
        destroy_stream(stream)
180
185
181
- # EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, const char* filename , int batch_size, bool cyclic,
186
+ # EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, int num_files, const char* const* filenames , int batch_size, bool cyclic,
182
187
# bool filtered, int random_fen_skipping, bool wld_filtered, int early_fen_skipping, int param_index)
183
188
create_sparse_batch_stream = dll .create_sparse_batch_stream
184
189
create_sparse_batch_stream .restype = ctypes .c_void_p
185
- create_sparse_batch_stream .argtypes = [ctypes .c_char_p , ctypes .c_int , ctypes .c_char_p , ctypes .c_int , ctypes .c_bool , ctypes .c_bool , ctypes .c_int , ctypes .c_bool , ctypes .c_int , ctypes .c_int ]
190
+ create_sparse_batch_stream .argtypes = [ctypes .c_char_p , ctypes .c_int , ctypes .c_int , ctypes . POINTER ( ctypes . c_char_p ) , ctypes .c_int , ctypes .c_bool , ctypes .c_bool , ctypes .c_int , ctypes .c_bool , ctypes .c_int , ctypes .c_int ]
186
191
destroy_sparse_batch_stream = dll .destroy_sparse_batch_stream
187
192
destroy_sparse_batch_stream .argtypes = [ctypes .c_void_p ]
188
193
194
def make_sparse_batch_stream(feature_set, concurrency, filenames, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index):
    """Create a native sparse-batch stream that reads from one or more files.

    Wraps ``create_sparse_batch_stream``: the Python file names are marshalled
    into a ``c_char_p`` array and passed together with their count, which is
    the ``(num_files, filenames)`` pair the native signature expects.

    Args:
        feature_set: Forwarded unchanged as the first native argument
            (declared as ``c_char_p`` in the binding's ``argtypes``).
        concurrency: Forwarded unchanged.
        filenames: Paths to read from. Each entry may be a ``str`` (encoded
            as UTF-8) or already-encoded ``bytes``. A single ``str``/``bytes``
            path is accepted and treated as a one-element list.
        batch_size: Forwarded unchanged.
        cyclic: Forwarded unchanged.
        filtered: Forwarded unchanged.
        random_fen_skipping: Forwarded unchanged.
        wld_filtered: Forwarded unchanged.
        early_fen_skipping: Forwarded unchanged.
        param_index: Forwarded unchanged.

    Returns:
        Whatever ``create_sparse_batch_stream`` returns (its ``restype`` is
        ``ctypes.c_void_p``, i.e. an opaque stream handle).
    """
    # A bare path would otherwise be iterated character by character.
    if isinstance(filenames, (str, bytes)):
        filenames = [filenames]

    # Callers may hold names that were already encoded; bytes has no .encode().
    encoded = [
        name if isinstance(name, bytes) else name.encode('utf-8')
        for name in filenames
    ]

    # NOTE(review): this array only lives for the duration of the call, so the
    # native side is assumed to copy the names before returning -- confirm.
    filenames_ = (ctypes.c_char_p * len(encoded))(*encoded)
    return create_sparse_batch_stream(
        feature_set,
        concurrency,
        len(encoded),
        filenames_,
        batch_size,
        cyclic,
        filtered,
        random_fen_skipping,
        wld_filtered,
        early_fen_skipping,
        param_index,
    )
198
+
189
199
fetch_next_sparse_batch = dll .fetch_next_sparse_batch
190
200
fetch_next_sparse_batch .restype = SparseBatchPtr
191
201
fetch_next_sparse_batch .argtypes = [ctypes .c_void_p ]
@@ -211,14 +221,14 @@ def make_sparse_batch_from_fens(feature_set, fens, scores, plies, results):
211
221
return b
212
222
213
223
class SparseBatchProvider (TrainingDataProvider ):
214
- def __init__ (self , feature_set , filename , batch_size , cyclic = True , num_workers = 1 , filtered = False , random_fen_skipping = 0 , wld_filtered = False , early_fen_skipping = - 1 , param_index = 0 , device = 'cpu' ):
224
+ def __init__ (self , feature_set , filenames , batch_size , cyclic = True , num_workers = 1 , filtered = False , random_fen_skipping = 0 , wld_filtered = False , early_fen_skipping = - 1 , param_index = 0 , device = 'cpu' ):
215
225
super (SparseBatchProvider , self ).__init__ (
216
226
feature_set ,
217
- create_sparse_batch_stream ,
227
+ make_sparse_batch_stream ,
218
228
destroy_sparse_batch_stream ,
219
229
fetch_next_sparse_batch ,
220
230
destroy_sparse_batch ,
221
- filename ,
231
+ filenames ,
222
232
cyclic ,
223
233
num_workers ,
224
234
batch_size ,
@@ -230,10 +240,10 @@ def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1
230
240
device )
231
241
232
242
class SparseBatchDataset (torch .utils .data .IterableDataset ):
233
- def __init__ (self , feature_set , filename , batch_size , cyclic = True , num_workers = 1 , filtered = False , random_fen_skipping = 0 , wld_filtered = False , early_fen_skipping = - 1 , param_index = 0 , device = 'cpu' ):
243
+ def __init__ (self , feature_set , filenames , batch_size , cyclic = True , num_workers = 1 , filtered = False , random_fen_skipping = 0 , wld_filtered = False , early_fen_skipping = - 1 , param_index = 0 , device = 'cpu' ):
234
244
super (SparseBatchDataset ).__init__ ()
235
245
self .feature_set = feature_set
236
- self .filename = filename
246
+ self .filenames = filenames
237
247
self .batch_size = batch_size
238
248
self .cyclic = cyclic
239
249
self .num_workers = num_workers
@@ -245,7 +255,7 @@ def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1
245
255
self .device = device
246
256
247
257
def __iter__(self):
    """Build a fresh SparseBatchProvider from the settings stored on this dataset."""
    # Positional part of the provider signature: feature set, files, batch size.
    positional = (self.feature_set, self.filenames, self.batch_size)

    # Everything else is passed by keyword, mirroring the stored attributes.
    options = {
        'cyclic': self.cyclic,
        'num_workers': self.num_workers,
        'filtered': self.filtered,
        'random_fen_skipping': self.random_fen_skipping,
        'wld_filtered': self.wld_filtered,
        'early_fen_skipping': self.early_fen_skipping,
        'param_index': self.param_index,
        'device': self.device,
    }
    return SparseBatchProvider(*positional, **options)
250
260
251
261
class FixedNumBatchesDataset (Dataset ):
0 commit comments