
Commit 8c0d00b

[grpo] fix incorrect placement of data in eval_queue during async_generate (#3573)
* fix
* move cur_queue to a property
* fix

Co-authored-by: hjh <[email protected]>
1 parent 7d8b5b9 commit 8c0d00b

File tree

1 file changed: +17, -11 lines changed


swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 17 additions & 11 deletions
@@ -102,7 +102,6 @@ def __init__(self,
         from swift.trainers.rlhf_arguments import GRPOConfig
         args: GRPOConfig = kwargs['args']
         self.args = args
-        self.queue = None
         self.train_queue = Queue()
         self.eval_queue = Queue()
         self.processing_class = kwargs.get('template').tokenizer
@@ -598,7 +597,7 @@ def patch_merge(model):
             unwrapped_model.unmerge_adapter()
 
     def _wait_queue(self):
-        while self.queue.empty():
+        while self._queue.empty():
             time.sleep(0.01)
 
     @staticmethod
@@ -621,9 +620,11 @@ def infer_task():
             return result
 
         future: Future = self.executor.submit(infer_task)
+        # pre-fetch the queue to avoid switching back to eval_queue at the end of training sample sampling
+        current_queue = self._queue
 
         def done(_self):
-            self.queue.put(DataCache(inputs, _self.result(), distributed_idx))
+            current_queue.put(DataCache(inputs, _self.result(), distributed_idx))
 
         future.add_done_callback(done)
 
@@ -634,9 +635,9 @@ def _prefetch(self, dataloader):
         if self.infer_rank >= 0:
             _input_slice = np.array(all_inputs)[distributed_idx[self.infer_rank]]
             outputs = self.engine.infer(_input_slice, self.request_config, use_tqdm=False)
-            self.queue.put(DataCache(inputs, outputs, distributed_idx))
+            self._queue.put(DataCache(inputs, outputs, distributed_idx))
         else:
-            self.queue.put(DataCache(inputs, [], distributed_idx))
+            self._queue.put(DataCache(inputs, [], distributed_idx))
         if self.accelerator.num_processes > 1:
             self.accelerator.wait_for_everyone()
 
@@ -666,7 +667,7 @@ def _fast_infer(self, inputs):
             _input_slice = np.array(all_inputs)[distributed_idx[self.infer_rank]]
             if self.args.async_generate:
                 self.async_infer(inputs, _input_slice, distributed_idx)
-                data_cache = self.queue.get()
+                data_cache = self._queue.get()
                 inputs = data_cache.inputs
                 outputs = data_cache.outputs
                 distributed_idx = data_cache.distributed_idx
@@ -690,8 +691,8 @@ def _fast_infer(self, inputs):
         else:
             if self.args.async_generate:
                 # using old model to generate, which will ignore the `clip` of advantages.
-                self.queue.put(DataCache(inputs, [], distributed_idx))
-                data_cache = self.queue.get()
+                self._queue.put(DataCache(inputs, [], distributed_idx))
+                data_cache = self._queue.get()
                 inputs = data_cache.inputs
                 distributed_idx = data_cache.distributed_idx
                 outputs = []
@@ -907,12 +908,17 @@ def _get_per_token_logps(self, model, inputs):
         return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens
 
     def evaluation_loop(self, dataloader, *args, **kwargs):
-        self.queue = self.eval_queue
-        if self.queue.empty() and self.args.async_generate:
+        if self._queue.empty() and self.args.async_generate:
             self._prefetch(dataloader)
         metric_key_prefix = kwargs['metric_key_prefix']
         output = super().evaluation_loop(dataloader, *args, **kwargs)
        metrics = {f'{metric_key_prefix}_{key}': sum(val) / len(val) for key, val in self._metrics['eval'].items()}
         output.metrics.update(metrics)
-        self.queue = self.train_queue
         return output
+
+    @property
+    def _queue(self):
+        if self.control.should_evaluate:
+            return self.eval_queue
+        else:
+            return self.train_queue
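
The change above can be read as two moves: route every queue access through a `_queue` property keyed off `self.control.should_evaluate`, and pin the destination queue before submitting the async generation task so the done callback cannot be redirected if evaluation starts while generation is still running. The sketch below reproduces both the original misplacement and the fixed behaviour in isolation; it is a minimal illustration, not the trainer code, and the `MiniTrainer` class, the bare `should_evaluate` flag, and the sleep-based timing are assumptions standing in for `GRPOTrainer`, `self.control.should_evaluate`, and real generation latency.

# Minimal sketch of the bug and the fix (illustrative; not the actual GRPOTrainer code).
import time
from queue import Queue
from concurrent.futures import ThreadPoolExecutor


class MiniTrainer:
    def __init__(self):
        self.train_queue = Queue()
        self.eval_queue = Queue()
        self.should_evaluate = False  # stand-in for self.control.should_evaluate
        self.executor = ThreadPoolExecutor(max_workers=1)

    @property
    def _queue(self):
        # Single source of truth: the current phase decides which queue is active.
        return self.eval_queue if self.should_evaluate else self.train_queue

    def async_infer_old(self, sample):
        # Old behaviour: the callback resolves the queue when generation *finishes*,
        # so a training sample completed after evaluation starts lands in eval_queue.
        future = self.executor.submit(lambda: time.sleep(0.1) or sample)
        future.add_done_callback(lambda f: self._queue.put(f.result()))

    def async_infer_new(self, sample):
        # Fixed behaviour: pre-fetch the queue at submission time so a later phase
        # switch cannot redirect the done callback.
        future = self.executor.submit(lambda: time.sleep(0.1) or sample)
        current_queue = self._queue
        future.add_done_callback(lambda f: current_queue.put(f.result()))


for method in ('async_infer_old', 'async_infer_new'):
    trainer = MiniTrainer()
    getattr(trainer, method)('train-batch')  # submitted while still training
    trainer.should_evaluate = True           # evaluation begins before generation finishes
    trainer.executor.shutdown(wait=True)     # wait for the callback to run
    print(method, 'train:', trainer.train_queue.qsize(), 'eval:', trainer.eval_queue.qsize())
# async_infer_old train: 0 eval: 1  (training data misplaced in eval_queue)
# async_infer_new train: 1 eval: 0  (result stays in train_queue)

Routing everything through the property is also why `evaluation_loop` no longer needs to swap `self.queue` between `eval_queue` and `train_queue` by hand.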
