From fab744a954418052d2dec6ad18e8797bc0340509 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 14 Apr 2025 18:49:53 +0000 Subject: [PATCH 1/4] fix: tgi image uri unit tests --- .../image_uris/test_huggingface_llm.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py index 0d96417e9f..205eb25239 100644 --- a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py +++ b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import pytest +from packaging import version from sagemaker.huggingface import get_huggingface_llm_image_uri from tests.unit.sagemaker.image_uris import expected_uris, conftest @@ -72,10 +73,29 @@ def test_huggingface_uris(load_config): VERSIONS = load_config["inference"]["versions"] device = load_config["inference"]["processors"][0] backend = "huggingface-neuronx" if device == "inf2" else "huggingface" + + # Fail if device is not in mapping + if device not in HF_VERSIONS_MAPPING: + raise ValueError(f"Device {device} not found in HF_VERSIONS_MAPPING") + + # Get highest version for the device + highest_version = max(HF_VERSIONS_MAPPING[device].keys(), key=lambda x: version.parse(x)) + for version in VERSIONS: ACCOUNTS = load_config["inference"]["versions"][version]["registries"] for region in ACCOUNTS.keys(): uri = get_huggingface_llm_image_uri(backend, region=region, version=version) + + # Skip only if test version is higher than highest known version. + # There's now automation to add new TGI releases to image_uri_config directory + # that doesn't involve a human raising a PR. + if version.parse(version) > version.parse(highest_version): + print( + f"Skipping test for version {test_version} as it is higher than the highest known version {highest_version}. " + "There is automation that now updates the image_uri_config without a human raising a PR." + ) + continue + expected = expected_uris.huggingface_llm_framework_uri( "huggingface-pytorch-tgi-inference", ACCOUNTS[region], From 6056bae3aae6061a78b8918883d45ee79f7880c9 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 14 Apr 2025 19:00:04 +0000 Subject: [PATCH 2/4] fix: black-format and flake8 failures --- .../unit/sagemaker/image_uris/test_huggingface_llm.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py index 205eb25239..0e085d2ca9 100644 --- a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py +++ b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py @@ -13,7 +13,7 @@ from __future__ import absolute_import import pytest -from packaging import version +from packaging.version import parse from sagemaker.huggingface import get_huggingface_llm_image_uri from tests.unit.sagemaker.image_uris import expected_uris, conftest @@ -89,10 +89,12 @@ def test_huggingface_uris(load_config): # Skip only if test version is higher than highest known version. # There's now automation to add new TGI releases to image_uri_config directory # that doesn't involve a human raising a PR. - if version.parse(version) > version.parse(highest_version): + if parse(version) > parse(highest_version): print( - f"Skipping test for version {test_version} as it is higher than the highest known version {highest_version}. " - "There is automation that now updates the image_uri_config without a human raising a PR." + f"Skipping test for version {version} as it is higher than " + "the highest known version {highest_version}. There is " + "automation that now updates the image_uri_config " + "without a human raising a PR." ) continue From 968eb3f5d1b96c018899fe3e0d2c4f9a1254ab12 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 14 Apr 2025 19:01:39 +0000 Subject: [PATCH 3/4] fix: parse --- tests/unit/sagemaker/image_uris/test_huggingface_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py index 0e085d2ca9..d12a8313e3 100644 --- a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py +++ b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py @@ -79,7 +79,7 @@ def test_huggingface_uris(load_config): raise ValueError(f"Device {device} not found in HF_VERSIONS_MAPPING") # Get highest version for the device - highest_version = max(HF_VERSIONS_MAPPING[device].keys(), key=lambda x: version.parse(x)) + highest_version = max(HF_VERSIONS_MAPPING[device].keys(), key=lambda x: parse(x)) for version in VERSIONS: ACCOUNTS = load_config["inference"]["versions"][version]["registries"] From 56b74cb31594a14a1d32b35caf00819da6ef5308 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 14 Apr 2025 22:46:31 +0000 Subject: [PATCH 4/4] fix: print statement --- tests/unit/sagemaker/image_uris/test_huggingface_llm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py index d12a8313e3..084c2d1438 100644 --- a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py +++ b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py @@ -91,10 +91,10 @@ def test_huggingface_uris(load_config): # that doesn't involve a human raising a PR. if parse(version) > parse(highest_version): print( - f"Skipping test for version {version} as it is higher than " - "the highest known version {highest_version}. There is " + f"Skipping version check for {version} as there is " "automation that now updates the image_uri_config " - "without a human raising a PR." + "without a human raising a PR. Tests will pass for " + f"versions higher than {highest_version} that are not in HF_VERSIONS_MAPPING." ) continue