Skip to content

Commit ccf7da9

Browse files
aichendouble and tensorflower-gardener
authored and committed
Add a FLAG checkpoint_model_name to specify the object name when saving the checkpoint,
i.e., the checkpoint will be saved using tf.train.Checkpoint(FLAGS.checkpoint_model_name=model)

PiperOrigin-RevId: 326672697
1 parent cf82a72 commit ccf7da9

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

official/nlp/bert/tf1_checkpoint_converter_lib.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,14 @@ def _get_new_shape(name, shape, num_heads):
111111
return None
112112

113113

114-
def create_v2_checkpoint(model, src_checkpoint, output_path):
114+
def create_v2_checkpoint(model,
115+
src_checkpoint,
116+
output_path,
117+
checkpoint_model_name="model"):
115118
"""Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
116119
# Uses streaming-restore in eager model to read V1 name-based checkpoints.
117120
model.load_weights(src_checkpoint).assert_existing_objects_matched()
118-
checkpoint = tf.train.Checkpoint(model=model)
121+
checkpoint = tf.train.Checkpoint(**{checkpoint_model_name: model})
119122
checkpoint.save(output_path)
120123

121124

official/nlp/bert/tf2_encoder_checkpoint_converter.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@
4242
"BertModel, with no task heads.)")
4343
flags.DEFINE_string("converted_checkpoint_path", None,
4444
"Name for the created object-based V2 checkpoint.")
45+
flags.DEFINE_string("checkpoint_model_name", "model",
46+
"The name of the model when saving the checkpoint, i.e., "
47+
"the checkpoint will be saved using: "
48+
"tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
4549

4650

4751
def _create_bert_model(cfg):
@@ -71,7 +75,8 @@ def _create_bert_model(cfg):
7175
return bert_encoder
7276

7377

74-
def convert_checkpoint(bert_config, output_path, v1_checkpoint):
78+
def convert_checkpoint(bert_config, output_path, v1_checkpoint,
79+
checkpoint_model_name="model"):
7580
"""Converts a V1 checkpoint into an OO V2 checkpoint."""
7681
output_dir, _ = os.path.split(output_path)
7782
tf.io.gfile.makedirs(output_dir)
@@ -90,7 +95,8 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
9095
# Create a V2 checkpoint from the temporary checkpoint.
9196
model = _create_bert_model(bert_config)
9297
tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
93-
output_path)
98+
output_path,
99+
checkpoint_model_name)
94100

95101
# Clean up the temporary checkpoint, if it exists.
96102
try:
@@ -103,8 +109,10 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
103109
def main(_):
104110
output_path = FLAGS.converted_checkpoint_path
105111
v1_checkpoint = FLAGS.checkpoint_to_convert
112+
checkpoint_model_name = FLAGS.checkpoint_model_name
106113
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
107-
convert_checkpoint(bert_config, output_path, v1_checkpoint)
114+
convert_checkpoint(bert_config, output_path, v1_checkpoint,
115+
checkpoint_model_name)
108116

109117

110118
if __name__ == "__main__":

0 commit comments

Comments (0)