🌌 Fix logits computation in trainer prediction step (huggingface#2969)
* Fix logits computation in DPO trainer prediction step

* fix compute_metrics for bco and test

* same for cpo

* same for dpo

* for kto

* and finally orpo

* Apply style fixes

---------

Co-authored-by: kyungdae-jo <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
5 people authored Feb 27, 2025
1 parent aa18ecf commit c0854c3
Showing 10 changed files with 201 additions and 19 deletions.
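
The five trainer diffs below all make the same change in `prediction_step`: the old code built the returned logits with `Tensor.unsqueeze` and `torch.stack`, which assumes the stored metric values are tensors, while the per-batch metrics are now plain Python scalars (BCO and KTO additionally renamed the keys to `logits/chosen_sum` / `logits/rejected_sum`, so the old lookups would no longer match). A minimal sketch of the before/after pattern, with illustrative values; the exact failure mode is inferred from the diffs rather than stated in the commit message:

import torch

# Per-batch metric summaries are now plain Python floats (values illustrative).
metrics = {"eval_logits/chosen": -1.25, "eval_logits/rejected": -2.50}
ignore_keys = []

# Old pattern: assumed tensor values, so .unsqueeze(dim=0) fails on floats.
# logits = tuple(v.unsqueeze(dim=0) for k, v in metrics.items() if k not in ignore_keys)
# logits = torch.stack(logits).mean(axis=1)

# New pattern: collect the scalars and build a 1-D tensor directly.
logits = [v for k, v in metrics.items() if k not in ignore_keys]
logits = torch.tensor(logits)          # shape (2,), one entry per metric key
labels = torch.zeros(logits.shape[0])  # placeholder labels for the eval loop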
38 changes: 38 additions & 0 deletions tests/test_bco_trainer.py
@@ -407,3 +407,41 @@ def test_bco_lora_save(self):
                AutoModelForCausalLM.from_pretrained(tmp_dir)
            except OSError:
                self.fail("Loading the saved peft adapter failed")

    @require_sklearn
    def test_compute_metrics(self):
        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        tokenizer.pad_token = tokenizer.eos_token

        dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")

        def dummy_compute_metrics(*args, **kwargs):
            return {"test": 0.0}

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = BCOConfig(
                output_dir=tmp_dir,
                remove_unused_columns=False,
                per_device_train_batch_size=2,
                do_eval=True,
                eval_strategy="steps",
                eval_steps=1,
                per_device_eval_batch_size=2,
                report_to="none",
            )

            trainer = BCOTrainer(
                model=model,
                ref_model=ref_model,
                args=training_args,
                processing_class=tokenizer,
                train_dataset=dummy_dataset["train"],
                eval_dataset=dummy_dataset["test"],
                compute_metrics=dummy_compute_metrics,
            )

            trainer.train()

            self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)
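
For reference, `compute_metrics` receives a `transformers.EvalPrediction` assembled from the `(logits, labels)` pairs that `prediction_step` returns, and the `Trainer` logs every returned key with an `eval_` prefix, which is why the assertion reads `eval_test`. The test indexes `log_history[-2]` because the final entry is the train-end summary. A sketch of a non-trivial callback (the `mean_logit` key and function name are hypothetical):

import numpy as np
from transformers import EvalPrediction

def compute_mean_logit(eval_pred: EvalPrediction) -> dict:
    # predictions concatenates the per-batch logit summaries emitted by
    # prediction_step; reduce them to a single scalar metric.
    preds = np.asarray(eval_pred.predictions)
    return {"mean_logit": float(preds.mean())}  # logged as "eval_mean_logit"

Passing `compute_metrics=compute_mean_logit` to any of these trainers would log `eval_mean_logit` at each evaluation step. The equivalent tests for CPO, DPO, KTO, and ORPO below follow the same pattern.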
35 changes: 35 additions & 0 deletions tests/test_cpo_trainer.py
@@ -152,3 +152,38 @@ def test_cpo_trainer_with_lora(self, config_name):
            new_param = trainer.model.get_parameter(n)
            if param.sum() != 0:  # ignore 0 biases
                self.assertFalse(torch.equal(param, new_param))

    def test_compute_metrics(self):
        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        tokenizer.pad_token = tokenizer.eos_token

        dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

        def dummy_compute_metrics(*args, **kwargs):
            return {"test": 0.0}

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = CPOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                remove_unused_columns=False,
                do_eval=True,
                eval_strategy="steps",
                eval_steps=1,
                per_device_eval_batch_size=2,
                report_to="none",
            )

            trainer = CPOTrainer(
                model=model,
                args=training_args,
                processing_class=tokenizer,
                train_dataset=dummy_dataset["train"],
                eval_dataset=dummy_dataset["test"],
                compute_metrics=dummy_compute_metrics,
            )

            trainer.train()

            self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)
36 changes: 36 additions & 0 deletions tests/test_dpo_trainer.py
@@ -1227,6 +1227,42 @@ def test_padding_free(self):
            if param.sum() != 0:  # ignore 0 biases
                self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))

    def test_compute_metrics(self):
        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        tokenizer.pad_token = tokenizer.eos_token

        dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

        def dummy_compute_metrics(*args, **kwargs):
            return {"test": 0.0}

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = DPOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                do_eval=True,
                eval_strategy="steps",
                eval_steps=1,
                per_device_eval_batch_size=2,
                report_to="none",
            )

            trainer = DPOTrainer(
                model=model,
                ref_model=ref_model,
                args=training_args,
                processing_class=tokenizer,
                train_dataset=dummy_dataset["train"],
                eval_dataset=dummy_dataset["test"],
                compute_metrics=dummy_compute_metrics,
            )

            trainer.train()

            self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)


@require_vision
class DPOVisionTrainerTester(unittest.TestCase):
37 changes: 37 additions & 0 deletions tests/test_kto_trainer.py
@@ -378,3 +378,40 @@ def test_kto_lora_save(self):
                AutoModelForCausalLM.from_pretrained(tmp_dir)
            except OSError:
                self.fail("Loading the saved peft adapter failed")

    def test_compute_metrics(self):
        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        tokenizer.pad_token = tokenizer.eos_token

        dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")

        def dummy_compute_metrics(*args, **kwargs):
            return {"test": 0.0}

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = KTOConfig(
                output_dir=tmp_dir,
                remove_unused_columns=False,
                per_device_train_batch_size=2,
                do_eval=True,
                eval_strategy="steps",
                eval_steps=1,
                per_device_eval_batch_size=2,
                report_to="none",
            )

            trainer = KTOTrainer(
                model=model,
                ref_model=ref_model,
                args=training_args,
                processing_class=tokenizer,
                train_dataset=dummy_dataset["train"],
                eval_dataset=dummy_dataset["test"],
                compute_metrics=dummy_compute_metrics,
            )

            trainer.train()

            self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)
35 changes: 35 additions & 0 deletions tests/test_orpo_trainer.py
@@ -146,3 +146,38 @@ def test_orpo_trainer_with_lora(self, config_name):
            new_param = trainer.model.get_parameter(n)
            if param.sum() != 0:  # ignore 0 biases
                self.assertFalse(torch.equal(param, new_param))

    def test_compute_metrics(self):
        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        tokenizer.pad_token = tokenizer.eos_token

        dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

        def dummy_compute_metrics(*args, **kwargs):
            return {"test": 0.0}

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = ORPOConfig(
                output_dir=tmp_dir,
                remove_unused_columns=False,
                per_device_train_batch_size=2,
                do_eval=True,
                eval_strategy="steps",
                eval_steps=1,
                per_device_eval_batch_size=2,
                report_to="none",
            )

            trainer = ORPOTrainer(
                model=model,
                args=training_args,
                processing_class=tokenizer,
                train_dataset=dummy_dataset["train"],
                eval_dataset=dummy_dataset["test"],
                compute_metrics=dummy_compute_metrics,
            )

            trainer.train()

            self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)
13 changes: 7 additions & 6 deletions trl/trainer/bco_trainer.py
@@ -1385,12 +1385,13 @@ def prediction_step(
             return (loss.detach(), None, None)
 
         # logits for the chosen and rejected samples from model
-        logits_dict = {
-            "eval_logits/chosen": metrics["logits/chosen"],
-            "eval_logits/rejected": metrics["logits/rejected"],
-        }
-        logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
-        logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
+        logits_dict = {}
+        if "logits/chosen_sum" in metrics:
+            logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"]
+        if "logits/rejected_sum" in metrics:
+            logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"]
+        logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
+        logits = torch.tensor(logits, device=self.accelerator.device)
         labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
 
         return (loss.detach(), logits, labels)
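
Unlike CPO/DPO/ORPO, BCO (and KTO below) wrap each key in a membership check, presumably because with unpaired data an eval batch can contain only desirable or only undesirable examples, leaving one of the two sums out of `metrics`; that reading is inferred from the diff. A minimal sketch of the guarded collection for a batch where only the chosen sum was recorded:

import torch

metrics = {"logits/chosen_sum": -1.25}  # "logits/rejected_sum" absent this batch
ignore_keys = []

logits_dict = {}
if "logits/chosen_sum" in metrics:
    logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"]
if "logits/rejected_sum" in metrics:
    logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"]

logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
logits = torch.tensor(logits)  # shape (1,) here, (2,) when both keys are present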
4 changes: 2 additions & 2 deletions trl/trainer/cpo_trainer.py
@@ -917,8 +917,8 @@ def prediction_step(
"eval_logits/chosen": metrics["eval_logits/chosen"],
"eval_logits/rejected": metrics["eval_logits/rejected"],
}
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
logits = torch.tensor(logits, device=self.accelerator.device)
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

return (loss.detach(), logits, labels)
4 changes: 2 additions & 2 deletions trl/trainer/dpo_trainer.py
@@ -1438,8 +1438,8 @@ def prediction_step(
"eval_logits/chosen": metrics["eval_logits/chosen"],
"eval_logits/rejected": metrics["eval_logits/rejected"],
}
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
logits = torch.tensor(logits, device=self.accelerator.device)
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

return (loss.detach(), logits, labels)
14 changes: 7 additions & 7 deletions trl/trainer/kto_trainer.py
@@ -1392,13 +1392,13 @@ def prediction_step(
             return (loss.detach(), None, None)
 
         # logits for the chosen and rejected samples from model
-        logits_dict = {
-            "eval_logits/chosen": metrics["logits/chosen"],
-            "eval_logits/rejected": metrics["logits/rejected"],
-        }
-        logits = torch.tensor(
-            [v for k, v in logits_dict.items() if k not in ignore_keys], device=self.accelerator.device
-        )
+        logits_dict = {}
+        if "logits/chosen_sum" in metrics:
+            logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"]
+        if "logits/rejected_sum" in metrics:
+            logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"]
+        logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
+        logits = torch.tensor(logits, device=self.accelerator.device)
         labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
 
         return (loss.detach(), logits, labels)
4 changes: 2 additions & 2 deletions trl/trainer/orpo_trainer.py
@@ -937,8 +937,8 @@ def prediction_step(
"eval_logits/chosen": metrics["eval_logits/chosen"],
"eval_logits/rejected": metrics["eval_logits/rejected"],
}
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
logits = [v for k, v in logits_dict.items() if k not in ignore_keys]
logits = torch.tensor(logits, device=self.accelerator.device)
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

return (loss.detach(), logits, labels)
