Skip to content

Commit

Permalink
#17863: Add test function to run locally
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Mar 9, 2025
1 parent de5691a commit 025fdaa
Showing 1 changed file with 98 additions and 12 deletions.
110 changes: 98 additions & 12 deletions tests/sweep_framework/sweeps/normalization/batch_norm/batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

import torch
import random
import itertools
import ttnn
import pytest
from tests.sweep_framework.sweep_utils.utils import gen_shapes
from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt

Expand Down Expand Up @@ -37,7 +39,7 @@
}


def run(
def run_batch_norm(
input_shape,
input_dtype,
input_layout,
Expand All @@ -49,9 +51,8 @@ def run(
bias,
eps,
momentum,
*,
device,
) -> list:
):
data_seed = random.randint(0, 20000000)
torch.manual_seed(data_seed)

Expand Down Expand Up @@ -110,14 +111,6 @@ def run(
output_tensor = ttnn.to_torch(result)
e2e_perf = stop_measuring_time(start_time)

tt_updated_mean = None
tt_updated_var = None
if training:
if check_mean:
tt_updated_mean = ttnn.to_torch(mean_tensor)
if check_var:
tt_updated_var = ttnn.to_torch(var_tensor)

torch_result = torch.nn.functional.batch_norm(
input=in_data,
running_mean=mean_data,
Expand All @@ -129,4 +122,97 @@ def run(
momentum=momentum,
)

return [check_with_pcc(torch_result, output_tensor, 0.99), e2e_perf]
passed = []
output_string = ""
passed_, output_string_ = check_with_pcc(torch_result, output_tensor, 0.99)
passed.append(passed_)
output_string += output_string_ + ", "

if training:
channels = input_shape[1]
if check_mean:
tt_updated_mean = ttnn.to_torch(mean_tensor)
passed_, output_string_ = check_with_pcc(tt_updated_mean, mean_data.view(1, channels, 1, 1), 0.99)
passed.append(passed_)
output_string += output_string_ + ", "
if check_var:
tt_updated_var = ttnn.to_torch(var_tensor)
passed_, output_string_ = check_with_pcc(tt_updated_var, var_data.view(1, channels, 1, 1), 0.99)
passed.append(passed_)
output_string += output_string_ + ", "

if all(passed):
passed = True
else:
passed = False

output_string = output_string[:-2]
e2e_perf = stop_measuring_time(start_time)

return [(passed, output_string), e2e_perf]


def run(
input_shape,
input_dtype,
input_layout,
input_memory_config,
training,
check_mean,
check_var,
weight,
bias,
eps,
momentum,
*,
device,
) -> list:
return run_batch_norm(
input_shape,
input_dtype,
input_layout,
input_memory_config,
training,
check_mean,
check_var,
weight,
bias,
eps,
momentum,
device,
)


param_keys = parameters["BN_Testing"].keys()
param_values = itertools.product(*parameters["BN_Testing"].values())


@pytest.mark.parametrize(",".join(param_keys), list(param_values))
def test_batch_norm(
input_shape,
input_dtype,
input_layout,
input_memory_config,
training,
check_mean,
check_var,
weight,
bias,
eps,
momentum,
device,
):
run_batch_norm(
input_shape,
input_dtype,
input_layout,
input_memory_config,
training,
check_mean,
check_var,
weight,
bias,
eps,
momentum,
device,
)

0 comments on commit 025fdaa

Please sign in to comment.