@@ -133,11 +133,15 @@ def test_freeu_enabled(self):
133
133
134
134
inputs = self .get_dummy_inputs (torch_device )
135
135
inputs ["return_dict" ] = False
136
+ inputs ["output_type" ] = "np"
137
+
136
138
output = pipe (** inputs )[0 ]
137
139
138
140
pipe .enable_freeu (s1 = 0.9 , s2 = 0.2 , b1 = 1.2 , b2 = 1.4 )
139
141
inputs = self .get_dummy_inputs (torch_device )
140
142
inputs ["return_dict" ] = False
143
+ inputs ["output_type" ] = "np"
144
+
141
145
output_freeu = pipe (** inputs )[0 ]
142
146
143
147
assert not np .allclose (
@@ -152,6 +156,8 @@ def test_freeu_disabled(self):
152
156
153
157
inputs = self .get_dummy_inputs (torch_device )
154
158
inputs ["return_dict" ] = False
159
+ inputs ["output_type" ] = "np"
160
+
155
161
output = pipe (** inputs )[0 ]
156
162
157
163
pipe .enable_freeu (s1 = 0.9 , s2 = 0.2 , b1 = 1.2 , b2 = 1.4 )
@@ -164,6 +170,8 @@ def test_freeu_disabled(self):
164
170
165
171
inputs = self .get_dummy_inputs (torch_device )
166
172
inputs ["return_dict" ] = False
173
+ inputs ["output_type" ] = "np"
174
+
167
175
output_no_freeu = pipe (** inputs )[0 ]
168
176
assert np .allclose (
169
177
output , output_no_freeu , atol = 1e-2
0 commit comments