Skip to content

Commit 81545be

Browse files
authored
Add flux transformer Dev test with IREE (nod-ai#843)
Remove the single-layer test with random weights in favor of the pretrained Dev model variant. We test IREE f32 and bf16 against eager f32. The particular initialization parameters for the random values caused large intermediate values during execution, which degraded the model output. The pretrained variant does not suffer from this problem, and the numerical error looks reasonable.
1 parent 3d36fe8 commit 81545be

File tree

1 file changed

+30
-38
lines changed

1 file changed

+30
-38
lines changed

sharktank/tests/models/flux/flux_test.py

+30-38
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
import torch
1111
import pytest
1212
import iree.compiler
13+
import iree.runtime
1314
from collections import OrderedDict
1415
from diffusers import FluxTransformer2DModel
1516
from sharktank.models.flux.export import (
@@ -84,7 +85,10 @@ def testExportDevRandomSingleLayerBf16(self):
8485
)
8586

8687
def runCompareIreeAgainstTorchEager(
87-
self, reference_model: FluxModelV1, target_dtype: torch.dtype
88+
self,
89+
reference_model: FluxModelV1,
90+
target_dtype: torch.dtype,
91+
atol: float,
8892
):
8993
target_theta = reference_model.theta.transform(
9094
functools.partial(set_float_dtype, dtype=target_dtype)
@@ -164,22 +168,30 @@ def runCompareIreeAgainstTorchEager(
164168
ops.to(iree_result[i], dtype=expected_outputs[i].dtype)
165169
for i in range(len(expected_outputs))
166170
]
167-
# TODO: figure out a good metric. Probably per pixel comparison would be good
168-
# enough.
169-
torch.testing.assert_close(actual_outputs, expected_outputs)
171+
torch.testing.assert_close(actual_outputs, expected_outputs, atol=atol, rtol=0)
170172

171-
def runCompareDevRandomSingleLayerIreeAgainstTorchEager(
172-
self, reference_dtype: torch.dtype, target_dtype: torch.dtype
173+
def runTestCompareDevIreeAgainstHuggingFace(
174+
self, reference_dtype: torch.dtype, target_dtype: torch.dtype, atol: float
173175
):
174-
config = make_dev_single_layer_config()
176+
parameters_output_path = self._temp_dir / "parameters.irpa"
175177

176-
reference_theta = make_random_theta(config, reference_dtype)
177-
reference_theta.rename_tensors_to_paths()
178+
import_flux_transformer_dataset_from_hugging_face(
179+
repo_id="black-forest-labs/FLUX.1-dev/black-forest-labs-transformer",
180+
parameters_output_path=parameters_output_path,
181+
)
182+
refrence_dataset = Dataset.load(parameters_output_path)
183+
refrence_dataset.root_theta = Theta(
184+
{
185+
k: set_float_dtype(t, reference_dtype)
186+
for k, t in refrence_dataset.root_theta.flatten().items()
187+
}
188+
)
178189
reference_model = FluxModelV1(
179-
theta=reference_theta,
180-
params=config,
190+
theta=refrence_dataset.root_theta,
191+
params=FluxParams.from_hugging_face_properties(refrence_dataset.properties),
181192
)
182-
self.runCompareIreeAgainstTorchEager(reference_model, target_dtype)
193+
194+
self.runCompareIreeAgainstTorchEager(reference_model, target_dtype, atol=atol)
183195

184196
def runTestCompareTorchEagerAgainstHuggingFace(
185197
self,
@@ -217,36 +229,16 @@ def runTestCompareTorchEagerAgainstHuggingFace(
217229

218230
torch.testing.assert_close(target_output, reference_output, atol=atol, rtol=0)
219231

220-
@pytest.mark.xfail(
221-
raises=AssertionError,
222-
reason="Accuracy is not good enough. The observed absolute error is 8976.53.",
223-
)
224-
@pytest.mark.skip(
225-
reason=(
226-
"Waiting on merging of fix for https://github.com/iree-org/iree/issues/19539. "
227-
"Without it IREE compilation enters an infinite loop."
228-
)
229-
)
230232
@with_flux_data
231-
def testCompareDevRandomSingleLayerIreeBf16AgainstTorchEagerF32(self):
232-
self.runCompareDevRandomSingleLayerIreeAgainstTorchEager(
233-
reference_dtype=torch.float32, target_dtype=torch.bfloat16
233+
def testCompareDevIreeF32AgainstHuggingFaceF32(self):
234+
self.runTestCompareDevIreeAgainstHuggingFace(
235+
reference_dtype=torch.float32, target_dtype=torch.float32, atol=1e-2
234236
)
235237

236-
@pytest.mark.xfail(
237-
raises=AssertionError,
238-
reason="Accuracy is probably not good enough. The observed absolute error is 73.25.",
239-
)
240-
@pytest.mark.skip(
241-
reason=(
242-
"Waiting on merging of fix for https://github.com/iree-org/iree/issues/19539. "
243-
"Without it IREE compilation enters an infinite loop."
244-
)
245-
)
246238
@with_flux_data
247-
def testCompareDevRandomSingleLayerIreeF32AgainstTorchEagerF32(self):
248-
self.runCompareDevRandomSingleLayerIreeAgainstTorchEager(
249-
reference_dtype=torch.float32, target_dtype=torch.float32
239+
def testCompareDevIreeBf16AgainstHuggingFaceF32(self):
240+
self.runTestCompareDevIreeAgainstHuggingFace(
241+
reference_dtype=torch.float32, target_dtype=torch.bfloat16, atol=1
250242
)
251243

252244
@with_flux_data

0 commit comments

Comments
 (0)