Skip to content

Commit 17d24a1

Browse files
authored
Login to ECR if needed for Local Mode (#121)
Instead of having docker-compose generate a random network each time, use a constant network. When running Local Mode, the SDK will perform an ECR login if required. Bump version to 1.2.1.
1 parent 2d1d5cf commit 17d24a1

File tree

4 files changed

+102
-5
lines changed

4 files changed

+102
-5
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
CHANGELOG
33
=========
44

5+
1.2.1
6+
========
7+
* bug-fix: Change Local Mode to use a sagemaker-local docker network
8+
59
1.2.0
610
========
711

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def read(fname):
1111

1212

1313
setup(name="sagemaker",
14-
version="1.2.0",
14+
version="1.2.1",
1515
description="Open source library for training and deploying models on Amazon SageMaker.",
1616
packages=find_packages('src'),
1717
package_dir={'': 'src'},

src/sagemaker/local/image.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,16 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
import base64
1314
import errno
1415
import json
1516
import logging
1617
import os
1718
import platform
19+
import random
1820
import shlex
1921
import shutil
22+
import string
2023
import subprocess
2124
import sys
2225
import tempfile
@@ -59,7 +62,10 @@ def __init__(self, instance_type, instance_count, image, sagemaker_session=None)
5962
self.instance_type = instance_type
6063
self.instance_count = instance_count
6164
self.image = image
62-
self.hosts = ['{}-{}'.format(CONTAINER_PREFIX, i) for i in range(1, self.instance_count + 1)]
65+
# Since we are using a single docker network, generate a random suffix to attach to the container names.
66+
# This way multiple jobs can run in parallel.
67+
suffix = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(5))
68+
self.hosts = ['{}-{}-{}'.format(CONTAINER_PREFIX, i, suffix) for i in range(1, self.instance_count + 1)]
6369
self.container_root = None
6470
self.container = None
6571
# set the local config. This is optional and will use reasonable defaults
@@ -110,6 +116,8 @@ def train(self, input_data_config, hyperparameters):
110116

111117
compose_data = self._generate_compose_file('train', additional_volumes=volumes)
112118
compose_command = self._compose()
119+
120+
_ecr_login_if_needed(self.sagemaker_session.boto_session, self.image)
113121
_execute_and_stream_output(compose_command)
114122

115123
s3_model_artifacts = self.retrieve_model_artifacts(compose_data)
@@ -152,6 +160,8 @@ def serve(self, primary_container):
152160

153161
env_vars = ['{}={}'.format(k, v) for k, v in primary_container['Environment'].items()]
154162

163+
_ecr_login_if_needed(self.sagemaker_session.boto_session, self.image)
164+
155165
self._generate_compose_file('serve', additional_env_vars=env_vars)
156166
compose_command = self._compose()
157167
self.container = _HostingContainer(compose_command)
@@ -296,7 +306,11 @@ def _generate_compose_file(self, command, additional_volumes=None, additional_en
296306
content = {
297307
# Some legacy hosts only support the 2.1 format.
298308
'version': '2.1',
299-
'services': services
309+
'services': services,
310+
'networks': {
311+
'sagemaker-local': {'name': 'sagemaker-local'}
312+
}
313+
300314
}
301315

302316
docker_compose_path = os.path.join(self.container_root, DOCKER_COMPOSE_FILENAME)
@@ -335,7 +349,12 @@ def _create_docker_host(self, host, environment, optml_subdirs, command, volumes
335349
'tty': True,
336350
'volumes': [v.map for v in optml_volumes],
337351
'environment': environment,
338-
'command': command
352+
'command': command,
353+
'networks': {
354+
'sagemaker-local': {
355+
'aliases': [host]
356+
}
357+
}
339358
}
340359

341360
serving_port = 8080 if self.local_config is None else self.local_config.get('serving_port', 8080)
@@ -390,7 +409,8 @@ def _build_optml_volumes(self, host, subdirs):
390409
return volumes
391410

392411
def _cleanup(self):
393-
_check_output('docker network prune -f')
412+
# we don't need to cleanup anything at the moment
413+
pass
394414

395415

396416
class _HostingContainer(object):
@@ -525,3 +545,24 @@ def _aws_credentials(session):
525545
def _write_json_file(filename, content):
    """Serialize ``content`` as JSON into ``filename``.

    The file is opened in ``'w'`` mode, so any existing content is replaced.

    Args:
        filename (str): path of the file to write.
        content: JSON-serializable object to dump.
    """
    with open(filename, 'w') as json_file:
        json.dump(content, fp=json_file)
548+
549+
550+
def _ecr_login_if_needed(boto_session, image):
    """Run ``docker login`` against ECR when ``image`` has to be pulled from it.

    Nothing happens when ``image`` does not look like an ECR image URI, or
    when ``docker images -q`` already reports a local copy of it.

    Args:
        boto_session: object whose ``client('ecr')`` returns an ECR client
            (presumably a boto3 session; confirm against the callers).
        image (str): docker image name. ECR images are expected to look like
            ``<registry-id>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>``.

    Raises:
        subprocess.CalledProcessError: if ``docker login`` exits non-zero.
    """
    # Only ECR images need login
    if not ('dkr.ecr' in image and 'amazonaws.com' in image):
        return

    # Skip the login when docker already lists the image locally.
    if _check_output('docker images -q %s' % image).strip():
        return

    ecr = boto_session.client('ecr')
    # The registry id is the first dot-separated field of the image URI.
    auth = ecr.get_authorization_token(registryIds=[image.split('.')[0]])
    authorization_data = auth['authorizationData'][0]

    # The decoded token has the form "AWS:<password>". Remove exactly that
    # prefix: str.strip('AWS:') treats its argument as a set of characters and
    # trims them from BOTH ends, which corrupts any password that starts or
    # ends with 'A', 'W', 'S' or ':'.
    raw_token = base64.b64decode(authorization_data['authorizationToken'])
    token = raw_token.decode('utf-8')
    user_prefix = 'AWS:'
    if token.startswith(user_prefix):
        token = token[len(user_prefix):]
    ecr_url = authorization_data['proxyEndpoint']

    # NOTE(review): the password is interpolated into a shell command line, so
    # it is visible in the process list; consider `docker login --password-stdin`.
    cmd = "docker login -u AWS -p %s %s" % (token, ecr_url)
    subprocess.check_output(cmd, shell=True)

tests/unit/test_image.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
import base64
1314
import json
1415
import os
1516

@@ -105,6 +106,7 @@ def test_write_config_file(LocalSession, tmpdir):
105106
@patch('sagemaker.local.local_session.LocalSession')
106107
def test_retrieve_artifacts(LocalSession, tmpdir):
107108
sagemaker_container = _SageMakerContainer('local', 2, 'my-image')
109+
sagemaker_container.hosts = ['algo-1', 'algo-2'] # avoid any randomness
108110
sagemaker_container.container_root = str(tmpdir.mkdir('container-root'))
109111

110112
volume1 = os.path.join(sagemaker_container.container_root, 'algo-1/output/')
@@ -227,3 +229,53 @@ def test_serve(up, copy, copytree, tmpdir, sagemaker_session):
227229
for h in sagemaker_container.hosts:
228230
assert config['services'][h]['image'] == image
229231
assert config['services'][h]['command'] == 'serve'
232+
233+
234+
def test_ecr_login_non_ecr():
    """A non-ECR image name must not cause any ECR client to be created."""
    session_mock = Mock()
    sagemaker.local.image._ecr_login_if_needed(session_mock, 'ubuntu')

    session_mock.assert_not_called()
    # Calling a child attribute of a Mock is not recorded as a call on the
    # parent, so the check above alone would pass even if an ECR client had
    # been requested. Assert on `client` explicitly.
    session_mock.client.assert_not_called()
239+
240+
241+
@patch('sagemaker.local.image._check_output', return_value='123451324')
def test_ecr_login_image_exists(_check_output):
    """An ECR image that docker already lists locally must not trigger a login."""
    session_mock = Mock()

    image = '520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-have:1.0'
    sagemaker.local.image._ecr_login_if_needed(session_mock, image)

    session_mock.assert_not_called()
    # Calling a child attribute of a Mock is not recorded as a call on the
    # parent, so also verify that no ECR client was requested.
    session_mock.client.assert_not_called()
    _check_output.assert_called()
250+
251+
252+
@patch('subprocess.check_output', return_value=''.encode('utf-8'))
def test_ecr_login_needed(check_output):
    """An ECR image that is not available locally triggers a ``docker login``.

    ``subprocess.check_output`` is patched to return empty bytes; presumably
    this makes the local image lookup come back empty (confirm against
    ``_check_output``), so the login path is exercised.
    """
    token = 'very-secure-token'
    encoded_token = base64.b64encode(('AWS:%s' % token).encode('utf-8'))

    # Canned ECR GetAuthorizationToken response.
    authorization_entry = {
        u'authorizationToken': encoded_token,
        u'proxyEndpoint': u'https://520713654638.dkr.ecr.us-east-1.amazonaws.com',
    }
    response_metadata = {
        'RetryAttempts': 0,
        'HTTPStatusCode': 200,
        'RequestId': '25b2ac63-36bf-11e8-ab6a-e5dc597d2ad9',
    }

    boto_session = Mock()
    ecr_client = boto_session.client('ecr')
    ecr_client.get_authorization_token.return_value = {
        u'authorizationData': [authorization_entry],
        'ResponseMetadata': response_metadata,
    }

    sagemaker.local.image._ecr_login_if_needed(
        boto_session,
        '520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-need:1.1')

    # The login must use the password without its "AWS:" prefix and the
    # proxy endpoint returned by ECR.
    check_output.assert_called_with(
        'docker login -u AWS -p %s https://520713654638.dkr.ecr.us-east-1.amazonaws.com' % token,
        shell=True)
    ecr_client.get_authorization_token.assert_called_with(registryIds=['520713654638'])

0 commit comments

Comments
 (0)