Commit 133f2da

Updates
1 parent bb013bd commit 133f2da

File tree

3 files changed: +160 -17 lines changed

pvnet/models/multimodal/attention_blocks.py

Lines changed: 40 additions & 6 deletions
@@ -14,6 +14,11 @@
 from abc import ABC, abstractmethod
 from typing import Dict, Optional
 from torch import nn
+import logging
+
+
+logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger('attention_blocks')
 
 
 class AbstractAttentionBlock(nn.Module, ABC):
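
Since `logging.basicConfig(level=logging.DEBUG, ...)` runs at import time, the shape logs added below are emitted by default once this module is imported. A minimal sketch of how a downstream user could quiet them; only the logger name `'attention_blocks'` comes from the diff, the rest is illustrative:

```python
import logging

# The module configures the root logger via logging.basicConfig at import time,
# so DEBUG records from 'attention_blocks' propagate to the root handler by default.
# Raising the named logger's level silences the per-forward shape logs:
logging.getLogger('attention_blocks').setLevel(logging.WARNING)
```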
@@ -53,16 +58,22 @@ def __init__(
         num_heads: int,
         dropout: float = 0.1
     ):
-
         super().__init__()
+
+        logger.info(f"Initialising MultiheadAttention with embed_dim={embed_dim}, num_heads={num_heads}")
+
         if embed_dim % num_heads != 0:
-            raise ValueError("embed_dim not divisible by num_heads")
+            error_msg = f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
+            logger.error(error_msg)
+            raise ValueError(error_msg)
 
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
         self.scale = self.head_dim ** -0.5
 
+        logger.debug(f"Head dimension: {self.head_dim}, Scale factor: {self.scale}")
+
         # Linear transformations for query-key-value projections
         # W_Q, W_K, W_V ∈ ℝ^{d×d}
         self.q_proj = nn.Linear(embed_dim, embed_dim)
@@ -82,29 +93,39 @@ def forward(
     ) -> torch.Tensor:
 
         batch_size = query.shape[0]
-
+        logger.debug(f"Input shapes - Query: {query.shape}, Key: {key.shape}, Value: {value.shape}")
+
         # Transform and partition input tensor
         # ℝ^{B×L×d} → ℝ^{B×h×L×(d/h)}
         q = self.q_proj(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
         k = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
         v = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+        logger.debug(f"After projection shapes - Q: {q.shape}, K: {k.shape}, V: {v.shape}")
+
         scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
-
+        logger.debug(f"Attention scores shape: {scores.shape}")
+
         if mask is not None:
+            logger.debug(f"Applying mask with shape: {mask.shape}")
             scores = scores.masked_fill(mask == 0, float('-inf'))
 
         # Compute attention distribution
         # α = softmax(QK^T/√d_k)
         # Compute weighted context
         # Σ_i α_i v_i
         attn_weights = F.softmax(scores, dim=-1)
+        logger.debug(f"Attention weights shape: {attn_weights.shape}")
+
         attn_weights = self.dropout(attn_weights)
         attn_output = torch.matmul(attn_weights, v)
+        logger.debug(f"Attention output shape (before reshape): {attn_output.shape}")
 
         # Restore tensor dimensionality
         # ℝ^{B×h×L×(d/h)} → ℝ^{B×L×d}
         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
-
+        logger.debug(f"Final output shape: {attn_output.shape}")
+
         return self.out_proj(attn_output)
 
 
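A minimal usage sketch of the updated `MultiheadAttention` block; the import path is taken from the file header and the constructor keywords (`embed_dim`, `num_heads`, `dropout`) from the diff, while the tensor sizes are purely illustrative:

```python
import torch
from pvnet.models.multimodal.attention_blocks import MultiheadAttention  # path from the diff header

# Illustrative sizes: batch 4, sequence length 10, embed_dim 64 split over 8 heads (head_dim = 8).
block = MultiheadAttention(embed_dim=64, num_heads=8, dropout=0.1)
x = torch.randn(4, 10, 64)

out = block(x, x, x, None)       # query = key = value, no mask
assert out.shape == (4, 10, 64)  # ℝ^{B×L×d} in, ℝ^{B×L×d} out
```

With `embed_dim=64, num_heads=8` the divisibility check passes; a non-divisible pair such as `embed_dim=64, num_heads=6` would now hit the new `logger.error` call before the `ValueError` is raised.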

@@ -120,6 +141,8 @@ def __init__(
         num_modalities: int = 2
     ):
         super().__init__()
+        logger.info(f"Initialising CrossModalAttention with {num_modalities} modalities")
+
         self.num_modalities = num_modalities
 
         # Parallel attention mechanisms for M modalities
@@ -136,25 +159,33 @@ def forward(
         modalities: Dict[str, torch.Tensor],
         mask: Optional[torch.Tensor] = None
     ) -> Dict[str, torch.Tensor]:
+        logger.info("Processing CrossModalAttention forward pass")
+        logger.debug(f"Input modalities: {[f'{k}: {v.shape}' for k, v in modalities.items()]}")
+
         updated_modalities = {}
         modality_keys = list(modalities.keys())
 
         for i, key in enumerate(modality_keys):
+            logger.debug(f"Processing modality: {key}")
             query = modalities[key]
             other_modalities = [modalities[k] for k in modality_keys if k != key]
 
             if other_modalities:
                 # Concatenate context modalities
                 # C = [m_1; ...; m_{i-1}; m_{i+1}; ...; m_M]
+                logger.debug(f"Concatenating {len(other_modalities)} other modalities")
                 key_value = torch.cat(other_modalities, dim=1)
+                logger.debug(f"Concatenated key_value shape: {key_value.shape}")
                 attn_output = self.attention_blocks[i](query, key_value, key_value, mask)
             else:
                 # Apply self-attention
                 # A(x,x,x) when |M| = 1
+                logger.debug("No other modalities found, applying self-attention")
                 attn_output = self.attention_blocks[i](query, query, query, mask)
 
             attn_output = self.dropout(attn_output)
             updated_modalities[key] = self.layer_norms[i](query + attn_output)
+            logger.debug(f"Updated modality {key} shape: {updated_modalities[key].shape}")
 
         return updated_modalities
 
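A small standalone sketch of how the key/value context is assembled for one modality in the loop above; the modality names and shapes are hypothetical, only the `torch.cat(..., dim=1)` pattern comes from the diff:

```python
import torch

# Hypothetical modality tensors: same embedding dim (64), different sequence lengths.
modalities = {"nwp": torch.randn(4, 12, 64), "sat": torch.randn(4, 7, 64)}

# For the "nwp" query, every other modality is concatenated along the sequence
# dimension (dim=1) to form the shared key/value context, mirroring the loop above.
query = modalities["nwp"]
other_modalities = [v for k, v in modalities.items() if k != "nwp"]
key_value = torch.cat(other_modalities, dim=1)

print(query.shape, key_value.shape)  # torch.Size([4, 12, 64]) torch.Size([4, 7, 64])
```

Because each update is `LayerNorm(query + attn_output)`, every entry of the returned dict keeps the shape of its input modality, which is what the new `Updated modality {key} shape` debug line records.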

@@ -174,8 +205,9 @@ def __init__(
         num_heads: int,
         dropout: float = 0.1
     ):
-
         super().__init__()
+        logger.info(f"Initialising SelfAttention with embed_dim={embed_dim}, num_heads={num_heads}")
+
         self.attention = MultiheadAttention(embed_dim, num_heads, dropout)
         self.layer_norm = nn.LayerNorm(embed_dim)
         self.dropout = nn.Dropout(dropout)
@@ -186,10 +218,12 @@ def forward(
         x: torch.Tensor,
         mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
+        logger.debug(f"SelfAttention input shape: {x.shape}")
 
         # Self-attention operation
         # SA(x) = LayerNorm(x + A(x,x,x))
         attn_output = self.attention(x, x, x, mask)
+        logger.debug(f"SelfAttention output shape (pre-dropout): {attn_output.shape}")
         attn_output = self.dropout(attn_output)
         return self.layer_norm(x + attn_output)
 
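Finally, a hedged end-to-end sketch of the `SelfAttention` wrapper; the import path and tensor sizes are assumptions, while the constructor arguments and the `SA(x) = LayerNorm(x + A(x,x,x))` contract come from the diff:

```python
import torch
from pvnet.models.multimodal.attention_blocks import SelfAttention  # path from the diff header

layer = SelfAttention(embed_dim=64, num_heads=8, dropout=0.1)
x = torch.randn(2, 16, 64)

y = layer(x)               # SA(x) = LayerNorm(x + A(x, x, x)); mask defaults to None
assert y.shape == x.shape  # the residual connection and LayerNorm preserve the input shape
```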
