
Commit 820d773

Update attention blocks
1 parent a304738 commit 820d773

1 file changed

pvnet/models/multimodal/attention_blocks.py

Lines changed: 36 additions & 26 deletions
@@ -20,6 +20,7 @@ class AbstractAttentionBlock(nn.Module, ABC):
     """ Abstract attention base class definition - for all derived attention mechanisms """
 
     # Forward pass
+    # f: X → Y - abstract attention space
     @abstractmethod
     def forward(
         self,
@@ -31,18 +32,21 @@ def forward(
         pass
 
 
-# Splits input into multiple heads - scales attention scores for stability
+# Partitions input into h parallel heads with scaled dot-product scoring
+# s(x) = <q,k>/√d_k
 class MultiheadAttention(AbstractAttentionBlock):
     """
     Multihead attention implementation / definition
 
-    Scaled dot-product attention: softmax(QKᵀ/√d_k)V
+    Scaled dot-product attention
 
     Parallel attention heads permit model to jointly 'attend' information from different representation subspaces
     """
 
-    # Initialisation of multihead attention
-    # Total embedding dimension and quantity of parallel attention heads
+    # Initialisation of h parallel attention mechanisms
+    # A_i: ℝ^d → ℝ^{d/h}
+    # Definition of embedding dimension d ∈ ℕ
+    # Definition of attention heads h | d mod h = 0
     def __init__(
         self,
         embed_dim: int,
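
Aside: a minimal, self-contained sketch of the scaled dot-product attention softmax(QKᵀ/√d_k)V referenced in the docstring above, on toy tensors. The shapes and names below are illustrative assumptions, not part of the commit.

import math
import torch
import torch.nn.functional as F

d_k = 8                                   # per-head dimension (illustrative)
q = torch.randn(2, 5, d_k)                # [batch, seq_len, d_k]
k = torch.randn(2, 5, d_k)
v = torch.randn(2, 5, d_k)

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # [2, 5, 5]
weights = F.softmax(scores, dim=-1)       # each row sums to 1
context = torch.matmul(weights, v)        # [2, 5, d_k]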
@@ -59,7 +63,8 @@ def __init__(
         self.head_dim = embed_dim // num_heads
         self.scale = self.head_dim ** -0.5
 
-        # Linear projections
+        # Linear transformations for query-key-value projections
+        # W_Q, W_K, W_V ∈ ℝ^{d×d}
         self.q_proj = nn.Linear(embed_dim, embed_dim)
         self.k_proj = nn.Linear(embed_dim, embed_dim)
         self.v_proj = nn.Linear(embed_dim, embed_dim)
@@ -78,8 +83,8 @@ def forward(
 
         batch_size = query.shape[0]
 
-        # Projection and reshape - define attention
-        # [batch_size, seq_len, embed_dim] → [batch_size, num_heads, seq_len, head_dim]
+        # Transform and partition input tensor
+        # ℝ^{B×L×d} → ℝ^{B×h×L×(d/h)}
         q = self.q_proj(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
         k = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
         v = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
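
Aside: the head partition ℝ^{B×L×d} → ℝ^{B×h×L×(d/h)} in the lines above, checked on a toy tensor; the values B=2, L=5, d=16, h=4 are assumptions for illustration only.

import torch

B, L, d, h = 2, 5, 16, 4
x = torch.randn(B, L, d)                          # [B, L, d]
heads = x.view(B, L, h, d // h).transpose(1, 2)   # [B, h, L, d/h]
print(heads.shape)                                # torch.Size([2, 4, 5, 4])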
@@ -88,76 +93,80 @@ def forward(
         if mask is not None:
             scores = scores.masked_fill(mask == 0, float('-inf'))
 
-        # Attention weights and subsequent output / weighted aggregation
+        # Compute attention distribution
+        # α = softmax(QK^T/√d_k)
+        # Compute weighted context
+        # Σ_i α_i v_i
         attn_weights = F.softmax(scores, dim=-1)
         attn_weights = self.dropout(attn_weights)
         attn_output = torch.matmul(attn_weights, v)
 
-        # Reshape: [batch_size, num_heads, seq_len, head_dim] → [batch_size, seq_len, embed_dim]
+        # Restore tensor dimensionality
+        # ℝ^{B×h×L×(d/h)} → ℝ^{B×L×d}
         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
 
         return self.out_proj(attn_output)
 
 
-# Enables singular modality to 'attend' to others utilising specific attention block
+# Enables singular modality to 'attend' to others context utilising specific attention block
 class CrossModalAttention(AbstractAttentionBlock):
     """ CrossModal attention - interaction between multiple modalities """
 
-    # Initialisation of CrossModal attention
-    # Total embedding dimension and quantity of parallel attention heads
     def __init__(
         self,
         embed_dim: int,
         num_heads: int,
         dropout: float = 0.1,
         num_modalities: int = 2
     ):
-
         super().__init__()
         self.num_modalities = num_modalities
 
-        # Parallel attention blocks for each modality
+        # Parallel attention mechanisms for M modalities
+        # {A_i}_{i=1}^M
         self.attention_blocks = nn.ModuleList([
             MultiheadAttention(embed_dim, num_heads, dropout=dropout)
             for _ in range(num_modalities)
         ])
         self.dropout = nn.Dropout(dropout)
         self.layer_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(num_modalities)])
 
-    # Forward pass - CrossModal attention
     def forward(
         self,
         modalities: Dict[str, torch.Tensor],
         mask: Optional[torch.Tensor] = None
     ) -> Dict[str, torch.Tensor]:
-
         updated_modalities = {}
         modality_keys = list(modalities.keys())
 
         for i, key in enumerate(modality_keys):
             query = modalities[key]
-
-            # Combine other modalities as key-value pairs - concatenate
             other_modalities = [modalities[k] for k in modality_keys if k != key]
+
             if other_modalities:
+                # Concatenate context modalities
+                # C = [m_1; ...; m_{i-1}; m_{i+1}; ...; m_M]
                 key_value = torch.cat(other_modalities, dim=1)
-
-                # Apply attention block for this modality - cross-modal
                 attn_output = self.attention_blocks[i](query, key_value, key_value, mask)
-                attn_output = self.dropout(attn_output)
-                updated_modalities[key] = self.layer_norms[i](query + attn_output)
             else:
-                # If no other modalities - pass through
-                updated_modalities[key] = query
+                # Apply self-attention
+                # A(x,x,x) when |M| = 1
+                attn_output = self.attention_blocks[i](query, query, query, mask)
+
+            attn_output = self.dropout(attn_output)
+            updated_modalities[key] = self.layer_norms[i](query + attn_output)
 
         return updated_modalities
 
 
 # Permits each element in input sequence to attend all other elements
+# I.e. all pair interaction via self attention
+# A(x_i, {x_j}_{j=1}^L)
 class SelfAttention(AbstractAttentionBlock):
     """ SelfAttention block for singular modality """
 
-    # Initialisation of self attention
+    # Initialisation of h parallel self-attention mechanisms
+    # S_i: ℝ^d → ℝ^{d/h}
     # Total embedding dimension and quantity of parallel attention heads
     def __init__(
         self,
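
Aside: a usage sketch of the updated CrossModalAttention, assuming the module path matches the file shown and the constructor signature shown in the hunk above; batch size, sequence lengths and embed_dim are illustrative. It also shows the behaviour change made in this commit: with a single modality the block now applies self-attention (plus dropout and the residual LayerNorm) rather than passing the input through unchanged.

import torch
from pvnet.models.multimodal.attention_blocks import CrossModalAttention

block = CrossModalAttention(embed_dim=64, num_heads=4, dropout=0.1, num_modalities=2)
modalities = {
    "sat": torch.randn(2, 12, 64),   # [batch, seq_len, embed_dim]
    "nwp": torch.randn(2, 18, 64),
}
out = block(modalities)              # output for each key keeps its query's shape
assert out["sat"].shape == (2, 12, 64)
assert out["nwp"].shape == (2, 18, 64)

# Single-modality case: now falls back to self-attention rather than identity
single = CrossModalAttention(embed_dim=64, num_heads=4, num_modalities=1)
out_single = single({"sat": torch.randn(2, 12, 64)})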
@@ -178,7 +187,8 @@ def forward(
         mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
 
-        # Self attention
+        # Self-attention operation
+        # SA(x) = LayerNorm(x + A(x,x,x))
         attn_output = self.attention(x, x, x, mask)
         attn_output = self.dropout(attn_output)
         return self.layer_norm(x + attn_output)
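
Aside: a sketch of the SelfAttention residual form SA(x) = LayerNorm(x + A(x, x, x)) documented above. It assumes SelfAttention's constructor mirrors MultiheadAttention's (embed_dim, num_heads, dropout), which this diff does not show in full.

import torch
from pvnet.models.multimodal.attention_blocks import SelfAttention

sa = SelfAttention(embed_dim=64, num_heads=4, dropout=0.1)
x = torch.randn(2, 10, 64)   # [batch, seq_len, embed_dim]
y = sa(x)                    # residual + LayerNorm keeps the shape: [2, 10, 64]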
