From 82bc7455fcbcdde45f2f6e055bc7eba5ad4a5040 Mon Sep 17 00:00:00 2001 From: Kristijan Mitrovic Date: Wed, 5 Mar 2025 12:05:46 +0000 Subject: [PATCH] Implemented a new way of recording properties that works even if test is skipped --- .github/workflows/build-and-test.yml | 2 +- tests/conftest.py | 142 +++++++----------- tests/jax/graphs/test_MLP_regression.py | 1 + tests/jax/graphs/test_activation_functions.py | 1 + tests/jax/graphs/test_example_graph.py | 1 + .../jax/graphs/test_linear_transformation.py | 1 + tests/jax/graphs/test_simple_gradient.py | 1 + tests/jax/graphs/test_simple_regression.py | 1 + tests/jax/graphs/test_softmax.py | 1 + .../models/albert/v2/base/test_albert_base.py | 28 ++-- .../albert/v2/large/test_albert_large.py | 28 ++-- .../albert/v2/xlarge/test_albert_xlarge.py | 28 ++-- .../albert/v2/xxlarge/test_albert_xxlarge.py | 27 ++-- tests/jax/models/bart/base/test_bart_base.py | 28 ++-- .../jax/models/bart/large/test_bart_large.py | 28 ++-- tests/jax/models/beit/base/test_beit_base.py | 25 ++- .../jax/models/beit/large/test_beit_large.py | 20 +-- tests/jax/models/bert/base/test_bert_base.py | 28 ++-- .../jax/models/bert/large/test_bert_large.py | 27 ++-- tests/jax/models/bloom/bloom_1b1/test_1b1.py | 29 ++-- tests/jax/models/bloom/bloom_1b7/test_1b7.py | 28 ++-- tests/jax/models/bloom/bloom_3b/test_3b.py | 28 ++-- .../jax/models/bloom/bloom_560m/test_560m.py | 28 ++-- tests/jax/models/bloom/bloom_7b/test_7b.py | 28 ++-- .../base_patch16/test_clip_base_patch16.py | 29 ++-- .../base_patch32/test_clip_base_patch32.py | 29 ++-- .../large_patch14/test_clip_large_patch14.py | 29 ++-- .../test_clip_large_patch14_336.py | 29 ++-- .../jax/models/distilbert/test_distilbert.py | 28 ++-- ...est_example_model_mixed_args_and_kwargs.py | 10 ++ .../only_args/test_example_model_only_args.py | 10 ++ .../test_example_model_only_kwargs.py | 10 ++ tests/jax/models/gpt2/base/test_gpt2_base.py | 28 ++-- .../jax/models/gpt2/large/test_gpt2_large.py | 28 ++-- .../models/gpt2/medium/test_gpt2_medium.py | 28 ++-- tests/jax/models/gpt2/xl/test_gpt2_xl.py | 27 ++-- .../gpt_neo/gpt_neo_125m/test_gpt_neo_125m.py | 27 ++-- .../gpt_neo/gpt_neo_1_3b/test_gpt_neo_1_3b.py | 28 ++-- .../gpt_neo/gpt_neo_2_7b/test_gpt_neo_2_7b.py | 27 ++-- .../openllama_3b_v2/test_openllama_3b_v2.py | 27 ++-- tests/jax/models/mlpmixer/test_mlpmixer.py | 30 ++-- .../cnn/dropout/test_mnist_cnn_dropout.py | 29 ++-- .../cnn/nodropout/test_mnist_cnn_nodropout.py | 29 ++-- tests/jax/models/mnist/mlp/test_mnist_mlp.py | 29 ++-- tests/jax/models/opt/opt_125m/test_125m.py | 28 ++-- tests/jax/models/opt/opt_1_3b/test_1_3b.py | 28 ++-- tests/jax/models/opt/opt_2_7b/test_2_7b.py | 28 ++-- tests/jax/models/opt/opt_350m/test_350m.py | 28 ++-- tests/jax/models/opt/opt_6_7b/test_6_7b.py | 28 ++-- .../models/roberta/base/test_roberta_base.py | 28 ++-- .../roberta/large/test_roberta_large.py | 28 ++-- .../test_efficient_mlm_m0_40.py | 20 ++- .../models/squeezebert/test_squeezebert.py | 28 ++-- tests/jax/multichip/manual/all_gather.py | 3 +- tests/jax/multichip/manual/data_paralelism.py | 5 +- tests/jax/multichip/manual/psum.py | 3 +- tests/jax/multichip/manual/psum_scatter.py | 3 +- tests/jax/multichip/manual/unary_eltwise.py | 3 +- tests/jax/ops/test_abs.py | 16 +- tests/jax/ops/test_add.py | 16 +- tests/jax/ops/test_broadcast_in_dim.py | 16 +- tests/jax/ops/test_cbrt.py | 16 +- tests/jax/ops/test_compare.py | 89 +++++------ tests/jax/ops/test_concatenate.py | 18 +-- tests/jax/ops/test_constant.py | 46 +++--- tests/jax/ops/test_convert.py | 19 +-- tests/jax/ops/test_convolution.py | 29 ++-- tests/jax/ops/test_divide.py | 14 +- tests/jax/ops/test_dot_general.py | 42 +++--- tests/jax/ops/test_exponential.py | 16 +- tests/jax/ops/test_exponential_minus_one.py | 20 +-- tests/jax/ops/test_log_plus_one.py | 16 +- tests/jax/ops/test_maximum.py | 16 +- tests/jax/ops/test_minimum.py | 16 +- tests/jax/ops/test_multiply.py | 16 +- tests/jax/ops/test_negate.py | 16 +- tests/jax/ops/test_reduce.py | 38 ++--- tests/jax/ops/test_reduce_window.py | 16 +- tests/jax/ops/test_remainder.py | 16 +- tests/jax/ops/test_reshape.py | 16 +- tests/jax/ops/test_rsqrt.py | 16 +- tests/jax/ops/test_sign.py | 16 +- tests/jax/ops/test_slice.py | 16 +- tests/jax/ops/test_sqrt.py | 16 +- tests/jax/ops/test_subtract.py | 16 +- tests/jax/ops/test_transpose.py | 16 +- tests/jax/test_data_types.py | 1 + tests/jax/test_device_initialization.py | 3 + tests/jax/test_ranks.py | 2 + tests/jax/test_scalar_types.py | 2 + tests/utils.py | 30 ---- 91 files changed, 880 insertions(+), 1108 deletions(-) diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index dba34f29..b6730273 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -250,7 +250,7 @@ jobs: run: | export LD_LIBRARY_PATH="/opt/ttmlir-toolchain/lib/:${{ steps.strings.outputs.install-output-dir }}/lib:${LD_LIBRARY_PATH}" source venv/activate - pytest ./tests \ + pytest ./tests/jax \ -m "${{ inputs.test_mark }}" \ --junitxml=${{ steps.strings.outputs.test_report_path }} \ 2>&1 | tee pytest.log diff --git a/tests/conftest.py b/tests/conftest.py index affb3e03..dabdd4fd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,101 +2,67 @@ # # SPDX-License-Identifier: Apache-2.0 -from datetime import datetime -from enum import Enum -from typing import Callable - import pytest -class RecordProperties(Enum): - """Properties we can record.""" - - # Timestamp of test start. - START_TIMESTAMP = "start_timestamp" - # Timestamp of test end. - END_TIMESTAMP = "end_timestamp" - # Frontend or framework used to run the test. - FRONTEND = "frontend" - # Kind of operation. e.g. eltwise. - OP_KIND = "op_kind" - # Name of the operation in the framework. e.g. torch.conv2d. - FRAMEWORK_OP_NAME = "framework_op_name" - # Name of the operation. e.g. ttir.conv2d. - OP_NAME = "op_name" - # Name of the model in which this op appears. - MODEL_NAME = "model_name" - - -@pytest.fixture(scope="function", autouse=True) -def record_test_timestamp(record_property: Callable): +def pytest_configure(config: pytest.Config): """ - Autouse fixture used to capture execution time of a test. - - Parameters: - ---------- - record_property: Callable - A pytest built-in function used to record test metadata, such as custom - properties or additional information about the test execution. - - Yields: - ------- - Callable - The `record_property` callable, allowing tests to add additional properties if - needed. - - - Example: - -------- - ``` - def test_model(fixture1, fixture2, ..., record_tt_xla_property): - record_tt_xla_property("key", value) - - # Test logic... - ``` + Registers custom pytest marker `record_properties(key1=val1, key2=val2, ...)`. + + Allowed keys are ["test_category", "jax_op_name", "op_name", "model_name"]. + - `test_category`: one of ["op_test", "graph_test", "model_test", "multichip_test", "other"] + - `jax_op_name`: name of the operation in jax, e.g. `jax.numpy.exp` + - `shlo_op_name`: name of the matching stablehlo operation + - `model_name`: name of the model under test (if recorded from a model test, or op + under test comes from some model and we want to note that in the report) + - `run_mode`: one of ["inference", "training"]. Only exists for model tests. + + These are used to tag the function under test with properties which will be dumped + to the final XML test report. These reports get picked up by other CI workflows and + are used to display state of tests on a dashboard. """ - start_timestamp = datetime.strftime(datetime.now(), "%Y-%m-%dT%H:%M:%S%z") - record_property(RecordProperties.START_TIMESTAMP.value, start_timestamp) + config.addinivalue_line( + "markers", + "record_properties(key_value_pairs): Record custom properties for the test", + ) - # Run the test. - yield - end_timestamp = datetime.strftime(datetime.now(), "%Y-%m-%dT%H:%M:%S%z") - record_property(RecordProperties.END_TIMESTAMP.value, end_timestamp) - - -@pytest.fixture(scope="function", autouse=True) -def record_tt_xla_property(record_property: Callable): +def pytest_collection_modifyitems(items): """ - Autouse fixture that automatically records some test properties for each test - function. - - It also yields back callable which can be explicitly used in tests to record - additional properties. - - Example: - - ``` - def test_model(fixture1, fixture2, ..., record_tt_xla_property): - record_tt_xla_property("key", value) - - # Test logic... - ``` - - Parameters: - ---------- - record_property: Callable - A pytest built-in function used to record test metadata, such as custom - properties or additional information about the test execution. - - Yields: - ------- - Callable - The `record_property` callable, allowing tests to add additional properties if - needed. + Pytest hook to process the custom marker and attach recorder properties to the test. """ - # Record default properties for tt-xla. - record_property(RecordProperties.FRONTEND.value, "tt-xla") - # Run the test. - yield record_property + def validate_keys(keys): + valid_keys = [ + "test_category", + "jax_op_name", + "shlo_op_name", + "model_name", + "run_mode", + ] + + if not all(key in valid_keys for key in keys): + raise KeyError( + f"Invalid keys found in 'record_properties' marker: {', '.join(keys)}. " + f"Allowed keys are: {', '.join(valid_keys)}" + ) + + for item in items: + # Add some test metadata in a 'tags' dictionary. + tags = {"test_name": item.originalname, "specific_test_case": item.name} + + # Look for the custom marker. + properties_marker = item.get_closest_marker(name="record_properties") + + if properties_marker: + # Extract the key-value pairs passed to the marker. + properties: dict = properties_marker.kwargs + # Validate that only allowed keys are used. + validate_keys(properties.keys()) + + # Tag them. + for key, value in properties.items(): + tags[key] = value + + # Attach the tags dictionary as a single property. + item.user_properties.append(("tags", tags)) diff --git a/tests/jax/graphs/test_MLP_regression.py b/tests/jax/graphs/test_MLP_regression.py index 0a23054c..9d40d4f2 100644 --- a/tests/jax/graphs/test_MLP_regression.py +++ b/tests/jax/graphs/test_MLP_regression.py @@ -17,6 +17,7 @@ def comparison_config() -> ComparisonConfig: @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties(test_category="graph_test") @pytest.mark.parametrize( ["W1", "b1", "W2", "b2", "X", "y"], [ diff --git a/tests/jax/graphs/test_activation_functions.py b/tests/jax/graphs/test_activation_functions.py index 5b537b69..bf31dc3e 100644 --- a/tests/jax/graphs/test_activation_functions.py +++ b/tests/jax/graphs/test_activation_functions.py @@ -10,6 +10,7 @@ @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties(test_category="graph_test") @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)]) def test_relu(x_shape: tuple): """Test ReLU activation function.""" diff --git a/tests/jax/graphs/test_example_graph.py b/tests/jax/graphs/test_example_graph.py index 7af1e8bb..f8e00c40 100644 --- a/tests/jax/graphs/test_example_graph.py +++ b/tests/jax/graphs/test_example_graph.py @@ -17,6 +17,7 @@ def example_graph(x: jax.Array, y: jax.Array) -> jax.Array: @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties(test_category="graph_test") @pytest.mark.parametrize( ["x_shape", "y_shape"], [ diff --git a/tests/jax/graphs/test_linear_transformation.py b/tests/jax/graphs/test_linear_transformation.py index 8208eaab..28399809 100644 --- a/tests/jax/graphs/test_linear_transformation.py +++ b/tests/jax/graphs/test_linear_transformation.py @@ -10,6 +10,7 @@ @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties(test_category="graph_test") @pytest.mark.parametrize( ["x_shape", "y_shape", "bias_shape"], [ diff --git a/tests/jax/graphs/test_simple_gradient.py b/tests/jax/graphs/test_simple_gradient.py index c529df4a..c1dc1446 100644 --- a/tests/jax/graphs/test_simple_gradient.py +++ b/tests/jax/graphs/test_simple_gradient.py @@ -9,6 +9,7 @@ @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties(test_category="graph_test") @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)]) def test_simple_gradient(x_shape: tuple): def simple_gradient(x: jax.Array): diff --git a/tests/jax/graphs/test_simple_regression.py b/tests/jax/graphs/test_simple_regression.py index 57512c97..46966462 100644 --- a/tests/jax/graphs/test_simple_regression.py +++ b/tests/jax/graphs/test_simple_regression.py @@ -9,6 +9,7 @@ @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties(test_category="graph_test") @pytest.mark.parametrize( ["weights", "bias", "X", "y"], [[(1, 2), (1, 1), (2, 1), (1, 1)]] ) diff --git a/tests/jax/graphs/test_softmax.py b/tests/jax/graphs/test_softmax.py index a83f30a9..169b720e 100644 --- a/tests/jax/graphs/test_softmax.py +++ b/tests/jax/graphs/test_softmax.py @@ -9,6 +9,7 @@ @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties(test_category="graph_test") @pytest.mark.parametrize( ["x_shape", "axis"], [ diff --git a/tests/jax/models/albert/v2/base/test_albert_base.py b/tests/jax/models/albert/v2/base/test_albert_base.py index 23c7be5b..67b5156b 100644 --- a/tests/jax/models/albert/v2/base/test_albert_base.py +++ b/tests/jax/models/albert/v2/base/test_albert_base.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import record_model_test_properties, runtime_fail +from utils import runtime_fail from ..tester import AlbertV2Tester @@ -31,6 +29,11 @@ def training_tester() -> AlbertV2Tester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail( reason=( runtime_fail( @@ -39,21 +42,16 @@ def training_tester() -> AlbertV2Tester: ) ) ) -def test_flax_albert_v2_base_inference( - inference_tester: AlbertV2Tester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_albert_v2_base_inference(inference_tester: AlbertV2Tester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_flax_albert_v2_base_training( - training_tester: AlbertV2Tester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_albert_v2_base_training(training_tester: AlbertV2Tester): training_tester.test() diff --git a/tests/jax/models/albert/v2/large/test_albert_large.py b/tests/jax/models/albert/v2/large/test_albert_large.py index 5c0b82f6..3c8c8a68 100644 --- a/tests/jax/models/albert/v2/large/test_albert_large.py +++ b/tests/jax/models/albert/v2/large/test_albert_large.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import record_model_test_properties, runtime_fail +from utils import runtime_fail from ..tester import AlbertV2Tester @@ -31,6 +29,11 @@ def training_tester() -> AlbertV2Tester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail( reason=( runtime_fail( @@ -39,21 +42,16 @@ def training_tester() -> AlbertV2Tester: ) ) ) -def test_flax_albert_v2_large_inference( - inference_tester: AlbertV2Tester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_albert_v2_large_inference(inference_tester: AlbertV2Tester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_flax_albert_v2_large_training( - training_tester: AlbertV2Tester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_albert_v2_large_training(training_tester: AlbertV2Tester): training_tester.test() diff --git a/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py b/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py index 14f3631e..94653de4 100644 --- a/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py +++ b/tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import record_model_test_properties, runtime_fail +from utils import runtime_fail from ..tester import AlbertV2Tester @@ -31,6 +29,11 @@ def training_tester() -> AlbertV2Tester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail( reason=( runtime_fail( @@ -39,21 +42,16 @@ def training_tester() -> AlbertV2Tester: ) ) ) -def test_flax_albert_v2_xlarge_inference( - inference_tester: AlbertV2Tester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_albert_v2_xlarge_inference(inference_tester: AlbertV2Tester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_flax_albert_v2_xlarge_training( - training_tester: AlbertV2Tester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_albert_v2_xlarge_training(training_tester: AlbertV2Tester): training_tester.test() diff --git a/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py b/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py index c2c4d381..abff14eb 100644 --- a/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py +++ b/tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py @@ -2,11 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable import pytest from infra import RunMode -from utils import record_model_test_properties, runtime_fail +from utils import runtime_fail from ..tester import AlbertV2Tester @@ -31,6 +30,11 @@ def training_tester() -> AlbertV2Tester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail( reason=( runtime_fail( @@ -39,21 +43,16 @@ def training_tester() -> AlbertV2Tester: ) ) ) -def test_flax_albert_v2_xxlarge_inference( - inference_tester: AlbertV2Tester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_albert_v2_xxlarge_inference(inference_tester: AlbertV2Tester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_flax_albert_v2_xxlarge_training( - training_tester: AlbertV2Tester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_albert_v2_xxlarge_training(training_tester: AlbertV2Tester): training_tester.test() diff --git a/tests/jax/models/bart/base/test_bart_base.py b/tests/jax/models/bart/base/test_bart_base.py index f5af6989..81ed705a 100644 --- a/tests/jax/models/bart/base/test_bart_base.py +++ b/tests/jax/models/bart/base/test_bart_base.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import record_model_test_properties, runtime_fail +from utils import runtime_fail from ..tester import FlaxBartForCausalLMTester @@ -31,6 +29,11 @@ def training_tester() -> FlaxBartForCausalLMTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail( reason=( runtime_fail( @@ -39,21 +42,16 @@ def training_tester() -> FlaxBartForCausalLMTester: ) ) ) -def test_flax_bart_base_inference( - inference_tester: FlaxBartForCausalLMTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_bart_base_inference(inference_tester: FlaxBartForCausalLMTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_flax_bart_base_training( - training_tester: FlaxBartForCausalLMTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_bart_base_training(training_tester: FlaxBartForCausalLMTester): training_tester.test() diff --git a/tests/jax/models/bart/large/test_bart_large.py b/tests/jax/models/bart/large/test_bart_large.py index b25a8bd0..739f253c 100644 --- a/tests/jax/models/bart/large/test_bart_large.py +++ b/tests/jax/models/bart/large/test_bart_large.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from ..tester import FlaxBartForCausalLMTester @@ -31,26 +29,26 @@ def training_tester() -> FlaxBartForCausalLMTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail( reason=compile_fail( "Unsupported data type (https://github.com/tenstorrent/tt-xla/issues/214)" ) ) -def test_flax_bart_large_inference( - inference_tester: FlaxBartForCausalLMTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_bart_large_inference(inference_tester: FlaxBartForCausalLMTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_flax_bart_large_training( - training_tester: FlaxBartForCausalLMTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_bart_large_training(training_tester: FlaxBartForCausalLMTester): training_tester.test() diff --git a/tests/jax/models/beit/base/test_beit_base.py b/tests/jax/models/beit/base/test_beit_base.py index 607a14c6..9063b018 100644 --- a/tests/jax/models/beit/base/test_beit_base.py +++ b/tests/jax/models/beit/base/test_beit_base.py @@ -2,12 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 - -from typing import Callable - import pytest from infra import RunMode -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from ..tester import FlaxBeitForImageClassificationTester @@ -32,22 +29,24 @@ def training_tester() -> FlaxBeitForImageClassificationTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail(reason=compile_fail("failed to legalize operation 'ttir.gather'")) def test_flax_beit_base_inference( inference_tester: FlaxBeitForImageClassificationTester, - record_tt_xla_property: Callable, ): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_flax_beit_base_training( - training_tester: FlaxBeitForImageClassificationTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_beit_base_training(training_tester: FlaxBeitForImageClassificationTester): training_tester.test() diff --git a/tests/jax/models/beit/large/test_beit_large.py b/tests/jax/models/beit/large/test_beit_large.py index 562f2e49..31a0e482 100644 --- a/tests/jax/models/beit/large/test_beit_large.py +++ b/tests/jax/models/beit/large/test_beit_large.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from ..tester import FlaxBeitForImageClassificationTester @@ -31,22 +29,26 @@ def training_tester() -> FlaxBeitForImageClassificationTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail(reason=compile_fail("failed to legalize operation 'ttir.gather'")) def test_flax_beit_large_inference( inference_tester: FlaxBeitForImageClassificationTester, - record_tt_xla_property: Callable, ): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") def test_flax_beit_large_training( training_tester: FlaxBeitForImageClassificationTester, - record_tt_xla_property: Callable, ): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - training_tester.test() diff --git a/tests/jax/models/bert/base/test_bert_base.py b/tests/jax/models/bert/base/test_bert_base.py index 57be6b9d..9dadaeec 100644 --- a/tests/jax/models/bert/base/test_bert_base.py +++ b/tests/jax/models/bert/base/test_bert_base.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import record_model_test_properties, runtime_fail +from utils import runtime_fail from ..tester import FlaxBertForMaskedLMTester @@ -31,6 +29,11 @@ def training_tester() -> FlaxBertForMaskedLMTester: @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail( reason=( runtime_fail( @@ -39,22 +42,17 @@ def training_tester() -> FlaxBertForMaskedLMTester: ) ) ) -def test_flax_bert_base_inference( - inference_tester: FlaxBertForMaskedLMTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_bert_base_inference(inference_tester: FlaxBertForMaskedLMTester): inference_tester.test() @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_flax_bert_base_training( - training_tester: FlaxBertForMaskedLMTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_bert_base_training(training_tester: FlaxBertForMaskedLMTester): training_tester.test() diff --git a/tests/jax/models/bert/large/test_bert_large.py b/tests/jax/models/bert/large/test_bert_large.py index 82d8e4cc..3e83907c 100644 --- a/tests/jax/models/bert/large/test_bert_large.py +++ b/tests/jax/models/bert/large/test_bert_large.py @@ -2,11 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable import pytest from infra import RunMode -from utils import record_model_test_properties, runtime_fail +from utils import runtime_fail from ..tester import FlaxBertForMaskedLMTester @@ -30,27 +29,27 @@ def training_tester() -> FlaxBertForMaskedLMTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail( reason=runtime_fail( "Host data with total size 16B does not match expected size 8B of device buffer! " "(https://github.com/tenstorrent/tt-xla/issues/182)" ) ) -def test_flax_bert_large_inference( - inference_tester: FlaxBertForMaskedLMTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_bert_large_inference(inference_tester: FlaxBertForMaskedLMTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_flax_bert_large_training( - training_tester: FlaxBertForMaskedLMTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_bert_large_training(training_tester: FlaxBertForMaskedLMTester): training_tester.test() diff --git a/tests/jax/models/bloom/bloom_1b1/test_1b1.py b/tests/jax/models/bloom/bloom_1b1/test_1b1.py index b6cb3c82..c40adc1a 100644 --- a/tests/jax/models/bloom/bloom_1b1/test_1b1.py +++ b/tests/jax/models/bloom/bloom_1b1/test_1b1.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import ModelTester, RunMode -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from ..tester import BloomTester @@ -29,27 +27,28 @@ def training_tester() -> BloomTester: # ----- Tests ----- + # This is an interesting one. # The error message seems to happen before the compile even begins # And then then compile segfaults with no useful information # It is highly likely that both are caused by the same root cause @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip(reason=compile_fail("Unsupported data type")) # segfault -def test_bloom_1b1_inference( - inference_tester: BloomTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_bloom_1b1_inference(inference_tester: BloomTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_bloom_1b1_training( - training_tester: BloomTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_bloom_1b1_training(training_tester: BloomTester): training_tester.test() diff --git a/tests/jax/models/bloom/bloom_1b7/test_1b7.py b/tests/jax/models/bloom/bloom_1b7/test_1b7.py index bc62e346..2b3fefda 100644 --- a/tests/jax/models/bloom/bloom_1b7/test_1b7.py +++ b/tests/jax/models/bloom/bloom_1b7/test_1b7.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import ModelTester, RunMode -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from ..tester import BloomTester @@ -31,22 +29,22 @@ def training_tester() -> BloomTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip(reason=compile_fail("Unsupported data type")) # segfault -def test_bloom_1b7_inference( - inference_tester: BloomTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_bloom_1b7_inference(inference_tester: BloomTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_bloom_1b7_training( - training_tester: BloomTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_bloom_1b7_training(training_tester: BloomTester): training_tester.test() diff --git a/tests/jax/models/bloom/bloom_3b/test_3b.py b/tests/jax/models/bloom/bloom_3b/test_3b.py index 79602606..d683c437 100644 --- a/tests/jax/models/bloom/bloom_3b/test_3b.py +++ b/tests/jax/models/bloom/bloom_3b/test_3b.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import ModelTester, RunMode -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from ..tester import BloomTester @@ -31,22 +29,22 @@ def training_tester() -> BloomTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip(reason=compile_fail("Unsupported data type")) # segfault -def test_bloom_3b_inference( - inference_tester: BloomTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_bloom_3b_inference(inference_tester: BloomTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_bloom_3b_training( - training_tester: BloomTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_bloom_3b_training(training_tester: BloomTester): training_tester.test() diff --git a/tests/jax/models/bloom/bloom_560m/test_560m.py b/tests/jax/models/bloom/bloom_560m/test_560m.py index 72cd6e02..0f998716 100644 --- a/tests/jax/models/bloom/bloom_560m/test_560m.py +++ b/tests/jax/models/bloom/bloom_560m/test_560m.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import ModelTester, RunMode -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from ..tester import BloomTester @@ -31,22 +29,22 @@ def training_tester() -> BloomTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip(reason=compile_fail("Unsupported data type")) # segfault -def test_bloom_560m_inference( - inference_tester: BloomTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_bloom_560m_inference(inference_tester: BloomTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_bloom_560m_training( - training_tester: BloomTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_bloom_560m_training(training_tester: BloomTester): training_tester.test() diff --git a/tests/jax/models/bloom/bloom_7b/test_7b.py b/tests/jax/models/bloom/bloom_7b/test_7b.py index 07a4ffb7..7a13f743 100644 --- a/tests/jax/models/bloom/bloom_7b/test_7b.py +++ b/tests/jax/models/bloom/bloom_7b/test_7b.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import ModelTester, RunMode -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from ..tester import BloomTester @@ -30,22 +28,22 @@ def training_tester() -> BloomTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip(reason=compile_fail("Unsupported data type")) # segfault -def test_bloom_7b_inference( - inference_tester: BloomTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_bloom_7b_inference(inference_tester: BloomTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_bloom_7b_training( - training_tester: BloomTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_bloom_7b_training(training_tester: BloomTester): training_tester.test() diff --git a/tests/jax/models/clip/base_patch16/test_clip_base_patch16.py b/tests/jax/models/clip/base_patch16/test_clip_base_patch16.py index 64da0ede..efadd9b0 100644 --- a/tests/jax/models/clip/base_patch16/test_clip_base_patch16.py +++ b/tests/jax/models/clip/base_patch16/test_clip_base_patch16.py @@ -2,12 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 - -from typing import Callable - import pytest from infra import RunMode -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from ..tester import FlaxCLIPTester @@ -33,26 +30,26 @@ def training_tester() -> FlaxCLIPTester: @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip( reason=compile_fail( 'Assertion `llvm::isUIntN(BitWidth, val) && "Value is not an N-bit unsigned value"\' failed.' ) ) -def test_clip_base_patch16_inference( - inference_tester: FlaxCLIPTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_clip_base_patch16_inference(inference_tester: FlaxCLIPTester): inference_tester.test() @pytest.mark.push +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_clip_base_patch16_training( - training_tester: FlaxCLIPTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_clip_base_patch16_training(training_tester: FlaxCLIPTester): training_tester.test() diff --git a/tests/jax/models/clip/base_patch32/test_clip_base_patch32.py b/tests/jax/models/clip/base_patch32/test_clip_base_patch32.py index 2270f18c..0ea88189 100644 --- a/tests/jax/models/clip/base_patch32/test_clip_base_patch32.py +++ b/tests/jax/models/clip/base_patch32/test_clip_base_patch32.py @@ -2,12 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 - -from typing import Callable - import pytest from infra import RunMode -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from ..tester import FlaxCLIPTester @@ -32,25 +29,25 @@ def training_tester() -> FlaxCLIPTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip( reason=compile_fail( 'Assertion `llvm::isUIntN(BitWidth, val) && "Value is not an N-bit unsigned value"\' failed.' ) ) -def test_clip_base_patch32_inference( - inference_tester: FlaxCLIPTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_clip_base_patch32_inference(inference_tester: FlaxCLIPTester): inference_tester.test() +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_clip_base_patch32_training( - training_tester: FlaxCLIPTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_clip_base_patch32_training(training_tester: FlaxCLIPTester): training_tester.test() diff --git a/tests/jax/models/clip/large_patch14/test_clip_large_patch14.py b/tests/jax/models/clip/large_patch14/test_clip_large_patch14.py index 551e1319..020baf52 100644 --- a/tests/jax/models/clip/large_patch14/test_clip_large_patch14.py +++ b/tests/jax/models/clip/large_patch14/test_clip_large_patch14.py @@ -2,12 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 - -from typing import Callable - import pytest from infra import RunMode -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from ..tester import FlaxCLIPTester @@ -32,25 +29,25 @@ def training_tester() -> FlaxCLIPTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip( reason=compile_fail( 'Assertion `llvm::isUIntN(BitWidth, val) && "Value is not an N-bit unsigned value"\' failed.' ) ) -def test_clip_large_patch14_inference( - inference_tester: FlaxCLIPTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_clip_large_patch14_inference(inference_tester: FlaxCLIPTester): inference_tester.test() +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_clip_large_patch14_training( - training_tester: FlaxCLIPTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_clip_large_patch14_training(training_tester: FlaxCLIPTester): training_tester.test() diff --git a/tests/jax/models/clip/large_patch14_336/test_clip_large_patch14_336.py b/tests/jax/models/clip/large_patch14_336/test_clip_large_patch14_336.py index 05297896..978bc95b 100644 --- a/tests/jax/models/clip/large_patch14_336/test_clip_large_patch14_336.py +++ b/tests/jax/models/clip/large_patch14_336/test_clip_large_patch14_336.py @@ -2,12 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 - -from typing import Callable - import pytest from infra import RunMode -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from ..tester import FlaxCLIPTester @@ -32,25 +29,25 @@ def training_tester() -> FlaxCLIPTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip( reason=compile_fail( 'Assertion `llvm::isUIntN(BitWidth, val) && "Value is not an N-bit unsigned value"\' failed.' ) ) -def test_clip_large_patch14_336_inference( - inference_tester: FlaxCLIPTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_clip_large_patch14_336_inference(inference_tester: FlaxCLIPTester): inference_tester.test() +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_clip_large_patch14_336_training( - training_tester: FlaxCLIPTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_clip_large_patch14_336_training(training_tester: FlaxCLIPTester): training_tester.test() diff --git a/tests/jax/models/distilbert/test_distilbert.py b/tests/jax/models/distilbert/test_distilbert.py index 7d74317c..878d785b 100644 --- a/tests/jax/models/distilbert/test_distilbert.py +++ b/tests/jax/models/distilbert/test_distilbert.py @@ -2,13 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Dict, Sequence +from typing import Dict, Sequence import jax import pytest from infra import ModelTester, RunMode from transformers import AutoTokenizer, FlaxDistilBertForMaskedLM, FlaxPreTrainedModel -from utils import record_model_test_properties, runtime_fail +from utils import runtime_fail MODEL_PATH = "distilbert/distilbert-base-uncased" MODEL_NAME = "distilbert" @@ -55,27 +55,27 @@ def training_tester() -> FlaxDistilBertForMaskedLMTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail( reason=runtime_fail( "Host data with total size 20B does not match expected size 10B of device buffer! " "(https://github.com/tenstorrent/tt-xla/issues/182)" ) ) -def test_flax_distilbert_inference( - inference_tester: FlaxDistilBertForMaskedLMTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_distilbert_inference(inference_tester: FlaxDistilBertForMaskedLMTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_flax_distilbert_training( - training_tester: FlaxDistilBertForMaskedLMTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_distilbert_training(training_tester: FlaxDistilBertForMaskedLMTester): training_tester.test() diff --git a/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py b/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py index 784bc7d0..e97b167b 100644 --- a/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py +++ b/tests/jax/models/example_model/mixed_args_and_kwargs/test_example_model_mixed_args_and_kwargs.py @@ -80,6 +80,11 @@ def training_tester() -> ExampleModelMixedArgsAndKwargsTester: @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name="Example", + run_mode=RunMode.INFERENCE.value, +) def test_example_model_inference( inference_tester: ExampleModelMixedArgsAndKwargsTester, ): @@ -88,6 +93,11 @@ def test_example_model_inference( @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name="Example", + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") def test_example_model_training(training_tester: ExampleModelMixedArgsAndKwargsTester): training_tester.test() diff --git a/tests/jax/models/example_model/only_args/test_example_model_only_args.py b/tests/jax/models/example_model/only_args/test_example_model_only_args.py index 6b1d631a..3b3a1144 100644 --- a/tests/jax/models/example_model/only_args/test_example_model_only_args.py +++ b/tests/jax/models/example_model/only_args/test_example_model_only_args.py @@ -75,12 +75,22 @@ def training_tester() -> ExampleModelOnlyArgsTester: @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name="Example", + run_mode=RunMode.INFERENCE.value, +) def test_example_model_inference(inference_tester: ExampleModelOnlyArgsTester): inference_tester.test() @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name="Example", + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") def test_example_model_training(training_tester: ExampleModelOnlyArgsTester): training_tester.test() diff --git a/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py b/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py index 08b9af00..fc308f23 100644 --- a/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py +++ b/tests/jax/models/example_model/only_kwargs/test_example_model_only_kwargs.py @@ -75,12 +75,22 @@ def training_tester() -> ExampleModelOnlyKwargsTester: @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name="Example", + run_mode=RunMode.INFERENCE.value, +) def test_example_model_inference(inference_tester: ExampleModelOnlyKwargsTester): inference_tester.test() @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name="Example", + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") def test_example_model_training(training_tester: ExampleModelOnlyKwargsTester): training_tester.test() diff --git a/tests/jax/models/gpt2/base/test_gpt2_base.py b/tests/jax/models/gpt2/base/test_gpt2_base.py index 73686fe1..85195b79 100644 --- a/tests/jax/models/gpt2/base/test_gpt2_base.py +++ b/tests/jax/models/gpt2/base/test_gpt2_base.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import ModelTester, RunMode -from utils import record_model_test_properties, runtime_fail +from utils import runtime_fail from ..tester import GPT2Tester @@ -31,28 +29,28 @@ def training_tester() -> GPT2Tester: @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail( reason=runtime_fail( "Host data with total size 4B does not match expected size 2B of device buffer! " "(https://github.com/tenstorrent/tt-xla/issues/182)" ) ) -def test_gpt2_base_inference( - inference_tester: GPT2Tester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_gpt2_base_inference(inference_tester: GPT2Tester): inference_tester.test() @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_gpt2_base_training( - training_tester: GPT2Tester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_gpt2_base_training(training_tester: GPT2Tester): training_tester.test() diff --git a/tests/jax/models/gpt2/large/test_gpt2_large.py b/tests/jax/models/gpt2/large/test_gpt2_large.py index 86cc2cfe..52950abe 100644 --- a/tests/jax/models/gpt2/large/test_gpt2_large.py +++ b/tests/jax/models/gpt2/large/test_gpt2_large.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import ModelTester, RunMode -from utils import record_model_test_properties, runtime_fail +from utils import runtime_fail from ..tester import GPT2Tester @@ -31,27 +29,27 @@ def training_tester() -> GPT2Tester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail( reason=runtime_fail( "Host data with total size 4B does not match expected size 2B of device buffer! " "(https://github.com/tenstorrent/tt-xla/issues/182)" ) ) -def test_gpt2_large_inference( - inference_tester: GPT2Tester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_gpt2_large_inference(inference_tester: GPT2Tester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_gpt2_large_training( - training_tester: GPT2Tester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_gpt2_large_training(training_tester: GPT2Tester): training_tester.test() diff --git a/tests/jax/models/gpt2/medium/test_gpt2_medium.py b/tests/jax/models/gpt2/medium/test_gpt2_medium.py index 1ddb1049..a500625c 100644 --- a/tests/jax/models/gpt2/medium/test_gpt2_medium.py +++ b/tests/jax/models/gpt2/medium/test_gpt2_medium.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import ModelTester, RunMode -from utils import record_model_test_properties, runtime_fail +from utils import runtime_fail from ..tester import GPT2Tester @@ -31,27 +29,27 @@ def training_tester() -> GPT2Tester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail( reason=runtime_fail( "Host data with total size 4B does not match expected size 2B of device buffer! " "(https://github.com/tenstorrent/tt-xla/issues/182)" ) ) -def test_gpt2_medium_inference( - inference_tester: GPT2Tester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_gpt2_medium_inference(inference_tester: GPT2Tester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_gpt2_medium_training( - training_tester: GPT2Tester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_gpt2_medium_training(training_tester: GPT2Tester): training_tester.test() diff --git a/tests/jax/models/gpt2/xl/test_gpt2_xl.py b/tests/jax/models/gpt2/xl/test_gpt2_xl.py index 0ee88081..bd589006 100644 --- a/tests/jax/models/gpt2/xl/test_gpt2_xl.py +++ b/tests/jax/models/gpt2/xl/test_gpt2_xl.py @@ -2,11 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import ModelTester, RunMode -from utils import record_model_test_properties from ..tester import GPT2Tester @@ -31,24 +28,24 @@ def training_tester() -> GPT2Tester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip( reason="OOMs in CI (https://github.com/tenstorrent/tt-xla/issues/186)" ) -def test_gpt2_xl_inference( - inference_tester: GPT2Tester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_gpt2_xl_inference(inference_tester: GPT2Tester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_gpt2_xl_training( - training_tester: GPT2Tester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_gpt2_xl_training(training_tester: GPT2Tester): training_tester.test() diff --git a/tests/jax/models/gpt_neo/gpt_neo_125m/test_gpt_neo_125m.py b/tests/jax/models/gpt_neo/gpt_neo_125m/test_gpt_neo_125m.py index 1da5cb50..a65925fa 100644 --- a/tests/jax/models/gpt_neo/gpt_neo_125m/test_gpt_neo_125m.py +++ b/tests/jax/models/gpt_neo/gpt_neo_125m/test_gpt_neo_125m.py @@ -2,11 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable import pytest from infra import RunMode -from utils import record_model_test_properties, runtime_fail +from utils import runtime_fail from ..tester import GPTNeoTester @@ -30,26 +29,26 @@ def training_tester() -> GPTNeoTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip( reason=runtime_fail( "Host data with total size 4B does not match expected size 2B of device buffer!" ) ) -def test_gpt_neo_125m_inference( - inference_tester: GPTNeoTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_gpt_neo_125m_inference(inference_tester: GPTNeoTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_gpt_neo_125m_training( - training_tester: GPTNeoTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_gpt_neo_125m_training(training_tester: GPTNeoTester): training_tester.test() diff --git a/tests/jax/models/gpt_neo/gpt_neo_1_3b/test_gpt_neo_1_3b.py b/tests/jax/models/gpt_neo/gpt_neo_1_3b/test_gpt_neo_1_3b.py index e158bbe9..bdee8bf4 100644 --- a/tests/jax/models/gpt_neo/gpt_neo_1_3b/test_gpt_neo_1_3b.py +++ b/tests/jax/models/gpt_neo/gpt_neo_1_3b/test_gpt_neo_1_3b.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import record_model_test_properties, runtime_fail +from utils import runtime_fail from ..tester import GPTNeoTester @@ -30,26 +28,26 @@ def training_tester() -> GPTNeoTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip( reason=runtime_fail( "Host data with total size 4B does not match expected size 2B of device buffer!" ) ) -def test_gpt_neo_1_3b_inference( - inference_tester: GPTNeoTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_gpt_neo_1_3b_inference(inference_tester: GPTNeoTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_gpt_neo_1_3b_training( - training_tester: GPTNeoTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_gpt_neo_1_3b_training(training_tester: GPTNeoTester): training_tester.test() diff --git a/tests/jax/models/gpt_neo/gpt_neo_2_7b/test_gpt_neo_2_7b.py b/tests/jax/models/gpt_neo/gpt_neo_2_7b/test_gpt_neo_2_7b.py index 0e36ae25..7737d549 100644 --- a/tests/jax/models/gpt_neo/gpt_neo_2_7b/test_gpt_neo_2_7b.py +++ b/tests/jax/models/gpt_neo/gpt_neo_2_7b/test_gpt_neo_2_7b.py @@ -2,11 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import record_model_test_properties from ..tester import GPTNeoTester @@ -30,22 +27,22 @@ def training_tester() -> GPTNeoTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip(reason="OOMs on CI.") -def test_gpt_neo_2_7b_inference( - inference_tester: GPTNeoTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_gpt_neo_2_7b_inference(inference_tester: GPTNeoTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_gpt_neo_2_7b_training( - training_tester: GPTNeoTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_gpt_neo_2_7b_training(training_tester: GPTNeoTester): training_tester.test() diff --git a/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py b/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py index e4c3f8f6..7df12d38 100644 --- a/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py +++ b/tests/jax/models/llama/openllama_3b_v2/test_openllama_3b_v2.py @@ -2,11 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import record_model_test_properties from ..tester import LLamaTester @@ -31,24 +28,24 @@ def training_tester() -> LLamaTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip( reason="OOMs in CI (https://github.com/tenstorrent/tt-xla/issues/186)" ) -def test_openllama3b_inference( - inference_tester: LLamaTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_openllama3b_inference(inference_tester: LLamaTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_openllama3b_training( - training_tester: LLamaTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_openllama3b_training(training_tester: LLamaTester): training_tester.test() diff --git a/tests/jax/models/mlpmixer/test_mlpmixer.py b/tests/jax/models/mlpmixer/test_mlpmixer.py index 9b7fb491..3d8b154a 100644 --- a/tests/jax/models/mlpmixer/test_mlpmixer.py +++ b/tests/jax/models/mlpmixer/test_mlpmixer.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, Sequence +from typing import Any, Dict, Sequence import flax.traverse_util import fsspec @@ -12,10 +12,12 @@ import pytest from flax import linen as nn from infra import ModelTester, RunMode -from utils import record_model_test_properties, runtime_fail +from utils import runtime_fail from .model_implementation import MlpMixer +MODEL_NAME = "mlpmixer" + # Hyperparameters for Mixer-B/16 patch_size = 16 num_classes = 21843 @@ -90,6 +92,11 @@ def training_tester() -> MlpMixerTester: @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip( reason=runtime_fail( "Statically allocated circular buffers in program 16 clash with L1 buffers " @@ -98,22 +105,17 @@ def training_tester() -> MlpMixerTester: "(https://github.com/tenstorrent/tt-xla/issues/187)" ) ) # segfault -def test_mlpmixer_inference( - inference_tester: MlpMixerTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, "mlpmixer") - +def test_mlpmixer_inference(inference_tester: MlpMixerTester): inference_tester.test() @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_mlpmixer_training( - training_tester: MlpMixerTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, "mlpmixer") - +def test_mlpmixer_training(training_tester: MlpMixerTester): training_tester.test() diff --git a/tests/jax/models/mnist/cnn/dropout/test_mnist_cnn_dropout.py b/tests/jax/models/mnist/cnn/dropout/test_mnist_cnn_dropout.py index 6e35bdda..14e38b03 100644 --- a/tests/jax/models/mnist/cnn/dropout/test_mnist_cnn_dropout.py +++ b/tests/jax/models/mnist/cnn/dropout/test_mnist_cnn_dropout.py @@ -2,17 +2,16 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import record_model_test_properties from ..tester import MNISTCNNTester from .model_implementation import MNISTCNNDropoutModel # ----- Fixtures ----- +MODEL_NAME = "mnist-cnn-dropout" + @pytest.fixture def inference_tester() -> MNISTCNNTester: @@ -29,22 +28,22 @@ def training_tester() -> MNISTCNNTester: @pytest.mark.push @pytest.mark.nightly -def test_mnist_cnn_dropout_inference( - inference_tester: MNISTCNNTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, "mnist-cnn-dropout") - +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) +def test_mnist_cnn_dropout_inference(inference_tester: MNISTCNNTester): inference_tester.test() @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_mnist_cnn_nodropout_training( - training_tester: MNISTCNNTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, "mnist-cnn-dropout") - +def test_mnist_cnn_nodropout_training(training_tester: MNISTCNNTester): training_tester.test() diff --git a/tests/jax/models/mnist/cnn/nodropout/test_mnist_cnn_nodropout.py b/tests/jax/models/mnist/cnn/nodropout/test_mnist_cnn_nodropout.py index 5f765f48..55f87392 100644 --- a/tests/jax/models/mnist/cnn/nodropout/test_mnist_cnn_nodropout.py +++ b/tests/jax/models/mnist/cnn/nodropout/test_mnist_cnn_nodropout.py @@ -2,17 +2,16 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import record_model_test_properties from ..tester import MNISTCNNTester from .model_implementation import MNISTCNNNoDropoutModel # ----- Fixtures ----- +MODEL_NAME = "mnist-cnn-nodropout" + @pytest.fixture def inference_tester() -> MNISTCNNTester: @@ -29,22 +28,22 @@ def training_tester() -> MNISTCNNTester: @pytest.mark.push @pytest.mark.nightly -def test_mnist_cnn_nodropout_inference( - inference_tester: MNISTCNNTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, "mnist-cnn-nodropout") - +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) +def test_mnist_cnn_nodropout_inference(inference_tester: MNISTCNNTester): inference_tester.test() @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_mnist_cnn_nodropout_training( - training_tester: MNISTCNNTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, "mnist-cnn-nodropout") - +def test_mnist_cnn_nodropout_training(training_tester: MNISTCNNTester): training_tester.test() diff --git a/tests/jax/models/mnist/mlp/test_mnist_mlp.py b/tests/jax/models/mnist/mlp/test_mnist_mlp.py index 11bf4f8a..6e7810a6 100644 --- a/tests/jax/models/mnist/mlp/test_mnist_mlp.py +++ b/tests/jax/models/mnist/mlp/test_mnist_mlp.py @@ -2,13 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Sequence +from typing import Sequence import jax import pytest from flax import linen as nn from infra import ComparisonConfig, ModelTester, RunMode -from utils import record_model_test_properties from .model_implementation import MNISTMLPModel @@ -51,6 +50,8 @@ def _get_forward_method_args(self): # ----- Fixtures ----- +MODEL_NAME = "mnist-mlp" + @pytest.fixture def inference_tester(request) -> MNISTMLPTester: @@ -67,6 +68,11 @@ def training_tester(request) -> MNISTMLPTester: @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.parametrize( "inference_tester", [ @@ -80,22 +86,17 @@ def training_tester(request) -> MNISTMLPTester: indirect=True, ids=lambda val: f"{val}", ) -def test_mnist_mlp_inference( - inference_tester: MNISTMLPTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, "mnist-mlp") - +def test_mnist_mlp_inference(inference_tester: MNISTMLPTester): inference_tester.test() @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_mnist_mlp_training( - training_tester: MNISTMLPTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MNISTMLPModel.__qualname__) - +def test_mnist_mlp_training(training_tester: MNISTMLPTester): training_tester.test() diff --git a/tests/jax/models/opt/opt_125m/test_125m.py b/tests/jax/models/opt/opt_125m/test_125m.py index a1b47259..829fd903 100644 --- a/tests/jax/models/opt/opt_125m/test_125m.py +++ b/tests/jax/models/opt/opt_125m/test_125m.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from ..tester import OPTTester @@ -31,22 +29,22 @@ def training_tester() -> OPTTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip(reason=compile_fail("Unsupported data type")) # segfault -def test_opt_125m_inference( - inference_tester: OPTTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_opt_125m_inference(inference_tester: OPTTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_opt_125m_training( - training_tester: OPTTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_opt_125m_training(training_tester: OPTTester): training_tester.test() diff --git a/tests/jax/models/opt/opt_1_3b/test_1_3b.py b/tests/jax/models/opt/opt_1_3b/test_1_3b.py index 6f7b364b..ab9ff01b 100644 --- a/tests/jax/models/opt/opt_1_3b/test_1_3b.py +++ b/tests/jax/models/opt/opt_1_3b/test_1_3b.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from ..tester import OPTTester @@ -31,22 +29,22 @@ def training_tester() -> OPTTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip(reason=compile_fail("Unsupported data type")) # segfault -def test_opt_1_3b_inference( - inference_tester: OPTTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_opt_1_3b_inference(inference_tester: OPTTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_opt_1_3b_training( - training_tester: OPTTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_opt_1_3b_training(training_tester: OPTTester): training_tester.test() diff --git a/tests/jax/models/opt/opt_2_7b/test_2_7b.py b/tests/jax/models/opt/opt_2_7b/test_2_7b.py index cabc245c..881b1f0a 100644 --- a/tests/jax/models/opt/opt_2_7b/test_2_7b.py +++ b/tests/jax/models/opt/opt_2_7b/test_2_7b.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from ..tester import OPTTester @@ -31,22 +29,22 @@ def training_tester() -> OPTTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip(reason=compile_fail("Unsupported data type")) # segfault -def test_opt_2_7b_inference( - inference_tester: OPTTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_opt_2_7b_inference(inference_tester: OPTTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_opt_2_7b_training( - training_tester: OPTTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_opt_2_7b_training(training_tester: OPTTester): training_tester.test() diff --git a/tests/jax/models/opt/opt_350m/test_350m.py b/tests/jax/models/opt/opt_350m/test_350m.py index f8244912..773f1ec5 100644 --- a/tests/jax/models/opt/opt_350m/test_350m.py +++ b/tests/jax/models/opt/opt_350m/test_350m.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from ..tester import OPTTester @@ -31,22 +29,22 @@ def training_tester() -> OPTTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip(reason=compile_fail("Unsupported data type")) # segfault -def test_opt_350m_inference( - inference_tester: OPTTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_opt_350m_inference(inference_tester: OPTTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_opt_350m_training( - training_tester: OPTTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_opt_350m_training(training_tester: OPTTester): training_tester.test() diff --git a/tests/jax/models/opt/opt_6_7b/test_6_7b.py b/tests/jax/models/opt/opt_6_7b/test_6_7b.py index 45e7c0fd..8a67c1fe 100644 --- a/tests/jax/models/opt/opt_6_7b/test_6_7b.py +++ b/tests/jax/models/opt/opt_6_7b/test_6_7b.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from ..tester import OPTTester @@ -31,22 +29,22 @@ def training_tester() -> OPTTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.skip(reason=compile_fail("Unsupported data type")) # segfault -def test_opt_6_7b_inference( - inference_tester: OPTTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_opt_6_7b_inference(inference_tester: OPTTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_opt_6_7b_training( - training_tester: OPTTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_opt_6_7b_training(training_tester: OPTTester): training_tester.test() diff --git a/tests/jax/models/roberta/base/test_roberta_base.py b/tests/jax/models/roberta/base/test_roberta_base.py index 9b7c77b0..7e4adbad 100644 --- a/tests/jax/models/roberta/base/test_roberta_base.py +++ b/tests/jax/models/roberta/base/test_roberta_base.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import record_model_test_properties, runtime_fail +from utils import runtime_fail from ..tester import FlaxRobertaForMaskedLMTester @@ -31,27 +29,27 @@ def training_tester() -> FlaxRobertaForMaskedLMTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail( reason=runtime_fail( "Host data with total size 20B does not match expected size 10B of device buffer! " "(https://github.com/tenstorrent/tt-xla/issues/182)" ) ) -def test_flax_roberta_base_inference( - inference_tester: FlaxRobertaForMaskedLMTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_roberta_base_inference(inference_tester: FlaxRobertaForMaskedLMTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_flax_roberta_base_training( - training_tester: FlaxRobertaForMaskedLMTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_roberta_base_training(training_tester: FlaxRobertaForMaskedLMTester): training_tester.test() diff --git a/tests/jax/models/roberta/large/test_roberta_large.py b/tests/jax/models/roberta/large/test_roberta_large.py index 1ef0fa37..26c641dd 100644 --- a/tests/jax/models/roberta/large/test_roberta_large.py +++ b/tests/jax/models/roberta/large/test_roberta_large.py @@ -2,11 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import pytest from infra import RunMode -from utils import record_model_test_properties, runtime_fail +from utils import runtime_fail from ..tester import FlaxRobertaForMaskedLMTester @@ -30,27 +28,27 @@ def training_tester() -> FlaxRobertaForMaskedLMTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail( reason=runtime_fail( "Host data with total size 20B does not match expected size 10B of device buffer! " "(https://github.com/tenstorrent/tt-xla/issues/182)" ) ) -def test_flax_roberta_large_inference( - inference_tester: FlaxRobertaForMaskedLMTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_roberta_large_inference(inference_tester: FlaxRobertaForMaskedLMTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_flax_roberta_large_training( - training_tester: FlaxRobertaForMaskedLMTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_flax_roberta_large_training(training_tester: FlaxRobertaForMaskedLMTester): training_tester.test() diff --git a/tests/jax/models/roberta_prelayernorm/efficient_mlm_m0_40/test_efficient_mlm_m0_40.py b/tests/jax/models/roberta_prelayernorm/efficient_mlm_m0_40/test_efficient_mlm_m0_40.py index 9c289d04..297ac4f7 100644 --- a/tests/jax/models/roberta_prelayernorm/efficient_mlm_m0_40/test_efficient_mlm_m0_40.py +++ b/tests/jax/models/roberta_prelayernorm/efficient_mlm_m0_40/test_efficient_mlm_m0_40.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Dict, Sequence +from typing import Dict, Sequence import jax import pytest @@ -12,7 +12,7 @@ FlaxPreTrainedModel, FlaxRobertaPreLayerNormForMaskedLM, ) -from utils import record_model_test_properties, runtime_fail +from utils import runtime_fail MODEL_PATH = "andreasmadsen/efficient_mlm_m0.40" MODEL_NAME = "roberta-prelayernorm" @@ -72,6 +72,11 @@ def training_tester() -> FlaxRobertaPreLayerNormForMaskedLMTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail( reason=runtime_fail( "Host data with total size 20B does not match expected size 10B of device buffer! " @@ -80,19 +85,18 @@ def training_tester() -> FlaxRobertaPreLayerNormForMaskedLMTester: ) def test_flax_roberta_prelayernorm_inference( inference_tester: FlaxRobertaPreLayerNormForMaskedLMTester, - record_tt_xla_property: Callable, ): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") def test_flax_roberta_prelayernorm_training( training_tester: FlaxRobertaPreLayerNormForMaskedLMTester, - record_tt_xla_property: Callable, ): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - training_tester.test() diff --git a/tests/jax/models/squeezebert/test_squeezebert.py b/tests/jax/models/squeezebert/test_squeezebert.py index 68945d02..ae832042 100644 --- a/tests/jax/models/squeezebert/test_squeezebert.py +++ b/tests/jax/models/squeezebert/test_squeezebert.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Dict, Sequence +from typing import Dict, Sequence import jax import pytest @@ -11,7 +11,7 @@ from huggingface_hub import hf_hub_download from infra import ModelTester, RunMode from transformers import AutoTokenizer -from utils import compile_fail, record_model_test_properties +from utils import compile_fail from .model_implementation import SqueezeBertConfig, SqueezeBertForMaskedLM @@ -76,24 +76,24 @@ def training_tester() -> SqueezeBertTester: @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.INFERENCE.value, +) @pytest.mark.xfail( reason=compile_fail("Failed to legalize operation 'ttir.convolution'") ) -def test_squeezebert_inference( - inference_tester: SqueezeBertTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_squeezebert_inference(inference_tester: SqueezeBertTester): inference_tester.test() @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="model_test", + model_name=MODEL_NAME, + run_mode=RunMode.TRAINING.value, +) @pytest.mark.skip(reason="Support for training not implemented") -def test_squeezebert_training( - training_tester: SqueezeBertTester, - record_tt_xla_property: Callable, -): - record_model_test_properties(record_tt_xla_property, MODEL_NAME) - +def test_squeezebert_training(training_tester: SqueezeBertTester): training_tester.test() diff --git a/tests/jax/multichip/manual/all_gather.py b/tests/jax/multichip/manual/all_gather.py index 91626ac3..2dd6648b 100644 --- a/tests/jax/multichip/manual/all_gather.py +++ b/tests/jax/multichip/manual/all_gather.py @@ -4,11 +4,12 @@ import jax import jax.numpy as jnp -from infra import run_multichip_test_with_random_inputs, make_partition_spec import pytest +from infra import make_partition_spec, run_multichip_test_with_random_inputs from utils import compile_fail +@pytest.mark.record_properties(test_category="multichip_test") @pytest.mark.parametrize( ("x_shape", "mesh_shape", "axis_names"), [((8192, 784), (2,), ("batch",))] ) diff --git a/tests/jax/multichip/manual/data_paralelism.py b/tests/jax/multichip/manual/data_paralelism.py index f3606887..994932fb 100644 --- a/tests/jax/multichip/manual/data_paralelism.py +++ b/tests/jax/multichip/manual/data_paralelism.py @@ -2,13 +2,14 @@ # # SPDX-License-Identifier: Apache-2.0 -from infra import run_multichip_test_with_random_inputs, make_partition_spec import jax import jax.numpy as jnp -from utils import compile_fail import pytest +from infra import make_partition_spec, run_multichip_test_with_random_inputs +from utils import compile_fail +@pytest.mark.record_properties(test_category="multichip_test") @pytest.mark.parametrize( [ "batch_shape", diff --git a/tests/jax/multichip/manual/psum.py b/tests/jax/multichip/manual/psum.py index a3bd39b4..64696905 100644 --- a/tests/jax/multichip/manual/psum.py +++ b/tests/jax/multichip/manual/psum.py @@ -4,11 +4,12 @@ import jax import jax.numpy as jnp -from infra import run_multichip_test_with_random_inputs, make_partition_spec import pytest +from infra import make_partition_spec, run_multichip_test_with_random_inputs from utils import compile_fail +@pytest.mark.record_properties(test_category="multichip_test") @pytest.mark.parametrize( ["batch_shape", "W1_shape", "B1_shape", "mesh_shape", "axis_names"], [ diff --git a/tests/jax/multichip/manual/psum_scatter.py b/tests/jax/multichip/manual/psum_scatter.py index 7730d112..e169c91d 100644 --- a/tests/jax/multichip/manual/psum_scatter.py +++ b/tests/jax/multichip/manual/psum_scatter.py @@ -2,13 +2,14 @@ # # SPDX-License-Identifier: Apache-2.0 -from infra import run_multichip_test_with_random_inputs, make_partition_spec import jax import jax.numpy as jnp import pytest +from infra import make_partition_spec, run_multichip_test_with_random_inputs from utils import compile_fail +@pytest.mark.record_properties(test_category="multichip_test") @pytest.mark.parametrize( ["batch_shape", "W1_shape", "B1_shape", "mesh_shape", "axis_names"], [ diff --git a/tests/jax/multichip/manual/unary_eltwise.py b/tests/jax/multichip/manual/unary_eltwise.py index 872ae287..c4a7bb13 100644 --- a/tests/jax/multichip/manual/unary_eltwise.py +++ b/tests/jax/multichip/manual/unary_eltwise.py @@ -2,13 +2,14 @@ # # SPDX-License-Identifier: Apache-2.0 -from infra import run_multichip_test_with_random_inputs, make_partition_spec import jax import jax.numpy as jnp import pytest +from infra import make_partition_spec, run_multichip_test_with_random_inputs from utils import compile_fail +@pytest.mark.record_properties(test_category="multichip_test") @pytest.mark.parametrize( ("x_shape", "mesh_shape", "axis_names"), [((256, 256), (1, 2), ("x", "y"))] ) diff --git a/tests/jax/ops/test_abs.py b/tests/jax/ops/test_abs.py index d5774e14..9ae45453 100644 --- a/tests/jax/ops/test_abs.py +++ b/tests/jax/ops/test_abs.py @@ -2,27 +2,23 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_unary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.abs", + shlo_op_name="stablehlo.abs", +) @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") -def test_abs(x_shape: tuple, record_tt_xla_property: Callable): +def test_abs(x_shape: tuple): def abs(x: jax.Array) -> jax.Array: return jnp.abs(x) - record_unary_op_test_properties( - record_tt_xla_property, - "jax.numpy.abs", - "stablehlo.abs", - ) - # Test both negative and positive values. run_op_test_with_random_inputs(abs, [x_shape], minval=-5.0, maxval=5.0) diff --git a/tests/jax/ops/test_add.py b/tests/jax/ops/test_add.py index 080ddd8d..70e18f60 100644 --- a/tests/jax/ops/test_add.py +++ b/tests/jax/ops/test_add.py @@ -2,17 +2,19 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_binary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.add", + shlo_op_name="stablehlo.add", +) @pytest.mark.parametrize( ["x_shape", "y_shape"], [ @@ -21,14 +23,8 @@ ], ids=lambda val: f"{val}", ) -def test_add(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable): +def test_add(x_shape: tuple, y_shape: tuple): def add(x: jax.Array, y: jax.Array) -> jax.Array: return jnp.add(x, y) - record_binary_op_test_properties( - record_tt_xla_property, - "jax.numpy.add", - "stablehlo.add", - ) - run_op_test_with_random_inputs(add, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_broadcast_in_dim.py b/tests/jax/ops/test_broadcast_in_dim.py index 8adafc60..16f2aef6 100644 --- a/tests/jax/ops/test_broadcast_in_dim.py +++ b/tests/jax/ops/test_broadcast_in_dim.py @@ -2,26 +2,22 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_unary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.broadcast_to", + shlo_op_name="stablehlo.broadcast_in_dim", +) @pytest.mark.parametrize("input_shapes", [[(2, 1)]], ids=lambda val: f"{val}") -def test_broadcast_in_dim(input_shapes: tuple, record_tt_xla_property: Callable): +def test_broadcast_in_dim(input_shapes: tuple): def broadcast(a: jax.Array): return jnp.broadcast_to(a, (2, 4)) - record_unary_op_test_properties( - record_tt_xla_property, - "jax.numpy.broadcast_to", - "stablehlo.broadcast_in_dim", - ) - run_op_test_with_random_inputs(broadcast, input_shapes) diff --git a/tests/jax/ops/test_cbrt.py b/tests/jax/ops/test_cbrt.py index 35ac43ec..6ac23638 100644 --- a/tests/jax/ops/test_cbrt.py +++ b/tests/jax/ops/test_cbrt.py @@ -2,26 +2,22 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_unary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.cbrt", + shlo_op_name="stablehlo.cbrt", +) @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") -def test_cbrt(x_shape: tuple, record_tt_xla_property: Callable): +def test_cbrt(x_shape: tuple): def cbrt(x: jax.Array) -> jax.Array: return jnp.cbrt(x) - record_unary_op_test_properties( - record_tt_xla_property, - "jax.numpy.cbrt", - "stablehlo.cbrt", - ) - run_op_test_with_random_inputs(cbrt, [x_shape]) diff --git a/tests/jax/ops/test_compare.py b/tests/jax/ops/test_compare.py index 4cb34615..45770d84 100644 --- a/tests/jax/ops/test_compare.py +++ b/tests/jax/ops/test_compare.py @@ -9,7 +9,6 @@ import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_binary_op_test_properties # NOTE TTNN does not support boolean data type, so bfloat16 is used instead. # Hence the output of comparison operation is bfloat16. JAX can not perform any @@ -36,6 +35,11 @@ def wrapper(*args, **kwargs): @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.equal", + shlo_op_name="stablehlo.compare{EQ}", +) @pytest.mark.parametrize( ["x_shape", "y_shape"], [ @@ -44,24 +48,21 @@ def wrapper(*args, **kwargs): ], ids=lambda val: f"{val}", ) -def test_compare_equal( - x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable -): +def test_compare_equal(x_shape: tuple, y_shape: tuple): @convert_output_to_bfloat16 def equal(x: jax.Array, y: jax.Array) -> jax.Array: return x == y - record_binary_op_test_properties( - record_tt_xla_property, - "jax.numpy.equal", - "stablehlo.compare{EQ}", - ) - run_op_test_with_random_inputs(equal, [x_shape, y_shape]) @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.not_equal", + shlo_op_name="stablehlo.compare{NE}", +) @pytest.mark.parametrize( ["x_shape", "y_shape"], [ @@ -70,24 +71,21 @@ def equal(x: jax.Array, y: jax.Array) -> jax.Array: ], ids=lambda val: f"{val}", ) -def test_compare_not_equal( - x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable -): +def test_compare_not_equal(x_shape: tuple, y_shape: tuple): @convert_output_to_bfloat16 def not_equal(x: jax.Array, y: jax.Array) -> jax.Array: return x != y - record_binary_op_test_properties( - record_tt_xla_property, - "jax.numpy.not_equal", - "stablehlo.compare{NE}", - ) - run_op_test_with_random_inputs(not_equal, [x_shape, y_shape]) @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.greater", + shlo_op_name="stablehlo.compare{GT}", +) @pytest.mark.parametrize( ["x_shape", "y_shape"], [ @@ -96,24 +94,21 @@ def not_equal(x: jax.Array, y: jax.Array) -> jax.Array: ], ids=lambda val: f"{val}", ) -def test_compare_greater( - x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable -): +def test_compare_greater(x_shape: tuple, y_shape: tuple): @convert_output_to_bfloat16 def greater(x: jax.Array, y: jax.Array) -> jax.Array: return x > y - record_binary_op_test_properties( - record_tt_xla_property, - "jax.numpy.greater", - "stablehlo.compare{GT}", - ) - run_op_test_with_random_inputs(greater, [x_shape, y_shape]) @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.greater_equal", + shlo_op_name="stablehlo.compare{GE}", +) @pytest.mark.parametrize( ["x_shape", "y_shape"], [ @@ -122,24 +117,21 @@ def greater(x: jax.Array, y: jax.Array) -> jax.Array: ], ids=lambda val: f"{val}", ) -def test_compare_greater_equal( - x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable -): +def test_compare_greater_equal(x_shape: tuple, y_shape: tuple): @convert_output_to_bfloat16 def greater_equal(x: jax.Array, y: jax.Array) -> jax.Array: return x >= y - record_binary_op_test_properties( - record_tt_xla_property, - "jax.numpy.greater_equal", - "stablehlo.compare{GE}", - ) - run_op_test_with_random_inputs(greater_equal, [x_shape, y_shape]) @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.less", + shlo_op_name="stablehlo.compare{LT}", +) @pytest.mark.parametrize( ["x_shape", "y_shape"], [ @@ -148,22 +140,21 @@ def greater_equal(x: jax.Array, y: jax.Array) -> jax.Array: ], ids=lambda val: f"{val}", ) -def test_compare_less(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable): +def test_compare_less(x_shape: tuple, y_shape: tuple): @convert_output_to_bfloat16 def less(x: jax.Array, y: jax.Array) -> jax.Array: return x < y - record_binary_op_test_properties( - record_tt_xla_property, - "jax.numpy.less", - "stablehlo.compare{LT}", - ) - run_op_test_with_random_inputs(less, [x_shape, y_shape]) @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.less_equal", + shlo_op_name="stablehlo.compare{LE}", +) @pytest.mark.parametrize( ["x_shape", "y_shape"], [ @@ -172,17 +163,9 @@ def less(x: jax.Array, y: jax.Array) -> jax.Array: ], ids=lambda val: f"{val}", ) -def test_compare_less_equal( - x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable -): +def test_compare_less_equal(x_shape: tuple, y_shape: tuple): @convert_output_to_bfloat16 def less_equal(x: jax.Array, y: jax.Array) -> jax.Array: return x <= y - record_binary_op_test_properties( - record_tt_xla_property, - "jax.numpy.less_equal", - "stablehlo.compare{LE}", - ) - run_op_test_with_random_inputs(less_equal, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_concatenate.py b/tests/jax/ops/test_concatenate.py index 4aa0d9f8..8711246a 100644 --- a/tests/jax/ops/test_concatenate.py +++ b/tests/jax/ops/test_concatenate.py @@ -2,17 +2,19 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_binary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.concatenate", + shlo_op_name="stablehlo.concatenate", +) @pytest.mark.parametrize( ["x_shape", "y_shape", "axis"], [ @@ -23,16 +25,8 @@ ], ids=lambda val: f"{val}", ) -def test_concatenate( - x_shape: tuple, y_shape: tuple, axis: int, record_tt_xla_property: Callable -): +def test_concatenate(x_shape: tuple, y_shape: tuple, axis: int): def concat(x: jax.Array, y: jax.Array) -> jax.Array: return jnp.concatenate([x, y], axis=axis) - record_binary_op_test_properties( - record_tt_xla_property, - "jax.numpy.concatenate", - "stablehlo.concatenate", - ) - run_op_test_with_random_inputs(concat, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_constant.py b/tests/jax/ops/test_constant.py index 1525691f..dbfdbded 100644 --- a/tests/jax/ops/test_constant.py +++ b/tests/jax/ops/test_constant.py @@ -2,60 +2,52 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax.numpy as jnp import pytest from infra import run_op_test -from utils import compile_fail, record_op_test_properties +from utils import compile_fail @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.zeros", + shlo_op_name="stablehlo.constant", +) @pytest.mark.parametrize("shape", [(32, 32), (1, 1)], ids=lambda val: f"{val}") -def test_constant_zeros(shape: tuple, record_tt_xla_property: Callable): +def test_constant_zeros(shape: tuple): def module_constant_zeros(): return jnp.zeros(shape) - record_op_test_properties( - record_tt_xla_property, - "Constant op", - "jax.numpy.zeros", - "stablehlo.constant", - ) - run_op_test(module_constant_zeros, []) @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.ones", + shlo_op_name="stablehlo.constant", +) @pytest.mark.parametrize("shape", [(32, 32), (1, 1)], ids=lambda val: f"{val}") -def test_constant_ones(shape: tuple, record_tt_xla_property: Callable): +def test_constant_ones(shape: tuple): def module_constant_ones(): return jnp.ones(shape) - record_op_test_properties( - record_tt_xla_property, - "Constant op", - "jax.numpy.ones", - "stablehlo.constant", - ) - run_op_test(module_constant_ones, []) @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.array", + shlo_op_name="stablehlo.constant", +) @pytest.mark.xfail(reason=compile_fail("failed to legalize operation 'ttir.constant'")) -def test_constant_multi_value(record_tt_xla_property: Callable): +def test_constant_multi_value(): def module_constant_multi(): return jnp.array([[1, 2], [3, 4]], dtype=jnp.float32) - record_op_test_properties( - record_tt_xla_property, - "Constant op", - "jax.numpy.array", - "stablehlo.constant", - ) - run_op_test(module_constant_multi, []) diff --git a/tests/jax/ops/test_convert.py b/tests/jax/ops/test_convert.py index f8a51702..1e9d53d2 100644 --- a/tests/jax/ops/test_convert.py +++ b/tests/jax/ops/test_convert.py @@ -2,15 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.lax as jlx import jax.numpy as jnp import pytest from infra import random_tensor, run_op_test from jax._src.typing import DTypeLike -from utils import compile_fail, record_unary_op_test_properties, runtime_fail +from utils import compile_fail, runtime_fail from tests.utils import enable_x64 @@ -165,6 +163,11 @@ def conditionally_skip(from_dtype: DTypeLike, to_dtype: DTypeLike): @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.lax.convert_element_type", + shlo_op_name="stablehlo.convert", +) @pytest.mark.parametrize( "from_dtype", [ @@ -239,18 +242,10 @@ def conditionally_skip(from_dtype: DTypeLike, to_dtype: DTypeLike): ), ], ) -def test_convert( - from_dtype: DTypeLike, to_dtype: DTypeLike, record_tt_xla_property: Callable -): +def test_convert(from_dtype: DTypeLike, to_dtype: DTypeLike): def convert(x: jax.Array) -> jax.Array: return jlx.convert_element_type(x, new_dtype=to_dtype) - record_unary_op_test_properties( - record_tt_xla_property, - "jax.lax.convert_element_type", - "stablehlo.convert", - ) - # Some dtype conversions are not supported. Check and decide whether to skip or # proceed. conditionally_skip(from_dtype, to_dtype) diff --git a/tests/jax/ops/test_convolution.py b/tests/jax/ops/test_convolution.py index 1de7b3b8..88938adc 100644 --- a/tests/jax/ops/test_convolution.py +++ b/tests/jax/ops/test_convolution.py @@ -2,12 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import pytest from infra import ComparisonConfig, random_tensor, run_op_test -from utils import record_op_test_properties # TODO investigate why conv has such poor precision. @@ -22,6 +19,11 @@ def comparison_config() -> ComparisonConfig: @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.lax.conv_general_dilated", + shlo_op_name="stablehlo.convolution", +) @pytest.mark.parametrize( ["img_shape", "kernel_shape"], [ @@ -36,7 +38,6 @@ def test_conv1d( img_shape: tuple, kernel_shape: tuple, comparison_config: ComparisonConfig, - record_tt_xla_property: Callable, ): def conv1d(img, weights): return jax.lax.conv_general_dilated( @@ -51,13 +52,6 @@ def conv1d(img, weights): batch_group_count=1, ) - record_op_test_properties( - record_tt_xla_property, - "Convolution op", - "jax.lax.conv_general_dilated", - "stablehlo.convolution", - ) - img = random_tensor(img_shape, dtype="bfloat16") kernel = random_tensor(kernel_shape, dtype="bfloat16") @@ -66,6 +60,11 @@ def conv1d(img, weights): @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.lax.conv_general_dilated", + shlo_op_name="stablehlo.convolution", +) @pytest.mark.parametrize( [ "batch_size", @@ -122,7 +121,6 @@ def test_conv2d( stride_w: int, padding: int, comparison_config: ComparisonConfig, - record_tt_xla_property: Callable, ): def conv2d(img: jax.Array, kernel: jax.Array): return jax.lax.conv_general_dilated( @@ -133,13 +131,6 @@ def conv2d(img: jax.Array, kernel: jax.Array): dimension_numbers=("NHWC", "OIHW", "NHWC"), ) - record_op_test_properties( - record_tt_xla_property, - "Convolution op", - "jax.lax.conv_general_dilated", - "stablehlo.convolution", - ) - img_shape = (batch_size, input_height, input_width, input_channels) kernel_shape = (output_channels, input_channels, filter_height, filter_width) diff --git a/tests/jax/ops/test_divide.py b/tests/jax/ops/test_divide.py index ec0a1a7f..638346bd 100644 --- a/tests/jax/ops/test_divide.py +++ b/tests/jax/ops/test_divide.py @@ -2,17 +2,19 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_binary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.divide", + shlo_op_name="stablehlo.divide", +) @pytest.mark.parametrize( ["x_shape", "y_shape"], [ @@ -21,12 +23,8 @@ ], ids=lambda val: f"{val}", ) -def test_divide(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable): +def test_divide(x_shape: tuple, y_shape: tuple): def divide(x: jax.Array, y: jax.Array) -> jax.Array: return jnp.divide(x, y) - record_binary_op_test_properties( - record_tt_xla_property, "jax.numpy.divide", "stablehlo.divide" - ) - run_op_test_with_random_inputs(divide, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_dot_general.py b/tests/jax/ops/test_dot_general.py index e1880a1d..6c7083ac 100644 --- a/tests/jax/ops/test_dot_general.py +++ b/tests/jax/ops/test_dot_general.py @@ -2,12 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import pytest from infra import run_op_test_with_random_inputs -from utils import record_binary_op_test_properties # Tests for dot_general op where vectors containing indices of contracting dimensions @@ -15,6 +12,11 @@ # this is the most common one we have. @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.lax.dot_general", + shlo_op_name="stablehlo.dot_general", +) @pytest.mark.parametrize( ["x_shape", "y_shape"], [ @@ -25,22 +27,21 @@ ], ids=lambda val: f"{val}", ) -def test_dot_general_common( - x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable -): +def test_dot_general_common(x_shape: tuple, y_shape: tuple): def dot_general(x: jax.Array, y: jax.Array) -> jax.Array: return jax.lax.dot_general(x, y, dimension_numbers=((1, 1), (0, 0))) - record_binary_op_test_properties( - record_tt_xla_property, "jax.lax.dot_general", "stablehlo.dot_general" - ) - run_op_test_with_random_inputs(dot_general, [x_shape, y_shape]) # Tests for dot_general op where this operation corresponds to regular matmul. @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.lax.dot_general", + shlo_op_name="stablehlo.dot_general", +) @pytest.mark.parametrize( ["x_shape", "y_shape"], [ @@ -48,16 +49,10 @@ def dot_general(x: jax.Array, y: jax.Array) -> jax.Array: [(2, 32, 64), (2, 64, 64)], ], ) -def test_dot_general_matmul( - x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable -): +def test_dot_general_matmul(x_shape: tuple, y_shape: tuple): def dot_general(x: jax.Array, y: jax.Array) -> jax.Array: return jax.lax.dot_general(x, y, dimension_numbers=((2, 1), (0, 0))) - record_binary_op_test_properties( - record_tt_xla_property, "jax.lax.dot_general", "stablehlo.dot_general" - ) - run_op_test_with_random_inputs(dot_general, [x_shape, y_shape]) @@ -65,6 +60,11 @@ def dot_general(x: jax.Array, y: jax.Array) -> jax.Array: # contracting dimensions are of size greater than 1. @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.lax.dot_general", + shlo_op_name="stablehlo.dot_general", +) @pytest.mark.parametrize( ["x_shape", "y_shape"], [ @@ -72,14 +72,8 @@ def dot_general(x: jax.Array, y: jax.Array) -> jax.Array: [(2, 8, 8, 16), (2, 8, 16, 8)], ], ) -def test_dot_general_multiple_contract( - x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable -): +def test_dot_general_multiple_contract(x_shape: tuple, y_shape: tuple): def dot_general(x: jax.Array, y: jax.Array) -> jax.Array: return jax.lax.dot_general(x, y, dimension_numbers=(((1, 3), (1, 2)), (0, 0))) - record_binary_op_test_properties( - record_tt_xla_property, "jax.lax.dot_general", "stablehlo.dot_general" - ) - run_op_test_with_random_inputs(dot_general, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_exponential.py b/tests/jax/ops/test_exponential.py index 38b2a27d..2f07296f 100644 --- a/tests/jax/ops/test_exponential.py +++ b/tests/jax/ops/test_exponential.py @@ -2,26 +2,22 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_unary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.exp", + shlo_op_name="stablehlo.exponential", +) @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") -def test_exponential(x_shape: tuple, record_tt_xla_property: Callable): +def test_exponential(x_shape: tuple): def exponential(x: jax.Array) -> jax.Array: return jnp.exp(x) - record_unary_op_test_properties( - record_tt_xla_property, - "jax.numpy.exp", - "stablehlo.exponential", - ) - run_op_test_with_random_inputs(exponential, [x_shape]) diff --git a/tests/jax/ops/test_exponential_minus_one.py b/tests/jax/ops/test_exponential_minus_one.py index fb6f9fcd..1a68c80e 100644 --- a/tests/jax/ops/test_exponential_minus_one.py +++ b/tests/jax/ops/test_exponential_minus_one.py @@ -2,13 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import ComparisonConfig, run_op_test_with_random_inputs -from utils import record_unary_op_test_properties @pytest.fixture @@ -25,21 +22,16 @@ def comparison_config() -> ComparisonConfig: @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.expm1", + shlo_op_name="stablehlo.exponential_minus_one", +) @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") -def test_exponential_minus_one( - x_shape: tuple, - comparison_config: ComparisonConfig, - record_tt_xla_property: Callable, -): +def test_exponential_minus_one(x_shape: tuple, comparison_config: ComparisonConfig): def expm1(x: jax.Array) -> jax.Array: return jnp.expm1(x) - record_unary_op_test_properties( - record_tt_xla_property, - "jax.numpy.expm1", - "stablehlo.exponential_minus_one", - ) - run_op_test_with_random_inputs( expm1, [x_shape], comparison_config=comparison_config ) diff --git a/tests/jax/ops/test_log_plus_one.py b/tests/jax/ops/test_log_plus_one.py index d0ed20cc..566fa1ee 100644 --- a/tests/jax/ops/test_log_plus_one.py +++ b/tests/jax/ops/test_log_plus_one.py @@ -2,26 +2,22 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_unary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.log1p", + shlo_op_name="stablehlo.log_plus_one", +) @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") -def test_log1p(x_shape: tuple, record_tt_xla_property: Callable): +def test_log1p(x_shape: tuple): def log1p(x: jax.Array) -> jax.Array: return jnp.log1p(x) - record_unary_op_test_properties( - record_tt_xla_property, - "jax.numpy.log1p", - "stablehlo.log_plus_one", - ) - run_op_test_with_random_inputs(log1p, [x_shape]) diff --git a/tests/jax/ops/test_maximum.py b/tests/jax/ops/test_maximum.py index 868fc820..99224844 100644 --- a/tests/jax/ops/test_maximum.py +++ b/tests/jax/ops/test_maximum.py @@ -2,17 +2,19 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_binary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.maximum", + shlo_op_name="stablehlo.maximum", +) @pytest.mark.parametrize( ["x_shape", "y_shape"], [ @@ -21,14 +23,8 @@ ], ids=lambda val: f"{val}", ) -def test_maximum(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable): +def test_maximum(x_shape: tuple, y_shape: tuple): def maximum(x: jax.Array, y: jax.Array) -> jax.Array: return jnp.maximum(x, y) - record_binary_op_test_properties( - record_tt_xla_property, - "jax.numpy.maximum", - "stablehlo.maximum", - ) - run_op_test_with_random_inputs(maximum, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_minimum.py b/tests/jax/ops/test_minimum.py index 1b652e7f..dcb0b949 100644 --- a/tests/jax/ops/test_minimum.py +++ b/tests/jax/ops/test_minimum.py @@ -2,17 +2,19 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_binary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.minimum", + shlo_op_name="stablehlo.minimum", +) @pytest.mark.parametrize( ["x_shape", "y_shape"], [ @@ -21,14 +23,8 @@ ], ids=lambda val: f"{val}", ) -def test_minimum(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable): +def test_minimum(x_shape: tuple, y_shape: tuple): def minimum(x: jax.Array, y: jax.Array) -> jax.Array: return jnp.minimum(x, y) - record_binary_op_test_properties( - record_tt_xla_property, - "jax.numpy.minimum", - "stablehlo.minimum", - ) - run_op_test_with_random_inputs(minimum, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_multiply.py b/tests/jax/ops/test_multiply.py index 52992134..f2f97477 100644 --- a/tests/jax/ops/test_multiply.py +++ b/tests/jax/ops/test_multiply.py @@ -2,17 +2,19 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_binary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.multiply", + shlo_op_name="stablehlo.multiply", +) @pytest.mark.parametrize( ["x_shape", "y_shape"], [ @@ -21,14 +23,8 @@ ], ids=lambda val: f"{val}", ) -def test_multiply(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable): +def test_multiply(x_shape: tuple, y_shape: tuple): def multiply(x: jax.Array, y: jax.Array) -> jax.Array: return jnp.multiply(x, y) - record_binary_op_test_properties( - record_tt_xla_property, - "jax.numpy.multiply", - "stablehlo.multiply", - ) - run_op_test_with_random_inputs(multiply, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_negate.py b/tests/jax/ops/test_negate.py index 00ea3499..b13636a7 100644 --- a/tests/jax/ops/test_negate.py +++ b/tests/jax/ops/test_negate.py @@ -2,27 +2,23 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_unary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.negative", + shlo_op_name="stablehlo.negative", +) @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") -def test_negate(x_shape: tuple, record_tt_xla_property: Callable): +def test_negate(x_shape: tuple): def negate(x: jax.Array) -> jax.Array: return jnp.negative(x) - record_unary_op_test_properties( - record_tt_xla_property, - "jax.numpy.negative", - "stablehlo.negate", - ) - # Trying both negative and positive values. run_op_test_with_random_inputs(negate, [x_shape], minval=-5.0, maxval=5.0) diff --git a/tests/jax/ops/test_reduce.py b/tests/jax/ops/test_reduce.py index db1fa636..05f8c1bc 100644 --- a/tests/jax/ops/test_reduce.py +++ b/tests/jax/ops/test_reduce.py @@ -2,13 +2,11 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable import jax import jax.numpy as jnp import pytest from infra import ComparisonConfig, run_op_test_with_random_inputs -from utils import record_op_test_properties # TODO investigate why this doesn't pass with default comparison config. @@ -24,22 +22,16 @@ def comparison_config() -> ComparisonConfig: # TODO axis should be parametrized as well. @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.sum", + shlo_op_name="stablehlo.reduce{SUM}", +) @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") -def test_reduce_sum( - x_shape: tuple, - comparison_config: ComparisonConfig, - record_tt_xla_property: Callable, -): +def test_reduce_sum(x_shape: tuple, comparison_config: ComparisonConfig): def reduce_sum(x: jax.Array) -> jax.Array: return jnp.sum(x) - record_op_test_properties( - record_tt_xla_property, - "Reduce op", - "jax.numpy.sum", - "stablehlo.reduce{SUM}", - ) - run_op_test_with_random_inputs( reduce_sum, [x_shape], comparison_config=comparison_config ) @@ -48,22 +40,16 @@ def reduce_sum(x: jax.Array) -> jax.Array: # TODO axis should be parametrized as well. @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.max", + shlo_op_name="stablehlo.reduce{MAX}", +) @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") -def test_reduce_max( - x_shape: tuple, - comparison_config: ComparisonConfig, - record_tt_xla_property: Callable, -): +def test_reduce_max(x_shape: tuple, comparison_config: ComparisonConfig): def reduce_max(x: jax.Array) -> jax.Array: return jnp.max(x) - record_op_test_properties( - record_tt_xla_property, - "Reduce op", - "jax.numpy.max", - "stablehlo.reduce{MAX}", - ) - run_op_test_with_random_inputs( reduce_max, [x_shape], comparison_config=comparison_config ) diff --git a/tests/jax/ops/test_reduce_window.py b/tests/jax/ops/test_reduce_window.py index d6227008..2c30000c 100644 --- a/tests/jax/ops/test_reduce_window.py +++ b/tests/jax/ops/test_reduce_window.py @@ -2,13 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import flax import jax import pytest from infra import ComparisonConfig, random_tensor, run_op_test -from utils import record_op_test_properties @pytest.fixture @@ -22,6 +19,11 @@ def comparison_config() -> ComparisonConfig: @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="flax.linen.max_pool", + shlo_op_name="stablehlo.reduce_window{MAX}", +) @pytest.mark.parametrize( "img_shape", ## NHWC [ @@ -69,20 +71,12 @@ def test_reduce_window_max( strides: tuple, padding: tuple, comparison_config: ComparisonConfig, - record_tt_xla_property: Callable, ): def maxpool2d(img: jax.Array): return flax.linen.max_pool( img, window_shape=window_shape, strides=strides, padding=padding ) - record_op_test_properties( - record_tt_xla_property, - "Maxpool op", - "flax.linen.max_pool", - "stablehlo.reduce_window{MAX}", - ) - # NOTE Some resnet convolutions seem to require bfloat16, ttnn throws in runtime # otherwise. On another note, MaxPool2d is also only supported for bfloat16 in ttnn, # so we have to run conv in bfloat16 for the time being. diff --git a/tests/jax/ops/test_remainder.py b/tests/jax/ops/test_remainder.py index 1f1acef2..37ed55af 100644 --- a/tests/jax/ops/test_remainder.py +++ b/tests/jax/ops/test_remainder.py @@ -2,17 +2,19 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.lax as jlx import pytest from infra import run_op_test_with_random_inputs -from utils import record_binary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.lax.rem", + shlo_op_name="stablehlo.remainder", +) @pytest.mark.parametrize( ["x_shape", "y_shape"], [ @@ -21,14 +23,8 @@ ], ids=lambda val: f"{val}", ) -def test_remainder(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable): +def test_remainder(x_shape: tuple, y_shape: tuple): def remainder(x: jax.Array, y: jax.Array) -> jax.Array: return jlx.rem(x, y) - record_binary_op_test_properties( - record_tt_xla_property, - "jax.lax.rem", - "stablehlo.remainder", - ) - run_op_test_with_random_inputs(remainder, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_reshape.py b/tests/jax/ops/test_reshape.py index 6a857749..54402650 100644 --- a/tests/jax/ops/test_reshape.py +++ b/tests/jax/ops/test_reshape.py @@ -2,17 +2,19 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_unary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.reshape", + shlo_op_name="stablehlo.reshape", +) @pytest.mark.parametrize( ["in_shape", "out_shape"], [ @@ -22,14 +24,8 @@ ], ids=lambda val: f"{val}", ) -def test_reshape(in_shape: tuple, out_shape: tuple, record_tt_xla_property: Callable): +def test_reshape(in_shape: tuple, out_shape: tuple): def reshape(x: jax.Array): return jnp.reshape(x, out_shape) - record_unary_op_test_properties( - record_tt_xla_property, - "jax.numpy.reshape", - "stablehlo.reshape", - ) - run_op_test_with_random_inputs(reshape, [in_shape]) diff --git a/tests/jax/ops/test_rsqrt.py b/tests/jax/ops/test_rsqrt.py index fe37e460..a60ce226 100644 --- a/tests/jax/ops/test_rsqrt.py +++ b/tests/jax/ops/test_rsqrt.py @@ -2,27 +2,23 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.lax as jlx import pytest from infra import run_op_test_with_random_inputs -from utils import record_unary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.lax.rsqrt", + shlo_op_name="stablehlo.rsqrt", +) @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") -def test_rsqrt(x_shape: tuple, record_tt_xla_property: Callable): +def test_rsqrt(x_shape: tuple): def rsqrt(x: jax.Array) -> jax.Array: return jlx.rsqrt(x) - record_unary_op_test_properties( - record_tt_xla_property, - "jax.lax.rsqrt", - "stablehlo.rsqrt", - ) - # Input must be strictly positive because of sqrt(x). run_op_test_with_random_inputs(rsqrt, [x_shape], minval=0.1, maxval=10.0) diff --git a/tests/jax/ops/test_sign.py b/tests/jax/ops/test_sign.py index 83fd8e26..45e80ee1 100644 --- a/tests/jax/ops/test_sign.py +++ b/tests/jax/ops/test_sign.py @@ -2,27 +2,23 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_unary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.sign", + shlo_op_name="stablehlo.sign", +) @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") -def test_sign(x_shape: tuple, record_tt_xla_property: Callable): +def test_sign(x_shape: tuple): def sign(x: jax.Array) -> jax.Array: return jnp.sign(x) - record_unary_op_test_properties( - record_tt_xla_property, - "jax.numpy.sign", - "stablehlo.sign", - ) - # Trying both negative and positive values. run_op_test_with_random_inputs(sign, [x_shape], minval=-5.0, maxval=5.0) diff --git a/tests/jax/ops/test_slice.py b/tests/jax/ops/test_slice.py index 7aa81d8e..e499a703 100644 --- a/tests/jax/ops/test_slice.py +++ b/tests/jax/ops/test_slice.py @@ -2,12 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_unary_op_test_properties dim0_cases = [] for begin in jnp.arange(10).tolist(): @@ -33,12 +30,17 @@ # TODO investigate if this test can be rewritten to make it easier for understanding. @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.lax.slice", + shlo_op_name="stablehlo.slice", +) @pytest.mark.parametrize( ["begin", "end", "dim"], [*dim2_cases, *dim3_cases, *dim0_cases, *dim1_cases], ids=lambda val: f"{val}", ) -def test_slice(begin: int, end: int, dim: int, record_tt_xla_property: Callable): +def test_slice(begin: int, end: int, dim: int): def module_slice(a): if dim == 0: return a[begin:end, :, :, :] @@ -49,12 +51,6 @@ def module_slice(a): else: return a[:, :, :, begin:end] - record_unary_op_test_properties( - record_tt_xla_property, - "jax.lax.slice", - "stablehlo.slice", - ) - shape = [10, 10, 10, 10] shape[dim] = 128 diff --git a/tests/jax/ops/test_sqrt.py b/tests/jax/ops/test_sqrt.py index 8f6a195f..86dcb050 100644 --- a/tests/jax/ops/test_sqrt.py +++ b/tests/jax/ops/test_sqrt.py @@ -2,27 +2,23 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_unary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.sqrt", + shlo_op_name="stablehlo.sqrt", +) @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") -def test_sqrt(x_shape: tuple, record_tt_xla_property: Callable): +def test_sqrt(x_shape: tuple): def sqrt(x: jax.Array) -> jax.Array: return jnp.sqrt(x) - record_unary_op_test_properties( - record_tt_xla_property, - "jax.numpy.sqrt", - "stablehlo.sqrt", - ) - # Input must be strictly positive because of sqrt(x). run_op_test_with_random_inputs(sqrt, [x_shape], minval=0.1, maxval=10.0) diff --git a/tests/jax/ops/test_subtract.py b/tests/jax/ops/test_subtract.py index 20f79075..65627d32 100644 --- a/tests/jax/ops/test_subtract.py +++ b/tests/jax/ops/test_subtract.py @@ -2,17 +2,19 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_binary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.subtract", + shlo_op_name="stablehlo.subtract", +) @pytest.mark.parametrize( ["x_shape", "y_shape"], [ @@ -21,14 +23,8 @@ ], ids=lambda val: f"{val}", ) -def test_subtract(x_shape: tuple, y_shape: tuple, record_tt_xla_property: Callable): +def test_subtract(x_shape: tuple, y_shape: tuple): def subtract(x: jax.Array, y: jax.Array) -> jax.Array: return jnp.subtract(x, y) - record_binary_op_test_properties( - record_tt_xla_property, - "jax.numpy.subtract", - "stablehlo.subtract", - ) - run_op_test_with_random_inputs(subtract, [x_shape, y_shape]) diff --git a/tests/jax/ops/test_transpose.py b/tests/jax/ops/test_transpose.py index be99870a..48e721df 100644 --- a/tests/jax/ops/test_transpose.py +++ b/tests/jax/ops/test_transpose.py @@ -2,26 +2,22 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Callable - import jax import jax.numpy as jnp import pytest from infra import run_op_test_with_random_inputs -from utils import record_unary_op_test_properties @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties( + test_category="op_test", + jax_op_name="jax.numpy.transpose", + shlo_op_name="stablehlo.transpose", +) @pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)], ids=lambda val: f"{val}") -def test_transpose(x_shape: tuple, record_tt_xla_property: Callable): +def test_transpose(x_shape: tuple): def transpose(x: jax.Array) -> jax.Array: return jnp.transpose(x) - record_unary_op_test_properties( - record_tt_xla_property, - "jax.numpy.transpose", - "stablehlo.transpose", - ) - run_op_test_with_random_inputs(transpose, [x_shape]) diff --git a/tests/jax/test_data_types.py b/tests/jax/test_data_types.py index 3c340262..e5cb1fcb 100644 --- a/tests/jax/test_data_types.py +++ b/tests/jax/test_data_types.py @@ -17,6 +17,7 @@ @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties(test_category="other") @pytest.mark.parametrize( "dtype", [ diff --git a/tests/jax/test_device_initialization.py b/tests/jax/test_device_initialization.py index 5385ffff..72f9d5e5 100644 --- a/tests/jax/test_device_initialization.py +++ b/tests/jax/test_device_initialization.py @@ -50,6 +50,7 @@ def is_tt_device(device: jax.Device) -> bool: @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties(test_category="other") def test_devices_are_connected(): cpus = jax.devices("cpu") @@ -64,6 +65,7 @@ def test_devices_are_connected(): @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties(test_category="other") def test_put_tensor_on_device(cpu: jax.Device, tt_device: jax.Device): # `random_tensor` is executed on cpu due to `@run_on_cpu` decorator so we don't have # to put it explicitly on cpu, but we will just for demonstration purposes. @@ -78,6 +80,7 @@ def test_put_tensor_on_device(cpu: jax.Device, tt_device: jax.Device): @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties(test_category="other") def test_device_output_comparison(cpu: jax.Device, tt_device: jax.Device): @jax.jit # Apply jit to this function. def add(x: jax.Array, y: jax.Array): diff --git a/tests/jax/test_ranks.py b/tests/jax/test_ranks.py index ee8d83ae..f8a8a9b5 100644 --- a/tests/jax/test_ranks.py +++ b/tests/jax/test_ranks.py @@ -14,6 +14,7 @@ @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties(test_category="other") @pytest.mark.parametrize( "x_shape", [ @@ -43,6 +44,7 @@ def negate(x: jax.Array) -> jax.Array: @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties(test_category="other") @pytest.mark.parametrize( "shape", [ diff --git a/tests/jax/test_scalar_types.py b/tests/jax/test_scalar_types.py index f7b2e7f5..f56c35a7 100644 --- a/tests/jax/test_scalar_types.py +++ b/tests/jax/test_scalar_types.py @@ -14,6 +14,7 @@ @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties(test_category="other") def test_scalar_scalar_add(): """Tests adding two scalars.""" @@ -25,6 +26,7 @@ def add() -> jax.Array: @pytest.mark.push @pytest.mark.nightly +@pytest.mark.record_properties(test_category="other") @pytest.mark.skip("Fails due to https://github.com/tenstorrent/tt-metal/issues/16701") def test_scalar_array_add(): """ diff --git a/tests/utils.py b/tests/utils.py index 73ea98fb..3a8ef4e4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,10 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 from contextlib import contextmanager -from typing import Callable import jax -from conftest import RecordProperties def compile_fail(reason: str) -> str: @@ -17,34 +15,6 @@ def runtime_fail(reason: str) -> str: return f"Runtime failed: {reason}" -def record_unary_op_test_properties( - record_property: Callable, framework_op_name: str, op_name: str -): - record_property(RecordProperties.OP_KIND.value, "Unary op") - record_property(RecordProperties.FRAMEWORK_OP_NAME.value, framework_op_name) - record_property(RecordProperties.OP_NAME.value, op_name) - - -def record_binary_op_test_properties( - record_property: Callable, framework_op_name: str, op_name: str -): - record_property(RecordProperties.OP_KIND.value, "Binary op") - record_property(RecordProperties.FRAMEWORK_OP_NAME.value, framework_op_name) - record_property(RecordProperties.OP_NAME.value, op_name) - - -def record_op_test_properties( - record_property: Callable, op_kind: str, framework_op_name: str, op_name: str -): - record_property(RecordProperties.OP_KIND.value, op_kind) - record_property(RecordProperties.FRAMEWORK_OP_NAME.value, framework_op_name) - record_property(RecordProperties.OP_NAME.value, op_name) - - -def record_model_test_properties(record_property: Callable, model_name: str): - record_property(RecordProperties.MODEL_NAME.value, model_name) - - @contextmanager def enable_x64(): """