
Commit

Add modules
KyleHerndon committed Nov 22, 2024
1 parent 6ae65c3 commit 4e2e1a4
Showing 1 changed file with 34 additions and 18 deletions.
52 changes: 34 additions & 18 deletions sharktank/sharktank/layers/mmdit_double.py
@@ -40,31 +40,47 @@ def __init__(self, theta, num_heads: int):
         super().__init__(theta)

         self.num_heads = num_heads
-        self.img_mod = ModulationLayer(theta("img_mod"), double=True)
-        self.img_attn_qkv = LinearLayer(theta("img_attn.qkv"))
-        self.img_attn_norm_q = RMSNormLayer(
-            theta("img_attn.norm.query_norm"), epsilon=1e-6
+        self.img_mod = self.add_module(
+            "img_mod", ModulationLayer(theta("img_mod"), double=True)
         )
-        self.img_attn_norm_k = RMSNormLayer(
-            theta("img_attn.norm.key_norm"), epsilon=1e-6
+        self.img_attn_qkv = self.add_module(
+            "img_attn_qkv", LinearLayer(theta("img_attn.qkv"))
         )
-        self.img_attn_proj = LinearLayer(theta("img_attn.proj"))
+        self.img_attn_norm_q = self.add_module(
+            "img_attn_query_norm",
+            RMSNormLayer(theta("img_attn.norm.query_norm"), epsilon=1e-6),
+        )
+        self.img_attn_norm_k = self.add_module(
+            "img_attn_key_norm",
+            RMSNormLayer(theta("img_attn.norm.key_norm"), epsilon=1e-6),
+        )
+        self.img_attn_proj = self.add_module(
+            "img_attn_proj", LinearLayer(theta("img_attn.proj"))
+        )

-        self.img_mlp1 = LinearLayer(theta("img_mlp.0"))
-        self.img_mlp2 = LinearLayer(theta("img_mlp.2"))
+        self.img_mlp1 = self.add_module("img_mlp_0", LinearLayer(theta("img_mlp.0")))
+        self.img_mlp2 = self.add_module("img_mlp_2", LinearLayer(theta("img_mlp.2")))

-        self.txt_mod = ModulationLayer(theta("txt_mod"), double=True)
-        self.txt_attn_qkv = LinearLayer(theta("txt_attn.qkv"))
-        self.txt_attn_norm_q = RMSNormLayer(
-            theta("txt_attn.norm.query_norm"), epsilon=1e-6
+        self.txt_mod = self.add_module(
+            "txt_mod", ModulationLayer(theta("txt_mod"), double=True)
         )
-        self.txt_attn_norm_k = RMSNormLayer(
-            theta("txt_attn.norm.key_norm"), epsilon=1e-6
+        self.txt_attn_qkv = self.add_module(
+            "txt_attn_qkv", LinearLayer(theta("txt_attn.qkv"))
         )
-        self.txt_attn_proj = LinearLayer(theta("txt_attn.proj"))
+        self.txt_attn_norm_q = self.add_module(
+            "txt_attn_query_norm",
+            RMSNormLayer(theta("txt_attn.norm.query_norm"), epsilon=1e-6),
+        )
+        self.txt_attn_norm_k = self.add_module(
+            "txt_attn_key_norm",
+            RMSNormLayer(theta("txt_attn.norm.key_norm"), epsilon=1e-6),
+        )
+        self.txt_attn_proj = self.add_module(
+            "txt_attn_proj", LinearLayer(theta("txt_attn.proj"))
+        )

-        self.txt_mlp1 = LinearLayer(theta("txt_mlp.0"))
-        self.txt_mlp2 = LinearLayer(theta("txt_mlp.2"))
+        self.txt_mlp1 = self.add_module("txt_mlp_0", LinearLayer(theta("txt_mlp.0")))
+        self.txt_mlp2 = self.add_module("txt_mlp_2", LinearLayer(theta("txt_mlp.2")))

     def forward(
         self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
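For context on the change above, here is a minimal, hypothetical sketch of what explicit add_module registration buys, using a plain torch.nn.Module and a made-up Block class rather than sharktank's ThetaLayer and layer constructors: the child is registered under a chosen name and then shows up in named_modules(), parameters(), and state_dict(). Note that torch.nn.Module.add_module itself returns None, so the child is normally retrieved afterwards through its registered attribute name.

# Hypothetical illustration only; Block, in_dim, and out_dim are made up.
import torch


class Block(torch.nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Register the child under an explicit name. add_module() returns None,
        # so the layer is accessed afterwards as self.img_proj, not via the
        # return value.
        self.add_module("img_proj", torch.nn.Linear(in_dim, out_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.img_proj(x)


block = Block(8, 8)
print([name for name, _ in block.named_modules()])  # ['', 'img_proj']
print(block.state_dict().keys())  # odict_keys(['img_proj.weight', 'img_proj.bias'])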
