Skip to content

Commit cf579af

Browse files
authored
Merge pull request official-stockfish#259 from Sopel97/builtin_interleave
Add support for multiple datasets in the trainer. The datasets are in…
2 parents 2aa01ec + 32b6986 commit cf579af

7 files changed

+181
-124
lines changed

lib/nnue_training_data_formats.h

+52-12
Original file line numberDiff line numberDiff line change
@@ -6726,6 +6726,12 @@ namespace binpack
67266726
}
67276727
}
67286728

6729+
inline std::ifstream::pos_type filesize(const char* filename)
6730+
{
6731+
std::ifstream in(filename, std::ifstream::ate | std::ifstream::binary);
6732+
return in.tellg();
6733+
}
6734+
67296735
struct CompressedTrainingDataFile
67306736
{
67316737
struct Header
@@ -6737,12 +6743,15 @@ namespace binpack
67376743
m_path(std::move(path)),
67386744
m_file(m_path, std::ios_base::binary | std::ios_base::in | std::ios_base::out | om)
67396745
{
6746+
// Racey but who cares
6747+
m_sizeBytes = filesize(m_path.c_str());
67406748
}
67416749

67426750
void append(const char* data, std::uint32_t size)
67436751
{
67446752
writeChunkHeader({size});
67456753
m_file.write(data, size);
6754+
m_sizeBytes += size + 8;
67466755
}
67476756

67486757
[[nodiscard]] bool hasNextChunk()
@@ -6756,6 +6765,11 @@ namespace binpack
67566765
return !m_file.eof();
67576766
}
67586767

6768+
void seek_to_start()
6769+
{
6770+
m_file.seekg(0);
6771+
}
6772+
67596773
[[nodiscard]] std::vector<unsigned char> readNextChunk()
67606774
{
67616775
auto size = readChunkHeader().chunkSize;
@@ -6764,9 +6778,15 @@ namespace binpack
67646778
return data;
67656779
}
67666780

6781+
[[nodiscard]] std::size_t sizeBytes() const
6782+
{
6783+
return m_sizeBytes;
6784+
}
6785+
67676786
private:
67686787
std::string m_path;
67696788
std::fstream m_file;
6789+
std::size_t m_sizeBytes;
67706790

67716791
void writeChunkHeader(Header h)
67726792
{
@@ -7558,21 +7578,32 @@ namespace binpack
75587578

75597579
CompressedTrainingDataEntryParallelReader(
75607580
int concurrency,
7561-
std::string path,
7581+
std::vector<std::string> paths,
75627582
std::ios_base::openmode om = std::ios_base::app,
7583+
bool cyclic = false,
75637584
std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr
75647585
) :
75657586
m_concurrency(concurrency),
7566-
m_inputFile(path, om),
75677587
m_bufferOffset(0),
7588+
m_cyclic(cyclic),
75687589
m_skipPredicate(std::move(skipPredicate))
75697590
{
75707591
m_numRunningWorkers.store(0);
7571-
if (!m_inputFile.hasNextChunk())
7592+
std::vector<double> sizes; // discrete distribution wants double weights
7593+
for (const auto& path : paths)
75727594
{
7573-
return;
7595+
auto& file = m_inputFiles.emplace_back(path, om);
7596+
7597+
if (!file.hasNextChunk())
7598+
{
7599+
return;
7600+
}
7601+
7602+
sizes.emplace_back(static_cast<double>(file.sizeBytes()));
75747603
}
75757604

7605+
m_inputFileDistribution = std::discrete_distribution<>(sizes.begin(), sizes.end());
7606+
75767607
m_stopFlag.store(false);
75777608

75787609
auto worker = [this]()
@@ -7742,8 +7773,10 @@ namespace binpack
77427773

77437774
private:
77447775
int m_concurrency;
7745-
CompressedTrainingDataFile m_inputFile;
7776+
std::vector<CompressedTrainingDataFile> m_inputFiles;
7777+
std::discrete_distribution<> m_inputFileDistribution;
77467778
std::atomic_int m_numRunningWorkers;
7779+
bool m_cyclic;
77477780

77487781
static constexpr int threadBufferSize = 256 * 256 * 16;
77497782

@@ -7763,17 +7796,24 @@ namespace binpack
77637796
{
77647797
if (m_offset + sizeof(PackedTrainingDataEntry) + 2 > m_chunk.size())
77657798
{
7799+
auto& prng = rng::get_thread_local_rng();
7800+
const std::size_t fileId = m_inputFileDistribution(prng);
7801+
auto& inputFile = m_inputFiles[fileId];
7802+
77667803
std::unique_lock lock(m_fileMutex);
77677804

7768-
if (!m_inputFile.hasNextChunk())
7769-
{
7770-
return true;
7771-
}
7772-
else
7805+
if (!inputFile.hasNextChunk())
77737806
{
7774-
m_chunk = m_inputFile.readNextChunk();
7775-
m_offset = 0;
7807+
if (m_cyclic)
7808+
{
7809+
inputFile.seek_to_start();
7810+
}
7811+
else
7812+
return true;
77767813
}
7814+
7815+
m_chunk = inputFile.readNextChunk();
7816+
m_offset = 0;
77777817
}
77787818

77797819
return false;

lib/nnue_training_data_stream.h

+10-41
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,9 @@ namespace training_data {
183183
static constexpr auto openmode = std::ios::in | std::ios::binary;
184184
static inline const std::string extension = "binpack";
185185

186-
BinpackSfenInputParallelStream(int concurrency, std::string filename, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate) :
187-
m_stream(std::make_unique<binpack::CompressedTrainingDataEntryParallelReader>(concurrency, filename, openmode, skipPredicate)),
188-
m_filename(filename),
186+
BinpackSfenInputParallelStream(int concurrency, const std::vector<std::string>& filenames, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate) :
187+
m_stream(std::make_unique<binpack::CompressedTrainingDataEntryParallelReader>(concurrency, filenames, openmode, cyclic, skipPredicate)),
188+
m_filenames(filenames),
189189
m_concurrency(concurrency),
190190
m_eof(false),
191191
m_cyclic(cyclic),
@@ -199,12 +199,6 @@ namespace training_data {
199199
auto v = m_stream->next();
200200
if (!v.has_value())
201201
{
202-
if (m_cyclic)
203-
{
204-
m_stream = std::make_unique<binpack::CompressedTrainingDataEntryParallelReader>(m_concurrency, m_filename, openmode, m_skipPredicate);
205-
return m_stream->next();
206-
}
207-
208202
m_eof = true;
209203
return std::nullopt;
210204
}
@@ -217,32 +211,7 @@ namespace training_data {
217211
auto k = m_stream->fill(v, n);
218212
if (n != k)
219213
{
220-
if (m_cyclic)
221-
{
222-
m_stream = std::make_unique<binpack::CompressedTrainingDataEntryParallelReader>(m_concurrency, m_filename, openmode, m_skipPredicate);
223-
n -= k;
224-
k = m_stream->fill(v, n);
225-
if (k == 0)
226-
{
227-
// No data in the file
228-
m_eof = true;
229-
return;
230-
}
231-
else if (k == n)
232-
{
233-
// We're done
234-
return;
235-
}
236-
else
237-
{
238-
// We need to read again
239-
this->fill(v, n - k);
240-
}
241-
}
242-
else
243-
{
244-
m_eof = true;
245-
}
214+
m_eof = true;
246215
}
247216
}
248217

@@ -255,7 +224,7 @@ namespace training_data {
255224

256225
private:
257226
std::unique_ptr<binpack::CompressedTrainingDataEntryParallelReader> m_stream;
258-
std::string m_filename;
227+
std::vector<std::string> m_filenames;
259228
int m_concurrency;
260229
bool m_eof;
261230
bool m_cyclic;
@@ -272,13 +241,13 @@ namespace training_data {
272241
return nullptr;
273242
}
274243

275-
inline std::unique_ptr<BasicSfenInputStream> open_sfen_input_file_parallel(int concurrency, const std::string& filename, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr)
244+
inline std::unique_ptr<BasicSfenInputStream> open_sfen_input_file_parallel(int concurrency, const std::vector<std::string>& filenames, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr)
276245
{
277246
// TODO (low priority): optimize and parallelize .bin reading.
278-
if (has_extension(filename, BinSfenInputStream::extension))
279-
return std::make_unique<BinSfenInputStream>(filename, cyclic, std::move(skipPredicate));
280-
else if (has_extension(filename, BinpackSfenInputParallelStream::extension))
281-
return std::make_unique<BinpackSfenInputParallelStream>(concurrency, filename, cyclic, std::move(skipPredicate));
247+
if (has_extension(filenames[0], BinSfenInputStream::extension))
248+
return std::make_unique<BinSfenInputStream>(filenames[0], cyclic, std::move(skipPredicate));
249+
else if (has_extension(filenames[0], BinpackSfenInputParallelStream::extension))
250+
return std::make_unique<BinpackSfenInputParallelStream>(concurrency, filenames, cyclic, std::move(skipPredicate));
282251

283252
return nullptr;
284253
}

nnue_dataset.py

+26-16
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,18 @@ def get_fens(self):
6767
return strings
6868

6969
FenBatchPtr = ctypes.POINTER(FenBatch)
70-
# EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, const char* filename, int batch_size, bool cyclic, bool filtered, int random_fen_skipping, bool wld_filtered, int param_index)
70+
# EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, int num_files, const char* const* filenames, int batch_size, bool cyclic, bool filtered, int random_fen_skipping, bool wld_filtered, int early_fen_skipping, int param_index)
7171
create_fen_batch_stream = dll.create_fen_batch_stream
7272
create_fen_batch_stream.restype = ctypes.c_void_p
73-
create_fen_batch_stream.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool, ctypes.c_int, ctypes.c_int]
73+
create_fen_batch_stream.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool, ctypes.c_int, ctypes.c_int]
7474
destroy_fen_batch_stream = dll.destroy_fen_batch_stream
7575
destroy_fen_batch_stream.argtypes = [ctypes.c_void_p]
7676

77+
def make_fen_batch_stream(concurrency, filenames, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index):
78+
filenames_ = (ctypes.c_char_p * len(filenames))()
79+
filenames_[:] = [filename.encode('utf-8') for filename in filenames]
80+
return create_fen_batch_stream(concurrency, len(filenames), filenames_, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
81+
7782
fetch_next_fen_batch = dll.fetch_next_fen_batch
7883
fetch_next_fen_batch.restype = FenBatchPtr
7984
fetch_next_fen_batch.argtypes = [ctypes.c_void_p]
@@ -103,9 +108,9 @@ def __init__(
103108
self.param_index = param_index
104109

105110
if batch_size:
106-
self.stream = create_fen_batch_stream(self.num_workers, self.filename, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
111+
self.stream = make_fen_batch_stream(self.num_workers, [self.filename], batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
107112
else:
108-
self.stream = create_fen_batch_stream(self.num_workers, self.filename, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
113+
self.stream = make_fen_batch_stream(self.num_workers, [self.filename], cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
109114

110115
def __iter__(self):
111116
return self
@@ -131,7 +136,7 @@ def __init__(
131136
destroy_stream,
132137
fetch_next,
133138
destroy_part,
134-
filename,
139+
filenames,
135140
cyclic,
136141
num_workers,
137142
batch_size=None,
@@ -147,7 +152,7 @@ def __init__(
147152
self.destroy_stream = destroy_stream
148153
self.fetch_next = fetch_next
149154
self.destroy_part = destroy_part
150-
self.filename = filename.encode('utf-8')
155+
self.filenames = filenames
151156
self.cyclic = cyclic
152157
self.num_workers = num_workers
153158
self.batch_size = batch_size
@@ -158,9 +163,9 @@ def __init__(
158163
self.device = device
159164

160165
if batch_size:
161-
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filename, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
166+
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filenames, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
162167
else:
163-
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filename, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
168+
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filenames, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
164169

165170
def __iter__(self):
166171
return self
@@ -178,14 +183,19 @@ def __next__(self):
178183
def __del__(self):
179184
self.destroy_stream(self.stream)
180185

181-
# EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, const char* filename, int batch_size, bool cyclic,
186+
# EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, int num_files, const char* const* filenames, int batch_size, bool cyclic,
182187
# bool filtered, int random_fen_skipping, bool wld_filtered, int early_fen_skipping, int param_index)
183188
create_sparse_batch_stream = dll.create_sparse_batch_stream
184189
create_sparse_batch_stream.restype = ctypes.c_void_p
185-
create_sparse_batch_stream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool, ctypes.c_int, ctypes.c_int]
190+
create_sparse_batch_stream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool, ctypes.c_int, ctypes.c_int]
186191
destroy_sparse_batch_stream = dll.destroy_sparse_batch_stream
187192
destroy_sparse_batch_stream.argtypes = [ctypes.c_void_p]
188193

194+
def make_sparse_batch_stream(feature_set, concurrency, filenames, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index):
195+
filenames_ = (ctypes.c_char_p * len(filenames))()
196+
filenames_[:] = [filename.encode('utf-8') for filename in filenames]
197+
return create_sparse_batch_stream(feature_set, concurrency, len(filenames), filenames_, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
198+
189199
fetch_next_sparse_batch = dll.fetch_next_sparse_batch
190200
fetch_next_sparse_batch.restype = SparseBatchPtr
191201
fetch_next_sparse_batch.argtypes = [ctypes.c_void_p]
@@ -211,14 +221,14 @@ def make_sparse_batch_from_fens(feature_set, fens, scores, plies, results):
211221
return b
212222

213223
class SparseBatchProvider(TrainingDataProvider):
214-
def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, early_fen_skipping=-1, param_index=0, device='cpu'):
224+
def __init__(self, feature_set, filenames, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, early_fen_skipping=-1, param_index=0, device='cpu'):
215225
super(SparseBatchProvider, self).__init__(
216226
feature_set,
217-
create_sparse_batch_stream,
227+
make_sparse_batch_stream,
218228
destroy_sparse_batch_stream,
219229
fetch_next_sparse_batch,
220230
destroy_sparse_batch,
221-
filename,
231+
filenames,
222232
cyclic,
223233
num_workers,
224234
batch_size,
@@ -230,10 +240,10 @@ def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1
230240
device)
231241

232242
class SparseBatchDataset(torch.utils.data.IterableDataset):
233-
def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, early_fen_skipping=-1, param_index=0, device='cpu'):
243+
def __init__(self, feature_set, filenames, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, early_fen_skipping=-1, param_index=0, device='cpu'):
234244
super(SparseBatchDataset).__init__()
235245
self.feature_set = feature_set
236-
self.filename = filename
246+
self.filenames = filenames
237247
self.batch_size = batch_size
238248
self.cyclic = cyclic
239249
self.num_workers = num_workers
@@ -245,7 +255,7 @@ def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1
245255
self.device = device
246256

247257
def __iter__(self):
248-
return SparseBatchProvider(self.feature_set, self.filename, self.batch_size, cyclic=self.cyclic, num_workers=self.num_workers,
258+
return SparseBatchProvider(self.feature_set, self.filenames, self.batch_size, cyclic=self.cyclic, num_workers=self.num_workers,
249259
filtered=self.filtered, random_fen_skipping=self.random_fen_skipping, wld_filtered=self.wld_filtered, early_fen_skipping = self.early_fen_skipping, param_index=self.param_index, device=self.device)
250260

251261
class FixedNumBatchesDataset(Dataset):

0 commit comments

Comments
 (0)