Skip to content

Commit

Permalink
tt-torch model test parser (#8)
Browse files Browse the repository at this point in the history
* tt-torch model test parser

* Formatting

* Add unit test

* continue-on-error: true
  • Loading branch information
nsmithtt authored Dec 4, 2024
1 parent f18e9a2 commit 1285bee
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import tarfile
import os
import json
from loguru import logger
from datetime import datetime
from pydantic_models import Test
from .parser import Parser
from enum import IntEnum


class OpCompilationStatus(IntEnum):
NOT_STARTED = 0
CREATED_GRAPH = 1
CONVERTED_TO_TORCH_IR = 2
CONVERTED_TO_TORCH_BACKEND_IR = 3
CONVERTED_TO_STABLE_HLO = 4
CONVERTED_TO_TTIR = 5
CONVERTED_TO_TTNN = 6
EXECUTED = 7


class TTTorchModelTestsParser(Parser):
"""Parser for python unitest report files."""

def can_parse(self, filepath: str):
basename = os.path.basename(filepath)
return basename.startswith("run") and basename.endswith(".tar")

def parse(self, filepath: str):
return get_tests(filepath)


def untar(filepath):
basename = os.path.basename(filepath)
path = f"/tmp/{basename}"
with tarfile.open(filepath, "r") as fd:
fd.extractall(path=path)
return path


def all_json_files(filepath):
for root, dirs, files in os.walk(filepath):
for file in files:
if file.endswith(".json") and not file.startswith("."):
yield os.path.join(root, file)


def get_tests_from_json(filepath):
with open(filepath, "r") as fd:
data = json.load(fd)

for name, test in data.items():
yield get_pydantic_test(filepath, name, test)


def get_pydantic_test(filepath, name, test, default_timestamp=datetime.now()):
status = OpCompilationStatus(test["compilation_status"])

skipped = False
failed = status < OpCompilationStatus.EXECUTED
error = False
success = not (failed or error)
error_message = str(status).split(".")[1]

properties = {}

test_start_ts = default_timestamp
test_end_ts = default_timestamp

model_name = os.path.basename(filepath).split(".")[0]

# leaving empty for now
group = None
owner = None

full_test_name = f"{filepath}::{name}"

# to be populated with [] if available
config = None

tags = None

return Test(
test_start_ts=test_start_ts,
test_end_ts=test_end_ts,
test_case_name=name,
filepath=filepath,
category="models",
group=None,
owner=None,
frontend="tt-torch",
model_name=model_name,
op_name=None,
framework_op_name=test["torch_name"],
op_kind=None,
error_message=error_message,
success=success,
skipped=skipped,
full_test_name=full_test_name,
config=config,
tags=tags,
)


def flatten(list_of_lists):
return [item for sublist in list_of_lists for item in sublist]


def get_tests(filepath):
untar_path = untar(filepath)
tests = map(get_tests_from_json, all_json_files(untar_path))
return flatten(tests)
17 changes: 14 additions & 3 deletions .github/actions/collect_data/src/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@

from parsers.python_unittest_parser import PythonUnittestParser
from parsers.python_pytest_parser import PythonPytestParser
from parsers.tt_torch_model_tests_parser import TTTorchModelTestsParser

parsers = [
PythonPytestParser(),
PythonUnittestParser(),
TTTorchModelTestsParser(),
]


Expand All @@ -27,8 +29,17 @@ def parse_file(filepath: str) -> List[Test]:
try:
return parser.parse(filepath)
except Exception as e:
logger.error(
f"Error parsing file: {filepath} using parser: {type(parser).__name__}, trying next parser."
)
logger.error(f"Error parsing file: {filepath} using parser: {type(parser).__name__}")
logger.error(f"Exception: {e}")
logger.error("Trying next parser")
logger.error(f"No parser available for file: {filepath}")
return []


if __name__ == "__main__":
import sys

if len(sys.argv) != 2:
print("Usage: python test_parser.py <file>")
sys.exit(1)
print(parse_file(sys.argv[1]))
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from parsers.tt_torch_model_tests_parser import TTTorchModelTestsParser


@pytest.mark.parametrize(
"tar, expected",
[
("run2.tar", {"tests_cnt": 32}),
],
)
def test_tt_torch_model_tests_parser(tar, expected):
filepath = f"test/data/tt_torch_models/{tar}"
parser = TTTorchModelTestsParser()
assert parser.can_parse(filepath)
tests = parser.parse(filepath)
assert len(tests) == expected["tests_cnt"]
1 change: 1 addition & 0 deletions .github/workflows/test_collect_data_action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@ jobs:
with:
pytest-coverage-path: .github/actions/collect_data/pytest-coverage.txt
junitxml-path: .github/actions/collect_data/pytest.xml
continue-on-error: true

0 comments on commit 1285bee

Please sign in to comment.