@@ -79,10 +79,6 @@ def mock_training(accelerator, model):
79
79
80
80
def check_weights (operation , state_1 , state_2 ):
81
81
for weight_1 , weight_2 in zip (state_1 .values (), state_2 .values ()):
82
- if str (weight_1 .device ) != torch_device :
83
- weight_1 = weight_1 .to (torch_device )
84
- if str (weight_2 .device ) != torch_device :
85
- weight_2 = weight_2 .to (torch_device )
86
82
if operation == "same" :
87
83
assert torch .allclose (weight_1 , weight_2 )
88
84
else :
@@ -91,15 +87,15 @@ def check_weights(operation, state_1, state_2):
91
87
92
88
def check_safetensors_weights (path , model ):
93
89
safe_state_dict = load_file (path / "model.safetensors" )
94
- safe_loaded_model = TinyModel ()
90
+ safe_loaded_model = TinyModel (). to ( torch_device )
95
91
check_weights ("diff" , model .state_dict (), safe_loaded_model .state_dict ())
96
92
safe_loaded_model .load_state_dict (safe_state_dict )
97
93
check_weights ("same" , model .state_dict (), safe_loaded_model .state_dict ())
98
94
99
95
100
96
def check_pytorch_weights (path , model ):
101
97
nonsafe_state_dict = torch .load (path / "pytorch_model.bin" , weights_only = True )
102
- nonsafe_loaded_model = TinyModel ()
98
+ nonsafe_loaded_model = TinyModel (). to ( torch_device )
103
99
check_weights ("diff" , model .state_dict (), nonsafe_loaded_model .state_dict ())
104
100
nonsafe_loaded_model .load_state_dict (nonsafe_state_dict )
105
101
check_weights ("same" , model .state_dict (), nonsafe_loaded_model .state_dict ())
0 commit comments