42
42
"BertModel, with no task heads.)" )
43
43
flags .DEFINE_string ("converted_checkpoint_path" , None ,
44
44
"Name for the created object-based V2 checkpoint." )
45
+ flags .DEFINE_string ("checkpoint_model_name" , "model" ,
46
+ "The name of the model when saving the checkpoint, i.e., "
47
+ "the checkpoint will be saved using: "
48
+ "tf.train.Checkpoint(FLAGS.checkpoint_model_name=model)." )
45
49
46
50
47
51
def _create_bert_model (cfg ):
@@ -71,7 +75,8 @@ def _create_bert_model(cfg):
71
75
return bert_encoder
72
76
73
77
74
- def convert_checkpoint (bert_config , output_path , v1_checkpoint ):
78
+ def convert_checkpoint (bert_config , output_path , v1_checkpoint ,
79
+ checkpoint_model_name = "model" ):
75
80
"""Converts a V1 checkpoint into an OO V2 checkpoint."""
76
81
output_dir , _ = os .path .split (output_path )
77
82
tf .io .gfile .makedirs (output_dir )
@@ -90,7 +95,8 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
90
95
# Create a V2 checkpoint from the temporary checkpoint.
91
96
model = _create_bert_model (bert_config )
92
97
tf1_checkpoint_converter_lib .create_v2_checkpoint (model , temporary_checkpoint ,
93
- output_path )
98
+ output_path ,
99
+ checkpoint_model_name )
94
100
95
101
# Clean up the temporary checkpoint, if it exists.
96
102
try :
@@ -103,8 +109,10 @@ def convert_checkpoint(bert_config, output_path, v1_checkpoint):
103
109
def main (_ ):
104
110
output_path = FLAGS .converted_checkpoint_path
105
111
v1_checkpoint = FLAGS .checkpoint_to_convert
112
+ checkpoint_model_name = FLAGS .checkpoint_model_name
106
113
bert_config = configs .BertConfig .from_json_file (FLAGS .bert_config_file )
107
- convert_checkpoint (bert_config , output_path , v1_checkpoint )
114
+ convert_checkpoint (bert_config , output_path , v1_checkpoint ,
115
+ checkpoint_model_name )
108
116
109
117
110
118
if __name__ == "__main__" :
0 commit comments