|
10 | 10 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
11 | 11 | # ANY KIND, either express or implied. See the License for the specific
|
12 | 12 | # language governing permissions and limitations under the License.
|
| 13 | +import base64 |
13 | 14 | import errno
|
14 | 15 | import json
|
15 | 16 | import logging
|
16 | 17 | import os
|
17 | 18 | import platform
|
| 19 | +import random |
18 | 20 | import shlex
|
19 | 21 | import shutil
|
| 22 | +import string |
20 | 23 | import subprocess
|
21 | 24 | import sys
|
22 | 25 | import tempfile
|
@@ -59,7 +62,10 @@ def __init__(self, instance_type, instance_count, image, sagemaker_session=None)
|
59 | 62 | self.instance_type = instance_type
|
60 | 63 | self.instance_count = instance_count
|
61 | 64 | self.image = image
|
62 |
| - self.hosts = ['{}-{}'.format(CONTAINER_PREFIX, i) for i in range(1, self.instance_count + 1)] |
| 65 | + # Since we are using a single docker network, generate a random suffix to attach to the container names. |
| 66 | + # This way multiple jobs can run in parallel. |
| 67 | + suffix = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(5)) |
| 68 | + self.hosts = ['{}-{}-{}'.format(CONTAINER_PREFIX, i, suffix) for i in range(1, self.instance_count + 1)] |
63 | 69 | self.container_root = None
|
64 | 70 | self.container = None
|
65 | 71 | # set the local config. This is optional and will use reasonable defaults
|
@@ -110,6 +116,8 @@ def train(self, input_data_config, hyperparameters):
|
110 | 116 |
|
111 | 117 | compose_data = self._generate_compose_file('train', additional_volumes=volumes)
|
112 | 118 | compose_command = self._compose()
|
| 119 | + |
| 120 | + _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image) |
113 | 121 | _execute_and_stream_output(compose_command)
|
114 | 122 |
|
115 | 123 | s3_model_artifacts = self.retrieve_model_artifacts(compose_data)
|
@@ -152,6 +160,8 @@ def serve(self, primary_container):
|
152 | 160 |
|
153 | 161 | env_vars = ['{}={}'.format(k, v) for k, v in primary_container['Environment'].items()]
|
154 | 162 |
|
| 163 | + _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image) |
| 164 | + |
155 | 165 | self._generate_compose_file('serve', additional_env_vars=env_vars)
|
156 | 166 | compose_command = self._compose()
|
157 | 167 | self.container = _HostingContainer(compose_command)
|
@@ -296,7 +306,11 @@ def _generate_compose_file(self, command, additional_volumes=None, additional_en
|
296 | 306 | content = {
|
297 | 307 | # Some legacy hosts only support the 2.1 format.
|
298 | 308 | 'version': '2.1',
|
299 |
| - 'services': services |
| 309 | + 'services': services, |
| 310 | + 'networks': { |
| 311 | + 'sagemaker-local': {'name': 'sagemaker-local'} |
| 312 | + } |
| 313 | + |
300 | 314 | }
|
301 | 315 |
|
302 | 316 | docker_compose_path = os.path.join(self.container_root, DOCKER_COMPOSE_FILENAME)
|
@@ -335,7 +349,12 @@ def _create_docker_host(self, host, environment, optml_subdirs, command, volumes
|
335 | 349 | 'tty': True,
|
336 | 350 | 'volumes': [v.map for v in optml_volumes],
|
337 | 351 | 'environment': environment,
|
338 |
| - 'command': command |
| 352 | + 'command': command, |
| 353 | + 'networks': { |
| 354 | + 'sagemaker-local': { |
| 355 | + 'aliases': [host] |
| 356 | + } |
| 357 | + } |
339 | 358 | }
|
340 | 359 |
|
341 | 360 | serving_port = 8080 if self.local_config is None else self.local_config.get('serving_port', 8080)
|
@@ -390,7 +409,8 @@ def _build_optml_volumes(self, host, subdirs):
|
390 | 409 | return volumes
|
391 | 410 |
|
392 | 411 | def _cleanup(self):
|
393 |
| - _check_output('docker network prune -f') |
| 412 | + # we don't need to cleanup anything at the moment |
| 413 | + pass |
394 | 414 |
|
395 | 415 |
|
396 | 416 | class _HostingContainer(object):
|
@@ -525,3 +545,24 @@ def _aws_credentials(session):
|
525 | 545 | def _write_json_file(filename, content):
|
526 | 546 | with open(filename, 'w') as f:
|
527 | 547 | json.dump(content, f)
|
| 548 | + |
| 549 | + |
| 550 | +def _ecr_login_if_needed(boto_session, image): |
| 551 | + # Only ECR images need login |
| 552 | + if not ('dkr.ecr' in image and 'amazonaws.com' in image): |
| 553 | + return |
| 554 | + |
| 555 | + # do we have the image? |
| 556 | + if _check_output('docker images -q %s' % image).strip(): |
| 557 | + return |
| 558 | + |
| 559 | + ecr = boto_session.client('ecr') |
| 560 | + auth = ecr.get_authorization_token(registryIds=[image.split('.')[0]]) |
| 561 | + authorization_data = auth['authorizationData'][0] |
| 562 | + |
| 563 | + raw_token = base64.b64decode(authorization_data['authorizationToken']) |
| 564 | + token = raw_token.decode('utf-8').strip('AWS:') |
| 565 | + ecr_url = auth['authorizationData'][0]['proxyEndpoint'] |
| 566 | + |
| 567 | + cmd = "docker login -u AWS -p %s %s" % (token, ecr_url) |
| 568 | + subprocess.check_output(cmd, shell=True) |
0 commit comments