from abc import ABC, abstractmethod
from typing import Dict, Optional
from torch import nn
+import logging
+
+
+logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger('attention_blocks')


class AbstractAttentionBlock(nn.Module, ABC):
@@ -53,16 +58,22 @@ def __init__(
        num_heads: int,
        dropout: float = 0.1
    ):
-
        super().__init__()
+
+        logger.info(f"Initialising MultiheadAttention with embed_dim={embed_dim}, num_heads={num_heads}")
+
        if embed_dim % num_heads != 0:
-            raise ValueError("embed_dim not divisible by num_heads")
+            error_msg = f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
+            logger.error(error_msg)
+            raise ValueError(error_msg)

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

+        logger.debug(f"Head dimension: {self.head_dim}, Scale factor: {self.scale}")
+
        # Linear transformations for query-key-value projections
        # W_Q, W_K, W_V ∈ ℝ^{d×d}
        self.q_proj = nn.Linear(embed_dim, embed_dim)
@@ -82,29 +93,39 @@ def forward(
    ) -> torch.Tensor:

        batch_size = query.shape[0]
-
+        logger.debug(f"Input shapes - Query: {query.shape}, Key: {key.shape}, Value: {value.shape}")
+
        # Transform and partition input tensor
        # ℝ^{B×L×d} → ℝ^{B×h×L×(d/h)}
        q = self.q_proj(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+        logger.debug(f"After projection shapes - Q: {q.shape}, K: {k.shape}, V: {v.shape}")
+
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
-
+        logger.debug(f"Attention scores shape: {scores.shape}")
+
        if mask is not None:
+            logger.debug(f"Applying mask with shape: {mask.shape}")
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Compute attention distribution
        # α = softmax(QK^T/√d_k)
        # Compute weighted context
        # Σ_i α_i v_i
        attn_weights = F.softmax(scores, dim=-1)
+        logger.debug(f"Attention weights shape: {attn_weights.shape}")
+
        attn_weights = self.dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, v)
+        logger.debug(f"Attention output shape (before reshape): {attn_output.shape}")

        # Restore tensor dimensionality
        # ℝ^{B×h×L×(d/h)} → ℝ^{B×L×d}
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
-
+        logger.debug(f"Final output shape: {attn_output.shape}")
+
        return self.out_proj(attn_output)

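As a quick sanity check on the shapes the new debug logs will report, here is a minimal usage sketch of this block (the batch size, sequence length, and width below are illustrative, not taken from the diff):

import torch

# Illustrative sizes: batch 4, sequence length 16, embed_dim 256 split over 8 heads.
attn = MultiheadAttention(embed_dim=256, num_heads=8, dropout=0.1)
x = torch.randn(4, 16, 256)

# Query, key and value may be the same tensor; the optional mask is omitted here.
out = attn(x, x, x)
assert out.shape == (4, 16, 256)  # ℝ^{B×L×d} is preserved end to end
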
@@ -120,6 +141,8 @@ def __init__(
        num_modalities: int = 2
    ):
        super().__init__()
+        logger.info(f"Initialising CrossModalAttention with {num_modalities} modalities")
+
        self.num_modalities = num_modalities

        # Parallel attention mechanisms for M modalities
@@ -136,25 +159,33 @@ def forward(
        modalities: Dict[str, torch.Tensor],
        mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
+        logger.info("Processing CrossModalAttention forward pass")
+        logger.debug(f"Input modalities: {[f'{k}: {v.shape}' for k, v in modalities.items()]}")
+
        updated_modalities = {}
        modality_keys = list(modalities.keys())

        for i, key in enumerate(modality_keys):
+            logger.debug(f"Processing modality: {key}")
            query = modalities[key]
            other_modalities = [modalities[k] for k in modality_keys if k != key]

            if other_modalities:
                # Concatenate context modalities
                # C = [m_1; ...; m_{i-1}; m_{i+1}; ...; m_M]
+                logger.debug(f"Concatenating {len(other_modalities)} other modalities")
                key_value = torch.cat(other_modalities, dim=1)
+                logger.debug(f"Concatenated key_value shape: {key_value.shape}")
                attn_output = self.attention_blocks[i](query, key_value, key_value, mask)
            else:
                # Apply self-attention
                # A(x,x,x) when |M| = 1
+                logger.debug("No other modalities found, applying self-attention")
                attn_output = self.attention_blocks[i](query, query, query, mask)

            attn_output = self.dropout(attn_output)
            updated_modalities[key] = self.layer_norms[i](query + attn_output)
+            logger.debug(f"Updated modality {key} shape: {updated_modalities[key].shape}")

        return updated_modalities

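A hedged usage sketch for the cross-modal block follows. Only num_modalities is visible in the constructor hunk above, so the embed_dim and num_heads arguments, the modality names, and the tensor sizes below are assumptions; what the diff does guarantee is that each modality keeps its own query shape, since the other modalities are only concatenated along the sequence dimension to form the key/value:

import torch

# Assumed constructor signature; only num_modalities=2 appears in the diff.
cross_attn = CrossModalAttention(embed_dim=256, num_heads=8, num_modalities=2)

# Two modalities with different sequence lengths but the same embedding width.
modalities = {
    'vision': torch.randn(4, 49, 256),
    'text': torch.randn(4, 16, 256),
}
updated = cross_attn(modalities)
assert set(updated) == {'vision', 'text'}
assert updated['vision'].shape == (4, 49, 256)  # residual + LayerNorm keep the query shape
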
@@ -174,8 +205,9 @@ def __init__(
        num_heads: int,
        dropout: float = 0.1
    ):
-
        super().__init__()
+        logger.info(f"Initialising SelfAttention with embed_dim={embed_dim}, num_heads={num_heads}")
+
        self.attention = MultiheadAttention(embed_dim, num_heads, dropout)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
@@ -186,10 +218,12 @@ def forward(
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
+        logger.debug(f"SelfAttention input shape: {x.shape}")

        # Self-attention operation
        # SA(x) = LayerNorm(x + A(x,x,x))
        attn_output = self.attention(x, x, x, mask)
+        logger.debug(f"SelfAttention output shape (pre-dropout): {attn_output.shape}")
        attn_output = self.dropout(attn_output)
        return self.layer_norm(x + attn_output)

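Finally, a usage sketch for SelfAttention; with logging.basicConfig set to DEBUG at import time, a call like this will emit the new shape traces (sizes and mask layout are illustrative; any mask that broadcasts against the B×h×L×L score tensor works):

import torch

self_attn = SelfAttention(embed_dim=256, num_heads=8, dropout=0.1)
x = torch.randn(4, 16, 256)

# Optional padding mask: positions equal to 0 are filled with -inf before the softmax.
mask = torch.ones(4, 1, 16, 16)
y = self_attn(x, mask)
assert y.shape == x.shape  # LayerNorm(x + A(x, x, x)) preserves ℝ^{B×L×d}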