|
10 | 10 | import torch
|
11 | 11 | import pytest
|
12 | 12 | import iree.compiler
|
| 13 | +import iree.runtime |
13 | 14 | from collections import OrderedDict
|
14 | 15 | from diffusers import FluxTransformer2DModel
|
15 | 16 | from sharktank.models.flux.export import (
|
@@ -84,7 +85,10 @@ def testExportDevRandomSingleLayerBf16(self):
|
84 | 85 | )
|
85 | 86 |
|
86 | 87 | def runCompareIreeAgainstTorchEager(
|
87 |
| - self, reference_model: FluxModelV1, target_dtype: torch.dtype |
| 88 | + self, |
| 89 | + reference_model: FluxModelV1, |
| 90 | + target_dtype: torch.dtype, |
| 91 | + atol: float, |
88 | 92 | ):
|
89 | 93 | target_theta = reference_model.theta.transform(
|
90 | 94 | functools.partial(set_float_dtype, dtype=target_dtype)
|
@@ -164,22 +168,30 @@ def runCompareIreeAgainstTorchEager(
|
164 | 168 | ops.to(iree_result[i], dtype=expected_outputs[i].dtype)
|
165 | 169 | for i in range(len(expected_outputs))
|
166 | 170 | ]
|
167 |
| - # TODO: figure out a good metric. Probably per pixel comparison would be good |
168 |
| - # enough. |
169 |
| - torch.testing.assert_close(actual_outputs, expected_outputs) |
| 171 | + torch.testing.assert_close(actual_outputs, expected_outputs, atol=atol, rtol=0) |
170 | 172 |
|
171 |
| - def runCompareDevRandomSingleLayerIreeAgainstTorchEager( |
172 |
| - self, reference_dtype: torch.dtype, target_dtype: torch.dtype |
| 173 | + def runTestCompareDevIreeAgainstHuggingFace( |
| 174 | + self, reference_dtype: torch.dtype, target_dtype: torch.dtype, atol: float |
173 | 175 | ):
|
174 |
| - config = make_dev_single_layer_config() |
| 176 | + parameters_output_path = self._temp_dir / "parameters.irpa" |
175 | 177 |
|
176 |
| - reference_theta = make_random_theta(config, reference_dtype) |
177 |
| - reference_theta.rename_tensors_to_paths() |
| 178 | + import_flux_transformer_dataset_from_hugging_face( |
| 179 | + repo_id="black-forest-labs/FLUX.1-dev/black-forest-labs-transformer", |
| 180 | + parameters_output_path=parameters_output_path, |
| 181 | + ) |
| 182 | + refrence_dataset = Dataset.load(parameters_output_path) |
| 183 | + refrence_dataset.root_theta = Theta( |
| 184 | + { |
| 185 | + k: set_float_dtype(t, reference_dtype) |
| 186 | + for k, t in refrence_dataset.root_theta.flatten().items() |
| 187 | + } |
| 188 | + ) |
178 | 189 | reference_model = FluxModelV1(
|
179 |
| - theta=reference_theta, |
180 |
| - params=config, |
| 190 | + theta=refrence_dataset.root_theta, |
| 191 | + params=FluxParams.from_hugging_face_properties(refrence_dataset.properties), |
181 | 192 | )
|
182 |
| - self.runCompareIreeAgainstTorchEager(reference_model, target_dtype) |
| 193 | + |
| 194 | + self.runCompareIreeAgainstTorchEager(reference_model, target_dtype, atol=atol) |
183 | 195 |
|
184 | 196 | def runTestCompareTorchEagerAgainstHuggingFace(
|
185 | 197 | self,
|
@@ -217,36 +229,16 @@ def runTestCompareTorchEagerAgainstHuggingFace(
|
217 | 229 |
|
218 | 230 | torch.testing.assert_close(target_output, reference_output, atol=atol, rtol=0)
|
219 | 231 |
|
220 |
| - @pytest.mark.xfail( |
221 |
| - raises=AssertionError, |
222 |
| - reason="Accuracy is not good enough. The observed absolute error is 8976.53.", |
223 |
| - ) |
224 |
| - @pytest.mark.skip( |
225 |
| - reason=( |
226 |
| - "Waiting on merging of fix for https://github.com/iree-org/iree/issues/19539. " |
227 |
| - "Without it IREE compilation enters an infinite loop." |
228 |
| - ) |
229 |
| - ) |
230 | 232 | @with_flux_data
|
231 |
| - def testCompareDevRandomSingleLayerIreeBf16AgainstTorchEagerF32(self): |
232 |
| - self.runCompareDevRandomSingleLayerIreeAgainstTorchEager( |
233 |
| - reference_dtype=torch.float32, target_dtype=torch.bfloat16 |
| 233 | + def testCompareDevIreeF32AgainstHuggingFaceF32(self): |
| 234 | + self.runTestCompareDevIreeAgainstHuggingFace( |
| 235 | + reference_dtype=torch.float32, target_dtype=torch.float32, atol=1e-2 |
234 | 236 | )
|
235 | 237 |
|
236 |
| - @pytest.mark.xfail( |
237 |
| - raises=AssertionError, |
238 |
| - reason="Accuracy is probably not good enough. The observed absolute error is 73.25.", |
239 |
| - ) |
240 |
| - @pytest.mark.skip( |
241 |
| - reason=( |
242 |
| - "Waiting on merging of fix for https://github.com/iree-org/iree/issues/19539. " |
243 |
| - "Without it IREE compilation enters an infinite loop." |
244 |
| - ) |
245 |
| - ) |
246 | 238 | @with_flux_data
|
247 |
| - def testCompareDevRandomSingleLayerIreeF32AgainstTorchEagerF32(self): |
248 |
| - self.runCompareDevRandomSingleLayerIreeAgainstTorchEager( |
249 |
| - reference_dtype=torch.float32, target_dtype=torch.float32 |
| 239 | + def testCompareDevIreeBf16AgainstHuggingFaceF32(self): |
| 240 | + self.runTestCompareDevIreeAgainstHuggingFace( |
| 241 | + reference_dtype=torch.float32, target_dtype=torch.bfloat16, atol=1 |
250 | 242 | )
|
251 | 243 |
|
252 | 244 | @with_flux_data
|
|
0 commit comments