diff --git a/tests/sweep_framework/sweeps/normalization/batch_norm/batch_norm.py b/tests/sweep_framework/sweeps/normalization/batch_norm/batch_norm.py index 4483c385fff..3079d177577 100644 --- a/tests/sweep_framework/sweeps/normalization/batch_norm/batch_norm.py +++ b/tests/sweep_framework/sweeps/normalization/batch_norm/batch_norm.py @@ -7,7 +7,9 @@ import torch import random +import itertools import ttnn +import pytest from tests.sweep_framework.sweep_utils.utils import gen_shapes from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt @@ -37,7 +39,7 @@ } -def run( +def run_batch_norm( input_shape, input_dtype, input_layout, @@ -49,9 +51,8 @@ def run( bias, eps, momentum, - *, device, -) -> list: +): data_seed = random.randint(0, 20000000) torch.manual_seed(data_seed) @@ -110,14 +111,6 @@ def run( output_tensor = ttnn.to_torch(result) e2e_perf = stop_measuring_time(start_time) - tt_updated_mean = None - tt_updated_var = None - if training: - if check_mean: - tt_updated_mean = ttnn.to_torch(mean_tensor) - if check_var: - tt_updated_var = ttnn.to_torch(var_tensor) - torch_result = torch.nn.functional.batch_norm( input=in_data, running_mean=mean_data, @@ -129,4 +122,97 @@ def run( momentum=momentum, ) - return [check_with_pcc(torch_result, output_tensor, 0.99), e2e_perf] + passed = [] + output_string = "" + passed_, output_string_ = check_with_pcc(torch_result, output_tensor, 0.99) + passed.append(passed_) + output_string += output_string_ + ", " + + if training: + channels = input_shape[1] + if check_mean: + tt_updated_mean = ttnn.to_torch(mean_tensor) + passed_, output_string_ = check_with_pcc(tt_updated_mean, mean_data.view(1, channels, 1, 1), 0.99) + passed.append(passed_) + output_string += output_string_ + ", " + if check_var: + tt_updated_var = ttnn.to_torch(var_tensor) + passed_, output_string_ = check_with_pcc(tt_updated_var, var_data.view(1, channels, 1, 1), 0.99) + passed.append(passed_) + output_string += output_string_ + ", " + + if all(passed): + passed = True + else: + passed = False + + output_string = output_string[:-2] + e2e_perf = stop_measuring_time(start_time) + + return [(passed, output_string), e2e_perf] + + +def run( + input_shape, + input_dtype, + input_layout, + input_memory_config, + training, + check_mean, + check_var, + weight, + bias, + eps, + momentum, + *, + device, +) -> list: + return run_batch_norm( + input_shape, + input_dtype, + input_layout, + input_memory_config, + training, + check_mean, + check_var, + weight, + bias, + eps, + momentum, + device, + ) + + +param_keys = parameters["BN_Testing"].keys() +param_values = itertools.product(*parameters["BN_Testing"].values()) + + +@pytest.mark.parametrize(",".join(param_keys), list(param_values)) +def test_batch_norm( + input_shape, + input_dtype, + input_layout, + input_memory_config, + training, + check_mean, + check_var, + weight, + bias, + eps, + momentum, + device, +): + run_batch_norm( + input_shape, + input_dtype, + input_layout, + input_memory_config, + training, + check_mean, + check_var, + weight, + bias, + eps, + momentum, + device, + )