# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import logging
import re
import subprocess
from pathlib import Path

import pytest

import sagemaker
from sagemaker.deserializers import JSONDeserializer
from sagemaker.huggingface import (
    HuggingFaceModel,
    get_huggingface_llm_image_uri,
)
from sagemaker.local import LocalSession
from sagemaker.serializers import JSONSerializer

# Replace this role ARN with an appropriate role for your environment
ROLE = "arn:aws:iam::111111111111:role/service-role/AmazonSageMaker-ExecutionRole-20200101T000001"
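# In local mode, deploying does not call the SageMaker control plane, so the role
# is not assumed; a syntactically valid placeholder ARN like the one above suffices.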


def ensure_docker_compose_installed():
    """
    Downloads the Docker Compose CLI plugin if it is not present, then verifies
    the installation by checking that the output of 'docker compose version'
    matches the pattern 'Docker Compose version vX.Y.Z'.
    """
    cli_plugins_path = Path.home() / ".docker" / "cli-plugins"
    cli_plugins_path.mkdir(parents=True, exist_ok=True)
    compose_binary_path = cli_plugins_path / "docker-compose"
    if not compose_binary_path.exists():
        subprocess.run(
            [
                "curl",
                "-SL",
                "https://github.com/docker/compose/releases/download/v2.3.3/docker-compose-linux-x86_64",
                "-o",
                str(compose_binary_path),
            ],
            check=True,
        )
        subprocess.run(["chmod", "+x", str(compose_binary_path)], check=True)
    # Verify that the plugin is callable and reports a well-formed version string.
    try:
        output = subprocess.check_output(
            ["docker", "compose", "version"], stderr=subprocess.STDOUT
        )
        output_decoded = output.decode("utf-8").strip()
        logging.info(f"'docker compose version' output: {output_decoded}")
        # Example expected format: "Docker Compose version v2.3.3"
        pattern = r"Docker Compose version v\d+\.\d+\.\d+"
        match = re.search(pattern, output_decoded)
        assert (
            match is not None
        ), f"Could not find a Docker Compose version string matching '{pattern}' in: {output_decoded}"
    except subprocess.CalledProcessError as e:
        raise AssertionError(f"Failed to verify Docker Compose: {e}")
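

# Minimal sketch of a GPU pre-check (an assumption, not part of the original
# tests): both tests below require a GPU host, and probing 'nvidia-smi' up front
# gives a clearer failure mode than a container crash after the image download.
# The helper name '_gpu_available' is illustrative.
def _gpu_available() -> bool:
    """Return True if 'nvidia-smi' exits successfully, i.e. a GPU driver is present."""
    try:
        subprocess.run(
            ["nvidia-smi"],
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        return True
    except (OSError, subprocess.CalledProcessError):
        return False


# Example use (illustrative): decorate a test with
# @pytest.mark.skipif(not _gpu_available(), reason="requires a GPU host")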
"""
Local Model: HuggingFace LLM Inference
"""
@pytest.mark.local
def test_huggingfacellm_local_model_inference():
"""
Test local mode inference with DJL-LMI inference containers
without a model_data path provided at runtime. This test should
be run on a GPU only machine with instance set to local_gpu.
"""
ensure_docker_compose_installed()
# 1. Create a local session for inference
sagemaker_session = LocalSession()
sagemaker_session.config = {"local": {"local_code": True}}
djllmi_model = sagemaker.Model(
image_uri="763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124",
env={
"HF_MODEL_ID": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
"OPTION_MAX_MODEL_LEN": "10000",
"OPTION_GPU_MEMORY_UTILIZATION": "0.95",
"OPTION_ENABLE_STREAMING": "false",
"OPTION_ROLLING_BATCH": "auto",
"OPTION_MODEL_LOADING_TIMEOUT": "3600",
"OPTION_PAGED_ATTENTION": "false",
"OPTION_DTYPE": "fp16",
},
role=ROLE,
sagemaker_session=sagemaker_session
)
logging.warning('Deploying endpoint in local mode')
logging.warning(
'Note: if launching for the first time in local mode, container image download might take a few minutes to complete.'
)
endpoint_name = "test-djl"
djllmi_model.deploy(
endpoint_name=endpoint_name,
initial_instance_count=1,
instance_type="local_gpu",
container_startup_health_check_timeout=600,
)
predictor = sagemaker.Predictor(
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=JSONSerializer(),
deserializer=JSONDeserializer(),
)
test_response = predictor.predict(
{
"inputs": """<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
You are a helpful assistant that thinks and reasons before answering.
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
What's 2x2?
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""
}
)
logging.warning(test_response)
gen_text = test_response['generated_text']
logging.warning(f"\n=======\nmodel response: {gen_text}\n=======\n")
assert type(test_response) == dict, f"invalid model response format: {gen_text}"
assert type(gen_text) == str, f"assistant response format: {gen_text}"
logging.warning('About to delete the endpoint')
predictor.delete_endpoint()
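

# For reference (an aside, not asserted by the original code): sagemaker.Model and
# HuggingFaceModel also accept a 'model_data' S3 URI for pre-packaged weights.
# These tests intentionally omit it so the container pulls weights from the
# HuggingFace Hub via the HF_MODEL_ID environment variable.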
"""
Local Model: HuggingFace TGI Inference
"""
@pytest.mark.local
def test_huggingfacetgi_local_model_inference():
"""
Test local mode inference with HuggingFace TGI inference containers
without a model_data path provided at runtime. This test should
be run on a GPU only machine with instance set to local_gpu.
"""
ensure_docker_compose_installed()
# 1. Create a local session for inference
sagemaker_session = LocalSession()
sagemaker_session.config = {"local": {"local_code": True}}
huggingface_model = HuggingFaceModel(
image_uri=get_huggingface_llm_image_uri(
"huggingface",
version="2.3.1"
),
env={
"HF_MODEL_ID": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
"ENDPOINT_SERVER_TIMEOUT": "3600",
"MESSAGES_API_ENABLED": "true",
"OPTION_ENTRYPOINT": "inference.py",
"SAGEMAKER_ENV": "1",
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
"SAGEMAKER_PROGRAM": "inference.py",
"SM_NUM_GPUS": "1",
"MAX_TOTAL_TOKENS": "1024",
"MAX_INPUT_TOKENS": "800",
"MAX_BATCH_PREFILL_TOKENS": "900",
"DTYPE": "bfloat16",
"PORT": "8080"
},
role=ROLE,
sagemaker_session=sagemaker_session
)
logging.warning('Deploying endpoint in local mode')
logging.warning(
'Note: if launching for the first time in local mode, container image download might take a few minutes to complete.'
)
endpoint_name = "test-hf"
huggingface_model.deploy(
endpoint_name=endpoint_name,
initial_instance_count=1,
instance_type="local_gpu",
container_startup_health_check_timeout=600,
)
predictor = sagemaker.Predictor(
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=JSONSerializer(),
deserializer=JSONDeserializer(),
)
test_response = predictor.predict(
{
"messages": [
{"role": "system", "content": "You are a helpful assistant." },
{"role": "user", "content": "What is 2x2?"}
]
}
)
logging.warning(test_response)
gen_text = test_response['choices'][0]['message']
logging.warning(f"\n=======\nmodel response: {gen_text}\n=======\n")
assert type(gen_text) == dict, f"invalid model response: {gen_text}"
assert gen_text['role'] == 'assistant', f"assistant response missing: {gen_text}"
logging.warning('About to delete the endpoint')
predictor.delete_endpoint()
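

# Hedged convenience entry point (not part of the original test suite): lets the
# module run directly via 'python test_local_mode_gpu_jobs.py' on a GPU host, in
# addition to the usual 'pytest -m local test_local_mode_gpu_jobs.py' invocation.
if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    test_huggingfacellm_local_model_inference()
    test_huggingfacetgi_local_model_inference()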