4242 "BertModel, with no task heads.)" )
4343flags .DEFINE_string ("converted_checkpoint_path" , None ,
4444 "Name for the created object-based V2 checkpoint." )
45+ flags .DEFINE_string ("checkpoint_model_name" , "model" ,
46+ "The name of the model when saving the checkpoint, i.e., "
47+ "the checkpoint will be saved using: "
48+ "tf.train.Checkpoint(FLAGS.checkpoint_model_name=model)." )
4549
4650
4751def _create_bert_model (cfg ):
@@ -71,7 +75,8 @@ def _create_bert_model(cfg):
7175 return bert_encoder
7276
7377
74- def convert_checkpoint (bert_config , output_path , v1_checkpoint ):
78+ def convert_checkpoint (bert_config , output_path , v1_checkpoint ,
79+ checkpoint_model_name = "model" ):
7580 """Converts a V1 checkpoint into an OO V2 checkpoint."""
7681 output_dir , _ = os .path .split (output_path )
7782 tf .io .gfile .makedirs (output_dir )
@@ -90,7 +95,8 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
9095 # Create a V2 checkpoint from the temporary checkpoint.
9196 model = _create_bert_model (bert_config )
9297 tf1_checkpoint_converter_lib .create_v2_checkpoint (model , temporary_checkpoint ,
93- output_path )
98+ output_path ,
99+ checkpoint_model_name )
94100
95101 # Clean up the temporary checkpoint, if it exists.
96102 try :
@@ -103,8 +109,10 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
103109def main (_ ):
104110 output_path = FLAGS .converted_checkpoint_path
105111 v1_checkpoint = FLAGS .checkpoint_to_convert
112+ checkpoint_model_name = FLAGS .checkpoint_model_name
106113 bert_config = configs .BertConfig .from_json_file (FLAGS .bert_config_file )
107- convert_checkpoint (bert_config , output_path , v1_checkpoint )
114+ convert_checkpoint (bert_config , output_path , v1_checkpoint ,
115+ checkpoint_model_name )
108116
109117
110118if __name__ == "__main__" :
0 commit comments