Skip to content

Commit a785992

Browse files
authored
[Tests] fix more sharding tests (#8797)
* fix * fix * ugly * okay * fix more * fix oops
1 parent 35cc66d commit a785992

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

tests/models/test_modeling_common.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -885,11 +885,11 @@ def test_model_parallelism(self):
885885

886886
@require_torch_gpu
887887
def test_sharded_checkpoints(self):
888+
torch.manual_seed(0)
888889
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
889890
model = self.model_class(**config).eval()
890891
model = model.to(torch_device)
891892

892-
torch.manual_seed(0)
893893
base_output = model(**inputs_dict)
894894

895895
model_size = compute_module_sizes(model)[""]
@@ -909,7 +909,8 @@ def test_sharded_checkpoints(self):
909909
new_model = new_model.to(torch_device)
910910

911911
torch.manual_seed(0)
912-
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
912+
if "generator" in inputs_dict:
913+
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
913914
new_output = new_model(**inputs_dict)
914915

915916
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
@@ -942,7 +943,8 @@ def test_sharded_checkpoints_device_map(self):
942943
new_model = new_model.to(torch_device)
943944

944945
torch.manual_seed(0)
945-
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
946+
if "generator" in inputs_dict:
947+
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
946948
new_output = new_model(**inputs_dict)
947949
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
948950

0 commit comments

Comments
 (0)