|
1 |
| -from typing import List, Tuple |
| 1 | +from typing import List, Tuple, Optional, Any |
2 | 2 | import os
|
3 | 3 | import uuid
|
| 4 | +import asyncio |
| 5 | +import numpy as np |
4 | 6 | from loguru import logger
|
| 7 | +from easydict import EasyDict |
5 | 8 | from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams, RequestOutput
|
| 9 | +from transformers import AutoTokenizer |
| 10 | + |
| 11 | +from ding.utils.data.rlhf_online_dataset import OnlineRLDataset |
| 12 | +from ding.utils import SERIAL_COLLECTOR_REGISTRY |
| 13 | +from .base_serial_collector import ISerialCollector |
6 | 14 |
|
7 | 15 |
|
8 | 16 | class VllmActor:
|
@@ -145,3 +153,156 @@ async def generate(
|
145 | 153 | # Use raw logprobs as confidence scores
|
146 | 154 | confidence_scores = [x.cumulative_logprob for x in response.outputs]
|
147 | 155 | return [(x.text.strip(), conf) for x, conf in zip(response.outputs, confidence_scores)]
|
| 156 | + |
| 157 | + |
@SERIAL_COLLECTOR_REGISTRY.register('vllm')
class VllmCollector(ISerialCollector):
    """
    Overview:
        Collector implementation for vLLM-based language models (LLM/VLM).
        This collector manages the interaction with vLLM models for text generation tasks:
        it owns a tokenizer, a prompt dataset, a ``VllmActor`` and a (optionally shuffled)
        index array that decides which prompts are served on each ``collect`` call.
    """
    config = dict(
        # (str) LLM/VLM model path
        model_path='',
        # (int) Maximum number of tokens to generate per request
        max_tokens=1024,
        # (float) Temperature for sampling, 0 means greedy decoding
        temperature=0.0,
        # (dict) Multimodal processor kwargs for vision-language models
        mm_processor_kwargs={
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
        },
        # Dataset related configs
        # (str) Key to access the input data in the dataset
        input_key='input',
        # (bool) Whether to apply a chat template to the input
        apply_chat_template=False,
        # (str) Template for the input
        input_template=None,
        # (bool) Whether to shuffle the dataset
        shuffle=True,
    )

    def __init__(self, cfg: EasyDict) -> None:
        """
        Overview:
            Initialize the VllmCollector with configuration.
        Arguments:
            - cfg (:obj:`EasyDict`): Configuration for the collector including model path, generation parameters, \
                and dataset configuration. Besides the keys listed in ``config``, ``cfg.dataset`` is read here \
                and has no default value, so the caller must provide it.
        """
        super().__init__()
        self._cfg = cfg
        self._envstep = 0

        # Initialize the tokenizer and dataset
        self._tokenizer = AutoTokenizer.from_pretrained(cfg.model_path)
        self._dataset = OnlineRLDataset(
            dataset=cfg.dataset,
            tokenizer=self._tokenizer,
            input_key=cfg.input_key,
            apply_chat_template=cfg.apply_chat_template,
            input_template=cfg.input_template,
        )

        self._model = VllmActor(model_path=cfg.model_path, mm_processor_kwargs=cfg.mm_processor_kwargs)
        self.reset()

    def reset(self) -> None:
        """
        Overview:
            Reset the collector, including the dataset index. The index is rebuilt as ``0..len(dataset)-1`` \
            and shuffled in place when ``cfg.shuffle`` is enabled.
        """
        self._index = np.arange(len(self._dataset))
        if self._cfg.shuffle:
            np.random.shuffle(self._index)

    def reset_policy(self, _model: Optional[str] = None) -> None:
        """
        Overview:
            Since LLM generation does not require an explicit policy and env, this function is empty.
        """
        pass

    def reset_env(self, _env: Optional[Any] = None) -> None:
        """
        Overview:
            Since LLM generation does not require an explicit policy and env, this function is empty.
        """
        pass

    def collect(
            self,
            n_samples: int = 100,
            num_samples_per_prompt: int = 1,
            train_iter: int = 0,
    ) -> List[Tuple[str, float]]:
        """
        Overview:
            Collect generated responses from the vLLM model.
        Arguments:
            - n_samples (:obj:`int`): Number of prompts to generate.
            - num_samples_per_prompt (:obj:`int`): Number of samples to generate per prompt.
            - train_iter (:obj:`int`): Current training iteration, used for logging.
        Returns:
            - responses (:obj:`List[Tuple[str, float]]`): List of (generated_text, confidence_score) pairs
        Raises:
            - RuntimeError: If the underlying model is ``None``.
        """
        if self._model is None:
            raise RuntimeError("Model not initialized. Call `reset` method first.")

        # NOTE(review): the dataset is indexed with a NumPy integer array here; this assumes
        # ``OnlineRLDataset.__getitem__`` accepts array indices -- confirm against its implementation.
        prompt = self._dataset[self._index[:n_samples]]
        # Rotate the index so that the prompts just consumed move to the back of the queue.
        # ``self._index`` is a NumPy array, so ``+`` would perform element-wise addition
        # (or fail to broadcast) instead of joining the two slices; concatenate explicitly.
        self._index = np.concatenate((self._index[n_samples:], self._index[:n_samples]))

        self._envstep += n_samples

        # Get the current event loop or create a new one
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        # Run the async generate method in the event loop. ``run_until_complete`` blocks until
        # generation finishes, so this method must not be called from inside a running loop.
        return loop.run_until_complete(
            self._model.generate(
                prompt=prompt,
                num_samples=num_samples_per_prompt,
                max_tokens=self._cfg.max_tokens,
                temperature=self._cfg.temperature
            )
        )

    @property
    def envstep(self) -> int:
        """
        Overview:
            Get the current environment step count.
        Returns:
            - count (:obj:`int`): Current environment step count
        """
        return self._envstep

    @envstep.setter
    def envstep(self, value: int) -> None:
        """
        Overview:
            Set the current environment step count.
        Arguments:
            - value (:obj:`int`): New environment step count.
        """
        self._envstep = value

    def close(self) -> None:
        """
        Overview:
            Close the collector. Currently a no-op.
        """
        pass

    def __del__(self) -> None:
        """
        Overview:
            Destructor for the collector; delegates to ``close``.
        """
        self.close()
0 commit comments