@@ -100,21 +100,46 @@ def __init__(
100
100
if use_residual :
101
101
self .layer_norm = nn .LayerNorm (hidden_dim )
102
102
103
+ # def _validate_features(self, features: Dict[str, torch.Tensor]) -> None:
104
+ # """ Validates input feature dimensions and sequence lengths """
105
+
106
+ # if not features:
107
+ # raise ValueError("Empty features dict")
108
+
109
+ # seq_length = None
110
+ # for name, feat in features.items():
111
+ # if feat is None:
112
+ # raise ValueError(f"None tensor for modality: {name}")
113
+
114
+ # if seq_length is None:
115
+ # seq_length = feat.size(1)
116
+ # elif feat.size(1) != seq_length:
117
+ # raise ValueError("All modalities must have same sequence length")
118
+
103
119
def _validate_features (self , features : Dict [str , torch .Tensor ]) -> None :
104
120
""" Validates input feature dimensions and sequence lengths """
105
-
106
- if not features :
121
+
122
+ # Handle case where features might be a single tensor or empty
123
+ if not isinstance (features , dict ) or not features :
124
+ if isinstance (features , torch .Tensor ):
125
+ return # Skip validation for single tensor
107
126
raise ValueError ("Empty features dict" )
108
-
109
- seq_length = None
127
+
128
+ # Collect feature lengths for features with 2D+ tensors
129
+ multi_dim_features = {}
110
130
for name , feat in features .items ():
111
131
if feat is None :
112
132
raise ValueError (f"None tensor for modality: { name } " )
113
-
114
- if seq_length is None :
115
- seq_length = feat .size (1 )
116
- elif feat .size (1 ) != seq_length :
117
- raise ValueError ("All modalities must have same sequence length" )
133
+
134
+ # Only consider features with more than 1 dimension
135
+ if feat .ndim > 1 :
136
+ multi_dim_features [name ] = feat .size (1 )
137
+
138
+ # If more than one unique length, raise an error
139
+ feature_lengths = set (multi_dim_features .values ())
140
+ if len (feature_lengths ) > 1 :
141
+ raise ValueError (f"All modalities must have same sequence length. Current lengths: { multi_dim_features } " )
142
+
118
143
119
144
def compute_modality_weights (
120
145
self ,
@@ -214,8 +239,6 @@ def forward(
214
239
return fused
215
240
216
241
217
-
218
-
219
242
class ModalityGating (AbstractFusionBlock ):
220
243
""" Implementation of modality specific gating mechanism """
221
244
0 commit comments