Commit 7177d2d

Add a test for mmdit
1 parent 4a18656 commit 7177d2d

5 files changed: +178 -97 lines changed

sharktank/sharktank/layers/mmdit_double.py

Lines changed: 38 additions & 17 deletions
@@ -11,7 +11,6 @@
 from .paged_llama_attention_block import PagedLlamaAttentionBlock
 
 
-
 def qk_norm(q, k, v, rms_q, rms_k):
     return rms_q(q).to(v), rms_k(k).to(v)
 
@@ -25,9 +24,11 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
 
 
 def attention(q, k, v, pe):
-    q, k = apply_rope(q, k, pe) #todo
+    q, k = apply_rope(q, k, pe)  # todo
 
-    x = ops.scaled_dot_product_attention(q=q, k=k, v=v, a=None, is_causal=True, scale=None)
+    x = ops.scaled_dot_product_attention(
+        q=q, k=k, v=v, a=None, is_causal=True, scale=None
+    )
     x = ops.permute(x, (0, 2, 1, 3))
     x = x.view(x.shape[0], x.shape[1], -1)
 
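The reformatted attention helper above hands rope application and scaled dot-product attention off to sharktank.ops. As a shape-level illustration only, here is a minimal plain-PyTorch sketch of the same permute/reshape bookkeeping; it substitutes torch.nn.functional.scaled_dot_product_attention for ops.scaled_dot_product_attention, omits the rotary-embedding step, and the 24-head / 128-dim split is an assumption read off the 3072-wide tensors used elsewhere in this commit.

import torch
import torch.nn.functional as F


def attention_sketch(q, k, v):
    # q, k, v: (batch, num_heads, seq_len, head_dim); apply_rope is omitted in this sketch.
    x = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, num_heads, head_dim)
    x = x.permute(0, 2, 1, 3)
    # fold the heads back into the model dimension: (batch, seq_len, num_heads * head_dim)
    return x.reshape(x.shape[0], x.shape[1], -1)


# e.g. batch=2, num_heads=24, seq_len=16, head_dim=128 -> output (2, 16, 3072)
q = k = v = torch.randn(2, 24, 16, 128)
assert attention_sketch(q, k, v).shape == (2, 16, 3072)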

@@ -41,44 +42,61 @@ def __init__(self, theta, num_heads: int):
         self.num_heads = num_heads
         self.img_mod = ModulationLayer(theta("img_mod"), double=True)
         self.img_attn_qkv = LinearLayer(theta("img_attn.qkv"))
-        self.img_attn_norm_q = RMSNormLayer(theta("img_attn.norm.query_norm"), epsilon=1e-6)
-        self.img_attn_norm_k = RMSNormLayer(theta("img_attn.norm.key_norm"), epsilon=1e-6)
+        self.img_attn_norm_q = RMSNormLayer(
+            theta("img_attn.norm.query_norm"), epsilon=1e-6
+        )
+        self.img_attn_norm_k = RMSNormLayer(
+            theta("img_attn.norm.key_norm"), epsilon=1e-6
+        )
         self.img_attn_proj = LinearLayer(theta("img_attn.proj"))
 
         self.img_mlp1 = LinearLayer(theta("img_mlp.0"))
         self.img_mlp2 = LinearLayer(theta("img_mlp.2"))
 
         self.txt_mod = ModulationLayer(theta("txt_mod"), double=True)
         self.txt_attn_qkv = LinearLayer(theta("txt_attn.qkv"))
-        self.txt_attn_norm_q = RMSNormLayer(theta("txt_attn.norm.query_norm"), epsilon=1e-6)
-        self.txt_attn_norm_k = RMSNormLayer(theta("txt_attn.norm.key_norm"), epsilon=1e-6)
+        self.txt_attn_norm_q = RMSNormLayer(
+            theta("txt_attn.norm.query_norm"), epsilon=1e-6
+        )
+        self.txt_attn_norm_k = RMSNormLayer(
+            theta("txt_attn.norm.key_norm"), epsilon=1e-6
+        )
         self.txt_attn_proj = LinearLayer(theta("txt_attn.proj"))
 
         self.txt_mlp1 = LinearLayer(theta("txt_mlp.0"))
         self.txt_mlp2 = LinearLayer(theta("txt_mlp.2"))
 
-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
+    def forward(
+        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
+    ) -> tuple[Tensor, Tensor]:
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)
 
         # prepare image for attention
         img_modulated = ops.layer_norm(img, None, None, eps=1e-6)
         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         img_qkv = self.img_attn_qkv(img_modulated)
-        img_qkv_2 = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1) #
+        img_qkv_2 = img_qkv.view(
+            img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1
+        )  #
         img_qkv_3 = ops.permute(img_qkv_2, (2, 0, 3, 1, 4))
         img_q, img_k, img_v = img_qkv_3
-        img_q, img_k = qk_norm(img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k)
-
+        img_q, img_k = qk_norm(
+            img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k
+        )
 
         # prepare txt for attention
         txt_modulated = ops.layer_norm(txt, None, None, eps=1e-6)
         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn_qkv(txt_modulated)
-        txt_qkv_2 = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1) #
+        txt_qkv_2 = txt_qkv.view(
+            txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1
+        )  #
         txt_qkv_3 = ops.permute(txt_qkv_2, (2, 0, 3, 1, 4))
         txt_q, txt_k, txt_v = txt_qkv_3
-        txt_q, txt_k = qk_norm(txt_q, txt_k, txt_v, self.txt_attn_norm_q, self.txt_attn_norm_k)
+        txt_q, txt_k = qk_norm(
+            txt_q, txt_k, txt_v, self.txt_attn_norm_q, self.txt_attn_norm_k
+        )
 
         # run actual attention
         q = torch.cat((txt_q, img_q), dim=2)
@@ -90,19 +108,22 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
 
         # calculate the img bloks
         img = img + img_mod1.gate * self.img_attn_proj(img_attn)
-        img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm(img, None, None, eps=1e-6) + img_mod2.shift
+        img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm(
+            img, None, None, eps=1e-6
+        ) + img_mod2.shift
         img_mlp_out1 = self.img_mlp1(img_mlp_in)
         img_mlp_out2 = ops.elementwise(F.gelu, img_mlp_out1)
         img_mlp_out3 = self.img_mlp2(img_mlp_out2)
         img = img + img_mod2.gate * img_mlp_out3
 
         # calculate the txt bloks
         txt = txt + txt_mod1.gate * self.txt_attn_proj(txt_attn)
-        txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm(txt, None, None, eps=1e-6) + txt_mod2.shift
+        txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm(
+            txt, None, None, eps=1e-6
+        ) + txt_mod2.shift
         txt_mlp_out1 = self.txt_mlp1(txt_mlp_in)
         txt_mlp_out2 = ops.elementwise(F.gelu, txt_mlp_out1)
         txt_mlp_out3 = self.txt_mlp2(txt_mlp_out2)
         txt = txt + txt_mod2.gate * txt_mlp_out3
-
-        return img, txt
 
+        return img, txt
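The commit message says a test for mmdit was added, but the new test file itself is not part of the diff text reproduced above. As a rough sketch of how the pieces shown here could be exercised together: the class name MMDITDoubleBlock, the pe tensor's shape, and num_heads=24 are assumptions, not taken from this diff (24 heads is inferred from the 3072-wide projections and the (128,)-wide q/k norm weights).

import torch
from sharktank.layers.testing import make_mmdit_block_theta
from sharktank.layers.mmdit_double import MMDITDoubleBlock  # class name assumed


def test_mmdit_double_block_smoke():
    dtype = torch.float32
    theta = make_mmdit_block_theta(dtype=dtype)
    block = MMDITDoubleBlock(theta, num_heads=24)  # 24 heads * 128 head_dim = 3072

    batch, txt_seq, img_seq, dim = 2, 512, 1024, 3072
    img = torch.rand(batch, img_seq, dim, dtype=dtype)
    txt = torch.rand(batch, txt_seq, dim, dtype=dtype)
    vec = torch.rand(batch, dim, dtype=dtype)
    # Positional embeddings consumed by apply_rope; this placeholder shape is an
    # assumption and would need to match whatever apply_rope actually expects.
    pe = torch.rand(batch, 1, txt_seq + img_seq, 64, 2, 2, dtype=dtype)

    img_out, txt_out = block.forward(img, txt, vec, pe)
    assert img_out.shape == img.shape
    assert txt_out.shape == txt.shape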

sharktank/sharktank/layers/modulation.py

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,8 @@
 from .base import Theta, ThetaLayer
 from .linear import LinearLayer
 
-class ModulationOut():
+
+class ModulationOut:
     def __init__(self, shift, scale, gate):
         self.shift = shift
         self.scale = scale
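ModulationOut is just a named bundle of the shift/scale/gate tensors that the double block wraps around each sub-layer. A minimal standalone sketch of that pattern, mirroring the forward code above (the helper name modulate is hypothetical):

import torch


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # normalize without learned affine parameters, then scale and shift
    x = torch.nn.functional.layer_norm(x, x.shape[-1:], eps=1e-6)
    return (1 + scale) * x + shift


# the gate then weights the sub-layer output before the residual add, e.g.
#     x = x + gate * sublayer(modulate(x, shift, scale))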

sharktank/sharktank/layers/testing.py

Lines changed: 30 additions & 79 deletions
@@ -44,130 +44,81 @@ def make_llama_attention_block_theta(
         }
     )
 
-def make_mmdit_block_theta(
-    dtype: torch.dtype | None = None
-) -> Theta:
+
+def make_mmdit_block_theta(dtype: torch.dtype | None = None) -> Theta:
     return Theta(
         {
-            "img_attn.norm.key_norm.weight": DefaultPrimitiveTensor( #
-                data=make_rand_torch(
-                    (128,), dtype=dtype
-                )
+            "img_attn.norm.key_norm.weight": DefaultPrimitiveTensor(  #
+                data=make_rand_torch((128,), dtype=dtype)
             ),
-            "img_attn.norm.query_norm.weight": DefaultPrimitiveTensor( #
-                data=make_rand_torch(
-                    (128,), dtype=dtype
-                )
+            "img_attn.norm.query_norm.weight": DefaultPrimitiveTensor(  #
+                data=make_rand_torch((128,), dtype=dtype)
             ),
             "img_attn.proj.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (3072,), dtype=dtype
-                )
+                data=make_rand_torch((3072,), dtype=dtype)
             ),
             "img_attn.proj.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (3072, 3072), dtype=dtype
-                )
+                data=make_rand_torch((3072, 3072), dtype=dtype)
             ),
             "img_attn.qkv.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (9216,), dtype=dtype
-                )
+                data=make_rand_torch((9216,), dtype=dtype)
             ),
             "img_attn.qkv.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (9216, 3072), dtype=dtype
-                )
+                data=make_rand_torch((9216, 3072), dtype=dtype)
             ),
             "img_mlp.0.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (12288), dtype=dtype
-                )
+                data=make_rand_torch((12288), dtype=dtype)
            ),
             "img_mlp.0.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (12288, 3072), dtype=dtype
-                )
+                data=make_rand_torch((12288, 3072), dtype=dtype)
             ),
             "img_mlp.2.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (3072), dtype=dtype
-                )
+                data=make_rand_torch((3072), dtype=dtype)
             ),
             "img_mlp.2.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (3072, 12288), dtype=dtype
-                )
+                data=make_rand_torch((3072, 12288), dtype=dtype)
             ),
             "img_mod.lin.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (18432,), dtype=dtype
-                )
+                data=make_rand_torch((18432,), dtype=dtype)
             ),
             "img_mod.lin.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (18432, 3072), dtype=dtype
-                )
+                data=make_rand_torch((18432, 3072), dtype=dtype)
             ),
-            "txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor( #
-                data=make_rand_torch(
-                    (128,), dtype=dtype
-                )
+            "txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor(  #
+                data=make_rand_torch((128,), dtype=dtype)
             ),
-            "txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor( #
-                data=make_rand_torch(
-                    (128,), dtype=dtype
-                )
+            "txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor(  #
+                data=make_rand_torch((128,), dtype=dtype)
             ),
             "txt_attn.proj.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (3072,), dtype=dtype
-                )
+                data=make_rand_torch((3072,), dtype=dtype)
             ),
             "txt_attn.proj.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (3072, 3072), dtype=dtype
-                )
+                data=make_rand_torch((3072, 3072), dtype=dtype)
             ),
             "txt_attn.qkv.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (9216,), dtype=dtype
-                )
+                data=make_rand_torch((9216,), dtype=dtype)
             ),
             "txt_attn.qkv.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (9216, 3072), dtype=dtype
-                )
+                data=make_rand_torch((9216, 3072), dtype=dtype)
             ),
             "txt_mlp.0.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (12288), dtype=dtype
-                )
+                data=make_rand_torch((12288), dtype=dtype)
             ),
             "txt_mlp.0.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (12288, 3072), dtype=dtype
-                )
+                data=make_rand_torch((12288, 3072), dtype=dtype)
             ),
             "txt_mlp.2.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (3072), dtype=dtype
-                )
+                data=make_rand_torch((3072), dtype=dtype)
             ),
             "txt_mlp.2.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (3072, 12288), dtype=dtype
-                )
+                data=make_rand_torch((3072, 12288), dtype=dtype)
             ),
             "txt_mod.lin.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (18432,), dtype=dtype
-                )
+                data=make_rand_torch((18432,), dtype=dtype)
             ),
             "txt_mod.lin.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch(
-                    (18432, 3072), dtype=dtype
-                )
+                data=make_rand_torch((18432, 3072), dtype=dtype)
             ),
         }
     )
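The reformatted factory above only packs random weights into a Theta; the sizes it uses encode the block geometry. A small sanity sketch of those relationships (the 24-head split is inferred from the shapes, not stated in the diff):

import torch
from sharktank.layers.testing import make_mmdit_block_theta

theta = make_mmdit_block_theta(dtype=torch.float32)

# Shape relationships implied by the sizes above, with model dim 3072:
assert 9216 == 3 * 3072    # fused q/k/v projection
assert 128 == 3072 // 24   # per-head dim for 24 heads, matching the (128,) norm weights
assert 18432 == 6 * 3072   # two (shift, scale, gate) triples from the double modulation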

sharktank/sharktank/ops/default_impls.py

Lines changed: 2 additions & 0 deletions
@@ -320,6 +320,7 @@ def interpolate_default(
         antialias=antialias,
     )
 
+
 def layer_norm_default(input, weight, bias, *, eps):
     input = unbox_tensor(input)
     if weight is not None:
@@ -334,6 +335,7 @@ def layer_norm_default(input, weight, bias, *, eps):
         input, normalized_shape=weight.shape, weight=weight, bias=bias, eps=eps
     )
 
+
 layer_norm.override(Tensor)(layer_norm_default)
 layer_norm.override(Tensor, Tensor)(layer_norm_default)
 layer_norm.override(Tensor, Tensor, Tensor)(layer_norm_default)
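These blank-line insertions are formatting-only, but the surrounding layer_norm_default is what ops.layer_norm(x, None, None, eps=1e-6) dispatches to in the block above. A minimal plain-torch sketch of the weight-less, bias-less case; normalizing over the trailing dimension is an assumption about what the None-weight branch does, since that branch is not visible in this hunk:

import torch
import torch.nn.functional as F

x = torch.randn(2, 16, 3072)
# no learned weight or bias: pure normalization over the last dimension
y = F.layer_norm(x, normalized_shape=x.shape[-1:], eps=1e-6)
assert torch.allclose(y.mean(-1), torch.zeros(2, 16), atol=1e-4)
assert torch.allclose(y.std(-1, unbiased=False), torch.ones(2, 16), atol=1e-3)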
