Skip to content

Commit

Permalink
Fixed formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
mmaksimovicTT committed Mar 10, 2025
1 parent a242670 commit 7f376df
Showing 1 changed file with 37 additions and 34 deletions.
71 changes: 37 additions & 34 deletions tests/jax/models/vit/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from utils import record_model_test_properties, runtime_fail



class FlaxViTForImageClassificationTester(ModelTester):
"""Tester for Vision Transformer model on an image classification task using Flax."""

Expand All @@ -32,7 +31,7 @@ def _get_model(self) -> FlaxViTForImageClassification:

# @override
def _get_input_activations(self) -> Sequence[jax.Array]:
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
processor = ViTImageProcessor.from_pretrained(self._model_name)
inputs = processor(images=image, return_tensors="jax")
Expand All @@ -53,6 +52,7 @@ def _get_static_argnames(self):

# ----- Fixtures -----


@pytest.fixture
def inference_tester(request) -> FlaxViTForImageClassificationTester:
return FlaxViTForImageClassificationTester(request.param)
Expand All @@ -65,24 +65,26 @@ def training_tester(request) -> FlaxViTForImageClassificationTester:

# ----- Tests -----


@pytest.mark.nightly
@pytest.mark.parametrize(
"inference_tester",
[
"google/vit-base-patch16-384",
"google/vit-base-patch32-384",
"google/vit-base-patch16-224",
"google/vit-base-patch32-224-in21k",
"google/vit-base-patch16-224-in21k",
"google/vit-large-patch16-224-in21k",
"google/vit-large-patch16-224",
"google/vit-large-patch16-384",
"google/vit-large-patch32-224-in21k",
"google/vit-large-patch32-384",
"google/vit-huge-patch14-224-in21k",
],
indirect=True,
ids=lambda val: val)
"inference_tester",
[
"google/vit-base-patch16-384",
"google/vit-base-patch32-384",
"google/vit-base-patch16-224",
"google/vit-base-patch32-224-in21k",
"google/vit-base-patch16-224-in21k",
"google/vit-large-patch16-224-in21k",
"google/vit-large-patch16-224",
"google/vit-large-patch16-384",
"google/vit-large-patch32-224-in21k",
"google/vit-large-patch32-384",
"google/vit-huge-patch14-224-in21k",
],
indirect=True,
ids=lambda val: val,
)
@pytest.mark.xfail(
reason=runtime_fail(
"Out of memory while performing convolution."
Expand All @@ -99,22 +101,23 @@ def test_vit_inference(

@pytest.mark.nightly
@pytest.mark.parametrize(
"training_tester",
[
"google/vit-base-patch16-384",
"google/vit-base-patch32-384",
"google/vit-base-patch16-224",
"google/vit-base-patch32-224-in21k",
"google/vit-base-patch16-224-in21k",
"google/vit-large-patch16-224-in21k",
"google/vit-large-patch16-224",
"google/vit-large-patch16-384",
"google/vit-large-patch32-224-in21k",
"google/vit-large-patch32-384",
"google/vit-huge-patch14-224-in21k",
],
indirect=True,
ids=lambda val: val)
"training_tester",
[
"google/vit-base-patch16-384",
"google/vit-base-patch32-384",
"google/vit-base-patch16-224",
"google/vit-base-patch32-224-in21k",
"google/vit-base-patch16-224-in21k",
"google/vit-large-patch16-224-in21k",
"google/vit-large-patch16-224",
"google/vit-large-patch16-384",
"google/vit-large-patch32-224-in21k",
"google/vit-large-patch32-384",
"google/vit-huge-patch14-224-in21k",
],
indirect=True,
ids=lambda val: val,
)
@pytest.mark.skip(reason="Support for training not implemented")
def test_vit_training(
training_tester: FlaxViTForImageClassificationTester,
Expand Down

0 comments on commit 7f376df

Please sign in to comment.