
Commit 83437b5

Add Flux transformer export for easier use outside of tests (#700)
Adapt the model to accept parameters as structured in the HF repo. Generalize the Punet HF parameter import so it can serve other models as well. When downloading a dataset from Hugging Face, return the local locations of all downloaded files, including extras, not just the "leading" file. Add a sample_inputs method to the BaseLayer interface to help standardize export. Introduce a standard export function for static-sized models.
1 parent b151ffa commit 83437b5

File tree

14 files changed: +417 -138 lines changed

.github/workflows/ci-sharktank.yml

Lines changed: 7 additions & 3 deletions

@@ -136,15 +136,19 @@ jobs:
           pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
 
       - name: Run tests
-        # TODO: unify with-t5-data and with-clip-data flags into a single flag
-        # and make it possible to run only tests that require data.
+        # TODO: unify with-*-data flags into a single flag and make it possible to run
+        # only tests that require data.
+        # We would still want the separate flags as we may endup with data being
+        # scattered on different CI machines.
         run: |
           source ${VENV_DIR}/bin/activate
           pytest \
-            --with-clip-data \
+            --with-clip-data \
+            --with-flux-data \
             --with-t5-data \
            sharktank/tests/models/clip/clip_test.py \
            sharktank/tests/models/t5/t5_test.py \
+           sharktank/tests/models/flux/flux_test.py \
            --durations=0

sharktank/conftest.py

Lines changed: 9 additions & 0 deletions

@@ -97,6 +97,15 @@ def pytest_addoption(parser):
             "code. The user is expected to provide the data"
         ),
     )
+    parser.addoption(
+        "--with-flux-data",
+        action="store_true",
+        default=False,
+        help=(
+            "Enable tests that use Flux data like models that is not a part of the source "
+            "code. The user is expected to provide the data"
+        ),
+    )
     parser.addoption(
         "--with-t5-data",
        action="store_true",
sharktank/integration/models/punet/integration_test.py

Lines changed: 1 addition & 1 deletion

@@ -67,7 +67,7 @@ def download(filename):
 
 @pytest.fixture(scope="module")
 def sdxl_fp16_dataset(sdxl_fp16_base_files, temp_dir):
-    from sharktank.models.punet.tools import import_hf_dataset
+    from sharktank.tools import import_hf_dataset
 
     dataset = temp_dir / "sdxl_fp16_dataset.irpa"
     import_hf_dataset.main(
sharktank/sharktank/export.py

Lines changed: 52 additions & 1 deletion

@@ -4,11 +4,14 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Callable, Any
+from typing import Callable, Optional, Any
 import torch
+from os import PathLike
+import iree.turbine.aot as aot
 from iree.turbine.aot import DeviceAffinity, FxProgramsBuilder
 from torch.utils._pytree import tree_structure, tree_unflatten, tree_flatten
 from .types.tensors import ShardedTensor
+from .layers import BaseLayer
 from torch.utils._pytree import PyTree, _is_leaf
 import functools
 
@@ -172,3 +175,51 @@ def flat_fn(*args, **kwargs):
     )
 
     assert False, "TODO: implement the case when not using an FxProgramsBuilder"
+
+
+def export_static_model_mlir(
+    model: BaseLayer,
+    output_path: PathLike,
+    function_batch_size_pairs: Optional[dict[Optional[str], list[int]]] = None,
+    batch_sizes: Optional[list[int]] = None,
+):
+    """Export a model with no dynamic dimensions.
+
+    For the set of provided function name batch sizes pair, the resulting MLIR will
+    have function names with the below format.
+    ```
+    <function_name>_bs<batch_size>
+    ```
+
+    If `batch_sizes` is given then it defaults to a single function with named
+    "forward".
+
+    The model is required to implement method `sample_inputs`.
+    """
+
+    assert not (function_batch_size_pairs is not None and batch_sizes is not None)
+
+    if batch_sizes is not None:
+        function_batch_size_pairs = {None: batch_sizes}
+
+    if function_batch_size_pairs is None and batch_sizes is None:
+        function_batch_size_pairs = {None: batch_sizes}
+
+    fxb = FxProgramsBuilder(model)
+
+    for function, batch_sizes in function_batch_size_pairs.items():
+        for batch_size in batch_sizes:
+            args, kwargs = model.sample_inputs(batch_size, function)
+
+            @fxb.export_program(
+                name=f"{function or 'forward'}_bs{batch_size}",
+                args=args,
+                kwargs=kwargs,
+                dynamic_shapes=None,
+                strict=False,
+            )
+            def _(model, **kwargs):
+                return model(**kwargs)
+
+    output = aot.export(fxb)
+    output.save_mlir(output_path)
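
A usage sketch of the new export entry point (the model placeholder stands for any BaseLayer subclass that implements sample_inputs; paths and batch sizes are illustrative):

from sharktank.export import export_static_model_mlir

# One entry point named "forward", exported as forward_bs1 and forward_bs4:
export_static_model_mlir(model, output_path="model.mlir", batch_sizes=[1, 4])

# Or several entry points, each with its own static batch sizes; the MLIR
# function names follow the <function_name>_bs<batch_size> pattern:
export_static_model_mlir(
    model,
    output_path="model.mlir",
    function_batch_size_pairs={"encode": [1], "decode": [1, 8]},
)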

sharktank/sharktank/layers/base.py

Lines changed: 18 additions & 6 deletions

@@ -4,15 +4,12 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Dict
-
+from typing import Dict, Optional
+from collections import OrderedDict
 import torch
 import torch.nn as nn
 
-from ..types import (
-    InferenceTensor,
-    Theta,
-)
+from ..types import InferenceTensor, Theta, AnyTensor
 from ..utils import debugging
 
 __all__ = [
@@ -56,6 +53,21 @@ def assert_not_nan(self, *ts: torch.Tensor):
             if torch.isnan(t).any():
                 raise AssertionError(f"Tensor contains nans! {t}")
 
+    def sample_inputs(
+        self, batch_size: int = 1, function: Optional[str] = None
+    ) -> tuple[tuple[AnyTensor], OrderedDict[str, AnyTensor]]:
+        """Return sample inputs that can be used to run the function from the model.
+        If function is None then layer is treated as the callable.
+        E.g.
+        ```
+        args, kwargs = model.sample_inputs()
+        model(*args, **kwargs)
+        ```
+
+        One purpose of this method is to standardize exportation of models to MLIR.
+        """
+        raise NotImplementedError()
+
 
 class ThetaLayer(BaseLayer):
     "Base class for layers that derive parameters from a Theta object."

sharktank/sharktank/layers/mmdit.py

Lines changed: 17 additions & 8 deletions

@@ -55,11 +55,15 @@ def __init__(self, theta, num_heads: int):
         self.add_module("img_attn_qkv", LinearLayer(theta("img_attn.qkv")))
         self.add_module(
             "img_attn_norm_q",
-            RMSNormLayer(theta("img_attn.norm.query_norm"), epsilon=1e-6),
+            RMSNormLayer(
+                theta("img_attn.norm.query_norm"), weight_name="scale", epsilon=1e-6
+            ),
         )
         self.add_module(
             "img_attn_norm_k",
-            RMSNormLayer(theta("img_attn.norm.key_norm"), epsilon=1e-6),
+            RMSNormLayer(
+                theta("img_attn.norm.key_norm"), weight_name="scale", epsilon=1e-6
+            ),
         )
         self.add_module("img_attn_proj", LinearLayer(theta("img_attn.proj")))
 
@@ -70,11 +74,15 @@ def __init__(self, theta, num_heads: int):
         self.add_module("txt_attn_qkv", LinearLayer(theta("txt_attn.qkv")))
         self.add_module(
             "txt_attn_norm_q",
-            RMSNormLayer(theta("txt_attn.norm.query_norm"), epsilon=1e-6),
+            RMSNormLayer(
+                theta("txt_attn.norm.query_norm"), weight_name="scale", epsilon=1e-6
+            ),
         )
         self.add_module(
             "txt_attn_norm_k",
-            RMSNormLayer(theta("txt_attn.norm.key_norm"), epsilon=1e-6),
+            RMSNormLayer(
+                theta("txt_attn.norm.key_norm"), weight_name="scale", epsilon=1e-6
+            ),
         )
         self.add_module("txt_attn_proj", LinearLayer(theta("txt_attn.proj")))
 
@@ -151,14 +159,15 @@ def __init__(self, theta, num_heads: int):
         super().__init__(theta)
 
         self.num_heads = num_heads
-        self.add_module("mod", ModulationLayer(theta("mod"), double=False))
+        self.add_module("mod", ModulationLayer(theta("modulation"), double=False))
         self.add_module(
-            "attn_norm_q", RMSNormLayer(theta("attn.norm.query_norm"), epsilon=1e-6)
+            "attn_norm_q",
+            RMSNormLayer(theta("norm.query_norm"), weight_name="scale", epsilon=1e-6),
         )
         self.add_module(
-            "attn_norm_k", RMSNormLayer(theta("attn.norm.key_norm"), epsilon=1e-6)
+            "attn_norm_k",
+            RMSNormLayer(theta("norm.key_norm"), weight_name="scale", epsilon=1e-6),
         )
-        self.add_module("attn_proj", LinearLayer(theta("attn.proj")))
 
         self.add_module("linear1", LinearLayer(theta("linear1")))
         self.add_module("linear2", LinearLayer(theta("linear2")))

sharktank/sharktank/layers/testing.py

Lines changed: 8 additions & 8 deletions

@@ -65,10 +65,10 @@ def make_mmdit_double_block_random_theta(
     mlp_hidden_size3 = int(2 * (mlp_ratio - 1) * hidden_size)
     return Theta(
         {
-            "img_attn.norm.key_norm.weight": DefaultPrimitiveTensor(  #
+            "img_attn.norm.key_norm.scale": DefaultPrimitiveTensor(  #
                 data=make_rand_torch((in_channels,), dtype=dtype)
             ),
-            "img_attn.norm.query_norm.weight": DefaultPrimitiveTensor(  #
+            "img_attn.norm.query_norm.scale": DefaultPrimitiveTensor(  #
                 data=make_rand_torch((in_channels,), dtype=dtype)
             ),
             "img_attn.proj.bias": DefaultPrimitiveTensor(
@@ -101,10 +101,10 @@ def make_mmdit_double_block_random_theta(
             "img_mod.lin.weight": DefaultPrimitiveTensor(
                 data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype)
             ),
-            "txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor(  #
+            "txt_attn.norm.key_norm.scale": DefaultPrimitiveTensor(  #
                 data=make_rand_torch((in_channels,), dtype=dtype)
             ),
-            "txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor(  #
+            "txt_attn.norm.query_norm.scale": DefaultPrimitiveTensor(  #
                 data=make_rand_torch((in_channels,), dtype=dtype)
             ),
             "txt_attn.proj.bias": DefaultPrimitiveTensor(
@@ -155,10 +155,10 @@ def make_mmdit_single_block_random_theta(
     mlp_hidden_size3 = int((2 * mlp_ratio - 1) * hidden_size)
     return Theta(
         {
-            "attn.norm.key_norm.weight": DefaultPrimitiveTensor(  #
+            "norm.key_norm.scale": DefaultPrimitiveTensor(  #
                 data=make_rand_torch((in_channels,), dtype=dtype)
            ),
-            "attn.norm.query_norm.weight": DefaultPrimitiveTensor(  #
+            "norm.query_norm.scale": DefaultPrimitiveTensor(  #
                 data=make_rand_torch((in_channels,), dtype=dtype)
             ),
             "attn.proj.bias": DefaultPrimitiveTensor(
@@ -179,10 +179,10 @@ def make_mmdit_single_block_random_theta(
             "linear2.weight": DefaultPrimitiveTensor(
                 data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype)
             ),
-            "mod.lin.bias": DefaultPrimitiveTensor(
+            "modulation.lin.bias": DefaultPrimitiveTensor(
                 data=make_rand_torch((mlp_hidden_size,), dtype=dtype)
             ),
-            "mod.lin.weight": DefaultPrimitiveTensor(
+            "modulation.lin.weight": DefaultPrimitiveTensor(
                 data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype)
             ),
         }
Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from os import PathLike
+
+from ...export import export_static_model_mlir
+from ...tools.import_hf_dataset import import_hf_dataset
+from .flux import FluxModelV1, FluxParams
+from ...types import Dataset
+from ...utils.hf_datasets import get_dataset
+
+flux_transformer_default_batch_sizes = [4]
+
+
+def export_flux_transformer_model_mlir(
+    model: FluxModelV1,
+    output_path: PathLike,
+    batch_sizes: list[int] = flux_transformer_default_batch_sizes,
+):
+    export_static_model_mlir(model, output_path=output_path, batch_sizes=batch_sizes)
+
+
+def export_flux_transformer_from_hugging_face(
+    repo_id: str,
+    mlir_output_path: PathLike,
+    parameters_output_path: PathLike,
+    batch_sizes: list[int] = flux_transformer_default_batch_sizes,
+):
+    hf_dataset = get_dataset(
+        repo_id,
+    ).download()
+
+    import_hf_dataset(
+        config_json_path=hf_dataset["config"][0],
+        param_paths=hf_dataset["parameters"],
+        output_irpa_file=parameters_output_path,
+    )
+
+    dataset = Dataset.load(parameters_output_path)
+    model = FluxModelV1(
+        theta=dataset.root_theta,
+        params=FluxParams.from_hugging_face_properties(dataset.properties),
+    )
+    export_flux_transformer_model_mlir(
+        model, output_path=mlir_output_path, batch_sizes=batch_sizes
+    )
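
An end-to-end invocation sketch; the import path is assumed from the relative imports above (the new file's name is not shown in this extract), and the repo id and output paths are placeholders for a Flux transformer dataset registered with sharktank.utils.hf_datasets:

from sharktank.models.flux.export import export_flux_transformer_from_hugging_face

export_flux_transformer_from_hugging_face(
    repo_id="black-forest-labs/FLUX.1-schnell",      # must be registered in hf_datasets
    mlir_output_path="flux_transformer.mlir",        # default export produces forward_bs4
    parameters_output_path="flux_transformer.irpa",  # IRPA file produced by import_hf_dataset
)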
