Skip to content

Commit 159a72d

Browse files
committed
Further updates
1 parent 227f988 commit 159a72d

File tree

2 files changed

+266
-238
lines changed

2 files changed

+266
-238
lines changed

pvnet/models/multimodal/fusion_blocks.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -100,21 +100,46 @@ def __init__(
100100
if use_residual:
101101
self.layer_norm = nn.LayerNorm(hidden_dim)
102102

103+
# def _validate_features(self, features: Dict[str, torch.Tensor]) -> None:
104+
# """ Validates input feature dimensions and sequence lengths """
105+
106+
# if not features:
107+
# raise ValueError("Empty features dict")
108+
109+
# seq_length = None
110+
# for name, feat in features.items():
111+
# if feat is None:
112+
# raise ValueError(f"None tensor for modality: {name}")
113+
114+
# if seq_length is None:
115+
# seq_length = feat.size(1)
116+
# elif feat.size(1) != seq_length:
117+
# raise ValueError("All modalities must have same sequence length")
118+
103119
def _validate_features(self, features: Dict[str, torch.Tensor]) -> None:
104120
""" Validates input feature dimensions and sequence lengths """
105-
106-
if not features:
121+
122+
# Handle case where features might be a single tensor or empty
123+
if not isinstance(features, dict) or not features:
124+
if isinstance(features, torch.Tensor):
125+
return # Skip validation for single tensor
107126
raise ValueError("Empty features dict")
108-
109-
seq_length = None
127+
128+
# Collect feature lengths for features with 2D+ tensors
129+
multi_dim_features = {}
110130
for name, feat in features.items():
111131
if feat is None:
112132
raise ValueError(f"None tensor for modality: {name}")
113-
114-
if seq_length is None:
115-
seq_length = feat.size(1)
116-
elif feat.size(1) != seq_length:
117-
raise ValueError("All modalities must have same sequence length")
133+
134+
# Only consider features with more than 1 dimension
135+
if feat.ndim > 1:
136+
multi_dim_features[name] = feat.size(1)
137+
138+
# If more than one unique length, raise an error
139+
feature_lengths = set(multi_dim_features.values())
140+
if len(feature_lengths) > 1:
141+
raise ValueError(f"All modalities must have same sequence length. Current lengths: {multi_dim_features}")
142+
118143

119144
def compute_modality_weights(
120145
self,
@@ -214,8 +239,6 @@ def forward(
214239
return fused
215240

216241

217-
218-
219242
class ModalityGating(AbstractFusionBlock):
220243
""" Implementation of modality specific gating mechanism """
221244

0 commit comments

Comments
 (0)