diff --git a/multimodal_transformers/data/tabular_torch_dataset.py b/multimodal_transformers/data/tabular_torch_dataset.py index 55fef1f..e5bb058 100644 --- a/multimodal_transformers/data/tabular_torch_dataset.py +++ b/multimodal_transformers/data/tabular_torch_dataset.py @@ -42,7 +42,7 @@ def __init__( ): self.df = df self.encodings = encodings - self.cat_feats = categorical_feats.values + self.cat_feats = categorical_feats.values if categorical_feats is not None else None self.numerical_feats = numerical_feats self.labels = labels self.label_list = (