diff --git a/pvnet/models/multimodal/attention_blocks.py b/pvnet/models/multimodal/attention_blocks.py
index c7a0a997..a1ac3e83 100644
--- a/pvnet/models/multimodal/attention_blocks.py
+++ b/pvnet/models/multimodal/attention_blocks.py
@@ -20,6 +20,7 @@ class AbstractAttentionBlock(nn.Module, ABC):
     """
     Abstract attention base class definition - for all derived attention mechanisms
     """
     # Forward pass
+    # f: X → Y - abstract attention space
     @abstractmethod
     def forward(
         self,
@@ -31,18 +32,21 @@ def forward(
         pass
 
 
-# Splits input into multiple heads - scales attention scores for stability
+# Partitions input into h parallel heads with scaled dot-product scoring
+# s(q, k) = qᵀk / √d_k
 class MultiheadAttention(AbstractAttentionBlock):
     """
     Multihead attention implementation / definition
 
-    Scaled dot-product attention: softmax(QKᵀ/√d_k)V
+    Scaled dot-product attention
 
     Parallel attention heads permit model to jointly 'attend' information
     from different representation subspaces
     """
 
-    # Initialisation of multihead attention
-    # Total embedding dimension and quantity of parallel attention heads
+    # Initialisation of h parallel attention mechanisms
+    # A_i: ℝ^d → ℝ^{d/h}
+    # Definition of embedding dimension d ∈ ℕ
+    # Definition of attention heads h | d mod h = 0
     def __init__(
         self,
         embed_dim: int,
@@ -59,7 +63,8 @@ def __init__(
         self.head_dim = embed_dim // num_heads
         self.scale = self.head_dim ** -0.5
 
-        # Linear projections
+        # Linear transformations for query-key-value projections
+        # W_Q, W_K, W_V ∈ ℝ^{d×d}
         self.q_proj = nn.Linear(embed_dim, embed_dim)
         self.k_proj = nn.Linear(embed_dim, embed_dim)
         self.v_proj = nn.Linear(embed_dim, embed_dim)
@@ -78,8 +83,8 @@ def forward(
         batch_size = query.shape[0]
 
-        # Projection and reshape - define attention
-        # [batch_size, seq_len, embed_dim] → [batch_size, num_heads, seq_len, head_dim]
+        # Transform and partition input tensor
+        # ℝ^{B×L×d} → ℝ^{B×h×L×(d/h)}
         q = self.q_proj(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
         k = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
         v = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
@@ -88,23 +93,25 @@ def forward(
         if mask is not None:
             scores = scores.masked_fill(mask == 0, float('-inf'))
 
-        # Attention weights and subsequent output / weighted aggregation
+        # Compute attention distribution
+        # α = softmax(QK^T/√d_k)
+        # Compute weighted context
+        # Σ_i α_i v_i
         attn_weights = F.softmax(scores, dim=-1)
         attn_weights = self.dropout(attn_weights)
         attn_output = torch.matmul(attn_weights, v)
 
-        # Reshape: [batch_size, num_heads, seq_len, head_dim] → [batch_size, seq_len, embed_dim]
+        # Restore tensor dimensionality
+        # ℝ^{B×h×L×(d/h)} → ℝ^{B×L×d}
         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
 
         return self.out_proj(attn_output)
 
 
-# Enables singular modality to 'attend' to others utilising specific attention block
+# Enables singular modality to 'attend' to other modalities' context utilising specific attention block
 class CrossModalAttention(AbstractAttentionBlock):
     """
     CrossModal attention - interaction between multiple modalities
     """
 
-    # Initialisation of CrossModal attention
-    # Total embedding dimension and quantity of parallel attention heads
     def __init__(
         self,
         embed_dim: int,
@@ -112,11 +119,11 @@ def __init__(
         dropout: float = 0.1,
         num_modalities: int = 2
     ):
-
         super().__init__()
         self.num_modalities = num_modalities
 
-        # Parallel attention blocks for each modality
+        # Parallel attention mechanisms for M modalities
+        # {A_i}_{i=1}^M
        self.attention_blocks = nn.ModuleList([
            MultiheadAttention(embed_dim, num_heads, dropout=dropout)
            for _ in range(num_modalities)
@@ -124,40 +131,42 @@ def __init__(
 
         self.dropout = nn.Dropout(dropout)
         self.layer_norms = nn.ModuleList([nn.LayerNorm(embed_dim) for _ in range(num_modalities)])
 
-    # Forward pass - CrossModal attention
     def forward(
         self,
         modalities: Dict[str, torch.Tensor],
         mask: Optional[torch.Tensor] = None
     ) -> Dict[str, torch.Tensor]:
-
         updated_modalities = {}
         modality_keys = list(modalities.keys())
 
         for i, key in enumerate(modality_keys):
             query = modalities[key]
-
-            # Combine other modalities as key-value pairs - concatenate
             other_modalities = [modalities[k] for k in modality_keys if k != key]
+
             if other_modalities:
+                # Concatenate context modalities
+                # C = [m_1; ...; m_{i-1}; m_{i+1}; ...; m_M]
                 key_value = torch.cat(other_modalities, dim=1)
-
-                # Apply attention block for this modality - cross-modal
                 attn_output = self.attention_blocks[i](query, key_value, key_value, mask)
-                attn_output = self.dropout(attn_output)
-                updated_modalities[key] = self.layer_norms[i](query + attn_output)
             else:
-                # If no other modalities - pass through
-                updated_modalities[key] = query
+                # Apply self-attention
+                # A(x,x,x) when |M| = 1
+                attn_output = self.attention_blocks[i](query, query, query, mask)
+
+            attn_output = self.dropout(attn_output)
+            updated_modalities[key] = self.layer_norms[i](query + attn_output)
 
         return updated_modalities
 
 
 # Permits each element in input sequence to attend all other elements
+# I.e. all pair interaction via self attention
+# A(x_i, {x_j}_{j=1}^L)
 class SelfAttention(AbstractAttentionBlock):
     """
     SelfAttention block for singular modality
     """
 
-    # Initialisation of self attention
+    # Initialisation of h parallel self-attention mechanisms
+    # S_i: ℝ^d → ℝ^{d/h}
     # Total embedding dimension and quantity of parallel attention heads
     def __init__(
         self,
@@ -178,7 +187,8 @@ def forward(
         mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
 
-        # Self attention
+        # Self-attention operation
+        # SA(x) = LayerNorm(x + A(x,x,x))
         attn_output = self.attention(x, x, x, mask)
         attn_output = self.dropout(attn_output)
         return self.layer_norm(x + attn_output)
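
A minimal usage sketch of the refactored `CrossModalAttention` block, not part of the diff: it assumes the dict-in/dict-out `forward` signature and the `embed_dim`/`dropout`/`num_modalities` constructor arguments shown above; the `num_heads` keyword name is assumed (that parameter falls between the hunks), and the modality names and tensor shapes are illustrative only.

```python
# Illustrative sketch only - signatures taken from the diff above;
# modality names and shapes are made up for the example.
import torch

from pvnet.models.multimodal.attention_blocks import CrossModalAttention

embed_dim, num_heads, batch_size = 64, 4, 2

# Two modalities with different sequence lengths but a shared embedding dimension d
modalities = {
    "nwp": torch.randn(batch_size, 12, embed_dim),
    "sat": torch.randn(batch_size, 7, embed_dim),
}

cross_attn = CrossModalAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,  # assumed keyword name, hidden between hunks
    dropout=0.1,
    num_modalities=len(modalities),
)

updated = cross_attn(modalities)

# Residual connection + LayerNorm preserve each modality's shape: [B, L_i, d]
for key, tensor in updated.items():
    assert tensor.shape == modalities[key].shape
```

Each modality attends over the concatenation of the remaining modalities; with a single modality the refactored block now falls back to self-attention rather than passing the input through unchanged, and dropout plus the residual LayerNorm are applied uniformly after either branch.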