Skip to content

Commit

Permalink
Fixed op tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kmitrovicTT committed Mar 10, 2025
1 parent cd3e64e commit fa7d14a
Show file tree
Hide file tree
Showing 106 changed files with 693 additions and 556 deletions.
9 changes: 8 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
from utils import TestCategory


def pytest_configure(config: pytest.Config):
Expand Down Expand Up @@ -69,7 +70,13 @@ def validate_keys(keys):
# Validate that only allowed keys are used.
validate_keys(properties.keys())

is_model_test = properties.get("test_category", None) == "model_test"
# Turn all properties to strings.
for k, v in properties:
properties[k] = str(v)

is_model_test = (
properties.get("test_category", None) == TestCategory.MODEL_TEST.value
)
if is_model_test:
model_group = properties.get("model_group", None)

Expand Down
3 changes: 3 additions & 0 deletions tests/infra/model_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class RunMode(Enum):
INFERENCE = "inference"
TRAINING = "training"

def __str__(self) -> str:
return self.value


class ModelTester(BaseTester, ABC):
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/jax/graphs/test_MLP_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def comparison_config() -> ComparisonConfig:

@pytest.mark.push
@pytest.mark.nightly
@pytest.mark.record_test_properties(test_category=TestCategory.GRAPH_TEST.value)
@pytest.mark.record_test_properties(test_category=TestCategory.GRAPH_TEST)
@pytest.mark.parametrize(
["W1", "b1", "W2", "b2", "X", "y"],
[
Expand Down
2 changes: 1 addition & 1 deletion tests/jax/graphs/test_activation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@pytest.mark.push
@pytest.mark.nightly
@pytest.mark.record_test_properties(test_category=TestCategory.GRAPH_TEST.value)
@pytest.mark.record_test_properties(test_category=TestCategory.GRAPH_TEST)
@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)])
def test_relu(x_shape: tuple):
"""Test ReLU activation function."""
Expand Down
1 change: 0 additions & 1 deletion tests/jax/graphs/test_example_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def example_graph(x: jax.Array, y: jax.Array) -> jax.Array:


@pytest.mark.push
@pytest.mark.nightly
@pytest.mark.parametrize(
["x_shape", "y_shape"],
[
Expand Down
2 changes: 1 addition & 1 deletion tests/jax/graphs/test_linear_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@pytest.mark.push
@pytest.mark.nightly
@pytest.mark.record_test_properties(test_category=TestCategory.GRAPH_TEST.value)
@pytest.mark.record_test_properties(test_category=TestCategory.GRAPH_TEST)
@pytest.mark.parametrize(
["x_shape", "y_shape", "bias_shape"],
[
Expand Down
2 changes: 1 addition & 1 deletion tests/jax/graphs/test_simple_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

@pytest.mark.push
@pytest.mark.nightly
@pytest.mark.record_test_properties(test_category=TestCategory.GRAPH_TEST.value)
@pytest.mark.record_test_properties(test_category=TestCategory.GRAPH_TEST)
@pytest.mark.parametrize("x_shape", [(32, 32), (64, 64)])
def test_simple_gradient(x_shape: tuple):
def simple_gradient(x: jax.Array):
Expand Down
2 changes: 1 addition & 1 deletion tests/jax/graphs/test_simple_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

@pytest.mark.push
@pytest.mark.nightly
@pytest.mark.record_test_properties(test_category=TestCategory.GRAPH_TEST.value)
@pytest.mark.record_test_properties(test_category=TestCategory.GRAPH_TEST)
@pytest.mark.parametrize(
["weights", "bias", "X", "y"], [[(1, 2), (1, 1), (2, 1), (1, 1)]]
)
Expand Down
2 changes: 1 addition & 1 deletion tests/jax/graphs/test_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

@pytest.mark.push
@pytest.mark.nightly
@pytest.mark.record_test_properties(test_category=TestCategory.GRAPH_TEST.value)
@pytest.mark.record_test_properties(test_category=TestCategory.GRAPH_TEST)
@pytest.mark.parametrize(
["x_shape", "axis"],
[
Expand Down
23 changes: 17 additions & 6 deletions tests/jax/models/albert/v2/base/test_albert_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,21 @@

import pytest
from infra import RunMode
from utils import runtime_fail
from utils import (
ModelGroup,
ModelSource,
ModelTask,
TestCategory,
build_model_name,
runtime_fail,
)

from ..tester import AlbertV2Tester

MODEL_PATH = "albert/albert-base-v2"
MODEL_NAME = "albert-v2-base"
MODEL_NAME = build_model_name(
"jax", "albert_base", ModelTask.MASKED_LM, ModelSource.HUGGING_FACE, "v2"
)


# ----- Fixtures -----
Expand All @@ -30,9 +39,10 @@ def training_tester() -> AlbertV2Tester:

@pytest.mark.model_test
@pytest.mark.record_test_properties(
test_category="model_test",
test_category=TestCategory.MODEL_TEST,
model_name=MODEL_NAME,
run_mode=RunMode.INFERENCE.value,
model_group=ModelGroup.GENERALITY,
run_mode=RunMode.INFERENCE,
)
@pytest.mark.xfail(
reason=(
Expand All @@ -47,9 +57,10 @@ def test_flax_albert_v2_base_inference(inference_tester: AlbertV2Tester):

@pytest.mark.model_test
@pytest.mark.record_test_properties(
test_category="model_test",
test_category=TestCategory.MODEL_TEST,
model_name=MODEL_NAME,
run_mode=RunMode.TRAINING.value,
model_group=ModelGroup.GENERALITY,
run_mode=RunMode.TRAINING,
)
@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_albert_v2_base_training(training_tester: AlbertV2Tester):
Expand Down
23 changes: 17 additions & 6 deletions tests/jax/models/albert/v2/large/test_albert_large.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,21 @@

import pytest
from infra import RunMode
from utils import runtime_fail
from utils import (
ModelGroup,
ModelSource,
ModelTask,
TestCategory,
build_model_name,
runtime_fail,
)

from ..tester import AlbertV2Tester

MODEL_PATH = "albert/albert-large-v2"
MODEL_NAME = "albert-v2-large"
MODEL_NAME = build_model_name(
"jax", "albert_large", ModelTask.MASKED_LM, ModelSource.HUGGING_FACE, "v2"
)


# ----- Fixtures -----
Expand All @@ -30,9 +39,10 @@ def training_tester() -> AlbertV2Tester:

@pytest.mark.model_test
@pytest.mark.record_test_properties(
test_category="model_test",
test_category=TestCategory.MODEL_TEST,
model_name=MODEL_NAME,
run_mode=RunMode.INFERENCE.value,
model_group=ModelGroup.GENERALITY,
run_mode=RunMode.INFERENCE,
)
@pytest.mark.xfail(
reason=(
Expand All @@ -47,9 +57,10 @@ def test_flax_albert_v2_large_inference(inference_tester: AlbertV2Tester):

@pytest.mark.model_test
@pytest.mark.record_test_properties(
test_category="model_test",
test_category=TestCategory.MODEL_TEST,
model_name=MODEL_NAME,
run_mode=RunMode.TRAINING.value,
model_group=ModelGroup.GENERALITY,
run_mode=RunMode.TRAINING,
)
@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_albert_v2_large_training(training_tester: AlbertV2Tester):
Expand Down
24 changes: 17 additions & 7 deletions tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,21 @@

import pytest
from infra import RunMode
from utils import runtime_fail
from utils import (
ModelGroup,
ModelSource,
ModelTask,
TestCategory,
build_model_name,
runtime_fail,
)

from ..tester import AlbertV2Tester

MODEL_PATH = "albert/albert-xlarge-v2"
MODEL_NAME = "albert-v2-xlarge"

MODEL_NAME = build_model_name(
"jax", "albert_xlarge", ModelTask.MASKED_LM, ModelSource.HUGGING_FACE, "v2"
)

# ----- Fixtures -----

Expand All @@ -30,9 +38,10 @@ def training_tester() -> AlbertV2Tester:

@pytest.mark.model_test
@pytest.mark.record_test_properties(
test_category="model_test",
test_category=TestCategory.MODEL_TEST,
model_name=MODEL_NAME,
run_mode=RunMode.INFERENCE.value,
model_group=ModelGroup.GENERALITY,
run_mode=RunMode.INFERENCE,
)
@pytest.mark.xfail(
reason=(
Expand All @@ -47,9 +56,10 @@ def test_flax_albert_v2_xlarge_inference(inference_tester: AlbertV2Tester):

@pytest.mark.model_test
@pytest.mark.record_test_properties(
test_category="model_test",
test_category=TestCategory.MODEL_TEST,
model_name=MODEL_NAME,
run_mode=RunMode.TRAINING.value,
model_group=ModelGroup.GENERALITY,
run_mode=RunMode.TRAINING,
)
@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_albert_v2_xlarge_training(training_tester: AlbertV2Tester):
Expand Down
24 changes: 17 additions & 7 deletions tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@
#
# SPDX-License-Identifier: Apache-2.0


import pytest
from infra import RunMode
from utils import runtime_fail
from utils import (
ModelGroup,
ModelSource,
ModelTask,
TestCategory,
build_model_name,
runtime_fail,
)

from ..tester import AlbertV2Tester

MODEL_PATH = "albert/albert-xxlarge-v2"
MODEL_NAME = "albert-v2-xxlarge"
MODEL_NAME = build_model_name(
"jax", "albert_xxlarge", ModelTask.MASKED_LM, ModelSource.HUGGING_FACE, "v2"
)


# ----- Fixtures -----
Expand All @@ -31,9 +39,10 @@ def training_tester() -> AlbertV2Tester:

@pytest.mark.model_test
@pytest.mark.record_test_properties(
test_category="model_test",
test_category=TestCategory.MODEL_TEST,
model_name=MODEL_NAME,
run_mode=RunMode.INFERENCE.value,
model_group=ModelGroup.GENERALITY,
run_mode=RunMode.INFERENCE,
)
@pytest.mark.xfail(
reason=(
Expand All @@ -48,9 +57,10 @@ def test_flax_albert_v2_xxlarge_inference(inference_tester: AlbertV2Tester):

@pytest.mark.model_test
@pytest.mark.record_test_properties(
test_category="model_test",
test_category=TestCategory.MODEL_TEST,
model_name=MODEL_NAME,
run_mode=RunMode.TRAINING.value,
model_group=ModelGroup.GENERALITY,
run_mode=RunMode.TRAINING,
)
@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_albert_v2_xxlarge_training(training_tester: AlbertV2Tester):
Expand Down
16 changes: 11 additions & 5 deletions tests/jax/models/bart/base/test_bart_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

import pytest
from infra import RunMode
from utils import runtime_fail
from utils import (
ModelGroup,
TestCategory,
runtime_fail,
)

from ..tester import FlaxBartForCausalLMTester

Expand All @@ -30,9 +34,10 @@ def training_tester() -> FlaxBartForCausalLMTester:

@pytest.mark.model_test
@pytest.mark.record_test_properties(
test_category="model_test",
test_category=TestCategory.MODEL_TEST,
model_name=MODEL_NAME,
run_mode=RunMode.INFERENCE.value,
model_group=ModelGroup.GENERALITY,
run_mode=RunMode.INFERENCE,
)
@pytest.mark.xfail(
reason=(
Expand All @@ -48,9 +53,10 @@ def test_flax_bart_base_inference(inference_tester: FlaxBartForCausalLMTester):

@pytest.mark.model_test
@pytest.mark.record_test_properties(
test_category="model_test",
test_category=TestCategory.MODEL_TEST,
model_name=MODEL_NAME,
run_mode=RunMode.TRAINING.value,
model_group=ModelGroup.GENERALITY,
run_mode=RunMode.TRAINING,
)
@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_bart_base_training(training_tester: FlaxBartForCausalLMTester):
Expand Down
12 changes: 7 additions & 5 deletions tests/jax/models/bart/large/test_bart_large.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest
from infra import RunMode
from utils import runtime_fail
from utils import ModelGroup, TestCategory, runtime_fail

from ..tester import FlaxBartForCausalLMTester

Expand All @@ -30,9 +30,10 @@ def training_tester() -> FlaxBartForCausalLMTester:

@pytest.mark.model_test
@pytest.mark.record_test_properties(
test_category="model_test",
test_category=TestCategory.MODEL_TEST,
model_name=MODEL_NAME,
run_mode=RunMode.INFERENCE.value,
model_group=ModelGroup.GENERALITY,
run_mode=RunMode.INFERENCE,
)
@pytest.mark.xfail(
reason=(
Expand All @@ -48,9 +49,10 @@ def test_flax_bart_large_inference(inference_tester: FlaxBartForCausalLMTester):

@pytest.mark.model_test
@pytest.mark.record_test_properties(
test_category="model_test",
test_category=TestCategory.MODEL_TEST,
model_name=MODEL_NAME,
run_mode=RunMode.TRAINING.value,
model_group=ModelGroup.GENERALITY,
run_mode=RunMode.TRAINING,
)
@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_bart_large_training(training_tester: FlaxBartForCausalLMTester):
Expand Down
Loading

0 comments on commit fa7d14a

Please sign in to comment.