Skip to content

Commit 0da9f25

Browse files
authored
Some cleanup on flux integration (#829)
Merges the diffusers reference models for sdxl and flux vae models. Also renames the exported function to decode from forward to avoid confusion with vae encode to be added in the future
1 parent d508b48 commit 0da9f25

File tree

4 files changed

+65
-52
lines changed

4 files changed

+65
-52
lines changed

sharktank/sharktank/models/vae/README.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,14 @@ python -m sharktank.models.punet.tools.import_hf_dataset \
1616
```
1717

1818
# Run Vae decoder model eager mode
19+
# Sample SDXL command
1920
```
20-
python -m sharktank.models.vae.tools.run_vae --irpa-file ~/models/vae.irpa --device cpu
21+
python -m sharktank.models.vae.tools.run_vae --irpa-file ~/models/vae.irpa --device cpu --dtype=float32
22+
```
23+
# Sample Flux command to run through iree and compare vs huggingface diffusers torch model
24+
```
25+
python -m sharktank.models.vae.tools.run_vae --irpa-file ~/models/vae.irpa --device cpu --compare_vs_torch --dtype=float32 --sharktank_config=flux --torch_model=black-forest-labs/FLUX.1-dev
2126
```
22-
2327
## License
2428

2529
Significant portions of this implementation were derived from diffusers,

sharktank/sharktank/models/vae/tools/diffuser_ref.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,15 @@ def __init__(
1616
self,
1717
hf_model_name,
1818
custom_vae="",
19+
height=1024,
20+
width=1024,
21+
flux=False,
1922
):
2023
super().__init__()
2124
self.vae = None
25+
self.height = height
26+
self.width = width
27+
self.flux = flux
2228
if custom_vae in ["", None]:
2329
self.vae = AutoencoderKL.from_pretrained(
2430
hf_model_name,
@@ -44,43 +50,37 @@ def __init__(
4450
custom_vae,
4551
subfolder="vae",
4652
)
53+
self.shift_factor = (
54+
0.0
55+
if self.vae.config.shift_factor is None
56+
else self.vae.config.shift_factor
57+
)
4758

4859
def decode(self, inp):
49-
# The reference vae decode does not do scaling and leaves it for the sdxl pipeline. We integrate it into vae for pipeline performance so using the hardcoded values from the config.json here
50-
img = 1 / 0.13025 * inp
60+
if self.flux:
61+
inp = rearrange(
62+
inp,
63+
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
64+
h=math.ceil(self.height / 16),
65+
w=math.ceil(self.width / 16),
66+
ph=2,
67+
pw=2,
68+
)
69+
img = inp / self.vae.config.scaling_factor + self.shift_factor
5170
x = self.vae.decode(img, return_dict=False)[0]
52-
return (x / 2 + 0.5).clamp(0, 1)
71+
if self.flux:
72+
return x.clamp(-1, 1)
73+
else:
74+
return (x / 2 + 0.5).clamp(0, 1)
5375

5476

55-
def run_torch_vae(hf_model_name, example_input):
56-
vae_model = VaeModel(hf_model_name)
77+
def run_torch_vae(
78+
hf_model_name,
79+
example_input,
80+
height=1024,
81+
width=1024,
82+
flux=False,
83+
dtype=torch.float32,
84+
):
85+
vae_model = VaeModel(hf_model_name, height=height, width=width, flux=flux).to(dtype)
5786
return vae_model.decode(example_input)
58-
59-
60-
# TODO Remove and integrate with VaeModel
61-
class FluxAEWrapper(torch.nn.Module):
62-
def __init__(self, height=1024, width=1024):
63-
super().__init__()
64-
self.ae = AutoencoderKL.from_pretrained(
65-
"black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
66-
)
67-
self.height = height
68-
self.width = width
69-
70-
def forward(self, z):
71-
d_in = rearrange(
72-
z,
73-
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
74-
h=math.ceil(self.height / 16),
75-
w=math.ceil(self.width / 16),
76-
ph=2,
77-
pw=2,
78-
)
79-
d_in = d_in / self.ae.config.scaling_factor + self.ae.config.shift_factor
80-
return self.ae.decode(d_in, return_dict=False)[0].clamp(-1, 1)
81-
82-
83-
def run_flux_vae(example_input, dtype):
84-
# TODO add support for other height/width sizes
85-
vae_model = FluxAEWrapper(1024, 1024).to(dtype)
86-
return vae_model.forward(example_input)

sharktank/sharktank/models/vae/tools/run_vae.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def export_vae(model, sample_inputs, decomp_attn):
3939
fxb = FxProgramsBuilder(model)
4040

4141
@fxb.export_program(
42-
name=f"forward",
42+
name=f"decode",
4343
args=tuple(torch.unsqueeze(sample_inputs, 0)),
4444
strict=False,
4545
)
@@ -86,7 +86,7 @@ def main(argv):
8686
parser.add_argument(
8787
"--torch_model",
8888
default="stabilityai/stable-diffusion-xl-base-1.0",
89-
help="HF reference model id",
89+
help="HF reference model id, currently tested with stabilityai/stable-diffusion-xl-base-1.0 and black-forest-labs/FLUX.1-dev",
9090
)
9191

9292
parser.add_argument(
@@ -141,12 +141,14 @@ def main(argv):
141141
intermediates_saver.save_file(args.save_intermediates_path)
142142

143143
if args.compare_vs_torch:
144-
from .diffuser_ref import run_torch_vae, run_flux_vae
144+
from .diffuser_ref import run_torch_vae
145145

146146
if args.sharktank_config == "flux":
147-
diffusers_results = run_flux_vae(inputs, torch.bfloat16)
147+
diffusers_results = run_torch_vae(
148+
args.torch_model, inputs, flux=True, dtype=dtype
149+
)
148150
elif args.sharktank_config == "sdxl":
149-
run_torch_vae(args.torch_model, inputs)
151+
run_torch_vae(args.torch_model, inputs, flux=False, dtype=dtype)
150152
print("diffusers results:", diffusers_results)
151153
torch.testing.assert_close(diffusers_results, results)
152154

sharktank/tests/models/vae/vae_test.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from sharktank.types import Dataset
1515
from sharktank.models.vae.model import VaeDecoderModel
16-
from sharktank.models.vae.tools.diffuser_ref import run_torch_vae, run_flux_vae
16+
from sharktank.models.vae.tools.diffuser_ref import run_torch_vae
1717
from sharktank.models.vae.tools.run_vae import export_vae
1818
from sharktank.models.vae.tools.sample_data import get_random_inputs
1919

@@ -166,7 +166,7 @@ def testVaeIreeVsHuggingFace(self):
166166
vm_context=iree_vm_context,
167167
args=iree_args,
168168
driver="hip",
169-
function_name="forward",
169+
function_name="decode",
170170
)[0].to_host()
171171
# TODO: Verify these numerics are good or if tolerances are too loose
172172
# TODO: Upload IR on passing tests to keep https://github.com/iree-org/iree/blob/main/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py at latest
@@ -194,7 +194,7 @@ def testVaeIreeVsHuggingFace(self):
194194
vm_context=iree_vm_context,
195195
args=iree_args,
196196
driver="hip",
197-
function_name="forward",
197+
function_name="decode",
198198
)[0].to_host()
199199
# TODO: Upload IR on passing tests
200200
torch.testing.assert_close(
@@ -237,7 +237,9 @@ def setUp(self):
237237
def testCompareBF16EagerVsHuggingface(self):
238238
dtype = torch.bfloat16
239239
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1, config="flux")
240-
ref_results = run_flux_vae(inputs, dtype)
240+
ref_results = run_torch_vae(
241+
"black-forest-labs/FLUX.1-dev", inputs, 1024, 1024, True, dtype
242+
)
241243

242244
ds = Dataset.load("{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa")
243245
model = VaeDecoderModel.from_dataset(ds).to(device="cpu")
@@ -249,7 +251,9 @@ def testCompareBF16EagerVsHuggingface(self):
249251
def testCompareF32EagerVsHuggingface(self):
250252
dtype = torch.float32
251253
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1, config="flux")
252-
ref_results = run_flux_vae(inputs, dtype)
254+
ref_results = run_torch_vae(
255+
"black-forest-labs/FLUX.1-dev", inputs, 1024, 1024, True, dtype
256+
)
253257

254258
ds = Dataset.load("{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa")
255259
model = VaeDecoderModel.from_dataset(ds).to(device="cpu", dtype=dtype)
@@ -262,8 +266,9 @@ def testVaeIreeVsHuggingFace(self):
262266
inputs = get_random_inputs(
263267
dtype=torch.float32, device="cpu", bs=1, config="flux"
264268
)
265-
ref_results = run_flux_vae(inputs.to(dtype), dtype)
266-
ref_results_f32 = run_flux_vae(inputs, torch.float32)
269+
ref_results = run_torch_vae(
270+
"black-forest-labs/FLUX.1-dev", inputs, 1024, 1024, True, torch.float32
271+
)
267272

268273
ds = Dataset.load("{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa")
269274
ds_f32 = Dataset.load("{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa")
@@ -324,12 +329,14 @@ def testVaeIreeVsHuggingFace(self):
324329
vm_context=iree_vm_context,
325330
args=iree_args,
326331
driver="hip",
327-
function_name="forward",
332+
function_name="decode",
328333
)[0]
329334
)
330335

331336
# TODO verify these numerics
332-
torch.testing.assert_close(ref_results, iree_result, atol=3.3e-2, rtol=4e5)
337+
torch.testing.assert_close(
338+
ref_results.to(torch.bfloat16), iree_result, atol=3.3e-2, rtol=4e5
339+
)
333340

334341
iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
335342
module_path="{self._temp_dir}/flux_vae_f32.vmfb",
@@ -349,11 +356,11 @@ def testVaeIreeVsHuggingFace(self):
349356
vm_context=iree_vm_context,
350357
args=iree_args,
351358
driver="hip",
352-
function_name="forward",
359+
function_name="decode",
353360
)[0]
354361
)
355362

356-
torch.testing.assert_close(ref_results_f32, iree_result_f32)
363+
torch.testing.assert_close(ref_results, iree_result_f32)
357364

358365

359366
if __name__ == "__main__":

0 commit comments

Comments
 (0)