@@ -885,11 +885,11 @@ def test_model_parallelism(self):

     @require_torch_gpu
     def test_sharded_checkpoints(self):
+        torch.manual_seed(0)
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
         model = model.to(torch_device)

-        torch.manual_seed(0)
         base_output = model(**inputs_dict)

         model_size = compute_module_sizes(model)[""]
@@ -909,7 +909,8 @@ def test_sharded_checkpoints(self):
             new_model = new_model.to(torch_device)

             torch.manual_seed(0)
-            _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            if "generator" in inputs_dict:
+                _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
             new_output = new_model(**inputs_dict)

             self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
@@ -942,7 +943,8 @@ def test_sharded_checkpoints_device_map(self):
             new_model = new_model.to(torch_device)

             torch.manual_seed(0)
-            _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            if "generator" in inputs_dict:
+                _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
             new_output = new_model(**inputs_dict)
             self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

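A minimal sketch of the determinism issue this diff addresses, assuming (as the test helpers suggest) that prepare_init_args_and_inputs_for_common() draws its inputs from torch's global RNG; make_inputs() below is a hypothetical stand-in for that helper, not the real API:

import torch

def make_inputs():
    # Hypothetical stand-in for prepare_init_args_and_inputs_for_common():
    # the returned tensors depend on the global RNG state at call time.
    return {"sample": torch.randn(1, 4)}

torch.manual_seed(0)   # seed BEFORE preparing inputs, as the diff now does
first = make_inputs()

torch.manual_seed(0)   # re-seed, then rebuild inputs for the reloaded model
second = make_inputs()

# Both calls observe the same RNG state, so the inputs (and hence the model
# outputs compared with torch.allclose in the tests) match exactly.
assert torch.equal(first["sample"], second["sample"])

Seeding after input preparation, as the old code did, leaves the base inputs unseeded, so a second prepare call after reload can produce different tensors and a spurious allclose failure.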