@@ -102,7 +102,6 @@ def __init__(self,
         from swift.trainers.rlhf_arguments import GRPOConfig
         args: GRPOConfig = kwargs['args']
         self.args = args
-        self.queue = None
         self.train_queue = Queue()
         self.eval_queue = Queue()
         self.processing_class = kwargs.get('template').tokenizer
@@ -598,7 +597,7 @@ def patch_merge(model):
             unwrapped_model.unmerge_adapter()

     def _wait_queue(self):
-        while self.queue.empty():
+        while self._queue.empty():
             time.sleep(0.01)

     @staticmethod
@@ -621,9 +620,11 @@ def infer_task():
             return result

         future: Future = self.executor.submit(infer_task)
+        # pre-fetch the queue to avoid switching back to eval_queue at the end of training sample sampling
+        current_queue = self._queue

         def done(_self):
-            self.queue.put(DataCache(inputs, _self.result(), distributed_idx))
+            current_queue.put(DataCache(inputs, _self.result(), distributed_idx))

         future.add_done_callback(done)

@@ -634,9 +635,9 @@ def _prefetch(self, dataloader):
             if self.infer_rank >= 0:
                 _input_slice = np.array(all_inputs)[distributed_idx[self.infer_rank]]
                 outputs = self.engine.infer(_input_slice, self.request_config, use_tqdm=False)
-                self.queue.put(DataCache(inputs, outputs, distributed_idx))
+                self._queue.put(DataCache(inputs, outputs, distributed_idx))
             else:
-                self.queue.put(DataCache(inputs, [], distributed_idx))
+                self._queue.put(DataCache(inputs, [], distributed_idx))
             if self.accelerator.num_processes > 1:
                 self.accelerator.wait_for_everyone()
@@ -666,7 +667,7 @@ def _fast_infer(self, inputs):
             _input_slice = np.array(all_inputs)[distributed_idx[self.infer_rank]]
             if self.args.async_generate:
                 self.async_infer(inputs, _input_slice, distributed_idx)
-                data_cache = self.queue.get()
+                data_cache = self._queue.get()
                 inputs = data_cache.inputs
                 outputs = data_cache.outputs
                 distributed_idx = data_cache.distributed_idx
@@ -690,8 +691,8 @@ def _fast_infer(self, inputs):
         else:
             if self.args.async_generate:
                 # using old model to generate, which will ignore the `clip` of advantages.
-                self.queue.put(DataCache(inputs, [], distributed_idx))
-                data_cache = self.queue.get()
+                self._queue.put(DataCache(inputs, [], distributed_idx))
+                data_cache = self._queue.get()
                 inputs = data_cache.inputs
                 distributed_idx = data_cache.distributed_idx
                 outputs = []
@@ -907,12 +908,17 @@ def _get_per_token_logps(self, model, inputs):
         return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens

     def evaluation_loop(self, dataloader, *args, **kwargs):
-        self.queue = self.eval_queue
-        if self.queue.empty() and self.args.async_generate:
+        if self._queue.empty() and self.args.async_generate:
             self._prefetch(dataloader)
         metric_key_prefix = kwargs['metric_key_prefix']
         output = super().evaluation_loop(dataloader, *args, **kwargs)
         metrics = {f'{metric_key_prefix}_{key}': sum(val) / len(val) for key, val in self._metrics['eval'].items()}
         output.metrics.update(metrics)
-        self.queue = self.train_queue
         return output
+
+    @property
+    def _queue(self):
+        if self.control.should_evaluate:
+            return self.eval_queue
+        else:
+            return self.train_queue
0 commit comments