Skip to content

Commit 5d21d4a

Browse files
authored
Fix FreeU tests (#7540)
update
1 parent 73ba810 commit 5d21d4a

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

tests/pipelines/test_pipelines_common.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,15 @@ def test_freeu_enabled(self):
133133

134134
inputs = self.get_dummy_inputs(torch_device)
135135
inputs["return_dict"] = False
136+
inputs["output_type"] = "np"
137+
136138
output = pipe(**inputs)[0]
137139

138140
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
139141
inputs = self.get_dummy_inputs(torch_device)
140142
inputs["return_dict"] = False
143+
inputs["output_type"] = "np"
144+
141145
output_freeu = pipe(**inputs)[0]
142146

143147
assert not np.allclose(
@@ -152,6 +156,8 @@ def test_freeu_disabled(self):
152156

153157
inputs = self.get_dummy_inputs(torch_device)
154158
inputs["return_dict"] = False
159+
inputs["output_type"] = "np"
160+
155161
output = pipe(**inputs)[0]
156162

157163
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
@@ -164,6 +170,8 @@ def test_freeu_disabled(self):
164170

165171
inputs = self.get_dummy_inputs(torch_device)
166172
inputs["return_dict"] = False
173+
inputs["output_type"] = "np"
174+
167175
output_no_freeu = pipe(**inputs)[0]
168176
assert np.allclose(
169177
output, output_no_freeu, atol=1e-2

0 commit comments

Comments
 (0)