 import threading
 import xgboost as xgb
+from typing import Optional
 from xgboost import rabit
-from xgboost.callback import _fmt_metric as format_metric
-from xgboost.core import Booster, XGBoostError
+from xgboost.callback import EvaluationMonitor
+from xgboost.core import XGBoostError


 TEMP_FILE_SUFFIX = ".sagemaker-ignore"
 FILE_LOCK_SUFFIX = ".sagemaker-uploading"
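
For context: these import swaps track xgboost's migration (in 1.3+) from function-style callbacks that received an env namedtuple to the class-based TrainingCallback API, which is why _fmt_metric and Booster are no longer needed here. A minimal sketch of that API, using illustrative names rather than code from this change:

    import xgboost as xgb

    class LoggingCallback(xgb.callback.TrainingCallback):
        """Toy callback: print the latest score for each dataset/metric pair."""

        def after_iteration(self, model, epoch, evals_log):
            # evals_log maps dataset name -> metric name -> list of scores
            for data, metrics in evals_log.items():
                for name, history in metrics.items():
                    print("[%d] %s-%s: %s" % (epoch, data, name, history[-1]))
            return False  # returning True would stop training early
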
@@ -42,29 +43,33 @@ def train(train_args, checkpoint_dir):
     xgb_model, start_iteration = load_checkpoint(checkpoint_dir)

+    # xgboost's default value for num_boost_round is 10.
+    # https://xgboost.readthedocs.io/en/stable/python/python_api.html#module-xgboost.training
+    # If num_boost_round <= 0, xgb.train() doesn't actually train and
+    # immediately returns a Booster object.
+    train_args["num_boost_round"] = train_args.get("num_boost_round", 10) - start_iteration
+
     if xgb_model is not None:
         logging.info("Checkpoint loaded from %s", xgb_model)
         logging.info("Resuming from iteration %s", start_iteration)

     callbacks = train_args.get("callbacks", [])
-    callbacks.append(print_checkpointed_evaluation(start_iteration=start_iteration))
-    callbacks.append(save_checkpoint(checkpoint_dir, start_iteration=start_iteration))
+    callbacks.append(print_checkpointed_evaluation(start_iteration=start_iteration,
+                                                   end_iteration=train_args["num_boost_round"]))
+    callbacks.append(save_checkpoint(checkpoint_dir, start_iteration=start_iteration, iteration=start_iteration,
+                                     end_iteration=train_args["num_boost_round"]))

     train_args["verbose_eval"] = False  # suppress xgboost's print_evaluation()
     train_args["xgb_model"] = xgb_model
     train_args["callbacks"] = callbacks
-    # xgboost's default value for num_boost_round is 10.
-    # If num_boost_round <= 0, xgb.train() doesn't actually train and
-    # immediately returns a Booster object.
-    train_args["num_boost_round"] = train_args.get("num_boost_round", 10) - start_iteration

     booster = xgb.train(**train_args)

     return booster


-def print_checkpointed_evaluation(period=1, show_stdv=True, start_iteration=0):
-    """Create a callback that print evaluation result.
+class PrintCheckpoint(xgb.callback.TrainingCallback):
+    """A callback that prints evaluation results every period iterations.

     This function was modified from https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/callback.py
     The only difference between the following function and the original function in xgboost.callback
@@ -73,41 +78,62 @@ def print_checkpointed_evaluation(period=1, show_stdv=True, start_iteration=0):
     We print the evaluation results every **period** iterations
     and on the first and the last iterations.

-    Parameters
+    Attributes
     ----------
     period : int
-        The period to log the evaluation results
-
+        The period to log the evaluation results
     show_stdv : bool, optional
-        Whether show stdv if provided
-
+        Whether show stdv if provided
     start_iteration: int, optional
-        Used for offsetting the iteration number that appears at the beginning of each evaluation result in the logs.
-
-    Returns
-    -------
-    callback : function
-        A callback that prints evaluation every period iterations.
+        Used for offsetting the iteration number that appears at the beginning of each evaluation result in the logs.
     """

-    def callback(env):
-        """internal function"""
-        if env.rank != 0 or (not env.evaluation_result_list) or period is False or period == 0:
-            return
-        i = env.iteration
-        if i % period == 0 or i + 1 == env.begin_iteration or i + 1 == env.end_iteration:
-            msg = "\t".join([format_metric(x, show_stdv) for x in env.evaluation_result_list])
-            rabit.tracker_print("[%d]\t%s\n" % (i + start_iteration, msg))
+    def __init__(self, end_iteration, iteration=0, rank=0, period=1, show_stdv=True, start_iteration=0):
+        self.period = period
+        self.show_stdv = show_stdv
+        self.start_iteration = start_iteration
+        self.rank = rank
+        self.iteration = iteration
+        self.end_iteration = end_iteration

-    return callback
+    def __call__(self, model, epoch=0, evals_log=None):
+        return self.after_iteration(model, epoch, evals_log)
+
+    def after_iteration(self, model, epoch=0, evals_log=None):
+        if self.rank != 0 or (not evals_log) or self.period is False or self.period == 0:
+            return
+        i = self.iteration
+        if i % self.period == 0 or i + 1 == self.start_iteration or i + 1 == self.end_iteration:
+            evaluation_monitor = EvaluationMonitor(self.rank, self.period, self.show_stdv)
+            msg: str = ""
+            for data, metric in evals_log.items():
+                for metric_name, log in metric.items():
+                    stdv: Optional[float] = None
+                    if isinstance(log[-1], tuple):
+                        score = log[-1][0]
+                        stdv = log[-1][1]
+                    else:
+                        score = log[-1]
+                    msg += evaluation_monitor._fmt_metric(data, metric_name, score, stdv)
+                    msg += "\n"
+            rabit.tracker_print("[%d]\t%s\n" % (i + self.start_iteration, msg))
+
+
+def print_checkpointed_evaluation(end_iteration, iteration=0, rank=0, period=1, show_stdv=True, start_iteration=0):
+    """A callback function that prints evaluation results every period iterations.
+
+    This is a wrapper function around PrintCheckpoint.
+    For details, see PrintCheckpoint.
+    """
+    return PrintCheckpoint(end_iteration, iteration, rank, period, show_stdv, start_iteration)


 def load_checkpoint(checkpoint_dir, max_try=5):
     """
     :param checkpoint_dir: e.g., /opt/ml/checkpoints
     :param max_try: number of times to try loading checkpoint before giving up.
     :return xgb_model: file path of stored xgb model. None if no checkpoint.
-    :return iteration: iterations completed before last checkpoiint.
+    :return iteration: iterations completed before last checkpoint.
     """
     if not checkpoint_dir or not os.path.exists(checkpoint_dir):
         return None, 0
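
The nested loop in PrintCheckpoint.after_iteration above walks xgboost's evals_log structure. A hedged sketch of what that structure looks like, with invented numbers: plain training logs bare floats, while cv-style histories log (score, stdv) tuples, which is why the callback checks isinstance(log[-1], tuple).

    # Invented values; the shape mirrors what xgb.train records per dataset/metric.
    evals_log = {
        "train": {"rmse": [0.53, 0.41, 0.35]},
        "validation": {"rmse": [(0.60, 0.02), (0.48, 0.03), (0.41, 0.03)]},
    }
    for data, metric in evals_log.items():
        for metric_name, log in metric.items():
            if isinstance(log[-1], tuple):
                score, stdv = log[-1]          # (mean, stdv) pair
            else:
                score, stdv = log[-1], None    # plain score
            print(data, metric_name, score, stdv)
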
@@ -124,9 +150,6 @@ def load_checkpoint(checkpoint_dir, max_try=5):
         try:
             latest_checkpoint = checkpoints.pop()
             xgb_model = os.path.join(checkpoint_dir, latest_checkpoint)
-            booster = Booster()
-            booster.load_model(xgb_model)
-
             filename, extension = latest_checkpoint.split(".")
             iteration = int(extension) + 1
             break
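
The surviving lines above recover the resume point purely from the checkpoint filename, so loading the Booster up front is no longer needed. A worked example, assuming CHECKPOINT_FILENAME is "xgboost-checkpoint" as comments elsewhere in this file suggest:

    latest_checkpoint = "xgboost-checkpoint.9"          # newest file in checkpoint_dir
    filename, extension = latest_checkpoint.split(".")  # -> "xgboost-checkpoint", "9"
    iteration = int(extension) + 1                      # 10: round 9 finished, resume at 10
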
@@ -141,18 +164,20 @@ def _sort_checkpoints(checkpoint_files):
     return checkpoint_files


-def save_checkpoint(checkpoint_dir, start_iteration=0, max_to_keep=5, num_round=None):
+def save_checkpoint(checkpoint_dir, start_iteration=0, max_to_keep=5, num_round=None, rank=0, iteration=0,
+                    end_iteration=None):
     """A callback function that saves checkpoints to disk.

     This is a wrapper function around SaveCheckpoint.
     For details, see SaveCheckpoint.
     """
-    return SaveCheckpoint(
-        checkpoint_dir=checkpoint_dir, start_iteration=start_iteration, max_to_keep=max_to_keep, num_round=num_round
+    return SaveCheckpointCallBack(
+        checkpoint_dir=checkpoint_dir, start_iteration=start_iteration, max_to_keep=max_to_keep, num_round=num_round,
+        iteration=iteration, end_iteration=end_iteration
     )


-class SaveCheckpoint(object):
+class SaveCheckpointCallBack(xgb.callback.TrainingCallback):
     """Create a callback that saves checkpoints to disk.

     The main purpose of this class is to support checkpointing for managed spot
@@ -192,19 +217,23 @@ class SaveCheckpoint(object):
     after round 19, start_iteration will be 20).
     num_round: (optional) indicates the number of boosting rounds.

-    Example:
-        >>> save_checkpoint = SaveCheckpoint("/opt/ml/checkpoints")
-        >>> xgboost.train(params, dtrain, callbacks=[save_checkpoint])
-    """
+    Example:
+        >>> save_checkpoint = SaveCheckpoint("/opt/ml/checkpoints")
+        >>> xgboost.train(params, dtrain, callbacks=[save_checkpoint])
+    """

     SENTINEL = None

-    def __init__(self, checkpoint_dir, start_iteration=0, max_to_keep=5, num_round=None):
+    def __init__(self, checkpoint_dir, start_iteration=0, max_to_keep=5, num_round=None, rank=0, iteration=0,
+                 end_iteration=None):
         """Init SaveCheckpoint with checkpoint_dir"""
         self.checkpoint_dir = checkpoint_dir
         self.max_to_keep = max_to_keep
         self.start_iteration = start_iteration
         self.num_round = num_round
+        self.rank = rank
+        self.iteration = iteration
+        self.end_iteration = end_iteration

         if not os.path.exists(self.checkpoint_dir):
             os.makedirs(self.checkpoint_dir)
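
Putting the new keyword arguments together, here is a hedged sketch of how train() above drives this callback when a spot-interrupted job resumes; params and dtrain are assumed to exist, and the 100-round budget is illustrative:

    xgb_model, start_iteration = load_checkpoint("/opt/ml/checkpoints")   # e.g., 20
    num_boost_round = 100 - start_iteration                               # 80 rounds remain
    booster = xgb.train(
        params,
        dtrain,
        num_boost_round=num_boost_round,
        xgb_model=xgb_model,                 # warm-start from the saved model
        callbacks=[save_checkpoint("/opt/ml/checkpoints",
                                   start_iteration=start_iteration,
                                   iteration=start_iteration,
                                   end_iteration=num_boost_round)],
    )
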
@@ -215,16 +244,46 @@ def __init__(self, checkpoint_dir, start_iteration=0, max_to_keep=5, num_round=N
         self.start()

-    def __call__(self, env):
+    def __call__(self, model, epoch=0, evals_log=None):
         """Make the class callable since it is meant to be used as a callback"""
-        return self.callback(env)
+        return self.after_iteration(model, epoch, evals_log)

     def format_path(self, iteration):
         """Return a file path to a checkpoint given an iteration number"""
         filename = "{}.{}".format(CHECKPOINT_FILENAME, iteration)
         checkpoint_path = os.path.join(self.checkpoint_dir, filename)
         return checkpoint_path

+    def after_iteration(self, model, epoch=0, evals_log=None) -> bool:
+        # rank: master node has rank 0.
+        # iteration: current boosting round
+        # end_iteration: round # when training will end; this is always num_round + 1.
+        # model: model object
+        if self.rank != 0:
+            logger.debug("Not master (rank = %d). Exiting checkpoint callback.", self.rank)
+            return
+
+        if len(os.listdir(self.checkpoint_dir)) != 0:
+            xgb_model, self.iteration = load_checkpoint(self.checkpoint_dir)
+            current_iteration = self.iteration
+        else:
+            current_iteration = self.start_iteration + self.iteration
+        self._save_checkpoint(model, current_iteration)
+
+        # For example, if we are at iteration 5 and max_to_keep is 5, we no
+        # longer need the checkpoint from iteration 0 (i.e., xgboost-checkpoint.0),
+        # so we put iteration_to_delete = 0 on the queue.
+        iteration_to_delete = current_iteration - self.max_to_keep
+        self.delete_queue.put(iteration_to_delete)
+
+        offset_iteration = self.end_iteration if self.num_round is None else self.num_round
+
+        training_has_ended = current_iteration + 1 >= self.start_iteration + offset_iteration
+
+        if training_has_ended:
+            self.stop()
+        return False
+
     def start(self):
         """Start a background thread that deletes old checkpoints
@@ -236,7 +295,6 @@ def start(self):
         When training is complete, we put SENTINEL on the queue, and when we
         see the SENTINEL, we clean up and exit the thread.
         """
-
         def _is_uploading(path):
             uploading = os.path.isfile(path + FILE_LOCK_SUFFIX)
             uploaded = os.path.isfile(path + FILE_SAFE_SUFFIX)
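
The start()/stop() pair described in the docstring above implements a sentinel-terminated worker queue. A simplified, self-contained sketch of that pattern; the real thread additionally skips files that are still being uploaded, as _is_uploading shows:

    import queue
    import threading

    SENTINEL = None
    delete_queue = queue.Queue()

    def _worker():
        while True:
            item = delete_queue.get()
            if item is SENTINEL:
                break                      # stop() enqueued the sentinel: clean up and exit
            print("would delete checkpoint", item)

    thread = threading.Thread(target=_worker, daemon=True)
    thread.start()
    delete_queue.put(0)                    # oldest checkpoint no longer needed
    delete_queue.put(SENTINEL)             # what stop() does when training ends
    thread.join()
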
@@ -286,7 +344,9 @@ def _delete_uploaded_files_and_cleanup():
             _delete_uploaded_files()
             _cleanup()

-        self.thread = threading.Thread(target=_delete_uploaded_files_and_cleanup, daemon=True)
+        self.thread = threading.Thread(
+            target=_delete_uploaded_files_and_cleanup,
+            daemon=True)
         self.thread.start()

     def stop(self):
@@ -304,30 +364,6 @@ def _save_checkpoint(self, model, iteration):
         save_file_path = self.format_path(iteration)
         os.rename(tf.name, save_file_path)

-    def callback(self, env):
-        # env.rank: rabit rank of the node/process. master node has rank 0.
-        # env.iteration: current boosting round
-        # env.begin_iteration: round # when training started. this is always 0.
-        # env.end_iteration: round # when training will end. this is always num_round + 1.
-        # env.model: model object
-        if env.rank != 0:
-            logger.debug("Not master (rank = %d). Exiting checkpoint callback.", env.rank)
-            return
-
-        current_iteration = self.start_iteration + env.iteration
-        self._save_checkpoint(env.model, current_iteration)
-
-        # For example, if we are at iteration 5 and max_to_keep is 5, we no
-        # longer need checkpoint from iteration 0 (i.e., xgboost-checkpoint.0),
-        # so we put iteration_to_delete = 0 on the queue.
-        iteration_to_delete = current_iteration - self.max_to_keep
-        self.delete_queue.put(iteration_to_delete)
-
-        offset_iteration = env.end_iteration if self.num_round is None else self.num_round
-        training_has_ended = current_iteration + 1 >= self.start_iteration + offset_iteration
-        if training_has_ended:
-            self.stop()
-

 def save_intermediate_model(intermediate_model_dir, model_name):
     """A callback function that saves intermediate models to disk.