Skip to content

Commit

Permalink
Add the initial batch of JAX model tests (#129)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgligorijevicTT authored Jan 3, 2025
1 parent 02f0440 commit 51d5799
Show file tree
Hide file tree
Showing 26 changed files with 627 additions and 70 deletions.
7 changes: 4 additions & 3 deletions tests/infra/base_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ def __init__(
) -> None:
self._comparison_config = comparison_config

@staticmethod
def _compile(executable: Callable) -> Callable:
def _compile(
self, executable: Callable, static_argnames: Sequence[str] = None
) -> Callable:
"""Sets up `executable` for just-in-time compile."""
return jax.jit(executable)
return jax.jit(executable, static_argnames=static_argnames)

def _compare(
self,
Expand Down
34 changes: 19 additions & 15 deletions tests/infra/device_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from .types import Tensor
from .workload import Workload

import inspect


class DeviceRunner:
"""
Expand Down Expand Up @@ -100,25 +102,27 @@ def _safely_put_workload_on_device(
To avoid that, we try to `jax.device_put` arg or kwarg, and if it doesn't
succeed, we leave it as is.
"""
args_on_device = []

for arg in workload.args:
try:
arg_on_device = jax.device_put(arg, device)
except:
arg_on_device = arg
fn_params = list(inspect.signature(workload.executable).parameters.keys())

args_on_device.append(arg_on_device)
args_on_device = []
for i, arg in enumerate(workload.args):
if fn_params[i] not in workload.static_argnames:
try:
args_on_device.append(jax.device_put(arg, device))
except:
args_on_device.append(arg)
else:
args_on_device.append(arg)

kwargs_on_device = {}

for key, value in workload.kwargs.items():
try:
value_on_device = jax.device_put(value, device)
except:
value_on_device = value

kwargs_on_device[key] = value_on_device
if key not in workload.static_argnames:
try:
kwargs_on_device[key] = jax.device_put(value, device)
except:
kwargs_on_device[key] = value
else:
kwargs_on_device[key] = value

return Workload(workload.executable, args_on_device, kwargs_on_device)

Expand Down
42 changes: 30 additions & 12 deletions tests/infra/model_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ class ModelTester(BaseTester, ABC):
Derived classes must provide implementations of:
```
_get_model() -> Model
_get_input_activations() -> Sequence[Any]
_get_forward_method_name() -> str # Optional, has default behaviour.
_get_model(self) -> Model
_get_input_activations(self) -> Sequence[Any]
_get_forward_method_name(self) -> str # Optional, has default behaviour.
# One of or both:
_get_forward_method_args(self) -> Sequence[Any] # Optional, has default behaviour.
_get_forward_method_kwargs(self) -> Mapping[str, Any] # Optional, has default behaviour.
Expand Down Expand Up @@ -72,23 +72,24 @@ def _init_model_hooks(self) -> None:

forward_pass_method = getattr(self._model, forward_method_name)

forward_static_args = self._get_static_argnames()

# Store model's forward pass method and its arguments as a workload.
self._workload = Workload(forward_pass_method, args, kwargs)
self._workload = Workload(
forward_pass_method, args, kwargs, forward_static_args
)

@staticmethod
@abstractmethod
def _get_model() -> Model:
def _get_model(self) -> Model:
"""Returns model instance."""
raise NotImplementedError("Subclasses must implement this method.")

@staticmethod
@abstractmethod
def _get_input_activations() -> Sequence[Any]:
def _get_input_activations(self) -> Sequence[Any]:
"""Returns input activations."""
raise NotImplementedError("Subclasses must implement this method.")

@staticmethod
def _get_forward_method_name() -> str:
def _get_forward_method_name(self) -> str:
"""
Returns string name of model's forward pass method.
Expand Down Expand Up @@ -119,6 +120,18 @@ def _get_forward_method_kwargs(self) -> Mapping[str, Any]:
"""
return {}

def _get_static_argnames(self) -> Sequence[str]:
"""
Return the names of arguments which should be treated as static by JIT compiler.
Static arguments are those which are not replaced with Tracer objects by the JIT
but rather are used as is, which is needed if control flow or shapes depend on them.
https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables
By default no arguments are static.
"""
return []

def test(self) -> None:
"""Tests the model depending on test type with which tester was configured."""
if self._run_mode == RunMode.INFERENCE:
Expand All @@ -136,7 +149,10 @@ def _test_inference(self) -> None:
compiled_forward_method = self._compile_model()

compiled_workload = Workload(
compiled_forward_method, self._workload.args, self._workload.kwargs
compiled_forward_method,
self._workload.args,
self._workload.kwargs,
self._workload.static_argnames,
)

tt_res = DeviceRunner.run_on_tt_device(compiled_workload)
Expand Down Expand Up @@ -175,4 +191,6 @@ def _configure_model_for_training(model: Model) -> None:

def _compile_model(self) -> Callable:
"""JIT-compiles model's forward pass into optimized kernels."""
return super()._compile(self._workload.executable)
return super()._compile(
self._workload.executable, self._workload.static_argnames
)
3 changes: 3 additions & 0 deletions tests/infra/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@ class Workload:
executable: Callable
args: Sequence[Any]
kwargs: Optional[Mapping[str, Any]] = None
static_argnames: Optional[Sequence[str]] = None

def __post_init__(self):
# If kwargs is None, initialize it to an empty dictionary.
if self.kwargs is None:
self.kwargs = {}
if self.static_argnames is None:
self.static_argnames = []

def execute(self) -> Any:
"""Calls callable passing stored args and kwargs directly."""
Expand Down
Empty file.
Empty file.
41 changes: 41 additions & 0 deletions tests/jax/models/albert/v2/base/test_albert_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from infra import RunMode

from ..tester import AlbertV2Tester


MODEL_PATH = "albert/albert-base-v2"


# ----- Fixtures -----


@pytest.fixture
def inference_tester() -> AlbertV2Tester:
return AlbertV2Tester(MODEL_PATH)


@pytest.fixture
def training_tester() -> AlbertV2Tester:
return AlbertV2Tester(MODEL_PATH, RunMode.TRAINING)


# ----- Tests -----


@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'")
def test_flax_albert_v2_base_inference(
inference_tester: AlbertV2Tester,
):
inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_albert_v2_base_training(
training_tester: AlbertV2Tester,
):
training_tester.test()
Empty file.
40 changes: 40 additions & 0 deletions tests/jax/models/albert/v2/large/test_albert_large.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from infra import RunMode

from ..tester import AlbertV2Tester

MODEL_PATH = "albert/albert-large-v2"


# ----- Fixtures -----


@pytest.fixture
def inference_tester() -> AlbertV2Tester:
return AlbertV2Tester(MODEL_PATH)


@pytest.fixture
def training_tester() -> AlbertV2Tester:
return AlbertV2Tester(MODEL_PATH, RunMode.TRAINING)


# ----- Tests -----


@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'")
def test_flax_albert_v2_large_inference(
inference_tester: AlbertV2Tester,
):
inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_albert_v2_large_training(
training_tester: AlbertV2Tester,
):
training_tester.test()
48 changes: 48 additions & 0 deletions tests/jax/models/albert/v2/tester.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, Sequence

import jax
from flax import linen as nn
from infra import ModelTester, RunMode, ComparisonConfig
from transformers import AutoTokenizer, FlaxAlbertForMaskedLM


class AlbertV2Tester(ModelTester):
"""Tester for Albert model on a masked language modeling task."""

def __init__(
self,
model_name: str,
comparison_config: ComparisonConfig = ComparisonConfig(),
run_mode: RunMode = RunMode.INFERENCE,
) -> None:
self._model_name = model_name
super().__init__(comparison_config, run_mode)

# @override
def _get_model(self) -> nn.Module:
return FlaxAlbertForMaskedLM.from_pretrained(self._model_name)

# @override
def _get_input_activations(self) -> Sequence[jax.Array]:
tokenizer = AutoTokenizer.from_pretrained(self._model_name)
inputs = tokenizer("Hello [MASK].", return_tensors="np")
return inputs["input_ids"]

# @override
def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]:
assert hasattr(self._model, "params")
return {
"params": self._model.params,
"input_ids": self._get_input_activations(),
}

# @ override
def _get_static_argnames(self):
return ["train"]


# TODO(stefan): Add testers for Albert when used as a question answering or sentiment analysis model.
Empty file.
40 changes: 40 additions & 0 deletions tests/jax/models/albert/v2/xlarge/test_albert_xlarge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from infra import RunMode

from ..tester import AlbertV2Tester

MODEL_PATH = "albert/albert-xlarge-v2"


# ----- Fixtures -----


@pytest.fixture
def inference_tester() -> AlbertV2Tester:
return AlbertV2Tester(MODEL_PATH)


@pytest.fixture
def training_tester() -> AlbertV2Tester:
return AlbertV2Tester(MODEL_PATH, RunMode.TRAINING)


# ----- Tests -----


@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'")
def test_flax_albert_v2_xlarge_inference(
inference_tester: AlbertV2Tester,
):
inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_albert_v2_xlarge_training(
training_tester: AlbertV2Tester,
):
training_tester.test()
Empty file.
40 changes: 40 additions & 0 deletions tests/jax/models/albert/v2/xxlarge/test_albert_xxlarge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest
from infra import RunMode

from ..tester import AlbertV2Tester

MODEL_PATH = "albert/albert-xxlarge-v2"


# ----- Fixtures -----


@pytest.fixture
def inference_tester() -> AlbertV2Tester:
return AlbertV2Tester(MODEL_PATH)


@pytest.fixture
def training_tester() -> AlbertV2Tester:
return AlbertV2Tester(MODEL_PATH, RunMode.TRAINING)


# ----- Tests -----


@pytest.mark.skip(reason="failed to legalize operation 'stablehlo.dot_general'")
def test_flax_albert_v2_xxlarge_inference(
inference_tester: AlbertV2Tester,
):
inference_tester.test()


@pytest.mark.skip(reason="Support for training not implemented")
def test_flax_albert_v2_xxlarge_training(
training_tester: AlbertV2Tester,
):
training_tester.test()
Loading

0 comments on commit 51d5799

Please sign in to comment.