Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement MMDIT block that is necessary for flux #592

Merged
merged 10 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sharktank/sharktank/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@
from .ffn_block import FFN
from .ffn_moe_block import FFNMOE
from .mixture_of_experts_block import MoeBlock
from .mmdit_double import MMDITDoubleBlock

from .configs import *
133 changes: 133 additions & 0 deletions sharktank/sharktank/layers/mmdit_double.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import torch.nn.functional as F
import torch
from torch import Tensor

from .. import ops

from .base import Theta, ThetaLayer
from .linear import LinearLayer
from .modulation import ModulationLayer
from .norm import RMSNormLayer
from .paged_llama_attention_block import PagedLlamaAttentionBlock


def qk_norm(q, k, v, rms_q, rms_k):
return rms_q(q).to(v), rms_k(k).to(v)


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


def attention(q, k, v, pe):
q, k = apply_rope(q, k, pe) # todo

x = ops.scaled_dot_product_attention(
q=q, k=k, v=v, a=None, is_causal=True, scale=None
)
x = ops.permute(x, (0, 2, 1, 3))
x = x.view(x.shape[0], x.shape[1], -1)

return x


class MMDITDoubleBlock(ThetaLayer):
def __init__(self, theta, num_heads: int):
super().__init__(theta)

self.num_heads = num_heads
self.add_module("img_mod", ModulationLayer(theta("img_mod"), double=True))
self.add_module("img_attn_qkv", LinearLayer(theta("img_attn.qkv")))
self.add_module(
"img_attn_norm_q",
RMSNormLayer(theta("img_attn.norm.query_norm"), epsilon=1e-6),
)
self.add_module(
"img_attn_norm_k",
RMSNormLayer(theta("img_attn.norm.key_norm"), epsilon=1e-6),
)
self.add_module("img_attn_proj", LinearLayer(theta("img_attn.proj")))

self.add_module("img_mlp1", LinearLayer(theta("img_mlp.0")))
self.add_module("img_mlp2", LinearLayer(theta("img_mlp.2")))

self.add_module("txt_mod", ModulationLayer(theta("txt_mod"), double=True))
self.add_module("txt_attn_qkv", LinearLayer(theta("txt_attn.qkv")))
self.add_module(
"txt_attn_norm_q",
RMSNormLayer(theta("txt_attn.norm.query_norm"), epsilon=1e-6),
)
self.add_module(
"txt_attn_norm_k",
RMSNormLayer(theta("txt_attn.norm.key_norm"), epsilon=1e-6),
)
self.add_module("txt_attn_proj", LinearLayer(theta("txt_attn.proj")))

self.add_module("txt_mlp1", LinearLayer(theta("txt_mlp.0")))
self.add_module("txt_mlp2", LinearLayer(theta("txt_mlp.2")))

def forward(
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)

# prepare image for attention
img_modulated = ops.layer_norm(img, None, None, eps=1e-6)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn_qkv(img_modulated)
img_qkv_2 = img_qkv.view(
img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1
) #
img_qkv_3 = ops.permute(img_qkv_2, (2, 0, 3, 1, 4))
img_q, img_k, img_v = img_qkv_3
img_q, img_k = qk_norm(
img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k
)

# prepare txt for attention
txt_modulated = ops.layer_norm(txt, None, None, eps=1e-6)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_qkv_2 = txt_qkv.view(
txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1
) #
txt_qkv_3 = ops.permute(txt_qkv_2, (2, 0, 3, 1, 4))
txt_q, txt_k, txt_v = txt_qkv_3
txt_q, txt_k = qk_norm(
txt_q, txt_k, txt_v, self.txt_attn_norm_q, self.txt_attn_norm_k
)

# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)

attn = attention(q, k, v, pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

# calculate the img bloks
img = img + img_mod1.gate * self.img_attn_proj(img_attn)
img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm(
img, None, None, eps=1e-6
) + img_mod2.shift
img_mlp_out1 = self.img_mlp1(img_mlp_in)
img_mlp_out2 = ops.elementwise(F.gelu, img_mlp_out1)
img_mlp_out3 = self.img_mlp2(img_mlp_out2)
img = img + img_mod2.gate * img_mlp_out3

# calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn_proj(txt_attn)
txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm(
txt, None, None, eps=1e-6
) + txt_mod2.shift
txt_mlp_out1 = self.txt_mlp1(txt_mlp_in)
txt_mlp_out2 = ops.elementwise(F.gelu, txt_mlp_out1)
txt_mlp_out3 = self.txt_mlp2(txt_mlp_out2)
txt = txt + txt_mod2.gate * txt_mlp_out3

return img, txt
32 changes: 32 additions & 0 deletions sharktank/sharktank/layers/modulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch
import torch.nn.functional as F

from .. import ops

from .base import Theta, ThetaLayer
from .linear import LinearLayer


class ModulationOut:
def __init__(self, shift, scale, gate):
self.shift = shift
self.scale = scale
self.gate = gate


class ModulationLayer(ThetaLayer):
def __init__(self, theta: Theta, double: bool):
super().__init__(theta)

self.is_double = double
self.multiplier = 6 if double else 3
self.add_module("lin", LinearLayer(theta("lin")))

def forward(self, vec: torch.Tensor) -> tuple[ModulationOut, ModulationOut | None]:
silu_result = ops.elementwise(F.silu, vec)
out = self.lin(silu_result)[:, None, :].chunk(self.multiplier, dim=-1)

return (
ModulationOut(*out[:3]),
ModulationOut(*out[3:]) if self.is_double else None,
)
79 changes: 79 additions & 0 deletions sharktank/sharktank/layers/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,82 @@ def make_llama_attention_block_theta(
),
}
)


def make_mmdit_block_theta(dtype: torch.dtype | None = None) -> Theta:
return Theta(
{
"img_attn.norm.key_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch((128,), dtype=dtype)
),
"img_attn.norm.query_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch((128,), dtype=dtype)
),
"img_attn.proj.bias": DefaultPrimitiveTensor(
data=make_rand_torch((3072,), dtype=dtype)
),
"img_attn.proj.weight": DefaultPrimitiveTensor(
data=make_rand_torch((3072, 3072), dtype=dtype)
),
"img_attn.qkv.bias": DefaultPrimitiveTensor(
data=make_rand_torch((9216,), dtype=dtype)
),
"img_attn.qkv.weight": DefaultPrimitiveTensor(
data=make_rand_torch((9216, 3072), dtype=dtype)
),
"img_mlp.0.bias": DefaultPrimitiveTensor(
data=make_rand_torch((12288), dtype=dtype)
),
"img_mlp.0.weight": DefaultPrimitiveTensor(
data=make_rand_torch((12288, 3072), dtype=dtype)
),
"img_mlp.2.bias": DefaultPrimitiveTensor(
data=make_rand_torch((3072), dtype=dtype)
),
"img_mlp.2.weight": DefaultPrimitiveTensor(
data=make_rand_torch((3072, 12288), dtype=dtype)
),
"img_mod.lin.bias": DefaultPrimitiveTensor(
data=make_rand_torch((18432,), dtype=dtype)
),
"img_mod.lin.weight": DefaultPrimitiveTensor(
data=make_rand_torch((18432, 3072), dtype=dtype)
),
"txt_attn.norm.key_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch((128,), dtype=dtype)
),
"txt_attn.norm.query_norm.weight": DefaultPrimitiveTensor( #
data=make_rand_torch((128,), dtype=dtype)
),
"txt_attn.proj.bias": DefaultPrimitiveTensor(
data=make_rand_torch((3072,), dtype=dtype)
),
"txt_attn.proj.weight": DefaultPrimitiveTensor(
data=make_rand_torch((3072, 3072), dtype=dtype)
),
"txt_attn.qkv.bias": DefaultPrimitiveTensor(
data=make_rand_torch((9216,), dtype=dtype)
),
"txt_attn.qkv.weight": DefaultPrimitiveTensor(
data=make_rand_torch((9216, 3072), dtype=dtype)
),
"txt_mlp.0.bias": DefaultPrimitiveTensor(
data=make_rand_torch((12288), dtype=dtype)
),
"txt_mlp.0.weight": DefaultPrimitiveTensor(
data=make_rand_torch((12288, 3072), dtype=dtype)
),
"txt_mlp.2.bias": DefaultPrimitiveTensor(
data=make_rand_torch((3072), dtype=dtype)
),
"txt_mlp.2.weight": DefaultPrimitiveTensor(
data=make_rand_torch((3072, 12288), dtype=dtype)
),
"txt_mod.lin.bias": DefaultPrimitiveTensor(
data=make_rand_torch((18432,), dtype=dtype)
),
"txt_mod.lin.weight": DefaultPrimitiveTensor(
data=make_rand_torch((18432, 3072), dtype=dtype)
),
}
)
16 changes: 13 additions & 3 deletions sharktank/sharktank/ops/default_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,16 +304,26 @@ def interpolate_default(
)


@layer_norm.override(Tensor, Tensor, Tensor)
def layer_norm_default(input, weight, bias, *, eps):
input = unbox_tensor(input)
weight = unbox_tensor(weight)
bias = unbox_tensor(bias)
if weight is not None:
weight = unbox_tensor(weight)
else:
weight = torch.ones(input.shape)
if bias is not None:
bias = unbox_tensor(bias)
else:
bias = torch.zeros(input.shape)
return F.layer_norm(
input, normalized_shape=weight.shape, weight=weight, bias=bias, eps=eps
)


layer_norm.override(Tensor)(layer_norm_default)
layer_norm.override(Tensor, Tensor)(layer_norm_default)
layer_norm.override(Tensor, Tensor, Tensor)(layer_norm_default)


# Linear
def linear_default(input, weight, bias, *, accum_dtype) -> Tensor:
input = unbox_tensor(input)
Expand Down
6 changes: 4 additions & 2 deletions sharktank/sharktank/ops/signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,12 +582,14 @@ def layer_norm(
def _layer_norm_trampoline(
d: SignatureDispatcher,
input: AnyTensor,
weight: AnyTensor,
weight: Optional[AnyTensor],
bias: Optional[AnyTensor],
*,
eps: float,
):
tensors = [input, weight]
tensors = [input]
if weight is not None:
tensors.append(bias)
if bias is not None:
tensors.append(bias)
for override in d.find_overrides(tensors):
Expand Down
58 changes: 58 additions & 0 deletions sharktank/tests/layers/mmdit_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging

logging.basicConfig(level=logging.DEBUG)

import unittest

import torch

from iree.turbine import aot
from sharktank.layers import (
MMDITDoubleBlock,
)
import sharktank.ops as ops
from sharktank.layers.testing import (
make_mmdit_block_theta,
)
from sharktank.types.tensors import DefaultPrimitiveTensor


class MMDITTest(unittest.TestCase):
def setUp(self):
torch.manual_seed(12345)
self.hidden_size = 3072
self.num_heads = 24
self.batch_size = 3

def testExport(self):

theta = make_mmdit_block_theta()
mmdit = MMDITDoubleBlock(
theta=theta,
num_heads=self.num_heads,
)

img = torch.rand([self.batch_size, 1024, self.hidden_size])
txt = torch.rand([self.batch_size, 512, self.hidden_size])
vec = torch.rand([self.batch_size, self.hidden_size])
rot = torch.rand([self.batch_size, 1, 1536, 64, 2, 2])
mmdit.forward(img, txt, vec, rot)
fxb = aot.FxProgramsBuilder(mmdit)

@fxb.export_program(name="mmdit", args=(img, txt, vec, rot), strict=False)
def _(model, img, txt, vec, rot) -> torch.Tensor:
return model.forward(img, txt, vec, rot)

output = aot.export(fxb)
output.verify()
asm = str(output.mlir_module)


if __name__ == "__main__":
unittest.main()
Loading