Skip to content

Commit

Permalink
[sharktank] Test additional version on windows (#697)
Browse files Browse the repository at this point in the history
  • Loading branch information
marbre authored and eagarvey-amd committed Jan 8, 2025
1 parent 089f3cd commit 9a51938
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .github/workflows/ci-sharktank.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ jobs:
- os: windows-2022
python-version: "3.11"
torch-version: "2.3.0"
- os: windows-2022
python-version: "3.12"
torch-version: "2.4.1"
exclude:
- python-version: "3.12"
# `torch.compile` requires torch>=2.4.0 for Python 3.12+
Expand Down
5 changes: 5 additions & 0 deletions sharktank/tests/examples/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import pytest
import sys
import unittest

from sharktank.utils.testing import MainRunnerTestBase


@pytest.mark.skipif(
sys.platform == "win32", reason="https://github.com/nod-ai/shark-ai/issues/698"
)
class ShardingTests(MainRunnerTestBase):
def testExportFfnNet(self):
from sharktank.examples.sharding.export_ffn_net import main
Expand Down
4 changes: 4 additions & 0 deletions sharktank/tests/layers/sharded_conv2d_with_iree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import pytest
import sys

from pathlib import Path
import tempfile
Expand Down Expand Up @@ -186,6 +187,9 @@ def run_test_sharded_conv2d_with_iree(
@pytest.mark.xfail(
torch.__version__ >= (2, 5), reason="https://github.com/nod-ai/shark-ai/issues/682"
)
@pytest.mark.skipif(
sys.platform == "win32", reason="https://github.com/nod-ai/shark-ai/issues/698"
)
def test_sharded_conv2d_with_iree(
mlir_path: Optional[Path],
module_path: Optional[Path],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import pytest
import sys

from pathlib import Path
import tempfile

Expand All @@ -19,7 +22,6 @@
import iree.runtime
from typing import List, Optional
import os
import pytest

vm_context: iree.runtime.VmContext = None

Expand Down Expand Up @@ -231,6 +233,9 @@ def run_test_sharded_resnet_block_with_iree(
@pytest.mark.xfail(
torch.__version__ >= (2, 5), reason="https://github.com/nod-ai/shark-ai/issues/683"
)
@pytest.mark.skipif(
sys.platform == "win32", reason="https://github.com/nod-ai/shark-ai/issues/698"
)
def test_sharded_resnet_block_with_iree(
mlir_path: Optional[Path],
module_path: Optional[Path],
Expand Down

0 comments on commit 9a51938

Please sign in to comment.