Fix missing initialization of contrastive loss
psaegert committed Jan 29, 2025
1 parent 154023d commit 9d133e5
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/flash_ansr/train/train.py
@@ -398,6 +398,8 @@ def _train_batch(self, batch: dict[str, torch.Tensor], numeric_prediction_loss_w
         # Calculate the loss
         ce_loss: torch.Tensor = self.cross_entropy_loss(flat_logits, flat_labels)

+        contrastive_loss = torch.tensor(0, device=self.model.device, dtype=torch.float32)
+
         if contrastive_loss_weight > 0:
             # Use memory (embeddings of encoder) that the model cached during the forward pass
             contrastive_loss = self.contrastive_loss_fn(self.model.memory.reshape(self.model.memory.shape[0], -1), skeleton_hashes)
@@ -531,6 +533,8 @@ def _validate(self, val_dataset: FlashANSRDataset, numeric_prediction_loss_weigh
         # Calculate the loss
         ce_loss: torch.Tensor = self.cross_entropy_loss(flat_logits, flat_labels)

+        contrastive_loss = torch.tensor(0, device=self.model.device, dtype=torch.float32)
+
         if contrastive_loss_weight > 0:
             # Use memory (embeddings of encoder) that the model cached during the forward pass
             contrastive_loss = self.contrastive_loss_fn(self.model.memory.reshape(self.model.memory.shape[0], -1), skeleton_hashes)
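Why the initialization matters: without it, contrastive_loss is bound only inside the `if contrastive_loss_weight > 0` branch, so any later use of the variable (for example, when summing the weighted loss terms) raises an UnboundLocalError once the weight is 0. A minimal sketch of the failure mode and the fix follows; the helper names and the weighted sum are illustrative, not the repository's exact code:

import torch

def total_loss_buggy(ce_loss: torch.Tensor, contrastive_loss_weight: float) -> torch.Tensor:
    # Bug: contrastive_loss is only assigned when the branch is taken.
    if contrastive_loss_weight > 0:
        contrastive_loss = torch.tensor(1.0)  # stand-in for the real contrastive term
    # Raises UnboundLocalError when contrastive_loss_weight == 0.
    return ce_loss + contrastive_loss_weight * contrastive_loss

def total_loss_fixed(ce_loss: torch.Tensor, contrastive_loss_weight: float) -> torch.Tensor:
    # Fix, as in the commit: start from a float32 zero tensor so the
    # weighted sum below is always well-defined.
    contrastive_loss = torch.tensor(0, dtype=torch.float32)
    if contrastive_loss_weight > 0:
        contrastive_loss = torch.tensor(1.0)  # stand-in for the real contrastive term
    return ce_loss + contrastive_loss_weight * contrastive_loss

print(total_loss_fixed(torch.tensor(0.5), 0.0))  # tensor(0.5000)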
