
Commit

Update attention blocks
felix-e-h-p committed Jan 18, 2025
1 parent a304738 commit 820d773
Showing 1 changed file with 36 additions and 26 deletions.
62 changes: 36 additions & 26 deletions pvnet/models/multimodal/attention_blocks.py
@@ -20,6 +20,7 @@ class AbstractAttentionBlock(nn.Module, ABC):
""" Abstract attention base class definition - for all derived attention mechanisms """

# Forward pass
# f: X → Y - abstract attention space
@abstractmethod
def forward(
self,
@@ -31,18 +32,21 @@ def forward(
pass


# Splits input into multiple heads - scales attention scores for stability
# Partitions input into h parallel heads with scaled dot-product scoring
# s(q, k) = ⟨q, k⟩/√d_k
class MultiheadAttention(AbstractAttentionBlock):
"""
Multihead attention implementation / definition
Scaled dot-product attention: softmax(QKᵀ/√d_k)V
Scaled dot-product attention
Parallel attention heads permit the model to jointly attend to information from different representation subspaces
"""

# Initialisation of multihead attention
# Total embedding dimension and quantity of parallel attention heads
# Initialisation of h parallel attention mechanisms
# A_i: ℝ^d → ℝ^{d/h}
# Definition of embedding dimension d ∈ ℕ
# Definition of attention heads h | d mod h = 0
def __init__(
self,
embed_dim: int,
@@ -59,7 +63,8 @@ def __init__(
self.head_dim = embed_dim // num_heads
self.scale = self.head_dim ** -0.5

# Linear projections
# Linear transformations for query-key-value projections
# W_Q, W_K, W_V ∈ ℝ^{d×d}
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
@@ -78,8 +83,8 @@ def forward(

batch_size = query.shape[0]

# Projection and reshape - define attention
# [batch_size, seq_len, embed_dim] → [batch_size, num_heads, seq_len, head_dim]
# Transform and partition input tensor
# ℝ^{B×L×d} → ℝ^{B×h×L×(d/h)}
q = self.q_proj(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
@@ -88,76 +93,80 @@
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))

# Attention weights and subsequent output / weighted aggregation
# Compute attention distribution
# α = softmax(QK^T/√d_k)
# Compute weighted context
# Σ_i α_i v_i
attn_weights = F.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
attn_output = torch.matmul(attn_weights, v)

# Reshape: [batch_size, num_heads, seq_len, head_dim] → [batch_size, seq_len, embed_dim]
# Restore tensor dimensionality
# ℝ^{B×h×L×(d/h)} → ℝ^{B×L×d}
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)

return self.out_proj(attn_output)
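
For reference, a minimal usage sketch of the MultiheadAttention block above (not part of the committed file). The dimensions are illustrative, and the constructor arguments are assumed to match the call made from CrossModalAttention below (embed_dim, num_heads, dropout):

import torch
from pvnet.models.multimodal.attention_blocks import MultiheadAttention

# Illustrative sizes: batch 4, query length 10, key/value length 12, embed_dim 64, 8 heads
attn = MultiheadAttention(embed_dim=64, num_heads=8, dropout=0.1)
query = torch.randn(4, 10, 64)
key = torch.randn(4, 12, 64)    # key/value sequence length may differ from the query
value = torch.randn(4, 12, 64)
out = attn(query, key, value)   # shape [4, 10, 64] - matches the query shape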


# Enables singular modality to 'attend' to others utilising specific attention block
# Enables a singular modality to 'attend' to the context of other modalities utilising a dedicated attention block
class CrossModalAttention(AbstractAttentionBlock):
""" CrossModal attention - interaction between multiple modalities """

# Initialisation of CrossModal attention
# Total embedding dimension and quantity of parallel attention heads
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.1,
num_modalities: int = 2
):

super().__init__()
self.num_modalities = num_modalities

# Parallel attention blocks for each modality
# Parallel attention mechanisms for M modalities
# {A_i}_{i=1}^M
self.attention_blocks = nn.ModuleList([
MultiheadAttention(embed_dim, num_heads, dropout=dropout)
for _ in range(num_modalities)
])
self.dropout = nn.Dropout(dropout)
self.layer_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(num_modalities)])

# Forward pass - CrossModal attention
def forward(
self,
modalities: Dict[str, torch.Tensor],
mask: Optional[torch.Tensor] = None
) -> Dict[str, torch.Tensor]:

updated_modalities = {}
modality_keys = list(modalities.keys())

for i, key in enumerate(modality_keys):
query = modalities[key]

# Combine other modalities as key-value pairs - concatenate
other_modalities = [modalities[k] for k in modality_keys if k != key]

if other_modalities:
# Concatenate context modalities
# C = [m_1; ...; m_{i-1}; m_{i+1}; ...; m_M]
key_value = torch.cat(other_modalities, dim=1)

# Apply attention block for this modality - cross-modal
attn_output = self.attention_blocks[i](query, key_value, key_value, mask)
attn_output = self.dropout(attn_output)
updated_modalities[key] = self.layer_norms[i](query + attn_output)
else:
# If no other modalities - pass through
updated_modalities[key] = query
# Apply self-attention
# A(x,x,x) when |M| = 1
attn_output = self.attention_blocks[i](query, query, query, mask)

attn_output = self.dropout(attn_output)
updated_modalities[key] = self.layer_norms[i](query + attn_output)

return updated_modalities
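
Similarly, a hedged sketch of how CrossModalAttention might be called with two modalities (not part of the committed file). The modality names below are placeholders; any string keys mapping to tensors of shape [batch, seq_len, embed_dim] should work, provided num_modalities covers the number of keys passed at forward time:

import torch
from pvnet.models.multimodal.attention_blocks import CrossModalAttention

# Two modalities sharing embed_dim 64; sequence lengths may differ per modality
cross_attn = CrossModalAttention(embed_dim=64, num_heads=8, dropout=0.1, num_modalities=2)
modalities = {
    "nwp": torch.randn(4, 10, 64),        # placeholder modality name
    "satellite": torch.randn(4, 6, 64),   # placeholder modality name
}
updated = cross_attn(modalities)
# Each output retains its query shape: updated["nwp"] is [4, 10, 64], updated["satellite"] is [4, 6, 64]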


# Permits each element in the input sequence to attend to all other elements
# I.e. all-pairs interaction via self-attention
# A(x_i, {x_j}_{j=1}^L)
class SelfAttention(AbstractAttentionBlock):
""" SelfAttention block for singular modality """

# Initialisation of self attention
# Initialisation of h parallel self-attention mechanisms
# S_i: ℝ^d → ℝ^{d/h}
# Total embedding dimension and quantity of parallel attention heads
def __init__(
self,
@@ -178,7 +187,8 @@ def forward(
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:

# Self attention
# Self-attention operation
# SA(x) = LayerNorm(x + A(x,x,x))
attn_output = self.attention(x, x, x, mask)
attn_output = self.dropout(attn_output)
return self.layer_norm(x + attn_output)
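
Finally, a minimal sketch of the SelfAttention block (not part of the committed file). Its full constructor signature is not visible in this hunk, so the arguments below are an assumption mirroring the other blocks (embed_dim, num_heads, dropout):

import torch
from pvnet.models.multimodal.attention_blocks import SelfAttention

# Assumed constructor arguments analogous to MultiheadAttention
self_attn = SelfAttention(embed_dim=64, num_heads=8, dropout=0.1)
x = torch.randn(4, 10, 64)
out = self_attn(x)   # residual + LayerNorm preserve the input shape [4, 10, 64]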