Commit 3b7903a

committed
feature(nyz): add vllm collector interface definition
1 parent c45e429 commit 3b7903a

File tree

1 file changed (+162, -1 lines)

ding/worker/collector/vllm_collector.py

Lines changed: 162 additions & 1 deletion
@@ -1,8 +1,16 @@
-from typing import List, Tuple
+from typing import List, Tuple, Optional, Any
 import os
 import uuid
+import asyncio
+import numpy as np
 from loguru import logger
+from easydict import EasyDict
 from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams, RequestOutput
+from transformers import AutoTokenizer
+
+from ding.utils.data.rlhf_online_dataset import OnlineRLDataset
+from ding.utils import SERIAL_COLLECTOR_REGISTRY
+from .base_serial_collector import ISerialCollector


 class VllmActor:
@@ -145,3 +153,156 @@ async def generate(
         # Use raw logprobs as confidence scores
         confidence_scores = [x.cumulative_logprob for x in response.outputs]
         return [(x.text.strip(), conf) for x, conf in zip(response.outputs, confidence_scores)]
+
+
+@SERIAL_COLLECTOR_REGISTRY.register('vllm')
+class VllmCollector(ISerialCollector):
+    """
+    Overview:
+        Collector implementation for vLLM-based language models (LLM/VLM).
+        This collector manages the interaction with vLLM models for text generation tasks.
+    """
+    config = dict(
+        # (str) LLM/VLM model path
+        model_path='',
+        # (int) Maximum number of tokens to generate per request
+        max_tokens=1024,
+        # (float) Temperature for sampling, 0 means greedy decoding
+        temperature=0.0,
+        # (dict) Multimodal processor kwargs for vision-language models
+        mm_processor_kwargs={
+            "min_pixels": 28 * 28,
+            "max_pixels": 1280 * 28 * 28,
+        },
+        # Dataset-related configs
+        # (str) Key to access the input data in the dataset
+        input_key='input',
+        # (bool) Whether to apply a chat template to the input
+        apply_chat_template=False,
+        # (str) Template for the input
+        input_template=None,
+        # (bool) Whether to shuffle the dataset
+        shuffle=True,
+    )
+
+    def __init__(self, cfg: EasyDict) -> None:
+        """
+        Overview:
+            Initialize the VllmCollector with the given configuration.
+        Arguments:
+            - cfg (:obj:`EasyDict`): Configuration for the collector, including the model path,
+                generation parameters and dataset settings.
+        """
+        super().__init__()
+        self._cfg = cfg
+        self._envstep = 0
+
+        # Initialize the tokenizer and dataset
+        self._tokenizer = AutoTokenizer.from_pretrained(cfg.model_path)
+        self._dataset = OnlineRLDataset(
+            dataset=cfg.dataset,
+            tokenizer=self._tokenizer,
+            input_key=cfg.input_key,
+            apply_chat_template=cfg.apply_chat_template,
+            input_template=cfg.input_template,
+        )
+
+        self._model = VllmActor(model_path=cfg.model_path, mm_processor_kwargs=cfg.mm_processor_kwargs)
+        self.reset()
+
+    def reset(self) -> None:
+        """
+        Overview:
+            Reset the collector, including the dataset index.
+        """
+        self._index = np.arange(len(self._dataset))
+        if self._cfg.shuffle:
+            np.random.shuffle(self._index)
+
+    def reset_policy(self, _model: Optional[str] = None) -> None:
+        """
+        Overview:
+            Since LLM generation does not require an explicit policy or env, this method is a no-op.
+        """
+        pass
+
+    def reset_env(self, _env: Optional[Any] = None) -> None:
+        """
+        Overview:
+            Since LLM generation does not require an explicit policy or env, this method is a no-op.
+        """
+        pass
+
+    def collect(
+            self,
+            n_samples: int = 100,
+            num_samples_per_prompt: int = 1,
+            train_iter: int = 0,
+    ) -> List[Tuple[str, float]]:
+        """
+        Overview:
+            Collect generated responses from the vLLM model.
+        Arguments:
+            - n_samples (:obj:`int`): Number of prompts to generate responses for.
+            - num_samples_per_prompt (:obj:`int`): Number of samples to generate per prompt.
+            - train_iter (:obj:`int`): Current training iteration, used for logging.
+        Returns:
+            - responses (:obj:`List[Tuple[str, float]]`): List of (generated_text, confidence_score) pairs.
+        """
+        if self._model is None:
+            raise RuntimeError("Model not initialized. Call `reset` method first.")
+
+        prompt = self._dataset[self._index[:n_samples]]
+        # Cyclically rotate the index so that subsequent calls traverse the rest of the dataset
+        self._index = np.concatenate([self._index[n_samples:], self._index[:n_samples]])
+
+        self._envstep += n_samples
+
+        # Get the current event loop or create a new one
+        try:
+            loop = asyncio.get_event_loop()
+        except RuntimeError:
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+        # Run the async generate method in the event loop
+        return loop.run_until_complete(
+            self._model.generate(
+                prompt=prompt,
+                num_samples=num_samples_per_prompt,
+                max_tokens=self._cfg.max_tokens,
+                temperature=self._cfg.temperature
+            )
+        )
+
+    @property
+    def envstep(self) -> int:
+        """
+        Overview:
+            Get the current environment step count.
+        Returns:
+            - count (:obj:`int`): Current environment step count.
+        """
+        return self._envstep
+
+    @envstep.setter
+    def envstep(self, value: int) -> None:
+        """
+        Overview:
+            Set the current environment step count.
+        """
+        self._envstep = value
+
+    def close(self) -> None:
+        """
+        Overview:
+            Close the collector.
+        """
+        pass
+
+    def __del__(self) -> None:
+        """
+        Overview:
+            Destructor for the collector.
+        """
+        self.close()
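
Below is a minimal usage sketch of the new collector, not part of the commit: the model_path and dataset values are hypothetical placeholders, the config keys mirror VllmCollector.config, and collect() is assumed to return (generated_text, confidence_score) pairs as produced by VllmActor.generate.

from easydict import EasyDict

from ding.worker.collector.vllm_collector import VllmCollector

# Hypothetical example config; model_path and dataset are placeholders.
cfg = EasyDict(VllmCollector.config)
cfg.model_path = '/path/to/local/llm'        # placeholder model path
cfg.dataset = 'placeholder/prompt_dataset'   # placeholder dataset passed through to OnlineRLDataset
cfg.max_tokens = 512
cfg.temperature = 0.7

collector = VllmCollector(cfg)
# Request 8 prompts with 2 samples each; results are (generated_text, confidence_score) pairs.
responses = collector.collect(n_samples=8, num_samples_per_prompt=2)
for text, score in responses:
    print(f"{score:.3f} {text[:60]}")
print(collector.envstep)  # 8 prompts have been consumed

Since collect() rotates the dataset index cyclically, repeated calls walk through the whole dataset; calling reset() reshuffles the index (when shuffle=True) and restarts the traversal.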

0 commit comments
