@@ -20,6 +20,7 @@ class AbstractAttentionBlock(nn.Module, ABC):
     """ Abstract attention base class definition - for all derived attention mechanisms """

     # Forward pass
+    # f: X → Y - abstract attention space
     @abstractmethod
     def forward(
         self,
@@ -31,18 +32,21 @@ def forward(
         pass


-# Splits input into multiple heads - scales attention scores for stability
+# Partitions input into h parallel heads with scaled dot-product scoring
+# s(x) = <q,k>/√d_k
 class MultiheadAttention(AbstractAttentionBlock):
     """
     Multihead attention implementation / definition

-    Scaled dot-product attention: softmax(QKᵀ/√d_k)V
+    Scaled dot-product attention

     Parallel attention heads permit the model to jointly 'attend' to information from different representation subspaces
     """

-    # Initialisation of multihead attention
-    # Total embedding dimension and quantity of parallel attention heads
+    # Initialisation of h parallel attention mechanisms
+    # A_i: ℝ^d → ℝ^{d/h}
+    # Definition of embedding dimension d ∈ ℕ
+    # Definition of attention heads h | d mod h = 0
     def __init__(
         self,
         embed_dim: int,
@@ -59,7 +63,8 @@ def __init__(
         self.head_dim = embed_dim // num_heads
         self.scale = self.head_dim ** -0.5

-        # Linear projections
+        # Linear transformations for query-key-value projections
+        # W_Q, W_K, W_V ∈ ℝ^{d×d}
         self.q_proj = nn.Linear(embed_dim, embed_dim)
         self.k_proj = nn.Linear(embed_dim, embed_dim)
         self.v_proj = nn.Linear(embed_dim, embed_dim)
@@ -78,8 +83,8 @@ def forward(

         batch_size = query.shape[0]

-        # Projection and reshape - define attention
-        # [batch_size, seq_len, embed_dim] → [batch_size, num_heads, seq_len, head_dim]
+        # Transform and partition input tensor
+        # ℝ^{B×L×d} → ℝ^{B×h×L×(d/h)}
         q = self.q_proj(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
         k = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
         v = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
@@ -88,76 +93,80 @@ def forward(
         if mask is not None:
             scores = scores.masked_fill(mask == 0, float('-inf'))

-        # Attention weights and subsequent output / weighted aggregation
+        # Compute attention distribution
+        # α = softmax(QK^T/√d_k)
+        # Compute weighted context
+        # Σ_i α_i v_i
         attn_weights = F.softmax(scores, dim=-1)
         attn_weights = self.dropout(attn_weights)
         attn_output = torch.matmul(attn_weights, v)

-        # Reshape: [batch_size, num_heads, seq_len, head_dim] → [batch_size, seq_len, embed_dim]
+        # Restore tensor dimensionality
+        # ℝ^{B×h×L×(d/h)} → ℝ^{B×L×d}
         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)

         return self.out_proj(attn_output)


-# Enables singular modality to 'attend' to others utilising specific attention block
+# Enables singular modality to 'attend' to other modalities' context utilising a specific attention block
 class CrossModalAttention(AbstractAttentionBlock):
     """ CrossModal attention - interaction between multiple modalities """

-    # Initialisation of CrossModal attention
-    # Total embedding dimension and quantity of parallel attention heads
     def __init__(
         self,
         embed_dim: int,
         num_heads: int,
         dropout: float = 0.1,
         num_modalities: int = 2
     ):
-
         super().__init__()
         self.num_modalities = num_modalities

-        # Parallel attention blocks for each modality
+        # Parallel attention mechanisms for M modalities
+        # {A_i}_{i=1}^M
         self.attention_blocks = nn.ModuleList([
             MultiheadAttention(embed_dim, num_heads, dropout=dropout)
             for _ in range(num_modalities)
         ])
         self.dropout = nn.Dropout(dropout)
         self.layer_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(num_modalities)])

-    # Forward pass - CrossModal attention
     def forward(
         self,
         modalities: Dict[str, torch.Tensor],
         mask: Optional[torch.Tensor] = None
     ) -> Dict[str, torch.Tensor]:
-
         updated_modalities = {}
         modality_keys = list(modalities.keys())

         for i, key in enumerate(modality_keys):
             query = modalities[key]
-
-            # Combine other modalities as key-value pairs - concatenate
             other_modalities = [modalities[k] for k in modality_keys if k != key]
+
             if other_modalities:
+                # Concatenate context modalities
+                # C = [m_1; ...; m_{i-1}; m_{i+1}; ...; m_M]
                 key_value = torch.cat(other_modalities, dim=1)
-
-                # Apply attention block for this modality - cross-modal
                 attn_output = self.attention_blocks[i](query, key_value, key_value, mask)
-                attn_output = self.dropout(attn_output)
-                updated_modalities[key] = self.layer_norms[i](query + attn_output)
             else:
-                # If no other modalities - pass through
-                updated_modalities[key] = query
+                # Apply self-attention
+                # A(x,x,x) when |M| = 1
+                attn_output = self.attention_blocks[i](query, query, query, mask)
+
+            attn_output = self.dropout(attn_output)
+            updated_modalities[key] = self.layer_norms[i](query + attn_output)

         return updated_modalities


 # Permits each element in input sequence to attend all other elements
+# i.e. all-pair interactions via self-attention
+# A(x_i, {x_j}_{j=1}^L)
 class SelfAttention(AbstractAttentionBlock):
     """ SelfAttention block for singular modality """

-    # Initialisation of self attention
+    # Initialisation of h parallel self-attention mechanisms
+    # S_i: ℝ^d → ℝ^{d/h}
     # Total embedding dimension and quantity of parallel attention heads
     def __init__(
         self,
@@ -178,7 +187,8 @@ def forward(
         mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:

-        # Self attention
+        # Self-attention operation
+        # SA(x) = LayerNorm(x + A(x,x,x))
         attn_output = self.attention(x, x, x, mask)
         attn_output = self.dropout(attn_output)
         return self.layer_norm(x + attn_output)
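A minimal usage sketch of the three blocks in this diff follows, mainly as a shape check. It is not part of the commit: the module path `attention` is a placeholder, and it assumes `SelfAttention.__init__` takes `(embed_dim, num_heads, dropout)` like the other blocks, which these hunks do not show.

```python
# Usage sketch - `attention` is a placeholder module path; the SelfAttention
# constructor signature is assumed to mirror MultiheadAttention.
import torch

from attention import CrossModalAttention, MultiheadAttention, SelfAttention

batch_size, seq_len, embed_dim, num_heads = 2, 16, 64, 8
x = torch.randn(batch_size, seq_len, embed_dim)

# Multihead attention: query/key/value of shape [B, L, d] -> output [B, L, d]
mha = MultiheadAttention(embed_dim, num_heads, dropout=0.1)
assert mha(x, x, x, mask=None).shape == (batch_size, seq_len, embed_dim)

# Self-attention: residual connection + LayerNorm around one attention block
sa = SelfAttention(embed_dim, num_heads, dropout=0.1)
assert sa(x).shape == (batch_size, seq_len, embed_dim)

# Cross-modal attention: each modality attends to the concatenation of the others;
# sequence lengths may differ across modalities, output shapes match the inputs
cma = CrossModalAttention(embed_dim, num_heads, dropout=0.1, num_modalities=2)
modalities = {
    "vision": torch.randn(batch_size, seq_len, embed_dim),
    "text": torch.randn(batch_size, 2 * seq_len, embed_dim),
}
updated = cma(modalities)
assert updated["vision"].shape == modalities["vision"].shape
assert updated["text"].shape == modalities["text"].shape
```

Dropout is active in training mode, so outputs are stochastic; the shape assertions hold either way.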