diff --git a/preprocess_imagenet_using_opencv/code_axs.py b/preprocess_imagenet_using_opencv/code_axs.py index 6b24eeb..0abdcfa 100755 --- a/preprocess_imagenet_using_opencv/code_axs.py +++ b/preprocess_imagenet_using_opencv/code_axs.py @@ -6,9 +6,22 @@ import sys import shutil -def generate_file_list(supported_extensions, calibration_dir=None, index_file=None, first_n=None, first_n_insert=None, images_directory=None,): - original_file_list = os.listdir(images_directory) - sorted_filenames = [filename for filename in sorted(original_file_list) if any(filename.lower().endswith(extension) for extension in supported_extensions) ] +def get_files_from_subdirs(directory, supported_extensions): + all_files = [] + for entry in os.listdir(directory): + entry_path = os.path.join(directory, entry) + if os.path.isdir(entry_path): + for filename in os.listdir(entry_path): + if any(filename.lower().endswith(ext) for ext in supported_extensions): + all_files.append(os.path.join(entry, filename)) + return sorted(all_files) + +def generate_file_list(supported_extensions, flat_dataset_structure, calibration_dir=None, index_file=None, first_n=None, first_n_insert=None, images_directory=None,): + if flat_dataset_structure: + original_file_list = os.listdir(images_directory) + sorted_filenames = [filename for filename in sorted(original_file_list) if any(filename.lower().endswith(extension) for extension in supported_extensions) ] + else: + sorted_filenames = get_files_from_subdirs(images_directory, supported_extensions) if index_file: index_file = os.path.join(calibration_dir, index_file) @@ -124,6 +137,7 @@ def preprocess_files(selected_filenames, images_directory, destination_dir, crop elif normalayout: image_data = norma_layout(image_data, data_type, data_layout, subtract_mean, given_channel_means, normalize_symmetric) + input_filename = os.path.basename(input_filename) output_filename = input_filename.rsplit('.', 1)[0] + '.' + new_file_extension if new_file_extension else input_filename output_filename_calib = input_filename.rsplit('.', 1)[0] + '.' + new_file_extension + '.raw' diff --git a/preprocess_imagenet_using_opencv/data_axs.json b/preprocess_imagenet_using_opencv/data_axs.json index 49aa7a1..28ac334 100644 --- a/preprocess_imagenet_using_opencv/data_axs.json +++ b/preprocess_imagenet_using_opencv/data_axs.json @@ -95,5 +95,6 @@ "convert_to_unsigned": true, "interpolation_method": "INTER_AREA", "dataset_name": "imagenet", + "flat_dataset_structure": true, "file_name": "preprocessed" }