 from .paged_llama_attention_block import PagedLlamaAttentionBlock
 
 
-
 def qk_norm(q, k, v, rms_q, rms_k):
     return rms_q(q).to(v), rms_k(k).to(v)
 
@@ -25,9 +24,11 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
 
 
 def attention(q, k, v, pe):
-    q, k = apply_rope(q, k, pe) # todo
+    q, k = apply_rope(q, k, pe)  # todo
 
-    x = ops.scaled_dot_product_attention(q=q, k=k, v=v, a=None, is_causal=True, scale=None)
+    x = ops.scaled_dot_product_attention(
+        q=q, k=k, v=v, a=None, is_causal=True, scale=None
+    )
 
     x = ops.permute(x, (0, 2, 1, 3))
     x = x.view(x.shape[0], x.shape[1], -1)
 
@@ -41,44 +42,61 @@ def __init__(self, theta, num_heads: int):
         self.num_heads = num_heads
         self.img_mod = ModulationLayer(theta("img_mod"), double=True)
         self.img_attn_qkv = LinearLayer(theta("img_attn.qkv"))
-        self.img_attn_norm_q = RMSNormLayer(theta("img_attn.norm.query_norm"), epsilon=1e-6)
-        self.img_attn_norm_k = RMSNormLayer(theta("img_attn.norm.key_norm"), epsilon=1e-6)
+        self.img_attn_norm_q = RMSNormLayer(
+            theta("img_attn.norm.query_norm"), epsilon=1e-6
+        )
+        self.img_attn_norm_k = RMSNormLayer(
+            theta("img_attn.norm.key_norm"), epsilon=1e-6
+        )
         self.img_attn_proj = LinearLayer(theta("img_attn.proj"))
 
         self.img_mlp1 = LinearLayer(theta("img_mlp.0"))
         self.img_mlp2 = LinearLayer(theta("img_mlp.2"))
 
         self.txt_mod = ModulationLayer(theta("txt_mod"), double=True)
         self.txt_attn_qkv = LinearLayer(theta("txt_attn.qkv"))
-        self.txt_attn_norm_q = RMSNormLayer(theta("txt_attn.norm.query_norm"), epsilon=1e-6)
-        self.txt_attn_norm_k = RMSNormLayer(theta("txt_attn.norm.key_norm"), epsilon=1e-6)
+        self.txt_attn_norm_q = RMSNormLayer(
+            theta("txt_attn.norm.query_norm"), epsilon=1e-6
+        )
+        self.txt_attn_norm_k = RMSNormLayer(
+            theta("txt_attn.norm.key_norm"), epsilon=1e-6
+        )
         self.txt_attn_proj = LinearLayer(theta("txt_attn.proj"))
 
         self.txt_mlp1 = LinearLayer(theta("txt_mlp.0"))
         self.txt_mlp2 = LinearLayer(theta("txt_mlp.2"))
 
-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
+    def forward(
+        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
+    ) -> tuple[Tensor, Tensor]:
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)
 
         # prepare image for attention
         img_modulated = ops.layer_norm(img, None, None, eps=1e-6)
         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         img_qkv = self.img_attn_qkv(img_modulated)
-        img_qkv_2 = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1)  #
+        img_qkv_2 = img_qkv.view(
+            img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1
+        )  #
         img_qkv_3 = ops.permute(img_qkv_2, (2, 0, 3, 1, 4))
         img_q, img_k, img_v = img_qkv_3
-        img_q, img_k = qk_norm(img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k)
-
+        img_q, img_k = qk_norm(
+            img_q, img_k, img_v, self.img_attn_norm_q, self.img_attn_norm_k
+        )
 
         # prepare txt for attention
         txt_modulated = ops.layer_norm(txt, None, None, eps=1e-6)
         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn_qkv(txt_modulated)
-        txt_qkv_2 = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1)  #
+        txt_qkv_2 = txt_qkv.view(
+            txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1
+        )  #
         txt_qkv_3 = ops.permute(txt_qkv_2, (2, 0, 3, 1, 4))
         txt_q, txt_k, txt_v = txt_qkv_3
-        txt_q, txt_k = qk_norm(txt_q, txt_k, txt_v, self.txt_attn_norm_q, self.txt_attn_norm_k)
+        txt_q, txt_k = qk_norm(
+            txt_q, txt_k, txt_v, self.txt_attn_norm_q, self.txt_attn_norm_k
+        )
 
         # run actual attention
         q = torch.cat((txt_q, img_q), dim=2)
@@ -90,19 +108,22 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
 
         # calculate the img bloks
         img = img + img_mod1.gate * self.img_attn_proj(img_attn)
-        img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm(img, None, None, eps=1e-6) + img_mod2.shift
+        img_mlp_in = (1 + img_mod2.scale) * ops.layer_norm(
+            img, None, None, eps=1e-6
+        ) + img_mod2.shift
         img_mlp_out1 = self.img_mlp1(img_mlp_in)
         img_mlp_out2 = ops.elementwise(F.gelu, img_mlp_out1)
         img_mlp_out3 = self.img_mlp2(img_mlp_out2)
         img = img + img_mod2.gate * img_mlp_out3
 
         # calculate the txt bloks
         txt = txt + txt_mod1.gate * self.txt_attn_proj(txt_attn)
-        txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm(txt, None, None, eps=1e-6) + txt_mod2.shift
+        txt_mlp_in = (1 + txt_mod2.scale) * ops.layer_norm(
+            txt, None, None, eps=1e-6
+        ) + txt_mod2.shift
         txt_mlp_out1 = self.txt_mlp1(txt_mlp_in)
         txt_mlp_out2 = ops.elementwise(F.gelu, txt_mlp_out1)
         txt_mlp_out3 = self.txt_mlp2(txt_mlp_out2)
         txt = txt + txt_mod2.gate * txt_mlp_out3
-
-        return img, txt
 
+        return img, txt
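
Note: the changes above are mechanical line wrapping (black-style formatting), so the block's behavior is unchanged. As a reading aid, here is a minimal sketch in plain PyTorch of the shape flow that the wrapped QKV reshape and the `attention` helper implement; the concrete sizes, the skipped RoPE step, and the use of `torch.nn.functional.scaled_dot_product_attention` in place of the project's `ops` wrappers are assumptions for illustration, not part of this PR.

```python
# Illustration only: mirrors the QKV reshape / attention shape flow from the
# diff in plain PyTorch. Sizes are made up, RoPE is omitted, and
# F.scaled_dot_product_attention stands in for the project's ops wrapper.
import torch
import torch.nn.functional as F

bs, seq_len, num_heads, head_dim = 2, 16, 4, 8
hidden = num_heads * head_dim

qkv = torch.randn(bs, seq_len, 3 * hidden)  # output of a fused QKV projection

# view -> (bs, seq, 3, heads, head_dim); permute -> (3, bs, heads, seq, head_dim)
q, k, v = qkv.view(bs, seq_len, 3, num_heads, -1).permute(2, 0, 3, 1, 4)

# attention(): scaled dot-product attention, then permute back and flatten heads
x = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (bs, heads, seq, head_dim)
x = x.permute(0, 2, 1, 3).reshape(bs, seq_len, -1)           # (bs, seq, hidden)
assert x.shape == (bs, seq_len, hidden)
```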