Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sagemaker env settings #3368

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/book/component-guide/orchestrators/sagemaker.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ Additional configuration for the Sagemaker orchestrator can be passed via `Sagem
* `sagemaker_session`
* `entrypoint`
* `base_job_name`
* `env`
* `environment`

For example, settings can be provided and applied in the following way:

Expand All @@ -180,6 +180,7 @@ from zenml.integrations.aws.flavors.sagemaker_orchestrator_flavor import (
sagemaker_orchestrator_settings = SagemakerOrchestratorSettings(
instance_type="ml.m5.large",
volume_size_in_gb=30,
environment={"MY_ENV_VAR": "my_value"}
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class SagemakerOrchestratorSettings(BaseSettings):
For processor_args.instance_type, check
https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
for a list of available instance types.
environment: Environment variables to pass to the container.
estimator_args: Arguments that are directly passed to the SageMaker
Estimator for a specific step, allowing for overriding the default
settings provided when configuring the component. See
Expand Down Expand Up @@ -116,6 +117,7 @@ class SagemakerOrchestratorSettings(BaseSettings):

processor_args: Dict[str, Any] = {}
estimator_args: Dict[str, Any] = {}
environment: Dict[str, str] = {}

input_data_s3_mode: str = "File"
input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class SagemakerStepOperatorSettings(BaseSettings):
For estimator_args.instance_type, check
https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-instance-types.html
for a list of available instance types.
environment: Environment variables to pass to the container.

"""

Expand All @@ -64,6 +65,7 @@ class SagemakerStepOperatorSettings(BaseSettings):
default=None, union_mode="left_to_right"
)
estimator_args: Dict[str, Any] = {}
environment: Dict[str, str] = {}

_deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
"instance_type"
Expand Down
18 changes: 18 additions & 0 deletions src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,19 @@ def prepare_or_run_pipeline(
ExecutionVariables.PIPELINE_EXECUTION_ARN
)

if step_settings.environment:
step_environment = step_settings.environment.copy()
# Sagemaker does not allow environment variables longer than 256
# characters to be passed to Processor steps. If an environment variable
# is longer than 256 characters, we split it into multiple environment
# variables (chunks) and re-construct it on the other side using the
# custom entrypoint configuration.
split_environment_variables(
size_limit=SAGEMAKER_PROCESSOR_STEP_ENV_VAR_SIZE_LIMIT,
env=step_environment,
)
environment.update(step_environment)

use_training_step = (
step_settings.use_training_step
if step_settings.use_training_step is not None
Expand Down Expand Up @@ -457,6 +470,11 @@ def prepare_or_run_pipeline(
)
)

# Convert environment to a dict of strings
environment = {
key: str(value) for key, value in environment.items()
}

if use_training_step:
# Create Estimator and TrainingStep
estimator = sagemaker.estimator.Estimator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,11 @@ def launch(
self.name,
)

settings = cast(SagemakerStepOperatorSettings, self.get_settings(info))

if settings.environment:
environment.update(settings.environment)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would allow users to potentially overwrite some crucial environment variables that we need to set in order to run steps

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should let them 😄

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see this rather as a feature than a bug. It allows users to hack ZenML in all sorts of twisted ways while waiting for official releases to deliver their bug fixes / features.


# Sagemaker does not allow environment variables longer than 512
# characters to be passed to Estimator steps. If an environment variable
# is longer than 512 characters, we split it into multiple environment
Expand All @@ -194,8 +199,6 @@ def launch(
image_name = info.get_image(key=SAGEMAKER_DOCKER_IMAGE_KEY)
environment[_ENTRYPOINT_ENV_VARIABLE] = " ".join(entrypoint_command)

settings = cast(SagemakerStepOperatorSettings, self.get_settings(info))

# Get and default fill SageMaker estimator arguments for full ZenML support
estimator_args = settings.estimator_args

Expand All @@ -221,6 +224,9 @@ def launch(
"instance_type", settings.instance_type or "ml.m5.large"
)

# Convert environment to a dict of strings
environment = {key: str(value) for key, value in environment.items()}

estimator_args["environment"] = environment
estimator_args["instance_count"] = 1
estimator_args["sagemaker_session"] = session
Expand Down
Loading