diff --git a/python/data_processing_pytorch.py b/python/data_processing_pytorch.py
index 55fcc585a..c47b55fe5 100644
--- a/python/data_processing_pytorch.py
+++ b/python/data_processing_pytorch.py
@@ -8,6 +8,8 @@
 import modelconfigs
 
+import concurrent.futures
+
 def read_npz_training_data(
   npz_files,
   batch_size: int,
@@ -24,7 +26,11 @@ def read_npz_training_data(
   num_global_features = modelconfigs.get_num_global_input_features(model_config)
   (h_base,h_builder) = build_history_matrices(model_config, device)
 
-  for npz_file in npz_files:
+  # Background loader: a single worker thread prefetches the next npz file
+  # from disk while the arrays of the current file are being consumed.
+  executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+
+  def load_npz_file(npz_file):
     with np.load(npz_file) as npz:
       binaryInputNCHWPacked = npz["binaryInputNCHWPacked"]
       globalInputNC = npz["globalInputNC"]
@@ -47,6 +53,19 @@ def read_npz_training_data(
       assert binaryInputNCHW.shape[1] == num_bin_features
       assert globalInputNC.shape[1] == num_global_features
 
+      return (binaryInputNCHW, globalInputNC, policyTargetsNCMove, globalTargetsNC, scoreDistrN, valueTargetsNCHW, metadataInputNC if include_meta else None)
+
+  # Prefetch pipeline: submit file i+1 as soon as file i's result is taken.
+  # Iterating by index avoids mutating the caller's npz_files list (an
+  # append("")/slice sentinel would leak the "" entry into the caller's list)
+  # and an empty npz_files is handled without indexing errors.
+  # NOTE(review): the executor is never shut down explicitly here — confirm
+  # the generator is always driven to completion or that leaked worker
+  # threads at teardown are acceptable.
+  future = executor.submit(load_npz_file, npz_files[0]) if npz_files else None
+  for i in range(len(npz_files)):
+    (binaryInputNCHW, globalInputNC, policyTargetsNCMove, globalTargetsNC, scoreDistrN, valueTargetsNCHW, metadataInputNC) = future.result()
+
+    if i + 1 < len(npz_files):
+      future = executor.submit(load_npz_file, npz_files[i + 1])
+
     num_samples = binaryInputNCHW.shape[0]
     # Just discard stuff that doesn't divide evenly
     num_whole_steps = num_samples // (batch_size * world_size)