Commit be826a6

Fix: correct labels (#3637)
1 parent 5939640 commit be826a6

1 file changed (+4, -1 lines changed)

examples/fsdp2/fsdp2_fp8.py

Lines changed: 4 additions & 1 deletion
@@ -187,7 +187,10 @@ def main():
     def collate_fn(batch):
         input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
         labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
-        return {"input_ids": input_ids, "labels": labels}
+        # Transformers expect `labels` to not be shifted, though we already shifted them, so we pass them both
+        # We need to pass both `shift_labels` and `labels` to the model, as the loss is calculated inside `if labels is not None`
+        # `shift_labels` take precedence over `labels` in this case
+        return {"input_ids": input_ids, "labels": labels, "shift_labels": labels}
 
     # We keep batch size at 1, as it is basically the same as sequence length, which we use instead
     dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)
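
The comments in this change describe a double-shift problem: the dataset's `labels` are already shifted by one position, while a causal-LM loss normally shifts unshifted `labels` itself. The sketch below illustrates that reasoning with plain Python lists. It is not the Transformers implementation: `shift_left` and `resolve_targets` are hypothetical helpers, and `-100` is assumed as the ignore value used for padding.

```python
IGNORE_INDEX = -100  # assumed "ignore" label value, as commonly used with cross-entropy


def shift_left(labels, ignore_index=IGNORE_INDEX):
    """Shift labels one position left so position i holds the token at i + 1."""
    return labels[1:] + [ignore_index]


def resolve_targets(labels, shift_labels=None):
    """Hypothetical stand-in for the target selection inside a causal-LM loss.

    If `shift_labels` is given it is used verbatim; otherwise `labels`
    is assumed to be unshifted and is shifted here.
    """
    if shift_labels is not None:
        return shift_labels
    return shift_left(labels)


tokens = [10, 11, 12, 13]

# What the dataset in this example is described as storing: labels shifted already.
pre_shifted = shift_left(tokens)

# Passing only `labels` shifts a second time, so targets end up off by one.
double_shifted = resolve_targets(pre_shifted)

# Passing the same tensor as `shift_labels` keeps the targets aligned.
correct = resolve_targets(pre_shifted, shift_labels=pre_shifted)

print(pre_shifted)
print(double_shifted)
print(correct)
```

Under these assumptions, `labels` is still passed only so that the loss branch guarded by `if labels is not None` runs, while `shift_labels` supplies the targets that are actually used.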
