 from abc import ABC, abstractmethod
 from typing import Dict, Optional
 from torch import nn
+import logging
+
+
+logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger('attention_blocks')


 class AbstractAttentionBlock(nn.Module, ABC):
@@ -53,16 +58,22 @@ def __init__(
         num_heads: int,
         dropout: float = 0.1
     ):
-
         super().__init__()
+
+        logger.info(f"Initialising MultiheadAttention with embed_dim={embed_dim}, num_heads={num_heads}")
+
         if embed_dim % num_heads != 0:
-            raise ValueError("embed_dim not divisible by num_heads")
+            error_msg = f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
+            logger.error(error_msg)
+            raise ValueError(error_msg)

         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
         self.scale = self.head_dim ** -0.5

+        logger.debug(f"Head dimension: {self.head_dim}, Scale factor: {self.scale}")
+
         # Linear transformations for query-key-value projections
         # W_Q, W_K, W_V ∈ ℝ^{d×d}
         self.q_proj = nn.Linear(embed_dim, embed_dim)
@@ -82,29 +93,39 @@ def forward(
     ) -> torch.Tensor:

         batch_size = query.shape[0]
-
+        logger.debug(f"Input shapes - Query: {query.shape}, Key: {key.shape}, Value: {value.shape}")
+
         # Transform and partition input tensor
         # ℝ^{B×L×d} → ℝ^{B×h×L×(d/h)}
         q = self.q_proj(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
         k = self.k_proj(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
         v = self.v_proj(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
+
+        logger.debug(f"After projection shapes - Q: {q.shape}, K: {k.shape}, V: {v.shape}")
+
         scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
-
+        logger.debug(f"Attention scores shape: {scores.shape}")
+
         if mask is not None:
+            logger.debug(f"Applying mask with shape: {mask.shape}")
             scores = scores.masked_fill(mask == 0, float('-inf'))

         # Compute attention distribution
         # α = softmax(QK^T/√d_k)
         # Compute weighted context
         # Σ_i α_i v_i
         attn_weights = F.softmax(scores, dim=-1)
+        logger.debug(f"Attention weights shape: {attn_weights.shape}")
+
         attn_weights = self.dropout(attn_weights)
         attn_output = torch.matmul(attn_weights, v)
+        logger.debug(f"Attention output shape (before reshape): {attn_output.shape}")

         # Restore tensor dimensionality
         # ℝ^{B×h×L×(d/h)} → ℝ^{B×L×d}
         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
-
+        logger.debug(f"Final output shape: {attn_output.shape}")
+
         return self.out_proj(attn_output)


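As a quick sanity check on the shape logging added above, here is a minimal usage sketch for the MultiheadAttention block. It is not part of the commit, assumes `torch` is imported earlier in the file (outside this hunk), and uses arbitrary example dimensions:

import torch

mha = MultiheadAttention(embed_dim=64, num_heads=8, dropout=0.1)
x = torch.randn(2, 10, 64)           # (batch, seq_len, embed_dim)
mask = torch.ones(2, 1, 10, 10)      # broadcasts over heads; 0 marks masked positions
out = mha(x, x, x, mask)             # query = key = value, i.e. self-attention
print(out.shape)                     # torch.Size([2, 10, 64])
# With the logger configured at DEBUG, the Q/K/V, score, weight and output shapes are logged.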
@@ -120,6 +141,8 @@ def __init__(
         num_modalities: int = 2
     ):
         super().__init__()
+        logger.info(f"Initialising CrossModalAttention with {num_modalities} modalities")
+
         self.num_modalities = num_modalities

         # Parallel attention mechanisms for M modalities
@@ -136,25 +159,33 @@ def forward(
         modalities: Dict[str, torch.Tensor],
         mask: Optional[torch.Tensor] = None
     ) -> Dict[str, torch.Tensor]:
+        logger.info("Processing CrossModalAttention forward pass")
+        logger.debug(f"Input modalities: {[f'{k}: {v.shape}' for k, v in modalities.items()]}")
+
         updated_modalities = {}
         modality_keys = list(modalities.keys())

         for i, key in enumerate(modality_keys):
+            logger.debug(f"Processing modality: {key}")
             query = modalities[key]
             other_modalities = [modalities[k] for k in modality_keys if k != key]

             if other_modalities:
                 # Concatenate context modalities
                 # C = [m_1; ...; m_{i-1}; m_{i+1}; ...; m_M]
+                logger.debug(f"Concatenating {len(other_modalities)} other modalities")
                 key_value = torch.cat(other_modalities, dim=1)
+                logger.debug(f"Concatenated key_value shape: {key_value.shape}")
                 attn_output = self.attention_blocks[i](query, key_value, key_value, mask)
             else:
                 # Apply self-attention
                 # A(x,x,x) when |M| = 1
+                logger.debug("No other modalities found, applying self-attention")
                 attn_output = self.attention_blocks[i](query, query, query, mask)

             attn_output = self.dropout(attn_output)
             updated_modalities[key] = self.layer_norms[i](query + attn_output)
+            logger.debug(f"Updated modality {key} shape: {updated_modalities[key].shape}")

         return updated_modalities

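A hypothetical call into CrossModalAttention, illustrating the per-modality logging above. It assumes the constructor also takes embed_dim, num_heads and dropout ahead of num_modalities, as the other blocks do (that part of the signature lies outside this hunk):

import torch

cma = CrossModalAttention(embed_dim=64, num_heads=8, dropout=0.1, num_modalities=2)
modalities = {
    "vision": torch.randn(2, 49, 64),   # e.g. 7x7 patch tokens
    "text": torch.randn(2, 12, 64),     # lengths may differ; embed_dim must match
}
updated = cma(modalities)
print({k: v.shape for k, v in updated.items()})
# {'vision': torch.Size([2, 49, 64]), 'text': torch.Size([2, 12, 64])}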
@@ -174,8 +205,9 @@ def __init__(
         num_heads: int,
         dropout: float = 0.1
     ):
-
         super().__init__()
+        logger.info(f"Initialising SelfAttention with embed_dim={embed_dim}, num_heads={num_heads}")
+
         self.attention = MultiheadAttention(embed_dim, num_heads, dropout)
         self.layer_norm = nn.LayerNorm(embed_dim)
         self.dropout = nn.Dropout(dropout)
@@ -186,10 +218,12 @@ def forward(
         x: torch.Tensor,
         mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
+        logger.debug(f"SelfAttention input shape: {x.shape}")

         # Self-attention operation
         # SA(x) = LayerNorm(x + A(x,x,x))
         attn_output = self.attention(x, x, x, mask)
+        logger.debug(f"SelfAttention output shape (pre-dropout): {attn_output.shape}")
         attn_output = self.dropout(attn_output)
         return self.layer_norm(x + attn_output)

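Finally, an illustrative snippet for SelfAttention, including how to quieten the new DEBUG output once the shapes have been verified (again assuming the module's own torch imports above this diff):

import logging
import torch

sa = SelfAttention(embed_dim=64, num_heads=8, dropout=0.1)
y = sa(torch.randn(2, 10, 64))       # LayerNorm(x + A(x, x, x))
print(y.shape)                       # torch.Size([2, 10, 64])

# basicConfig above enables DEBUG globally; raise the level for quieter runs.
logging.getLogger('attention_blocks').setLevel(logging.INFO)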