
Commit

Add a test for mmdit
KyleHerndon committed Nov 22, 2024
1 parent 4a18656 commit 7177d2d
Showing 5 changed files with 178 additions and 97 deletions.
55 changes: 38 additions & 17 deletions sharktank/sharktank/layers/mmdit_double.py
@@ -11,7 +11,6 @@
from .paged_llama_attention_block import PagedLlamaAttentionBlock



def qk_norm(q, k, v, rms_q, rms_k):
return rms_q(q).to(v), rms_k(k).to(v)

@@ -25,9 +24,11 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:


def attention(q, k, v, pe):
q, k = apply_rope(q, k, pe) #todo
q, k = apply_rope(q, k, pe) # todo

x = ops.scaled_dot_product_attention(q=q, k=k, v=v, a=None, is_causal=True, scale=None)
x = ops.scaled_dot_product_attention(
q=q, k=k, v=v, a=None, is_causal=True, scale=None
)
x = ops.permute(x, (0, 2, 1, 3))
x = x.view(x.shape[0], x.shape[1], -1)
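
For reference, the reformatted attention call expects q, k and v as (batch, heads, seq_len, head_dim) and flattens the heads back into the hidden dimension afterwards. A minimal plain-PyTorch sketch of that shape round-trip (rope application and the sharktank ops wrappers are omitted, and the sizes are illustrative assumptions, not taken from this file):

import torch
import torch.nn.functional as F

# Illustrative sizes only: batch 2, 24 heads, sequence length 16, head_dim 128.
b, h, l, d = 2, 24, 16, 128
q = torch.randn(b, h, l, d)
k = torch.randn(b, h, l, d)
v = torch.randn(b, h, l, d)

# Stand-in for ops.scaled_dot_product_attention(q=q, k=k, v=v, a=None, is_causal=True, scale=None).
x = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Same post-processing as the helper: (b, h, l, d) -> (b, l, h, d) -> (b, l, h*d).
# reshape is used instead of view because the permuted tensor is not contiguous.
x = x.permute(0, 2, 1, 3).reshape(b, l, h * d)
assert x.shape == (b, l, h * d)
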

@@ -41,44 +42,61 @@ def __init__(self, theta, num_heads: int):
self.num_heads = num_heads
self.img_mod = ModulationLayer(theta("img_mod"), double=True)
self.img_attn_qkv = LinearLayer(theta("img_attn.qkv"))
self.img_attn_norm_q = RMSNormLayer(theta("img_attn.norm.query_norm"), epsilon=1e-6)
self.img_attn_norm_k = RMSNormLayer(theta("img_attn.norm.key_norm"), epsilon=1e-6)
self.img_attn_norm_q = RMSNormLayer(
theta("img_attn.norm.query_norm"), epsilon=1e-6
)
self.img_attn_norm_k = RMSNormLayer(
theta("img_attn.norm.key_norm"), epsilon=1e-6
)
self.img_attn_proj = LinearLayer(theta("img_attn.proj"))

self.img_mlp1 = LinearLayer(theta("img_mlp.0"))
self.img_mlp2 = LinearLayer(theta("img_mlp.2"))

self.txt_mod = ModulationLayer(theta("txt_mod"), double=True)
self.txt_attn_qkv = LinearLayer(theta("txt_attn.qkv"))
self.txt_attn_norm_q = RMSNormLayer(theta("txt_attn.norm.query_norm"), epsilon=1e-6)
self.txt_attn_norm_k = RMSNormLayer(theta("txt_attn.norm.key_norm"), epsilon=1e-6)
self.txt_attn_norm_q = RMSNormLayer(
theta("txt_attn.norm.query_norm"), epsilon=1e-6
)
self.txt_attn_norm_k = RMSNormLayer(
theta("txt_attn.norm.key_norm"), epsilon=1e-6
)
self.txt_attn_proj = LinearLayer(theta("txt_attn.proj"))

self.txt_mlp1 = LinearLayer(theta("txt_mlp.0"))
self.txt_mlp2 = LinearLayer(theta("txt_mlp.2"))

def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
def forward(
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)

# prepare image for attention
img_modulated = ops.layer_norm(img, None, None, eps=1e-6)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn_qkv(img_modulated)
img_qkv_2 = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1) #
img_qkv_2 = img_qkv.view(
img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1
) #
img_qkv_3 = ops.permute(img_qkv_2, (2, 0, 3, 1, 4))
img_q, img_k, img_v = img_qkv_3
img_q, img_k = qk_norm(img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k)

img_q, img_k = qk_norm(
img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k
)

# prepare txt for attention
txt_modulated = ops.layer_norm(txt, None, None, eps=1e-6)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_qkv_2 = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1) #
txt_qkv_2 = txt_qkv.view(
txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1
) #
txt_qkv_3 = ops.permute(txt_qkv_2, (2, 0, 3, 1, 4))
txt_q, txt_k, txt_v = txt_qkv_3
txt_q, txt_k = qk_norm(txt_q, txt_k, txt_v, self.txt_attn_norm_q, self.txt_attn_norm_k)
txt_q, txt_k = qk_norm(
txt_q, txt_k, txt_v, self.txt_attn_norm_q, self.txt_attn_norm_k
)

# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
@@ -90,19 +108,22 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:

# calculate the img blocks
img = img + img_mod1.gate * self.img_attn_proj(img_attn)
img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm(img, None, None, eps=1e-6) + img_mod2.shift
img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm(
img, None, None, eps=1e-6
) + img_mod2.shift
img_mlp_out1 = self.img_mlp1(img_mlp_in)
img_mlp_out2 = ops.elementwise(F.gelu, img_mlp_out1)
img_mlp_out3 = self.img_mlp2(img_mlp_out2)
img = img + img_mod2.gate * img_mlp_out3

# calculate the txt blocks
txt = txt + txt_mod1.gate * self.txt_attn_proj(txt_attn)
txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm(txt, None, None, eps=1e-6) + txt_mod2.shift
txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm(
txt, None, None, eps=1e-6
) + txt_mod2.shift
txt_mlp_out1 = self.txt_mlp1(txt_mlp_in)
txt_mlp_out2 = ops.elementwise(F.gelu, txt_mlp_out1)
txt_mlp_out3 = self.txt_mlp2(txt_mlp_out2)
txt = txt + txt_mod2.gate * txt_mlp_out3

return img, txt

return img, txt
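
The fused qkv projection above is split into per-head q, k, v by the view/permute pair in forward. A standalone shape check of that split in plain PyTorch (the hidden size of 3072 and 24 heads are assumptions consistent with the 9216-wide qkv weights in the test theta below, not values stated in this file):

import torch

b, l, hidden, heads = 2, 16, 3072, 24     # assumed sizes; head_dim = 3072 / 24 = 128
qkv = torch.randn(b, l, 3 * hidden)       # output of the fused qkv linear (9216 wide)

qkv_2 = qkv.view(b, l, 3, heads, -1)      # (b, l, 3, heads, head_dim)
qkv_3 = qkv_2.permute(2, 0, 3, 1, 4)      # (3, b, heads, l, head_dim)
q, k, v = qkv_3                           # unpack along the leading axis of size 3
assert q.shape == (b, heads, l, hidden // heads)

The txt stream goes through the same split before the two sequences are concatenated along the sequence axis for joint attention.
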
3 changes: 2 additions & 1 deletion sharktank/sharktank/layers/modulation.py
@@ -6,7 +6,8 @@
from .base import Theta, ThetaLayer
from .linear import LinearLayer

class ModulationOut():

class ModulationOut:
def __init__(self, shift, scale, gate):
self.shift = shift
self.scale = scale
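
ModulationOut is just a container for the shift/scale/gate triple that the double block consumes; with double=True the modulation layer emits two such triples per stream, consistent with the 18432 = 6 * 3072 img_mod.lin width in the test theta below. A hedged sketch of the modulation pattern in plain PyTorch, independent of ModulationLayer's internals:

import torch
import torch.nn.functional as F

class ModulationOut:
    def __init__(self, shift, scale, gate):
        self.shift = shift
        self.scale = scale
        self.gate = gate

# Assumed sizes for illustration.
b, l, hidden = 2, 16, 3072
x = torch.randn(b, l, hidden)
mod = ModulationOut(*(torch.randn(b, 1, hidden) for _ in range(3)))

normed = F.layer_norm(x, (hidden,), eps=1e-6)       # parameter-free layer norm
modulated = (1 + mod.scale) * normed + mod.shift    # shift/scale modulation before attention/MLP
branch = torch.tanh(modulated)                      # placeholder for the attention or MLP branch
x = x + mod.gate * branch                           # gated residual update, as in the double block
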
109 changes: 30 additions & 79 deletions sharktank/sharktank/layers/testing.py
@@ -44,130 +44,81 @@ def make_llama_attention_block_theta(
}
)

def make_mmdit_block_theta(
dtype: torch.dtype | None = None
) -> Theta:

def make_mmdit_block_theta(dtype: torch.dtype | None = None) -> Theta:
return Theta(
{
"img_attn.norm.key_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch(
(128,), dtype=dtype
)
"img_attn.norm.key_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch((128,), dtype=dtype)
),
"img_attn.norm.query_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch(
(128,), dtype=dtype
)
"img_attn.norm.query_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch((128,), dtype=dtype)
),
"img_attn.proj.bias": DefaultPrimitiveTensor(
data=make_rand_torch(
(3072,), dtype=dtype
)
data=make_rand_torch((3072,), dtype=dtype)
),
"img_attn.proj.weight": DefaultPrimitiveTensor(
data=make_rand_torch(
(3072, 3072), dtype=dtype
)
data=make_rand_torch((3072, 3072), dtype=dtype)
),
"img_attn.qkv.bias": DefaultPrimitiveTensor(
data=make_rand_torch(
(9216,), dtype=dtype
)
data=make_rand_torch((9216,), dtype=dtype)
),
"img_attn.qkv.weight": DefaultPrimitiveTensor(
data=make_rand_torch(
(9216, 3072), dtype=dtype
)
data=make_rand_torch((9216, 3072), dtype=dtype)
),
"img_mlp.0.bias": DefaultPrimitiveTensor(
data=make_rand_torch(
(12288), dtype=dtype
)
data=make_rand_torch((12288), dtype=dtype)
),
"img_mlp.0.weight": DefaultPrimitiveTensor(
data=make_rand_torch(
(12288, 3072), dtype=dtype
)
data=make_rand_torch((12288, 3072), dtype=dtype)
),
"img_mlp.2.bias": DefaultPrimitiveTensor(
data=make_rand_torch(
(3072), dtype=dtype
)
data=make_rand_torch((3072), dtype=dtype)
),
"img_mlp.2.weight": DefaultPrimitiveTensor(
data=make_rand_torch(
(3072, 12288), dtype=dtype
)
data=make_rand_torch((3072, 12288), dtype=dtype)
),
"img_mod.lin.bias": DefaultPrimitiveTensor(
data=make_rand_torch(
(18432,), dtype=dtype
)
data=make_rand_torch((18432,), dtype=dtype)
),
"img_mod.lin.weight": DefaultPrimitiveTensor(
data=make_rand_torch(
(18432, 3072), dtype=dtype
)
data=make_rand_torch((18432, 3072), dtype=dtype)
),
"txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch(
(128,), dtype=dtype
)
"txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch((128,), dtype=dtype)
),
"txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch(
(128,), dtype=dtype
)
"txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch((128,), dtype=dtype)
),
"txt_attn.proj.bias": DefaultPrimitiveTensor(
data=make_rand_torch(
(3072,), dtype=dtype
)
data=make_rand_torch((3072,), dtype=dtype)
),
"txt_attn.proj.weight": DefaultPrimitiveTensor(
data=make_rand_torch(
(3072, 3072), dtype=dtype
)
data=make_rand_torch((3072, 3072), dtype=dtype)
),
"txt_attn.qkv.bias": DefaultPrimitiveTensor(
data=make_rand_torch(
(9216,), dtype=dtype
)
data=make_rand_torch((9216,), dtype=dtype)
),
"txt_attn.qkv.weight": DefaultPrimitiveTensor(
data=make_rand_torch(
(9216, 3072), dtype=dtype
)
data=make_rand_torch((9216, 3072), dtype=dtype)
),
"txt_mlp.0.bias": DefaultPrimitiveTensor(
data=make_rand_torch(
(12288), dtype=dtype
)
data=make_rand_torch((12288), dtype=dtype)
),
"txt_mlp.0.weight": DefaultPrimitiveTensor(
data=make_rand_torch(
(12288, 3072), dtype=dtype
)
data=make_rand_torch((12288, 3072), dtype=dtype)
),
"txt_mlp.2.bias": DefaultPrimitiveTensor(
data=make_rand_torch(
(3072), dtype=dtype
)
data=make_rand_torch((3072), dtype=dtype)
),
"txt_mlp.2.weight": DefaultPrimitiveTensor(
data=make_rand_torch(
(3072, 12288), dtype=dtype
)
data=make_rand_torch((3072, 12288), dtype=dtype)
),
"txt_mod.lin.bias": DefaultPrimitiveTensor(
data=make_rand_torch(
(18432,), dtype=dtype
)
data=make_rand_torch((18432,), dtype=dtype)
),
"txt_mod.lin.weight": DefaultPrimitiveTensor(
data=make_rand_torch(
(18432, 3072), dtype=dtype
)
data=make_rand_torch((18432, 3072), dtype=dtype)
),
}
)
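
The tensor shapes in this fixture are internally consistent with a hidden size of 3072: the fused qkv is 3x that, the MLP expansion 4x, and the modulation linear 6x (two shift/scale/gate triples). The 128-wide query/key norms then suggest 24 heads, though the head count is an inference, not something stated in the commit:

hidden = 3072

assert 3 * hidden == 9216      # img_attn.qkv / txt_attn.qkv output width
assert 4 * hidden == 12288     # img_mlp.0 / txt_mlp.0 expansion
assert 6 * hidden == 18432     # img_mod.lin / txt_mod.lin: two (shift, scale, gate) triples
assert hidden // 128 == 24     # head_dim 128 implies 24 heads, if hidden is split evenly
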
2 changes: 2 additions & 0 deletions sharktank/sharktank/ops/default_impls.py
@@ -320,6 +320,7 @@ def interpolate_default(
antialias=antialias,
)


def layer_norm_default(input, weight, bias, *, eps):
input = unbox_tensor(input)
if weight is not None:
@@ -334,6 +335,7 @@ def layer_norm_default(input, weight, bias, *, eps):
input, normalized_shape=weight.shape, weight=weight, bias=bias, eps=eps
)


layer_norm.override(Tensor)(layer_norm_default)
layer_norm.override(Tensor, Tensor)(layer_norm_default)
layer_norm.override(Tensor, Tensor, Tensor)(layer_norm_default)
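
The new blank lines here touch layer_norm_default, which the double block calls with weight and bias both None. Assuming the weight-free path normalizes over the trailing dimension (the visible hunk only shows the weighted branch), a plain-PyTorch equivalent of ops.layer_norm(x, None, None, eps=1e-6) would look roughly like this:

import torch
import torch.nn.functional as F

x = torch.randn(2, 16, 3072)

# Assumed equivalent of ops.layer_norm(x, None, None, eps=1e-6): standardize over the
# last dimension with no affine parameters.
y = F.layer_norm(x, normalized_shape=x.shape[-1:], eps=1e-6)

assert y.shape == x.shape
# Each position is standardized: per-position mean is ~ 0.
assert torch.allclose(y.mean(dim=-1), torch.zeros(2, 16), atol=1e-4)
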
