Skip to content

Commit

Permalink
Support for superset integration (#174)
Browse files Browse the repository at this point in the history
- Recorded relevant properties of op and model tests
- Added compile/runtime fail to xfail/skip messages of model inference
tests
- Split test_compare into multiple functions
- Added lambda function in each parametrization to generate more
informative test descriptions (instead of `some_test[xshape0-yshape0]`
it will report `some_test[(32, 32)-(32, 32)]`).
  • Loading branch information
kmitrovicTT authored Feb 3, 2025
1 parent 4fd0d29 commit 5175d7f
Show file tree
Hide file tree
Showing 53 changed files with 918 additions and 112 deletions.
Empty file added tests/__init__.py
Empty file.
102 changes: 102 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from datetime import datetime
from enum import Enum
from typing import Callable

import pytest


class RecordProperties(Enum):
"""Properties we can record."""

# Timestamp of test start.
START_TIMESTAMP = "start_timestamp"
# Timestamp of test end.
END_TIMESTAMP = "end_timestamp"
# Frontend or framework used to run the test.
FRONTEND = "frontend"
# Kind of operation. e.g. eltwise.
OP_KIND = "op_kind"
# Name of the operation in the framework. e.g. torch.conv2d.
FRAMEWORK_OP_NAME = "framework_op_name"
# Name of the operation. e.g. ttir.conv2d.
OP_NAME = "op_name"
# Name of the model in which this op appears.
MODEL_NAME = "model_name"


@pytest.fixture(scope="function", autouse=True)
def record_test_timestamp(record_property: Callable):
"""
Autouse fixture used to capture execution time of a test.
Parameters:
----------
record_property: Callable
A pytest built-in function used to record test metadata, such as custom
properties or additional information about the test execution.
Yields:
-------
Callable
The `record_property` callable, allowing tests to add additional properties if
needed.
Example:
--------
```
def test_model(fixture1, fixture2, ..., record_tt_xla_property):
record_tt_xla_property("key", value)
# Test logic...
```
"""
start_timestamp = datetime.strftime(datetime.now(), "%Y-%m-%dT%H:%M:%S%z")
record_property(RecordProperties.START_TIMESTAMP.value, start_timestamp)

# Run the test.
yield

end_timestamp = datetime.strftime(datetime.now(), "%Y-%m-%dT%H:%M:%S%z")
record_property(RecordProperties.END_TIMESTAMP.value, end_timestamp)


@pytest.fixture(scope="function", autouse=True)
def record_tt_xla_property(record_property: Callable):
"""
Autouse fixture that automatically records some test properties for each test
function.
It also yields back callable which can be explicitly used in tests to record
additional properties.
Example:
```
def test_model(fixture1, fixture2, ..., record_tt_xla_property):
record_tt_xla_property("key", value)
# Test logic...
```
Parameters:
----------
record_property: Callable
A pytest built-in function used to record test metadata, such as custom
properties or additional information about the test execution.
Yields:
-------
Callable
The `record_property` callable, allowing tests to add additional properties if
needed.
"""
# Record default properties for tt-xla.
record_property(RecordProperties.FRONTEND.value, "tt-xla")

# Run the test.
yield record_property
3 changes: 0 additions & 3 deletions tests/jax/graphs/test_simple_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
@pytest.mark.parametrize(
["weights", "bias", "X", "y"], [[(1, 2), (1, 1), (2, 1), (1, 1)]]
)
@pytest.mark.xfail(
reason="Atol comparison failed. Calculated: atol=0.850662112236023. Required: atol=0.16"
)
def test_simple_regression(weights, bias, X, y):
def simple_regression(weights, bias, X, y):
def loss(weights, bias, X, y):
Expand Down
17 changes: 16 additions & 1 deletion tests/jax/models/albert/v2/base/test_albert_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Callable

import pytest
from infra import RunMode
from utils import record_model_test_properties, runtime_fail

from ..tester import AlbertV2Tester

MODEL_PATH = "albert/albert-base-v2"
MODEL_NAME = "albert-v2-base"


# ----- Fixtures -----
Expand All @@ -27,16 +31,27 @@ def training_tester() -> AlbertV2Tester:


@pytest.mark.xfail(
reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)"
reason=(
runtime_fail(
"Cannot get the device from a tensor with host storage "
"(https://github.com/tenstorrent/tt-xla/issues/171)"
)
)
)
def test_flax_albert_v2_base_inference(
inference_tester: AlbertV2Tester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_albert_v2_base_training(
training_tester: AlbertV2Tester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

training_tester.test()
17 changes: 16 additions & 1 deletion tests/jax/models/albert/v2/large/test_albert_large.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Callable

import pytest
from infra import RunMode
from utils import record_model_test_properties, runtime_fail

from ..tester import AlbertV2Tester

MODEL_PATH = "albert/albert-large-v2"
MODEL_NAME = "albert-v2-large"


# ----- Fixtures -----
Expand All @@ -27,16 +31,27 @@ def training_tester() -> AlbertV2Tester:


@pytest.mark.xfail(
reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)"
reason=(
runtime_fail(
"Cannot get the device from a tensor with host storage "
"(https://github.com/tenstorrent/tt-xla/issues/171)"
)
)
)
def test_flax_albert_v2_large_inference(
inference_tester: AlbertV2Tester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_albert_v2_large_training(
training_tester: AlbertV2Tester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

training_tester.test()
17 changes: 16 additions & 1 deletion tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Callable

import pytest
from infra import RunMode
from utils import record_model_test_properties, runtime_fail

from ..tester import AlbertV2Tester

MODEL_PATH = "albert/albert-xlarge-v2"
MODEL_NAME = "albert-v2-xlarge"


# ----- Fixtures -----
Expand All @@ -27,16 +31,27 @@ def training_tester() -> AlbertV2Tester:


@pytest.mark.xfail(
reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)"
reason=(
runtime_fail(
"Cannot get the device from a tensor with host storage "
"(https://github.com/tenstorrent/tt-xla/issues/171)"
)
)
)
def test_flax_albert_v2_xlarge_inference(
inference_tester: AlbertV2Tester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_albert_v2_xlarge_training(
training_tester: AlbertV2Tester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

training_tester.test()
17 changes: 16 additions & 1 deletion tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Callable

import pytest
from infra import RunMode
from utils import record_model_test_properties, runtime_fail

from ..tester import AlbertV2Tester

MODEL_PATH = "albert/albert-xxlarge-v2"
MODEL_NAME = "albert-v2-xxlarge"


# ----- Fixtures -----
Expand All @@ -27,16 +31,27 @@ def training_tester() -> AlbertV2Tester:


@pytest.mark.xfail(
reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)"
reason=(
runtime_fail(
"Cannot get the device from a tensor with host storage "
"(https://github.com/tenstorrent/tt-xla/issues/171)"
)
)
)
def test_flax_albert_v2_xxlarge_inference(
inference_tester: AlbertV2Tester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_albert_v2_xxlarge_training(
training_tester: AlbertV2Tester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

training_tester.test()
17 changes: 16 additions & 1 deletion tests/jax/models/bart/base/test_bart_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Callable

import pytest
from infra import RunMode
from utils import record_model_test_properties, runtime_fail

from ..tester import FlaxBartForCausalLMTester

MODEL_PATH = "facebook/bart-base"
MODEL_NAME = "bart-base"


# ----- Fixtures -----
Expand All @@ -27,16 +31,27 @@ def training_tester() -> FlaxBartForCausalLMTester:


@pytest.mark.xfail(
reason="Cannot get the device from a tensor with host storage (https://github.com/tenstorrent/tt-xla/issues/171)"
reason=(
runtime_fail(
"Cannot get the device from a tensor with host storage "
"(https://github.com/tenstorrent/tt-xla/issues/171)"
)
)
)
def test_flax_bart_base_inference(
inference_tester: FlaxBartForCausalLMTester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_bart_base_training(
training_tester: FlaxBartForCausalLMTester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

training_tester.test()
14 changes: 13 additions & 1 deletion tests/jax/models/bart/large/test_bart_large.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Callable

import pytest
from infra import RunMode
from utils import compile_fail, record_model_test_properties

from ..tester import FlaxBartForCausalLMTester

MODEL_PATH = "facebook/bart-large"
MODEL_NAME = "bart-large"


# ----- Fixtures -----
Expand All @@ -27,16 +31,24 @@ def training_tester() -> FlaxBartForCausalLMTester:


@pytest.mark.xfail(
reason="Unsupported data type (https://github.com/tenstorrent/tt-xla/issues/214)"
reason=compile_fail(
"Unsupported data type (https://github.com/tenstorrent/tt-xla/issues/214)"
)
)
def test_flax_bart_large_inference(
inference_tester: FlaxBartForCausalLMTester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_bart_large_training(
training_tester: FlaxBartForCausalLMTester,
record_tt_xla_property: Callable,
):
record_model_test_properties(record_tt_xla_property, MODEL_NAME)

training_tester.test()
Loading

0 comments on commit 5175d7f

Please sign in to comment.