
Commit ceb62be

Added xgboost_script_mode_local_training_and_serving sample

1 parent 5ece1d2 commit ceb62be
File tree

5 files changed: +285 -1 lines changed

tensorflow_script_mode_california_housing_local_training_and_serving/tensorflow_script_mode_california_housing_local_training_and_serving.py
+1 -1

@@ -1,4 +1,4 @@
-# This is a sample Python program that trains a simple TensorFlow CIFAR-10 model.
+# This is a sample Python program that trains a simple TensorFlow California Housing model.
 # This implementation will work on your *local computer* or in the *AWS Cloud*.
 # To run training and inference *locally* set: `config = get_config(LOCAL_MODE)`
 # To run training and inference on the *cloud* set: `config = get_config(CLOUD_MODE)` and set a valid IAM role value in get_config()
code/abalone.py (new file)

@@ -0,0 +1,133 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import print_function

import argparse
import json
import logging
import os
import pickle as pkl

import pandas as pd
import xgboost as xgb
from sagemaker_containers import entry_point
from sagemaker_xgboost_container import distributed
from sagemaker_xgboost_container.data_utils import get_dmatrix


def _xgb_train(params, dtrain, evals, num_boost_round, model_dir, is_master):
    """Run xgb.train() on the given arguments with rabit initialized.

    This is our rabit execution function.

    :param params: Parameter dictionary used to run xgb.train().
    :param is_master: True if current node is master host in distributed training,
        or is running a single node training job.
        Note that rabit_run will include this argument.
    """
    booster = xgb.train(params=params, dtrain=dtrain, evals=evals, num_boost_round=num_boost_round)

    if is_master:
        model_location = model_dir + "/xgboost-model"
        pkl.dump(booster, open(model_location, "wb"))
        logging.info("Stored trained model at {}".format(model_location))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Hyperparameters are described here.
    parser.add_argument(
        "--max_depth",
        type=int,
    )
    parser.add_argument("--eta", type=float)
    parser.add_argument("--gamma", type=int)
    parser.add_argument("--min_child_weight", type=int)
    parser.add_argument("--subsample", type=float)
    parser.add_argument("--verbosity", type=int)
    parser.add_argument("--objective", type=str)
    parser.add_argument("--num_round", type=int)
    parser.add_argument("--tree_method", type=str, default="auto")
    parser.add_argument("--predictor", type=str, default="auto")

    # SageMaker specific arguments. Defaults are set in the environment variables.
    parser.add_argument("--output_data_dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR"))
    parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR"))
    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN"))
    parser.add_argument("--validation", type=str, default=os.environ.get("SM_CHANNEL_VALIDATION"))
    parser.add_argument("--sm_hosts", type=str, default=os.environ.get("SM_HOSTS"))
    parser.add_argument("--sm_current_host", type=str, default=os.environ.get("SM_CURRENT_HOST"))

    args, _ = parser.parse_known_args()

    # Get SageMaker host information from runtime environment variables
    sm_hosts = json.loads(args.sm_hosts)
    sm_current_host = args.sm_current_host

    dtrain = get_dmatrix(args.train, "libsvm")
    dval = get_dmatrix(args.validation, "libsvm")
    watchlist = (
        [(dtrain, "train"), (dval, "validation")] if dval is not None else [(dtrain, "train")]
    )

    train_hp = {
        "max_depth": args.max_depth,
        "eta": args.eta,
        "gamma": args.gamma,
        "min_child_weight": args.min_child_weight,
        "subsample": args.subsample,
        "verbosity": args.verbosity,
        "objective": args.objective,
        "tree_method": args.tree_method,
        "predictor": args.predictor,
    }

    xgb_train_args = dict(
        params=train_hp,
        dtrain=dtrain,
        evals=watchlist,
        num_boost_round=args.num_round,
        model_dir=args.model_dir,
    )

    if len(sm_hosts) > 1:
        # Wait until all hosts are able to find each other
        entry_point._wait_hostname_resolution()

        # Execute training function after initializing rabit.
        distributed.rabit_run(
            exec_fun=_xgb_train,
            args=xgb_train_args,
            include_in_training=(dtrain is not None),
            hosts=sm_hosts,
            current_host=sm_current_host,
            update_rabit_args=True,
        )
    else:
        # If single node training, call training method directly.
        if dtrain:
            xgb_train_args["is_master"] = True
            _xgb_train(**xgb_train_args)
        else:
            raise ValueError("Training channel must have data to train model.")


def model_fn(model_dir):
    """Deserialize and return fitted model.

    Note that this should have the same name as the serialized model in the _xgb_train method.
    """
    model_file = "xgboost-model"
    booster = pkl.load(open(os.path.join(model_dir, model_file), "rb"))
    return booster
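Because `_xgb_train` pickles the booster as `xgboost-model` inside `model_dir`, and SageMaker tars that directory into the `model.tar.gz` reported by `estimator.model_data`, the artifact can be sanity-checked outside the container. A minimal sketch, assuming the `model.tar.gz` produced by training has been copied to the working directory, that unpacks it and unpickles the booster the same way `model_fn` does:

import pickle as pkl
import tarfile

# Assumption: model.tar.gz was copied from estimator.model_data to the CWD.
with tarfile.open("model.tar.gz") as tar:
    tar.extract("xgboost-model")  # the member name _xgb_train used when pickling

with open("xgboost-model", "rb") as f:
    booster = pkl.load(f)  # same unpickling as model_fn above

print(booster.attributes())  # quick sanity check on the loaded booster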
code/inference.py (new file)

@@ -0,0 +1,63 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import json
import os
import pickle as pkl

import numpy as np
import sagemaker_xgboost_container.encoder as xgb_encoders


def model_fn(model_dir):
    """
    Deserialize and return fitted model.
    """
    model_file = "xgboost-model"
    booster = pkl.load(open(os.path.join(model_dir, model_file), "rb"))
    return booster


def input_fn(request_body, request_content_type):
    """
    The SageMaker XGBoost model server receives the request data body and the content type,
    and invokes the `input_fn`.

    Return a DMatrix (an object that can be passed to predict_fn).
    """
    if request_content_type == "text/libsvm":
        return xgb_encoders.libsvm_to_dmatrix(request_body)
    else:
        raise ValueError("Content type {} is not supported.".format(request_content_type))


def predict_fn(input_data, model):
    """
    The SageMaker XGBoost model server invokes `predict_fn` on the return value of `input_fn`.

    Return a two-dimensional NumPy array where the first column is the prediction
    and the remaining columns are the feature contributions (SHAP values) for that prediction.
    """
    prediction = model.predict(input_data)
    feature_contribs = model.predict(input_data, pred_contribs=True, validate_features=False)
    output = np.hstack((prediction[:, np.newaxis], feature_contribs))
    return output


def output_fn(predictions, content_type):
    """
    After invoking predict_fn, the model server invokes `output_fn`.
    """
    if content_type == "text/csv":
        return ",".join(str(x) for x in predictions[0])
    else:
        raise ValueError("Content type {} is not supported.".format(content_type))
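Since `predict_fn` stacks the point prediction with the per-feature SHAP contributions and `output_fn` serializes the first row as CSV, a client receives one comma-separated line: the prediction first, then one contribution per feature, with XGBoost's bias (expected value) as the final contribution column. A minimal sketch of parsing such a response; the numbers below are made up for illustration:

def parse_csv_response(csv_body):
    # First value: model prediction; remaining values: SHAP contributions,
    # where the last one is XGBoost's bias / expected-value column.
    values = [float(x) for x in csv_body.split(",")]
    return values[0], values[1:]

# Hypothetical response body for an 8-feature abalone record (8 contributions + bias).
pred, contribs = parse_csv_response("9.51,0.12,-0.30,0.05,0.0,0.44,-0.11,0.02,0.01,9.28")
print("prediction:", pred)
print("contributions (last value is bias):", contribs)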
requirements.txt (new file)

@@ -0,0 +1,4 @@
numpy
pandas
sagemaker>=2.0.0,<3.0.0
sagemaker[local]
xgboost_script_mode_local_training_and_serving.py (new file)

@@ -0,0 +1,84 @@
# This is a sample Python program that trains a simple XGBoost model on the Abalone dataset.
# This implementation will work on your *local computer* or in the *AWS Cloud*.
# To run training and inference *locally*, keep `instance_type='local'`.
# To run training and inference on the *cloud*, set `instance_type` to a SageMaker
# instance type (e.g. 'ml.m5.xlarge') and replace DUMMY_IAM_ROLE with a valid IAM role ARN.
#
# Prerequisites:
#   1. Install required Python packages:
#      `pip install -r requirements.txt`
#   2. Docker Desktop installed and running on your computer:
#      `docker ps`
#   3. You should have AWS credentials configured on your local machine
#      in order to be able to pull the docker image from ECR.
###############################################################################################

from sagemaker import TrainingInput
from sagemaker.xgboost import XGBoost, XGBoostModel

DUMMY_IAM_ROLE = 'arn:aws:iam::111111111111:role/service-role/AmazonSageMaker-ExecutionRole-20200101T000001'


def do_inference_on_local_endpoint(predictor, libsvm_str):
    label, *features = libsvm_str.strip().split()
    predictions = predictor.predict(" ".join(["-99"] + features))  # use dummy label -99
    print("Prediction: {}".format(predictions))


def main():
    print('Starting model training.')
    print('Note: if launching for the first time in local mode, container image download might take a few minutes to complete.')

    hyperparameters = {
        "max_depth": "5",
        "eta": "0.2",
        "gamma": "4",
        "min_child_weight": "6",
        "subsample": "0.7",
        "objective": "reg:squarederror",
        "num_round": "50",
        "verbosity": "2",
    }

    xgb_script_mode_estimator = XGBoost(
        entry_point="./code/abalone.py",
        hyperparameters=hyperparameters,
        role=DUMMY_IAM_ROLE,
        instance_count=1,
        instance_type='local',
        framework_version="1.2-1",
    )

    train_input = TrainingInput("s3://xgboost-script-mode-local-training-and-serving/train/abalone", content_type="text/libsvm")

    xgb_script_mode_estimator.fit({"train": train_input, "validation": train_input})

    print('Completed model training')

    model_data = xgb_script_mode_estimator.model_data
    print(model_data)

    xgb_inference_model = XGBoostModel(
        model_data=model_data,
        role=DUMMY_IAM_ROLE,
        entry_point="./code/inference.py",
        framework_version="1.2-1",
    )

    print('Deploying endpoint in local mode')
    predictor = xgb_inference_model.deploy(
        initial_instance_count=1,
        instance_type="local",
    )

    a_young_abalone = "6 1:3 2:0.37 3:0.29 4:0.095 5:0.249 6:0.1045 7:0.058 8:0.067"
    do_inference_on_local_endpoint(predictor, a_young_abalone)

    an_old_abalone = "15 1:1 2:0.655 3:0.53 4:0.175 5:1.2635 6:0.486 7:0.2635 8:0.415"
    do_inference_on_local_endpoint(predictor, an_old_abalone)

    print('About to delete the endpoint to stop paying (if in cloud mode).')
    predictor.delete_endpoint(predictor.endpoint_name)


if __name__ == "__main__":
    main()
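The script above is hard-wired to local mode. For reference, a hedged sketch of what the cloud-mode variant of the estimator could look like; the instance type and role ARN below are placeholders, not values from this commit:

from sagemaker.xgboost import XGBoost

xgb_cloud_estimator = XGBoost(
    entry_point="./code/abalone.py",
    hyperparameters=hyperparameters,  # same dict built in main() above
    role="arn:aws:iam::123456789012:role/MySageMakerExecutionRole",  # placeholder ARN
    instance_count=1,
    instance_type="ml.m5.xlarge",  # any supported SageMaker training instance
    framework_version="1.2-1",
)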

0 commit comments