Skip to content

Commit

Permalink
Move autogen directory under tests/autogen_accuracy_tests and add pri…
Browse files Browse the repository at this point in the history
…ntout of generated items
  • Loading branch information
kevinwuTT committed Jan 6, 2025
1 parent a0941b6 commit f10cf5f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/before_merge.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,12 @@ jobs:
with:
pattern: model-accuracy-tests-group-*
merge-multiple: true
path: accuracy_tests/
path: tests/autogen_accuracy_tests/

- name: Run Accuracy Tests
run: |
set +e
cd accuracy_tests
cd tests/autogen_accuracy_tests
find . -type f -name "*.py" -exec python {} ";"
exit 0;
shell: bash
Expand Down
11 changes: 8 additions & 3 deletions torch_ttnn/generate_op_accuracy_tests.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import logging
import lzma
import pickle
import torch.utils._pytree as pytree
Expand Down Expand Up @@ -379,7 +380,7 @@ def test_accuracy(expected, actual):
"""

# main function definition
directory = Path("accuracy_tests")
directory = Path("tests/autogen_accuracy_tests")
input_pkl_file = Path(f"{model_name}_inputs.pickle")
full_input_pkl_path = directory / input_pkl_file
full_input_pkl_path.parent.mkdir(parents=True, exist_ok=True)
Expand All @@ -404,11 +405,15 @@ def test_accuracy(expected, actual):
)
full_text = "\n".join(full_code)

with open(directory / Path(f"{model_name}_code.py"), "w") as text_file:
code_full_path = directory / Path(f"{model_name}_code.py")
with open(code_full_path, "w") as text_file:
print(full_text, file=text_file)
logging.info(f"Accuracy test code saved to {code_full_path}.")

with lzma.open(directory / Path(f"{model_name}_inputs.pickle"), "wb") as f:
data_full_path = directory / Path(f"{model_name}_inputs.pickle")
with lzma.open(data_full_path, "wb") as f:
pickle.dump(all_inputs, f)
logging.info(f"Accuracy data object saved to {data_full_path}.")


def generate_op_accuracy_tests(model_name, aten_fx_graphs, ttnn_fx_graphs, all_inputs, *, verbose=False):
Expand Down

0 comments on commit f10cf5f

Please sign in to comment.