Skip to content

Commit 63286ff

Browse files
author
Michael Denkowski
authored
Option to not reload the best training checkpoint when reducing the learning rate (#1045)
1 parent 23ffd29 commit 63286ff

File tree

6 files changed

+18
-3
lines changed

6 files changed

+18
-3
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa
1111

1212
Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
1313

14+
## [3.1.13]
15+
16+
### Added
17+
18+
- Added `sockeye-train` argument `--no-reload-on-learning-rate-reduce` that disables reloading the best training checkpoint when reducing the learning rate. This currently only applies to the `plateau-reduce` learning rate scheduler since other schedulers do not reload checkpoints.
19+
1420
## [3.1.12]
1521

1622
### Fixed

sockeye/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
1313

14-
__version__ = '3.1.12'
14+
__version__ = '3.1.13'

sockeye/arguments.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,12 @@ def add_training_args(params):
958958
default=0,
959959
help="Number of warmup steps. If set to x, linearly increases learning rate from 10%% "
960960
"to 100%% of the initial learning rate. Default: %(default)s.")
961+
train_params.add_argument('--no-reload-on-learning-rate-reduce',
962+
action='store_true',
963+
default=False,
964+
help='Do not reload the best training checkpoint when reducing the learning rate. '
965+
'Default: %(default)s.')
966+
961967

962968
train_params.add_argument('--fixed-param-strategy',
963969
default=None,

sockeye/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -967,7 +967,8 @@ def train(args: argparse.Namespace, custom_metrics_logger: Optional[Callable] =
967967
max_epochs=args.max_num_epochs,
968968
max_seconds=args.max_seconds,
969969
update_interval=args.update_interval,
970-
stop_training_on_decoder_failure=args.stop_training_on_decoder_failure)
970+
stop_training_on_decoder_failure=args.stop_training_on_decoder_failure,
971+
no_reload_on_learning_rate_reduce=args.no_reload_on_learning_rate_reduce)
971972
if trainer_config.min_epochs is not None and trainer_config.max_epochs is not None:
972973
check_condition(trainer_config.min_epochs <= trainer_config.max_epochs,
973974
"Minimum number of epochs must be smaller than maximum number of epochs")

sockeye/training.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class TrainerConfig(Config):
7373
max_seconds: Optional[int] = None
7474
update_interval: int = 1
7575
stop_training_on_decoder_failure: bool = False
76+
no_reload_on_learning_rate_reduce: bool = False
7677

7778

7879
class TrainState:
@@ -549,7 +550,7 @@ def _adjust_learning_rate(self, has_improved: bool):
549550
lr_adjusted = scheduler.new_evaluation_result(has_improved) # type: ignore
550551
else:
551552
lr_adjusted = False
552-
if lr_adjusted and not has_improved:
553+
if lr_adjusted and not has_improved and not self.config.no_reload_on_learning_rate_reduce:
553554
logger.info("Loading model parameters and optimizer states from best checkpoint: %d",
554555
self.state.best_checkpoint)
555556
if os.path.exists(self.best_params_fname):

test/unit/test_arguments.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def test_inference_args(test_params, expected_params):
207207
learning_rate_reduce_factor=0.9,
208208
learning_rate_reduce_num_not_improved=8,
209209
learning_rate_warmup=0,
210+
no_reload_on_learning_rate_reduce=False,
210211
fixed_param_names=[],
211212
fixed_param_strategy=None,
212213
decode_and_evaluate=500,

0 commit comments

Comments
 (0)