Skip to content

Commit 30c3913

Browse files
author
Michael Denkowski
authored
Ignore false-positive missing keys for traced modules (#1042)
1 parent b2d5e85 commit 30c3913

File tree

4 files changed

+18
-2
lines changed

4 files changed

+18
-2
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa
1111

1212
Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
1313

14+
## [3.1.10]
15+
16+
### Fixed
17+
18+
- When loading parameters, SockeyeModel now ignores false positive missing parameters for traced modules. These modules use the same parameters as their original non-traced versions.
19+
1420
## [3.1.9]
1521

1622
### Changed

sockeye/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
1313

14-
__version__ = '3.1.9'
14+
__version__ = '3.1.10'

sockeye/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,11 @@ def load_parameters(self,
362362
# Earlier versions of Sockeye may have saved parameters for traced
363363
# modules. These parameters can be safely ignored.
364364
unexpected = [key for key in unexpected if 'traced' not in key]
365+
# We also ignore cases where traced modules exist and appear to be
366+
# missing parameters. These modules actually use the same parameters as
367+
# their original non-traced versions so there are no separate parameters
368+
# to load.
369+
missing = [key for key in missing if 'traced' not in key]
365370
if not allow_missing:
366371
utils.check_condition(not missing, f"missing keys: {missing}")
367372
if not ignore_extra:

test/integration/test_seq_copy_int.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,10 @@ def _test_parameter_averaging(model_path: str):
231231

232232
def _test_checkpoint_decoder(dev_source_path: str, dev_target_path: str, model_path: str):
233233
"""
234-
Runs checkpoint decoder on 10% of the dev data and checks whether metric keys are present in the result dict.
234+
Runs checkpoint decoder on 10% of the dev data and checks whether metric
235+
keys are present in the result dict. Also checks that we can reload model
236+
parameters after running the checkpoint decoder (case when using the
237+
plateau-reduce scheduler).
235238
"""
236239
with open(dev_source_path) as dev_fd:
237240
num_dev_sent = sum(1 for _ in dev_fd)
@@ -254,3 +257,5 @@ def _test_checkpoint_decoder(dev_source_path: str, dev_target_path: str, model_p
254257
assert 'bleu' in cp_metrics
255258
assert 'chrf' in cp_metrics
256259
assert 'decode-walltime' in cp_metrics
260+
261+
model.load_parameters(os.path.join(model_path, C.PARAMS_BEST_NAME), device=pt.device('cpu'))

0 commit comments

Comments
 (0)