Commit

Updates
felix-e-h-p committed Feb 14, 2025
1 parent bb013bd commit 133f2da
Showing 3 changed files with 160 additions and 17 deletions.
46 changes: 40 additions & 6 deletions pvnet/models/multimodal/attention_blocks.py
@@ -14,6 +14,11 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional
from torch import nn
import logging


logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('attention_blocks')


class AbstractAttentionBlock(nn.Module, ABC):
@@ -53,16 +58,22 @@ def __init__(
num_heads: int,
dropout: float = 0.1
):

super().__init__()

logger.info(f"Initialising MultiheadAttention with embed_dim={embed_dim}, num_heads={num_heads}")

if embed_dim % num_heads != 0:
    error_msg = f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
    logger.error(error_msg)
    raise ValueError(error_msg)

self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.scale = self.head_dim ** -0.5

logger.debug(f"Head dimension: {self.head_dim}, Scale factor: {self.scale}")

# Linear transformations for query-key-value projections
# W_Q, W_K, W_V ∈ ℝ^{d×d}
self.q_proj = nn.Linear(embed_dim, embed_dim)
@@ -82,29 +93,39 @@ def forward(
) -> torch.Tensor:

batch_size = query.shape[0]

logger.debug(f"Input shapes - Query: {query.shape}, Key: {key.shape}, Value: {value.shape}")

# Transform and partition input tensor
# ℝ^{B×L×d} → ℝ^{B×h×L×(d/h)}
q = self.q_proj(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

logger.debug(f"After projection shapes - Q: {q.shape}, K: {k.shape}, V: {v.shape}")

scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

logger.debug(f"Attention scores shape: {scores.shape}")

if mask is not None:
logger.debug(f"Applying mask with shape: {mask.shape}")
scores = scores.masked_fill(mask == 0, float('-inf'))

# Compute attention distribution
# α = softmax(QK^T/√d_k)
# Compute weighted context
# Σ_i α_i v_i
attn_weights = F.softmax(scores, dim=-1)
logger.debug(f"Attention weights shape: {attn_weights.shape}")

attn_weights = self.dropout(attn_weights)
attn_output = torch.matmul(attn_weights, v)
logger.debug(f"Attention output shape (before reshape): {attn_output.shape}")

# Restore tensor dimensionality
# ℝ^{B×h×L×(d/h)} → ℝ^{B×L×d}
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)

logger.debug(f"Final output shape: {attn_output.shape}")

return self.out_proj(attn_output)

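A minimal usage sketch of the block above (assuming the module is importable as pvnet.models.multimodal.attention_blocks and that torch is installed; tensor shapes follow the B×L×d convention from the comments):

import torch
from pvnet.models.multimodal.attention_blocks import MultiheadAttention

# Hypothetical shapes: batch of 2, query length 10, context length 20, embed_dim 64
attn = MultiheadAttention(embed_dim=64, num_heads=8, dropout=0.1)
query = torch.randn(2, 10, 64)
key = torch.randn(2, 20, 64)
value = torch.randn(2, 20, 64)
out = attn(query, key, value)  # shape (2, 10, 64), same as the query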

Expand All @@ -120,6 +141,8 @@ def __init__(
num_modalities: int = 2
):
super().__init__()
logger.info(f"Initialising CrossModalAttention with {num_modalities} modalities")

self.num_modalities = num_modalities

# Parallel attention mechanisms for M modalities
@@ -136,25 +159,33 @@ def forward(
modalities: Dict[str, torch.Tensor],
mask: Optional[torch.Tensor] = None
) -> Dict[str, torch.Tensor]:
logger.info("Processing CrossModalAttention forward pass")
logger.debug(f"Input modalities: {[f'{k}: {v.shape}' for k, v in modalities.items()]}")

updated_modalities = {}
modality_keys = list(modalities.keys())

for i, key in enumerate(modality_keys):
logger.debug(f"Processing modality: {key}")
query = modalities[key]
other_modalities = [modalities[k] for k in modality_keys if k != key]

if other_modalities:
# Concatenate context modalities
# C = [m_1; ...; m_{i-1}; m_{i+1}; ...; m_M]
logger.debug(f"Concatenating {len(other_modalities)} other modalities")
key_value = torch.cat(other_modalities, dim=1)
logger.debug(f"Concatenated key_value shape: {key_value.shape}")
attn_output = self.attention_blocks[i](query, key_value, key_value, mask)
else:
# Apply self-attention
# A(x,x,x) when |M| = 1
logger.debug("No other modalities found, applying self-attention")
attn_output = self.attention_blocks[i](query, query, query, mask)

attn_output = self.dropout(attn_output)
updated_modalities[key] = self.layer_norms[i](query + attn_output)
logger.debug(f"Updated modality {key} shape: {updated_modalities[key].shape}")

return updated_modalities

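A hedged sketch of how CrossModalAttention might be called. The constructor arguments other than num_modalities are not visible in this hunk and are assumed to mirror MultiheadAttention; the modality keys are illustrative only:

import torch
from pvnet.models.multimodal.attention_blocks import CrossModalAttention

# Assumed constructor signature; modality keys ("nwp", "sat") are hypothetical
cross_attn = CrossModalAttention(embed_dim=64, num_heads=8, dropout=0.1, num_modalities=2)
modalities = {
    "nwp": torch.randn(2, 12, 64),  # (batch, seq_len, embed_dim)
    "sat": torch.randn(2, 24, 64),  # sequence lengths may differ across modalities
}
updated = cross_attn(modalities)  # dict with the same keys; each tensor keeps its input shape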
Expand All @@ -174,8 +205,9 @@ def __init__(
num_heads: int,
dropout: float = 0.1
):

super().__init__()
logger.info(f"Initialising SelfAttention with embed_dim={embed_dim}, num_heads={num_heads}")

self.attention = MultiheadAttention(embed_dim, num_heads, dropout)
self.layer_norm = nn.LayerNorm(embed_dim)
self.dropout = nn.Dropout(dropout)
@@ -186,10 +218,12 @@ def forward(
x: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
logger.debug(f"SelfAttention input shape: {x.shape}")

# Self-attention operation
# SA(x) = LayerNorm(x + A(x,x,x))
attn_output = self.attention(x, x, x, mask)
logger.debug(f"SelfAttention output shape (pre-dropout): {attn_output.shape}")
attn_output = self.dropout(attn_output)
return self.layer_norm(x + attn_output)

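For completeness, the residual self-attention wrapper can be exercised the same way (again a sketch under the same import assumption); because of the LayerNorm residual, the output shape matches the input:

import torch
from pvnet.models.multimodal.attention_blocks import SelfAttention

self_attn = SelfAttention(embed_dim=64, num_heads=8, dropout=0.1)
x = torch.randn(2, 10, 64)
y = self_attn(x)  # LayerNorm(x + A(x, x, x)) -> shape (2, 10, 64)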