Commit 4a18656

Implement MMDIT block that is necessary for flux
1 parent ac17f86 commit 4a18656

File tree

6 files changed: +286 -6 lines

sharktank/sharktank/layers/__init__.py
sharktank/sharktank/layers/mmdit_double.py
sharktank/sharktank/layers/modulation.py
sharktank/sharktank/layers/testing.py
sharktank/sharktank/ops/default_impls.py
sharktank/sharktank/ops/signatures.py

sharktank/sharktank/layers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -17,5 +17,7 @@
 from .ffn_block import FFN
 from .ffn_moe_block import FFNMOE
 from .mixture_of_experts_block import MoeBlock
+from .mmdit_double import MMDITDoubleBlock
+from .mmdit_single import MMDITSingleBlock
 
 from .configs import *

sharktank/sharktank/layers/mmdit_double.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
import torch.nn.functional as F
import torch
from torch import Tensor

from .. import ops

from .base import Theta, ThetaLayer
from .linear import LinearLayer
from .modulation import ModulationLayer
from .norm import RMSNormLayer
from .paged_llama_attention_block import PagedLlamaAttentionBlock


def qk_norm(q, k, v, rms_q, rms_k):
    return rms_q(q).to(v), rms_k(k).to(v)


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


def attention(q, k, v, pe):
    q, k = apply_rope(q, k, pe)  # TODO

    x = ops.scaled_dot_product_attention(q=q, k=k, v=v, a=None, is_causal=True, scale=None)
    x = ops.permute(x, (0, 2, 1, 3))
    x = x.view(x.shape[0], x.shape[1], -1)

    return x


class MMDITDoubleBlock(ThetaLayer):
    def __init__(self, theta, num_heads: int):
        super().__init__(theta)

        self.num_heads = num_heads
        self.img_mod = ModulationLayer(theta("img_mod"), double=True)
        self.img_attn_qkv = LinearLayer(theta("img_attn.qkv"))
        self.img_attn_norm_q = RMSNormLayer(theta("img_attn.norm.query_norm"), epsilon=1e-6)
        self.img_attn_norm_k = RMSNormLayer(theta("img_attn.norm.key_norm"), epsilon=1e-6)
        self.img_attn_proj = LinearLayer(theta("img_attn.proj"))

        self.img_mlp1 = LinearLayer(theta("img_mlp.0"))
        self.img_mlp2 = LinearLayer(theta("img_mlp.2"))

        self.txt_mod = ModulationLayer(theta("txt_mod"), double=True)
        self.txt_attn_qkv = LinearLayer(theta("txt_attn.qkv"))
        self.txt_attn_norm_q = RMSNormLayer(theta("txt_attn.norm.query_norm"), epsilon=1e-6)
        self.txt_attn_norm_k = RMSNormLayer(theta("txt_attn.norm.key_norm"), epsilon=1e-6)
        self.txt_attn_proj = LinearLayer(theta("txt_attn.proj"))

        self.txt_mlp1 = LinearLayer(theta("txt_mlp.0"))
        self.txt_mlp2 = LinearLayer(theta("txt_mlp.2"))

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # Prepare the image stream for attention.
        img_modulated = ops.layer_norm(img, None, None, eps=1e-6)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn_qkv(img_modulated)
        img_qkv_2 = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1)
        img_qkv_3 = ops.permute(img_qkv_2, (2, 0, 3, 1, 4))
        img_q, img_k, img_v = img_qkv_3
        img_q, img_k = qk_norm(img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k)

        # Prepare the txt stream for attention.
        txt_modulated = ops.layer_norm(txt, None, None, eps=1e-6)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn_qkv(txt_modulated)
        txt_qkv_2 = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1)
        txt_qkv_3 = ops.permute(txt_qkv_2, (2, 0, 3, 1, 4))
        txt_q, txt_k, txt_v = txt_qkv_3
        txt_q, txt_k = qk_norm(txt_q, txt_k, txt_v, self.txt_attn_norm_q, self.txt_attn_norm_k)

        # Run the actual attention over the concatenated streams.
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # Calculate the img blocks.
        img = img + img_mod1.gate * self.img_attn_proj(img_attn)
        img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm(img, None, None, eps=1e-6) + img_mod2.shift
        img_mlp_out1 = self.img_mlp1(img_mlp_in)
        img_mlp_out2 = ops.elementwise(F.gelu, img_mlp_out1)
        img_mlp_out3 = self.img_mlp2(img_mlp_out2)
        img = img + img_mod2.gate * img_mlp_out3

        # Calculate the txt blocks.
        txt = txt + txt_mod1.gate * self.txt_attn_proj(txt_attn)
        txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm(txt, None, None, eps=1e-6) + txt_mod2.shift
        txt_mlp_out1 = self.txt_mlp1(txt_mlp_in)
        txt_mlp_out2 = ops.elementwise(F.gelu, txt_mlp_out1)
        txt_mlp_out3 = self.txt_mlp2(txt_mlp_out2)
        txt = txt + txt_mod2.gate * txt_mlp_out3

        return img, txt
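
As a rough aid for reading the fused qkv handling above, the view/permute sequence can be traced with plain torch. The batch size, sequence length, and the 24-head / 3072-wide configuration below are illustrative assumptions, not values taken from this commit.

import torch

# Illustrative sizes only: 24 heads of width 128 give a 3072 hidden size.
bsz, seq, heads, hidden = 2, 1024, 24, 3072
qkv = torch.randn(bsz, seq, 3 * hidden)   # stands in for self.img_attn_qkv(img_modulated)
qkv = qkv.view(bsz, seq, 3, heads, -1)    # (2, 1024, 3, 24, 128)
qkv = qkv.permute(2, 0, 3, 1, 4)          # (3, 2, 24, 1024, 128)
q, k, v = qkv                             # each (batch, heads, seq_len, head_dim)
assert q.shape == (bsz, heads, seq, hidden // heads)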

sharktank/sharktank/layers/modulation.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
import torch
import torch.nn.functional as F

from .. import ops

from .base import Theta, ThetaLayer
from .linear import LinearLayer


class ModulationOut:
    def __init__(self, shift, scale, gate):
        self.shift = shift
        self.scale = scale
        self.gate = gate


class ModulationLayer(ThetaLayer):
    def __init__(self, theta: Theta, double: bool):
        super().__init__(theta)

        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = LinearLayer(theta("lin"))

    def forward(self, vec: torch.Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        silu_result = ops.elementwise(F.silu, vec)
        out = self.lin(silu_result)[:, None, :].chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )
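
For orientation, the double-mode split can be mimicked with plain torch. The batch size and the 3072 hidden width are assumptions for illustration, and the random tensor stands in for the SiLU-plus-linear projection.

import torch

lin_out = torch.randn(2, 6 * 3072)              # stands in for self.lin(F.silu(vec)) with double=True
chunks = lin_out[:, None, :].chunk(6, dim=-1)   # six tensors of shape (2, 1, 3072)
shift1, scale1, gate1, shift2, scale2, gate2 = chunks
# The first (shift, scale, gate) triple modulates the attention path; the second modulates the MLP path.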

sharktank/sharktank/layers/testing.py

Lines changed: 128 additions & 0 deletions
@@ -43,3 +43,131 @@ def make_llama_attention_block_theta(
             ),
         }
     )
+
+def make_mmdit_block_theta(
+    dtype: torch.dtype | None = None
+) -> Theta:
+    return Theta(
+        {
+            "img_attn.norm.key_norm.weight": DefaultPrimitiveTensor(  #
+                data=make_rand_torch(
+                    (128,), dtype=dtype
+                )
+            ),
+            "img_attn.norm.query_norm.weight": DefaultPrimitiveTensor(  #
+                data=make_rand_torch(
+                    (128,), dtype=dtype
+                )
+            ),
+            "img_attn.proj.bias": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (3072,), dtype=dtype
+                )
+            ),
+            "img_attn.proj.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (3072, 3072), dtype=dtype
+                )
+            ),
+            "img_attn.qkv.bias": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (9216,), dtype=dtype
+                )
+            ),
+            "img_attn.qkv.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (9216, 3072), dtype=dtype
+                )
+            ),
+            "img_mlp.0.bias": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (12288), dtype=dtype
+                )
+            ),
+            "img_mlp.0.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (12288, 3072), dtype=dtype
+                )
+            ),
+            "img_mlp.2.bias": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (3072), dtype=dtype
+                )
+            ),
+            "img_mlp.2.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (3072, 12288), dtype=dtype
+                )
+            ),
+            "img_mod.lin.bias": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (18432,), dtype=dtype
+                )
+            ),
+            "img_mod.lin.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (18432, 3072), dtype=dtype
+                )
+            ),
+            "txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor(  #
+                data=make_rand_torch(
+                    (128,), dtype=dtype
+                )
+            ),
+            "txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor(  #
+                data=make_rand_torch(
+                    (128,), dtype=dtype
+                )
+            ),
+            "txt_attn.proj.bias": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (3072,), dtype=dtype
+                )
+            ),
+            "txt_attn.proj.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (3072, 3072), dtype=dtype
+                )
+            ),
+            "txt_attn.qkv.bias": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (9216,), dtype=dtype
+                )
+            ),
+            "txt_attn.qkv.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (9216, 3072), dtype=dtype
+                )
+            ),
+            "txt_mlp.0.bias": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (12288), dtype=dtype
+                )
+            ),
+            "txt_mlp.0.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (12288, 3072), dtype=dtype
+                )
+            ),
+            "txt_mlp.2.bias": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (3072), dtype=dtype
+                )
+            ),
+            "txt_mlp.2.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (3072, 12288), dtype=dtype
+                )
+            ),
+            "txt_mod.lin.bias": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (18432,), dtype=dtype
+                )
+            ),
+            "txt_mod.lin.weight": DefaultPrimitiveTensor(
+                data=make_rand_torch(
+                    (18432, 3072), dtype=dtype
+                )
+            ),
+        }
+    )
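
A minimal smoke-test sketch wiring this factory to the new block might look like the following. The num_heads=24 choice (3072 hidden / 128 per head) and every tensor shape, including the rotary-embedding layout, are assumptions made for illustration; this is not a test that exists in the repository.

import torch
from sharktank.layers import MMDITDoubleBlock
from sharktank.layers.testing import make_mmdit_block_theta

theta = make_mmdit_block_theta(dtype=torch.float32)
block = MMDITDoubleBlock(theta, num_heads=24)    # assumed: 3072 hidden / 128 per head

img = torch.randn(1, 1024, 3072)                 # image tokens (assumed shape)
txt = torch.randn(1, 256, 3072)                  # text tokens (assumed shape)
vec = torch.randn(1, 3072)                       # conditioning vector (assumed shape)
pe = torch.randn(1, 1, 1024 + 256, 64, 2, 2)     # rotary embedding; layout is an assumption
img_out, txt_out = block(img=img, txt=txt, vec=vec, pe=pe)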

sharktank/sharktank/ops/default_impls.py

Lines changed: 13 additions & 4 deletions
@@ -320,16 +320,24 @@ def interpolate_default(
         antialias=antialias,
     )
 
-
-@layer_norm.override(Tensor, Tensor, Tensor)
 def layer_norm_default(input, weight, bias, *, eps):
     input = unbox_tensor(input)
-    weight = unbox_tensor(weight)
-    bias = unbox_tensor(bias)
+    if weight is not None:
+        weight = unbox_tensor(weight)
+    else:
+        weight = torch.ones(input.shape)
+    if bias is not None:
+        bias = unbox_tensor(bias)
+    else:
+        bias = torch.zeros(input.shape)
     return F.layer_norm(
         input, normalized_shape=weight.shape, weight=weight, bias=bias, eps=eps
     )
 
+layer_norm.override(Tensor)(layer_norm_default)
+layer_norm.override(Tensor, Tensor)(layer_norm_default)
+layer_norm.override(Tensor, Tensor, Tensor)(layer_norm_default)
+
 
 # Linear
 def linear_default(input, weight, bias, *, accum_dtype) -> Tensor:

@@ -411,6 +419,7 @@ def rms_norm_default(x, weight, *, epsilon: float) -> Tensor:
     variance = x.pow(2).mean(-1, keepdim=True)
     output = x * elementwise(torch.rsqrt, variance + epsilon)
     # The cast here is to match the hf implementation, affects numerics
+    print(x.shape, weight.shape)
     output = elementwise(torch.mul, weight, to(output, weight.dtype))
     return output
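
The None-handling above is what the MMDIT block relies on when it calls ops.layer_norm(img, None, None, eps=1e-6). A plain-torch sketch of the fallback path, with illustrative shapes only:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 16)
weight = torch.ones(x.shape)    # substituted when weight is None
bias = torch.zeros(x.shape)     # substituted when bias is None
y = F.layer_norm(x, normalized_shape=weight.shape, weight=weight, bias=bias, eps=1e-6)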

sharktank/sharktank/ops/signatures.py

Lines changed: 4 additions & 2 deletions
@@ -563,12 +563,14 @@ def layer_norm(
 def _layer_norm_trampoline(
     d: SignatureDispatcher,
     input: AnyTensor,
-    weight: AnyTensor,
+    weight: Optional[AnyTensor],
     bias: Optional[AnyTensor],
     *,
     eps: float,
 ):
-    tensors = [input, weight]
+    tensors = [input]
+    if weight is not None:
+        tensors.append(weight)
     if bias is not None:
         tensors.append(bias)
     for override in d.find_overrides(tensors):