File tree Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -187,7 +187,10 @@ def main():
187
187
def collate_fn (batch ):
188
188
input_ids = torch .tensor ([item ["input_ids" ] for item in batch ], dtype = torch .long )
189
189
labels = torch .tensor ([item ["labels" ] for item in batch ], dtype = torch .long )
190
- return {"input_ids" : input_ids , "labels" : labels }
190
+ # Transformers expects `labels` to be unshifted, but ours are already shifted, so we pass them under both keys
191
+ # We need to pass both `shift_labels` and `labels` to the model, because the loss is only calculated inside `if labels is not None`
192
+ # `shift_labels` takes precedence over `labels` in this case
193
+ return {"input_ids" : input_ids , "labels" : labels , "shift_labels" : labels }
191
194
192
195
# We keep batch size at 1, as it is basically the same as sequence length, which we use instead
193
196
dataloader = DataLoader (dataset , batch_size = 1 , collate_fn = collate_fn )
You can’t perform that action at this time.
0 commit comments