From 6ae65c30db06360f0eac5f3332cea10f5257a6d1 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Fri, 22 Nov 2024 00:54:31 -0800 Subject: [PATCH] Add a test for mmdit --- sharktank/sharktank/layers/__init__.py | 1 - sharktank/sharktank/layers/mmdit_double.py | 55 +++++++---- sharktank/sharktank/layers/modulation.py | 3 +- sharktank/sharktank/layers/testing.py | 109 ++++++--------------- sharktank/sharktank/ops/default_impls.py | 3 +- sharktank/tests/layers/mmdit_test.py | 106 ++++++++++++++++++++ 6 files changed, 178 insertions(+), 99 deletions(-) create mode 100644 sharktank/tests/layers/mmdit_test.py diff --git a/sharktank/sharktank/layers/__init__.py b/sharktank/sharktank/layers/__init__.py index f5def99e3..ae44f2d7a 100644 --- a/sharktank/sharktank/layers/__init__.py +++ b/sharktank/sharktank/layers/__init__.py @@ -18,6 +18,5 @@ from .ffn_moe_block import FFNMOE from .mixture_of_experts_block import MoeBlock from .mmdit_double import MMDITDoubleBlock -from .mmdit_single import MMDITSingleBlock from .configs import * diff --git a/sharktank/sharktank/layers/mmdit_double.py b/sharktank/sharktank/layers/mmdit_double.py index b0cef175f..230b1ea95 100644 --- a/sharktank/sharktank/layers/mmdit_double.py +++ b/sharktank/sharktank/layers/mmdit_double.py @@ -11,7 +11,6 @@ from .paged_llama_attention_block import PagedLlamaAttentionBlock - def qk_norm(q, k, v, rms_q, rms_k): return rms_q(q).to(v), rms_k(k).to(v) @@ -25,9 +24,11 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tenso def attention(q, k, v, pe): - q, k = apply_rope(q, k, pe) #todo + q, k = apply_rope(q, k, pe) # todo - x = ops.scaled_dot_product_attention(q=q, k=k, v=v, a=None, is_causal=True, scale=None) + x = ops.scaled_dot_product_attention( + q=q, k=k, v=v, a=None, is_causal=True, scale=None + ) x = ops.permute(x, (0, 2, 1, 3)) x = x.view(x.shape[0], x.shape[1], -1) @@ -41,8 +42,12 @@ def __init__(self, theta, num_heads: int): self.num_heads = num_heads self.img_mod = ModulationLayer(theta("img_mod"), double=True) self.img_attn_qkv = LinearLayer(theta("img_attn.qkv")) - self.img_attn_norm_q = RMSNormLayer(theta("img_attn.norm.query_norm"), epsilon=1e-6) - self.img_attn_norm_k = RMSNormLayer(theta("img_attn.norm.key_norm"), epsilon=1e-6) + self.img_attn_norm_q = RMSNormLayer( + theta("img_attn.norm.query_norm"), epsilon=1e-6 + ) + self.img_attn_norm_k = RMSNormLayer( + theta("img_attn.norm.key_norm"), epsilon=1e-6 + ) self.img_attn_proj = LinearLayer(theta("img_attn.proj")) self.img_mlp1 = LinearLayer(theta("img_mlp.0")) @@ -50,14 +55,20 @@ def __init__(self, theta, num_heads: int): self.txt_mod = ModulationLayer(theta("txt_mod"), double=True) self.txt_attn_qkv = LinearLayer(theta("txt_attn.qkv")) - self.txt_attn_norm_q = RMSNormLayer(theta("txt_attn.norm.query_norm"), epsilon=1e-6) - self.txt_attn_norm_k = RMSNormLayer(theta("txt_attn.norm.key_norm"), epsilon=1e-6) + self.txt_attn_norm_q = RMSNormLayer( + theta("txt_attn.norm.query_norm"), epsilon=1e-6 + ) + self.txt_attn_norm_k = RMSNormLayer( + theta("txt_attn.norm.key_norm"), epsilon=1e-6 + ) self.txt_attn_proj = LinearLayer(theta("txt_attn.proj")) self.txt_mlp1 = LinearLayer(theta("txt_mlp.0")) self.txt_mlp2 = LinearLayer(theta("txt_mlp.2")) - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: + def forward( + self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor + ) -> tuple[Tensor, Tensor]: img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) @@ -65,20 +76,27 @@ def 
forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Te img_modulated = ops.layer_norm(img, None, None, eps=1e-6) img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift img_qkv = self.img_attn_qkv(img_modulated) - img_qkv_2 = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1) # + img_qkv_2 = img_qkv.view( + img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1 + ) # img_qkv_3 = ops.permute(img_qkv_2, (2, 0, 3, 1, 4)) img_q, img_k, img_v = img_qkv_3 - img_q, img_k = qk_norm(img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k) - + img_q, img_k = qk_norm( + img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k + ) # prepare txt for attention txt_modulated = ops.layer_norm(txt, None, None, eps=1e-6) txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift txt_qkv = self.txt_attn_qkv(txt_modulated) - txt_qkv_2 = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1) # + txt_qkv_2 = txt_qkv.view( + txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1 + ) # txt_qkv_3 = ops.permute(txt_qkv_2, (2, 0, 3, 1, 4)) txt_q, txt_k, txt_v = txt_qkv_3 - txt_q, txt_k = qk_norm(txt_q, txt_k, txt_v, self.txt_attn_norm_q, self.txt_attn_norm_k) + txt_q, txt_k = qk_norm( + txt_q, txt_k, txt_v, self.txt_attn_norm_q, self.txt_attn_norm_k + ) # run actual attention q = torch.cat((txt_q, img_q), dim=2) @@ -90,7 +108,9 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Te # calculate the img bloks img = img + img_mod1.gate * self.img_attn_proj(img_attn) - img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm(img, None, None, eps=1e-6) + img_mod2.shift + img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm( + img, None, None, eps=1e-6 + ) + img_mod2.shift img_mlp_out1 = self.img_mlp1(img_mlp_in) img_mlp_out2 = ops.elementwise(F.gelu, img_mlp_out1) img_mlp_out3 = self.img_mlp2(img_mlp_out2) @@ -98,11 +118,12 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Te # calculate the txt bloks txt = txt + txt_mod1.gate * self.txt_attn_proj(txt_attn) - txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm(txt, None, None, eps=1e-6) + txt_mod2.shift + txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm( + txt, None, None, eps=1e-6 + ) + txt_mod2.shift txt_mlp_out1 = self.txt_mlp1(txt_mlp_in) txt_mlp_out2 = ops.elementwise(F.gelu, txt_mlp_out1) txt_mlp_out3 = self.txt_mlp2(txt_mlp_out2) txt = txt + txt_mod2.gate * txt_mlp_out3 - - return img, txt + return img, txt diff --git a/sharktank/sharktank/layers/modulation.py b/sharktank/sharktank/layers/modulation.py index dcde34606..e45cc6ed1 100644 --- a/sharktank/sharktank/layers/modulation.py +++ b/sharktank/sharktank/layers/modulation.py @@ -6,7 +6,8 @@ from .base import Theta, ThetaLayer from .linear import LinearLayer -class ModulationOut(): + +class ModulationOut: def __init__(self, shift, scale, gate): self.shift = shift self.scale = scale diff --git a/sharktank/sharktank/layers/testing.py b/sharktank/sharktank/layers/testing.py index 5d4cbcbc1..86b81a126 100644 --- a/sharktank/sharktank/layers/testing.py +++ b/sharktank/sharktank/layers/testing.py @@ -44,130 +44,81 @@ def make_llama_attention_block_theta( } ) -def make_mmdit_block_theta( - dtype: torch.dtype | None = None -) -> Theta: + +def make_mmdit_block_theta(dtype: torch.dtype | None = None) -> Theta: return Theta( { - "img_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # - data=make_rand_torch( - (128,), dtype=dtype - ) + 
"img_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((128,), dtype=dtype) ), - "img_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # - data=make_rand_torch( - (128,), dtype=dtype - ) + "img_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((128,), dtype=dtype) ), "img_attn.proj.bias": DefaultPrimitiveTensor( - data=make_rand_torch( - (3072,), dtype=dtype - ) + data=make_rand_torch((3072,), dtype=dtype) ), "img_attn.proj.weight": DefaultPrimitiveTensor( - data=make_rand_torch( - (3072, 3072), dtype=dtype - ) + data=make_rand_torch((3072, 3072), dtype=dtype) ), "img_attn.qkv.bias": DefaultPrimitiveTensor( - data=make_rand_torch( - (9216,), dtype=dtype - ) + data=make_rand_torch((9216,), dtype=dtype) ), "img_attn.qkv.weight": DefaultPrimitiveTensor( - data=make_rand_torch( - (9216, 3072), dtype=dtype - ) + data=make_rand_torch((9216, 3072), dtype=dtype) ), "img_mlp.0.bias": DefaultPrimitiveTensor( - data=make_rand_torch( - (12288), dtype=dtype - ) + data=make_rand_torch((12288), dtype=dtype) ), "img_mlp.0.weight": DefaultPrimitiveTensor( - data=make_rand_torch( - (12288, 3072), dtype=dtype - ) + data=make_rand_torch((12288, 3072), dtype=dtype) ), "img_mlp.2.bias": DefaultPrimitiveTensor( - data=make_rand_torch( - (3072), dtype=dtype - ) + data=make_rand_torch((3072), dtype=dtype) ), "img_mlp.2.weight": DefaultPrimitiveTensor( - data=make_rand_torch( - (3072, 12288), dtype=dtype - ) + data=make_rand_torch((3072, 12288), dtype=dtype) ), "img_mod.lin.bias": DefaultPrimitiveTensor( - data=make_rand_torch( - (18432,), dtype=dtype - ) + data=make_rand_torch((18432,), dtype=dtype) ), "img_mod.lin.weight": DefaultPrimitiveTensor( - data=make_rand_torch( - (18432, 3072), dtype=dtype - ) + data=make_rand_torch((18432, 3072), dtype=dtype) ), - "txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # - data=make_rand_torch( - (128,), dtype=dtype - ) + "txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((128,), dtype=dtype) ), - "txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # - data=make_rand_torch( - (128,), dtype=dtype - ) + "txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor( # + data=make_rand_torch((128,), dtype=dtype) ), "txt_attn.proj.bias": DefaultPrimitiveTensor( - data=make_rand_torch( - (3072,), dtype=dtype - ) + data=make_rand_torch((3072,), dtype=dtype) ), "txt_attn.proj.weight": DefaultPrimitiveTensor( - data=make_rand_torch( - (3072, 3072), dtype=dtype - ) + data=make_rand_torch((3072, 3072), dtype=dtype) ), "txt_attn.qkv.bias": DefaultPrimitiveTensor( - data=make_rand_torch( - (9216,), dtype=dtype - ) + data=make_rand_torch((9216,), dtype=dtype) ), "txt_attn.qkv.weight": DefaultPrimitiveTensor( - data=make_rand_torch( - (9216, 3072), dtype=dtype - ) + data=make_rand_torch((9216, 3072), dtype=dtype) ), "txt_mlp.0.bias": DefaultPrimitiveTensor( - data=make_rand_torch( - (12288), dtype=dtype - ) + data=make_rand_torch((12288), dtype=dtype) ), "txt_mlp.0.weight": DefaultPrimitiveTensor( - data=make_rand_torch( - (12288, 3072), dtype=dtype - ) + data=make_rand_torch((12288, 3072), dtype=dtype) ), "txt_mlp.2.bias": DefaultPrimitiveTensor( - data=make_rand_torch( - (3072), dtype=dtype - ) + data=make_rand_torch((3072), dtype=dtype) ), "txt_mlp.2.weight": DefaultPrimitiveTensor( - data=make_rand_torch( - (3072, 12288), dtype=dtype - ) + data=make_rand_torch((3072, 12288), dtype=dtype) ), "txt_mod.lin.bias": DefaultPrimitiveTensor( - data=make_rand_torch( - (18432,), dtype=dtype - ) + 
data=make_rand_torch((18432,), dtype=dtype) ), "txt_mod.lin.weight": DefaultPrimitiveTensor( - data=make_rand_torch( - (18432, 3072), dtype=dtype - ) + data=make_rand_torch((18432, 3072), dtype=dtype) ), } ) diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index 1f1798caa..1a79533d6 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -320,6 +320,7 @@ def interpolate_default( antialias=antialias, ) + def layer_norm_default(input, weight, bias, *, eps): input = unbox_tensor(input) if weight is not None: @@ -334,6 +335,7 @@ def layer_norm_default(input, weight, bias, *, eps): input, normalized_shape=weight.shape, weight=weight, bias=bias, eps=eps ) + layer_norm.override(Tensor)(layer_norm_default) layer_norm.override(Tensor, Tensor)(layer_norm_default) layer_norm.override(Tensor, Tensor, Tensor)(layer_norm_default) @@ -419,7 +421,6 @@ def rms_norm_default(x, weight, *, epsilon: float) -> Tensor: variance = x.pow(2).mean(-1, keepdim=True) output = x * elementwise(torch.rsqrt, variance + epsilon) # The cast here is to match the hf implementation, affects numerics - print(x.shape, weight.shape) output = elementwise(torch.mul, weight, to(output, weight.dtype)) return output diff --git a/sharktank/tests/layers/mmdit_test.py b/sharktank/tests/layers/mmdit_test.py new file mode 100644 index 000000000..33547e1c5 --- /dev/null +++ b/sharktank/tests/layers/mmdit_test.py @@ -0,0 +1,106 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging + +logging.basicConfig(level=logging.DEBUG) + +import unittest + +import torch + +from iree.turbine import aot +from sharktank.layers import ( + MMDITDoubleBlock, + PagedLlamaAttentionBlock, + PagedKVCache, + RotaryEmbeddingLayer, +) +import sharktank.ops as ops +from sharktank.layers.testing import ( + make_llama_attention_block_theta, + make_mmdit_block_theta, +) +from sharktank.types.tensors import DefaultPrimitiveTensor + + +class MMDITTest(unittest.TestCase): + def setUp(self): + torch.manual_seed(12345) + self.hidden_size = 3072 + self.num_heads = 24 + + self.transformer_block_count = 13 + self.block_index = 1 + self.shard_count = 3 + self.head_count_kv = 2 * self.shard_count + self.attention_head_count = 5 * self.head_count_kv + self.attention_head_dim = 24 + self.rms_epsilon = 0.01 + self.cache_partition_count = 2 + self.page_count = 23 + self.embedding_length = self.attention_head_count * self.attention_head_dim + self.rope_dimension_count = self.attention_head_dim + self.block_seqlen = 7 + self.block_seq_stride = 17 + self.max_seqlen = self.block_seq_stride * self.block_seqlen + self.rope_freq_base = None + self.batch_size = 3 + self.start_index = 0 + + def testExport(self): + dtype = torch.float32 + + txt_ids = torch.rand([self.batch_size, 3, self.max_seqlen, 3]) + img_ids = torch.rand([self.batch_size, 3, self.max_seqlen, 3]) + pe_dim = self.hidden_size // self.num_heads + axes_dim = [16, 56, 56] + theta = 10000 + + theta = make_mmdit_block_theta() + mmdit = MMDITDoubleBlock( + theta=theta, + num_heads=self.num_heads, + ) + + seq_block_ids = torch.arange(self.batch_size * self.block_seqlen).view( + self.batch_size, -1 + ) + + embedding_module = RotaryEmbeddingLayer( + rope_dimension_count=self.rope_dimension_count, + max_seqlen=self.max_seqlen, + 
rope_freq_base=self.rope_freq_base, + ) + + class MyModule(torch.nn.Module): + def forward(self, img, txt, vec, pe): + return mmdit.forward( + img, + txt, + vec, + pe, + ) + + mod = MyModule() + img = torch.rand([self.batch_size, 1024, self.hidden_size]) + txt = torch.rand([self.batch_size, 512, self.hidden_size]) + vec = torch.rand([self.batch_size, self.hidden_size]) + rot = torch.rand([self.batch_size, 1, 1536, 64, 2, 2]) + mod.forward(img, txt, vec, rot) + fxb = aot.FxProgramsBuilder(mod) + + @fxb.export_program(name="mmdit", args=(img, txt, vec, rot), strict=False) + def _(model, img, txt, vec, rot) -> torch.Tensor: + return mod(img, txt, vec, rot) + + output = aot.export(fxb) + output.verify() + asm = str(output.mlir_module) + + +if __name__ == "__main__": + unittest.main()
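
Usage note: below is a minimal eager-mode sketch (outside the patch itself) of exercising the new MMDITDoubleBlock with the random theta produced by make_mmdit_block_theta. The tensor shapes mirror the ones testExport builds above; the seed, variable names, and printed shape check are illustrative assumptions rather than anything the patch requires.

    import torch

    from sharktank.layers import MMDITDoubleBlock
    from sharktank.layers.testing import make_mmdit_block_theta

    torch.manual_seed(0)  # any seed; the weights are random either way

    hidden_size = 3072  # matches self.hidden_size in MMDITTest
    num_heads = 24      # matches self.num_heads in MMDITTest
    batch_size = 3

    # Random weights with the same shapes the test theta uses.
    theta = make_mmdit_block_theta()
    mmdit = MMDITDoubleBlock(theta=theta, num_heads=num_heads)

    # Image tokens, text tokens, conditioning vector, and rotary embeddings,
    # shaped as in testExport.
    img = torch.rand([batch_size, 1024, hidden_size])
    txt = torch.rand([batch_size, 512, hidden_size])
    vec = torch.rand([batch_size, hidden_size])
    pe = torch.rand([batch_size, 1, 1536, 64, 2, 2])

    # The block returns updated image and text streams with unchanged shapes.
    img_out, txt_out = mmdit.forward(img, txt, vec, pe)
    print(img_out.shape, txt_out.shape)  # expected [3, 1024, 3072] and [3, 512, 3072]

The test itself can be run with pytest (pytest sharktank/tests/layers/mmdit_test.py) or directly as a script via the unittest.main() entry point.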