Skip to content

Commit 77989cf

Browse files
authored
Update MIGraphX resnet50 example (#440)
* Add FP32 mode to resnet50 migraphx example * Fix JSON dump error in migraphx resnet50 example * Fix fail in migraphx resnet50 example with RGBA and CMYK images Also fix issue where returned batch size differs from the expected
1 parent a69f52c commit 77989cf

File tree

1 file changed

+30
-17
lines changed

1 file changed

+30
-17
lines changed

quantization/image_classification/migraphx/resnet50/e2e_migraphx_resnet_example.py

+30-17
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ def parse_input_args():
2121
help='Perform fp16 quantizaton in addition to int8',
2222
)
2323

24+
parser.add_argument(
25+
"--fp32",
26+
action="store_true",
27+
required=False,
28+
default=False,
29+
help='Perform no quantization',
30+
)
31+
2432
parser.add_argument(
2533
"--image_dir",
2634
required=False,
@@ -56,10 +64,10 @@ def __init__(self,
5664
'''
5765
:param image_folder: image dataset folder
5866
:param width: image width
59-
:param height: image height
67+
:param height: image height
6068
:param start_index: start index of images
6169
:param end_index: end index of images
62-
:param stride: image size of each data get
70+
:param stride: image size of each data get
6371
:param batch_size: batch size of inference
6472
:param model_path: model name and path
6573
:param input_name: model input name
@@ -153,12 +161,14 @@ def preprocess_imagenet(self, images_folder, height, width, start_index=0, size_
153161
parameter images_folder: path to folder storing images
154162
parameter height: image height in pixels
155163
parameter width: image width in pixels
156-
parameter start_index: image index to start with
164+
parameter start_index: image index to start with
157165
parameter size_limit: number of images to load. Default is 0 which means all images are picked.
158166
return: list of matrices characterizing multiple images
159167
'''
160168
def preprocess_images(input, channels=3, height=224, width=224):
161169
image = input.resize((width, height), Image.Resampling.LANCZOS)
170+
if image.mode in ["CMYK", "RGBA"]:
171+
image = image.convert("RGB")
162172
input_data = np.asarray(image).astype(np.float32)
163173
if len(input_data.shape) != 2:
164174
input_data = input_data.transpose([2, 0, 1])
@@ -249,7 +259,7 @@ def __init__(self,
249259
providers=["MIGraphXExecutionProvider"]):
250260
'''
251261
:param model_path: ONNX model to validate
252-
:param synset_id: ILSVRC2012 synset id
262+
:param synset_id: ILSVRC2012 synset id
253263
:param data_reader: user implemented object to read in and preprocess calibration dataset
254264
based on CalibrationDataReader Interface
255265
:param providers: ORT execution provider type
@@ -281,9 +291,8 @@ def predict(self):
281291
self.prediction_result_list = inference_outputs_list
282292

283293
def top_k_accuracy(self, truth, prediction, k=1):
284-
'''From https://github.com/chainer/chainer/issues/606
294+
'''From https://github.com/chainer/chainer/issues/606
285295
'''
286-
287296
y = np.argsort(prediction)[:, -k:]
288297
return np.any(y.T == truth.argmax(axis=1), axis=0).mean()
289298

@@ -293,7 +302,7 @@ def evaluate(self, prediction_results):
293302
y_prediction = np.empty((total_val_images, 1000), dtype=np.float32)
294303
i = 0
295304
for res in prediction_results:
296-
y_prediction[i:i + batch_size, :] = res[0]
305+
y_prediction[i:i + res[0].shape[0], :] = res[0]
297306
i = i + batch_size
298307
print("top 1: ", self.top_k_accuracy(self.synset_id, y_prediction, k=1))
299308
print("top 5: ", self.top_k_accuracy(self.synset_id, y_prediction, k=5))
@@ -344,8 +353,9 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
344353
2. Download ILSVRC2012 validation dataset and development kit from http://www.image-net.org/challenges/LSVRC/2012/downloads.
345354
3. Extract validation dataset JPEG files to 'ILSVRC2012/val'.
346355
4. Extract development kit to 'ILSVRC2012/devkit'. Two files in the development kit are used, 'ILSVRC2012_validation_ground_truth.txt' and 'meta.mat'.
356+
These are also available to download at https://github.com/miraclewkf/MobileNetV2-PyTorch/tree/master/ImageNet/ILSVRC2012_devkit_t12/data
347357
5. Download 'synset_words.txt' from https://github.com/HoldenCaulfieldRye/caffe/blob/master/data/ilsvrc12/synset_words.txt into 'ILSVRC2012/'.
348-
358+
349359
Please download Resnet50 model from ONNX model zoo https://github.com/onnx/models/blob/master/vision/classification/resnet/model/resnet50-v2-7.tar.gz
350360
Untar the model into the workspace
351361
'''
@@ -356,15 +366,18 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
356366
ilsvrc2012_dataset_path = flags.image_dir
357367
augmented_model_path = "./augmented_model.onnx"
358368
batch_size = flags.batch
359-
calibration_dataset_size = flags.cal_size # Size of dataset for calibration
369+
calibration_dataset_size = 0 if flags.fp32 else flags.cal_size # Size of dataset for calibration
370+
371+
calibration_table_generation_enable = False
372+
if not flags.fp32:
373+
# INT8 calibration setting
374+
calibration_table_generation_enable = True # Enable/Disable INT8 calibration
360375

361-
# INT8 calibration setting
362-
calibration_table_generation_enable = True # Enable/Disable INT8 calibration
376+
# MIGraphX EP INT8 settings
377+
os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "1" # Enable INT8 precision
378+
os.environ["ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"] = "calibration.flatbuffers" # Calibration table name
379+
os.environ["ORT_MIGRAPHX_INT8_NATIVE_CALIBRATION_TABLE"] = "0" # Calibration table name
363380

364-
# MIGraphX EP INT8 settings
365-
os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "1" # Enable INT8 precision
366-
os.environ["ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"] = "calibration.flatbuffers" # Calibration table name
367-
os.environ["ORT_MIGRAPHX_INT8_NATIVE_CALIBRATION_TABLE"] = "0" # Calibration table name
368381
execution_provider = ["MIGraphXExecutionProvider"]
369382

370383
# Convert static batch to dynamic batch
@@ -378,7 +391,7 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
378391
if calibration_table_generation_enable:
379392
print("Generating Calibration Table")
380393
calibrator = create_calibrator(new_model_path, [], augmented_model_path=augmented_model_path)
381-
calibrator.set_execution_providers(["ROCMExecutionProvider"])
394+
calibrator.set_execution_providers(["ROCMExecutionProvider"])
382395
data_reader = ImageNetDataReader(ilsvrc2012_dataset_path,
383396
start_index=0,
384397
end_index=calibration_dataset_size,
@@ -391,7 +404,7 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
391404

392405
serial_cal_tensors = {}
393406
for keys, values in cal_tensors.data.items():
394-
serial_cal_tensors[keys] = values.range_value
407+
serial_cal_tensors[keys] = [float(x[0]) for x in values.range_value]
395408

396409
print("Writing calibration table")
397410
write_calibration_table(serial_cal_tensors)

0 commit comments

Comments
 (0)