from .paged_llama_attention_block import PagedLlamaAttentionBlock


-
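# qk_norm applies RMS normalization to the query and key tensors and casts
# them to the value tensor's dtype before attention (query/key norm).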
def qk_norm(q, k, v, rms_q, rms_k):
    return rms_q(q).to(v), rms_k(k).to(v)

@@ -25,9 +24,11 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:


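# attention applies rotary position embeddings to q and k via apply_rope,
# runs scaled dot-product attention, then folds the per-head outputs back
# into the hidden dimension.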
def attention(q, k, v, pe):
-    q, k = apply_rope(q, k, pe)  # todo
+    q, k = apply_rope(q, k, pe)  # todo

-    x = ops.scaled_dot_product_attention(q=q, k=k, v=v, a=None, is_causal=True, scale=None)
+    x = ops.scaled_dot_product_attention(
+        q=q, k=k, v=v, a=None, is_causal=True, scale=None
+    )
    x = ops.permute(x, (0, 2, 1, 3))
    x = x.view(x.shape[0], x.shape[1], -1)

@@ -41,44 +42,61 @@ def __init__(self, theta, num_heads: int):
        self.num_heads = num_heads
        self.img_mod = ModulationLayer(theta("img_mod"), double=True)
        self.img_attn_qkv = LinearLayer(theta("img_attn.qkv"))
-        self.img_attn_norm_q = RMSNormLayer(theta("img_attn.norm.query_norm"), epsilon=1e-6)
-        self.img_attn_norm_k = RMSNormLayer(theta("img_attn.norm.key_norm"), epsilon=1e-6)
+        self.img_attn_norm_q = RMSNormLayer(
+            theta("img_attn.norm.query_norm"), epsilon=1e-6
+        )
+        self.img_attn_norm_k = RMSNormLayer(
+            theta("img_attn.norm.key_norm"), epsilon=1e-6
+        )
        self.img_attn_proj = LinearLayer(theta("img_attn.proj"))

        self.img_mlp1 = LinearLayer(theta("img_mlp.0"))
        self.img_mlp2 = LinearLayer(theta("img_mlp.2"))

        self.txt_mod = ModulationLayer(theta("txt_mod"), double=True)
        self.txt_attn_qkv = LinearLayer(theta("txt_attn.qkv"))
-        self.txt_attn_norm_q = RMSNormLayer(theta("txt_attn.norm.query_norm"), epsilon=1e-6)
-        self.txt_attn_norm_k = RMSNormLayer(theta("txt_attn.norm.key_norm"), epsilon=1e-6)
+        self.txt_attn_norm_q = RMSNormLayer(
+            theta("txt_attn.norm.query_norm"), epsilon=1e-6
+        )
+        self.txt_attn_norm_k = RMSNormLayer(
+            theta("txt_attn.norm.key_norm"), epsilon=1e-6
+        )
        self.txt_attn_proj = LinearLayer(theta("txt_attn.proj"))

        self.txt_mlp1 = LinearLayer(theta("txt_mlp.0"))
        self.txt_mlp2 = LinearLayer(theta("txt_mlp.2"))

-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
+    def forward(
+        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
+    ) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)
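        # img_mod / txt_mod each return two modulation sets (shift, scale, gate):
        # the first is applied around the attention branch, the second around the MLP.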

        # prepare image for attention
        img_modulated = ops.layer_norm(img, None, None, eps=1e-6)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn_qkv(img_modulated)
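        # Split the fused QKV projection into per-head tensors:
        # (B, L, 3*H*D) -> (B, L, 3, H, D) -> (3, B, H, L, D); same pattern for txt below.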
-        img_qkv_2 = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1)  #
+        img_qkv_2 = img_qkv.view(
+            img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1
+        )  #
        img_qkv_3 = ops.permute(img_qkv_2, (2, 0, 3, 1, 4))
        img_q, img_k, img_v = img_qkv_3
-        img_q, img_k = qk_norm(img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k)
-
+        img_q, img_k = qk_norm(
+            img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k
+        )

        # prepare txt for attention
        txt_modulated = ops.layer_norm(txt, None, None, eps=1e-6)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn_qkv(txt_modulated)
-        txt_qkv_2 = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1)  #
+        txt_qkv_2 = txt_qkv.view(
+            txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1
+        )  #
        txt_qkv_3 = ops.permute(txt_qkv_2, (2, 0, 3, 1, 4))
        txt_q, txt_k, txt_v = txt_qkv_3
-        txt_q, txt_k = qk_norm(txt_q, txt_k, txt_v, self.txt_attn_norm_q, self.txt_attn_norm_k)
+        txt_q, txt_k = qk_norm(
+            txt_q, txt_k, txt_v, self.txt_attn_norm_q, self.txt_attn_norm_k
+        )

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
@@ -90,19 +108,22 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn_proj(img_attn)
-        img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm(img, None, None, eps=1e-6) + img_mod2.shift
+        img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm(
+            img, None, None, eps=1e-6
+        ) + img_mod2.shift
        img_mlp_out1 = self.img_mlp1(img_mlp_in)
        img_mlp_out2 = ops.elementwise(F.gelu, img_mlp_out1)
        img_mlp_out3 = self.img_mlp2(img_mlp_out2)
        img = img + img_mod2.gate * img_mlp_out3

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn_proj(txt_attn)
-        txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm(txt, None, None, eps=1e-6) + txt_mod2.shift
+        txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm(
+            txt, None, None, eps=1e-6
+        ) + txt_mod2.shift
        txt_mlp_out1 = self.txt_mlp1(txt_mlp_in)
        txt_mlp_out2 = ops.elementwise(F.gelu, txt_mlp_out1)
        txt_mlp_out3 = self.txt_mlp2(txt_mlp_out2)
        txt = txt + txt_mod2.gate * txt_mlp_out3
-
-        return img, txt

+        return img, txt