diff --git a/tests/rl/test_evaluator.py b/tests/rl/test_evaluator.py deleted file mode 100644 index ffd52930a..000000000 --- a/tests/rl/test_evaluator.py +++ /dev/null @@ -1,121 +0,0 @@ -import os -import unittest -import ray -import tempfile -from transformers import AutoTokenizer - -from xtuner.v1.rl.rollout.worker import RolloutConfig -try: - from xtuner.v1.ray.judger.controller import JudgerConfig -except Exception: - class JudgerConfig: - def __init__(self, *args, **kwargs): - self.__dict__.update(kwargs) -from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers -try: - from xtuner.v1.ray.environment import SingleTurnEnvironment -except Exception: - SingleTurnEnvironment = None -from xtuner.v1.rl.evaluator import Evaluator, EvaluatorConfig -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig, DatasetConfig, OpenaiTokenizeFunctionConfig - - -MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] -TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] - - -@unittest.skipIf(SingleTurnEnvironment is None, "ray environment unavailable") -class TestEvaluator(unittest.TestCase): - - @classmethod - def setUpClass(cls) -> None: - os.environ["XTUNER_USE_FA3"] = "1" - os.environ["LMD_SKIP_WARMUP"] = "1" - - @classmethod - def tearDownClass(cls) -> None: - del os.environ["XTUNER_USE_FA3"] - del os.environ["LMD_SKIP_WARMUP"] - - def init_config(self): - self.resources_cfg = AcceleratorResourcesConfig( - accelerator="GPU", - num_workers=8, - num_cpus_per_worker=8, - cpu_memory_per_worker=16 * 1024**3, # 16 GB - ) - self.max_prompt_length = 512 - self.max_response_length = 1024 - self.rollout_cfg = RolloutConfig( - env="test_rollout", - model_path=MODEL_PATH, - model_name=os.path.basename(MODEL_PATH).lower(), - tokenizer_path=MODEL_PATH, - tensor_parallel_size=8, - context_length=self.max_prompt_length + self.max_response_length, - worker_log_dir=self.worker_log_dir - ) - from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig - gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", judger_type="router") - self.judger_cfg = JudgerConfig( - reward_judger_configs=[gsm8k_judger_config], - worker_log_dir=self.worker_log_dir - ) - self.eval_dataset_cfg = [ - { - "dataset": DatasetConfig(name="gsm8k", - anno_path=TEST_DATA_PATH, - sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length) - }, - ] - self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) - self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) - self.test_env = SingleTurnEnvironment.remote( - "test_env", - self.pg, - self.rollout_cfg, - None, - self.judger_cfg - ) - self.sample_params = SampleParams( - top_p=1.0, - temperature=0.0, - max_tokens=self.max_response_length, - top_k=1 - ) - - def setUp(self): - ray.init(num_cpus=80) - self.model_path = MODEL_PATH - self.temp_dir = tempfile.TemporaryDirectory() - self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - self.init_config() - - def tearDown(self): - ray.shutdown() - self.temp_dir.cleanup() - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_lmdeploy_evaluator(self): - def custom_compute_metric(samples): - return {"custom_accuracy": sum(s.env.judger.reward["score"] > 0 for s in samples) / len(samples)} - - evaluator_cfg = EvaluatorConfig( - dataset_cfg=self.eval_dataset_cfg, - tokenizer=self.tokenizer, - max_concurrent=16, - 
eval_sample_ratio=0.004, # generate 5 samples - compute_metric_func=custom_compute_metric, - sample_params=self.sample_params, - worker_log_dir=self.worker_log_dir - ) - evaluator = Evaluator.remote(evaluator_cfg, self.test_env) - try: - ray.get(evaluator.run.remote()) - except Exception as e: - self.fail(f"evaluator.run.remote() raised an exception: {e}") - -if __name__ == '__main__': - unittest.main() diff --git a/tests/rl/test_rl_train_with_sft.py b/tests/rl/test_rl_train_with_sft.py deleted file mode 100644 index e0476de71..000000000 --- a/tests/rl/test_rl_train_with_sft.py +++ /dev/null @@ -1,180 +0,0 @@ -import os -import unittest -from transformers import AutoTokenizer -import shutil -import tempfile -import json -import torch -from xtuner.v1.data_proto.sequence_context import SequenceContext -import ray -from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.rl.trainer import WorkerConfig, TrainingController, TrainingWorker as BaseTrainingWorker -from xtuner.v1.rl.loss import GRPOLossConfig as LossConfig -from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig -from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig -from xtuner.v1.loss import CELossConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.train.trainer import LoadCheckpointConfig - -QWEN3_PATH = os.environ["QWEN3_PATH"] -ALPACA_PATH = os.environ["ALPACA_PATH"] - - -class TestRLTrainWithSFT(unittest.TestCase): - def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) - - resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_accelerators_per_worker=1, - num_cpus_per_worker=8, - num_workers=8, - cpu_memory_per_worker=16 * 1024 ** 3, # 16 GB - ) - - pg = AutoAcceleratorWorkers.build_placement_group(resources) - self.pg = pg - - self.temp_dir = tempfile.mkdtemp() - tokenizer = AutoTokenizer.from_pretrained(QWEN3_PATH, trust_remote_code=True) - self.tokenizer = tokenizer - self.prompt_repeat_k = 8 - file = './tests/ray/rollout_output.jsonl' - with open(file, 'r') as f: - data = [json.loads(line) for line in f] - data_groups = [data[i:i + self.prompt_repeat_k] for i in range(0, len(data), self.prompt_repeat_k)] - data_groups = data_groups[:8] - data_batches = [] - for group in data_groups: - prompt_ids = tokenizer(group[0]['prompt'], return_tensors='pt')['input_ids'].flatten().tolist() - rewards = [item['reward'] for item in group] - rewards = torch.tensor(rewards, dtype=torch.float32) - advantages = (rewards - rewards.mean(0)) / (rewards.std(0) + 1e-8) - - for i in range(self.prompt_repeat_k): - item = group[i] - response_ids = tokenizer(item['response'], return_tensors='pt')['input_ids'].flatten().tolist() - input_ids = prompt_ids + response_ids - shifted_labels = [-100] * (len(prompt_ids) - 1) + response_ids + [-100] - input_ids = torch.tensor(input_ids, dtype=torch.int64).unsqueeze(0) - shifted_labels = torch.tensor(shifted_labels, dtype=torch.int64).unsqueeze(0) - data_batches.append( - dict( - seq_ctx=SequenceContext.from_input_ids((input_ids,), device="cpu"), - shifted_labels=shifted_labels, - advantage=advantages[i].item(), - ) - ) - self.data_batches = data_batches - - def tearDown(self): - shutil.rmtree(self.temp_dir) - ray.shutdown() - - def build_train_controller(self): - model_cfg = Qwen3Dense8BConfig() - optim_cfg: AdamWConfig = AdamWConfig(lr=5e-7, foreach=False) - fsdp_cfg: FSDPConfig = FSDPConfig( - 
torch_compile=True, - cpu_offload=False, - ep_size=1, - ) - lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=5e-7) - - dataset_config = [] - _data_cfg = {"dataset": DatasetConfig(name='apach', - anno_path=ALPACA_PATH), - "tokenize_fn": OpenaiTokenizeFunctionConfig( - chat_template='qwen3', - max_length=32768 - ) - } - dataset_config.append(_data_cfg) - - sft_dataloader_cfg = DataloaderConfig( - dataset_config_list=dataset_config, - pack_max_length=32768, - pack_to_max_length=True, - num_workers=0, - ) - sft_global_batch_size = 8 - loss_reduction = "square" - sft_loss_cfg = CELossConfig(mode="chunk", chunk_size=1024, loss_reduction=loss_reduction) - - worker_cfg: WorkerConfig = WorkerConfig( - sft_dataloader_cfg=sft_dataloader_cfg, - sft_global_batch_size=sft_global_batch_size, - sft_loss_cfg=sft_loss_cfg, - seed=42, - model_cfg=model_cfg, - optim_cfg=optim_cfg, - loss_cfg=LossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=True, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="eager"), - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - load_from=QWEN3_PATH, - sp_size=1, - pack_max_length=8192, - ) - - TrainingWorker = ray.remote( - runtime_env={ - "env_vars": { - "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", - "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", - } - }, - )(BaseTrainingWorker) - train_workers, _ = AutoAcceleratorWorkers.from_placement_group( - TrainingWorker, worker_cfg, self.pg - ) - futures = [worker.test_all_reduce.remote() for worker in train_workers] - print(ray.get(futures)) - train_controller = TrainingController.remote( - workers=train_workers, - ) - ray.get(train_controller.__ray_ready__.remote()) - return train_controller - - def test_rl_train_with_sft(self): - train_controller = self.build_train_controller() - - ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=0)) - ray.get(train_controller.save.remote(os.path.join(self.temp_dir, "save_test"), no_save_optimizer=True)) - - log_infos = ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=1)) - efficient_attn_ratio_list = [] - for log_info in log_infos: - efficient_attn_ratio_list.append(log_info['sft_train_metrics']['efficient_attn_ratio']) - assert all([efficient_attn_ratio > 0 for efficient_attn_ratio in efficient_attn_ratio_list]) - - ray.kill(train_controller) - train_controller = self.build_train_controller() - load_checkpoint_cfg = LoadCheckpointConfig(checkpoint_path=os.path.join(self.temp_dir, "save_test"), - load_optimizer_states=False, - load_optimizer_args=False - ) - ray.get(train_controller.resume.remote(load_checkpoint_cfg)) - - log_infos = ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=1)) - new_efficient_attn_ratio_list = [] - for log_info in log_infos: - new_efficient_attn_ratio_list.append(log_info['sft_train_metrics']['efficient_attn_ratio']) - - efficient_attn_ratio_list.sort() - new_efficient_attn_ratio_list.sort() - self.assertEqual(efficient_attn_ratio_list, new_efficient_attn_ratio_list) diff --git a/tests/rl/test_rl_trainer.py b/tests/rl/test_rl_trainer.py deleted file mode 100644 index 82688151d..000000000 --- a/tests/rl/test_rl_trainer.py +++ /dev/null @@ -1,266 +0,0 @@ -import os -import tempfile -import unittest -from pathlib import Path - -import ray -import torch - -from transformers import AutoTokenizer -from xtuner.v1.config import ( - AdamWConfig, - 
FSDPConfig, - LRConfig, -) -from xtuner.v1.data_proto.rl_data import SampleParams -from xtuner.v1.datasets import RLTokenizeFnConfig -from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig -from xtuner.v1.model import get_model_config_from_hf -from xtuner.v1.rl.utils import AcceleratorResourcesConfig, CPUResourcesConfig -from xtuner.v1.rl.rollout.worker import RolloutConfig -try: - from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig -except Exception: - class DataFlowConfig: - def __init__(self, *args, **kwargs): - self.__dict__.update(kwargs) - - class ReplayBufferConfig: - def __init__(self, *args, **kwargs): - self.__dict__.update(kwargs) -try: - from xtuner.v1.ray.judger.controller import JudgerConfig -except Exception: - class JudgerConfig: - def __init__(self, *args, **kwargs): - self.__dict__.update(kwargs) -from xtuner.v1.rl.trainer.worker import WorkerConfig -from xtuner.v1.rl.loss import GRPOLossConfig -from xtuner.v1.train.rl_trainer import RLTrainer, RLTrainerConfig - - -MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] -TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] -TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] -resource_map = { - "npu": "NPU", - "cuda": "GPU", -} - - -class TestRLTrainer(unittest.TestCase): - @classmethod - def setUpClass(cls): - os.environ["XTUNER_USE_FA3"] = "1" - os.environ["LMD_SKIP_WARMUP"] = "1" - - @classmethod - def tearDownClass(cls): - del os.environ["XTUNER_USE_FA3"] - del os.environ["LMD_SKIP_WARMUP"] - - def init_traine_worker_config(self, train_optimizer_steps, pack_max_length): - model_cfg = get_model_config_from_hf(Path(MODEL_PATH)) - optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False) - loss_cfg = GRPOLossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - clip_ratio_c=10.0, - log_prob_diff_min=-20.0, - log_prob_diff_max=20.0, - ), - ignore_idx=-100, - use_kl_loss=False, - kl_loss_coef=0.0, - kl_loss_type="low_var_kl", - mode="chunk", - chunk_size=512, - ) - lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6) - fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1) - train_worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - load_from=MODEL_PATH, - optim_cfg=optim_cfg, - loss_cfg=loss_cfg, - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - sp_size=1, - optimizer_steps=train_optimizer_steps, - pack_max_length=pack_max_length, - ) - return train_worker_cfg - - def init_replay_buffer_config(self, max_prompt_length): - train_dataset_cfg = [ - { - "dataset": DatasetConfig(name="gsm8k", anno_path=TRAIN_DATA_PATH, sample_ratio=1.0), - "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length), - }, - ] - dataloader_cfg = DataloaderConfig( - collator="fake_collator", - pack_level="none", - group_by_length=False, - ) - tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) - replay_buffer_cfg = ReplayBufferConfig( - dataset_cfg=train_dataset_cfg, - dataloader_cfg=dataloader_cfg, - tokenizer=tokenizer, - worker_log_dir=self.worker_log_dir, - ) - return replay_buffer_cfg - - def init_resources_config(self, num_workers, num_cpus_per_worker, cpu_memory_per_worker): - resources = AcceleratorResourcesConfig( - accelerator=resource_map[torch.accelerator.current_accelerator().type], - num_workers=num_workers, - num_cpus_per_worker=num_cpus_per_worker, - cpu_memory_per_worker=cpu_memory_per_worker, - ) - return resources - - def 
init_cpu_resources_config(self, num_cpus_per_worker, cpu_memory_per_worker): - cpu_resources = CPUResourcesConfig( - num_cpus_per_worker=num_cpus_per_worker, - cpu_memory_per_worker=cpu_memory_per_worker, - ) - return cpu_resources - - def init_rollout_config(self, max_prompt_length, max_response_length): - rollout_config = RolloutConfig( - env="test_rl_trainer", - model_path=MODEL_PATH, - worker_log_dir=self.worker_log_dir, - rollout_max_batch_size_per_instance=1024, - context_length=max_response_length + max_prompt_length, - ) - return rollout_config - - def init_dataflow_config(self, max_response_length, global_batch_size, prompt_repeat_k, enable_partial_rollout): - sample_params = SampleParams( - max_tokens=max_response_length, - ) - dataflow_config = DataFlowConfig( - env="test_rl_trainer", - global_batch_size=global_batch_size, - prompt_repeat_k=prompt_repeat_k, - worker_log_dir=self.worker_log_dir, - sample_params=sample_params, - enable_partial_rollout=enable_partial_rollout, - max_concurrent=1024, - ) - return dataflow_config - - def init_judger_config(self): - from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig - - gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k", judger_type="router") - judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config], worker_log_dir=self.worker_log_dir) - return judger_cfg - - def init_multi_judger_config(self): - from xtuner.v1.rl.judger.gsm8k import GSM8KJudgerConfig - - # 支持一个GSM8KJudgerConfig创建多个实例 - gsm8k_judger_config_1 = GSM8KJudgerConfig(judger_name="openai/gsm8k-1", judger_type="router") - gsm8k_judger_config_2 = GSM8KJudgerConfig(judger_name="openai/gsm8k-2", judger_type="router") - judger_cfg = JudgerConfig( - reward_judger_configs=[gsm8k_judger_config_1, gsm8k_judger_config_2], - worker_log_dir=self.worker_log_dir, - ) - return judger_cfg - - def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) - self.temp_dir = tempfile.TemporaryDirectory() - self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") - - train_optimizer_steps = 2 - pack_max_length = 32768 - max_prompt_length = 2048 - max_response_length = 1024 - global_batch_size = 4 - prompt_repeat_k = 4 - enable_partial_rollout = False - - self.train_worker_cfg = self.init_traine_worker_config(train_optimizer_steps, pack_max_length) - self.replay_buffer_cfg = self.init_replay_buffer_config(max_prompt_length) - self.resources_cfg = self.init_resources_config( - num_workers=8, num_cpus_per_worker=8, cpu_memory_per_worker=8 * 1024**3 - ) - self.cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=1, cpu_memory_per_worker=1 * 1024**3) - self.rollout_config = self.init_rollout_config( - max_response_length=max_response_length, max_prompt_length=max_prompt_length - ) - self.dataflow_config = self.init_dataflow_config( - max_response_length=max_response_length, - global_batch_size=global_batch_size, - prompt_repeat_k=prompt_repeat_k, - enable_partial_rollout=enable_partial_rollout, - ) - self.judger_config = self.init_judger_config() - - def tearDown(self): - self.temp_dir.cleanup() - ray.shutdown() - - def test_rl_trainer(self): - multi_judger_config = self.init_multi_judger_config() - cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=2, cpu_memory_per_worker=2 * 1024**3) - trainer_config = RLTrainerConfig( - load_from=MODEL_PATH, - resources=self.resources_cfg, - cpu_resources=cpu_resources, - rollout_config=self.rollout_config, - dataflow_config=self.dataflow_config, - judger_config=multi_judger_config, - 
replay_buffer_config=self.replay_buffer_cfg, - train_worker_config=self.train_worker_cfg, - work_dir=self.worker_log_dir, - tokenizer_path=MODEL_PATH, - total_epochs=1, - rollout_steps=1, - ) - trainer = RLTrainer.from_config(trainer_config) - self.assertIsNotNone(trainer, "Trainer should be created successfully") - try: - trainer.fit() - except Exception as e: - self.fail(f"trainer.fit() raised unexpected exception: {e}") - # assure all writers are closed before checking log files - del trainer - log_files = list(Path(self.worker_log_dir).rglob("*.log")) - self.assertGreater(len(log_files), 0, "Should generate log files") - trajectory_files = list(Path(self.worker_log_dir).rglob("*_trajectory.jsonl")) - self.assertGreater(len(trajectory_files), 0, "Should generate trajectory files") - - def test_judger_cpu_pg_creation_with_error(self): - """Test RLTrainer judger_cpu_pg creation.""" - multi_judger_config = self.init_multi_judger_config() - # error resource with multi-judger - cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=1, cpu_memory_per_worker=1 * 1024**3) - trainer_config = RLTrainerConfig( - load_from=MODEL_PATH, - resources=self.resources_cfg, - cpu_resources=cpu_resources, - rollout_config=self.rollout_config, - dataflow_config=self.dataflow_config, - judger_config=multi_judger_config, - replay_buffer_config=self.replay_buffer_cfg, - train_worker_config=self.train_worker_cfg, - work_dir=self.worker_log_dir, - tokenizer_path=MODEL_PATH, - total_epochs=1, - rollout_steps=1, - ) - with self.assertRaises(AssertionError) as cm: - trainer = RLTrainer.from_config(trainer_config) - - print(f"Expected AssertionError caught: {cm.exception}") - -if __name__ == "__main__": - unittest.main() diff --git a/tests/rl/test_rollout_api_server.py b/tests/rl/test_rollout_api_server.py new file mode 100644 index 000000000..b197f25d9 --- /dev/null +++ b/tests/rl/test_rollout_api_server.py @@ -0,0 +1,314 @@ +import os +import subprocess +import tempfile +import time +import unittest + +import httpx +import ray +import torch +from transformers import AutoTokenizer + +from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams +from xtuner.v1.rl.rollout import RolloutController +from xtuner.v1.rl.rollout.worker import RolloutConfig +from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers + + +TEST_TEXT_MESSAGES = [{"role": "user", "content": "Hello!"}] +MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] +MOE_MODEL_PATH = os.environ.get("QWEN3_MOE_PATH") or os.environ["QWEN30B_MODEL_PATH"] +RESOURCE_MAP = { + "npu": "NPU", + "cuda": "GPU", +} + + +@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") +class TestRolloutAPIServer(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + os.environ["XTUNER_USE_FA3"] = "1" + os.environ["LMD_SKIP_WARMUP"] = "1" + + @classmethod + def tearDownClass(cls) -> None: + del os.environ["XTUNER_USE_FA3"] + del os.environ["LMD_SKIP_WARMUP"] + + def init_config(self): + self.resources_cfg = AcceleratorResourcesConfig( + accelerator=RESOURCE_MAP[torch.accelerator.current_accelerator().type], + num_workers=4, + num_cpus_per_worker=8, + cpu_memory_per_worker=16 * 1024**3, + ) + self.max_prompt_length = 512 + self.max_response_length = 1024 + self.context_length = self.max_prompt_length + self.max_response_length + self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + + def setUp(self): + ray.init(num_cpus=80, ignore_reinit_error=True) + 
self.temp_dir = tempfile.TemporaryDirectory() + self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs") + self.init_config() + + def tearDown(self): + ray.shutdown() + self._cleanup_lmdeploy_ray_worker_wrapper() + self.temp_dir.cleanup() + + def _cleanup_lmdeploy_ray_worker_wrapper(self): + try: + result = subprocess.run( + ["pkill", "-f", "ray::RayWorkerWrapper*"], + capture_output=True, + text=True, + timeout=10, + check=False, + ) + if result.returncode != 0: + print( + f"pkill command failed with return code {result.returncode}: {result.stderr}." + " Maybe no lmdeploy ray::RayWorkerWrapper processes found." + ) + except Exception as exc: + print(f"Error stopping ray::RayWorkerWrapper cluster: {exc}") + + def _wait_until_ready(self, base_url: str): + deadline = time.time() + 1800 + last_error = None + while time.time() < deadline: + try: + response = httpx.get(f"{base_url}/healthz", timeout=10.0) + if response.status_code == 200: + return + last_error = f"healthz returned {response.status_code}: {response.text}" + except httpx.HTTPError as exc: + last_error = repr(exc) + time.sleep(5) + raise RuntimeError(f"API server at {base_url} did not become ready in time: {last_error}") + + def test_dense_model(self): + resource_config = AcceleratorResourcesConfig( + accelerator=RESOURCE_MAP[torch.accelerator.current_accelerator().type], + num_workers=4, + num_cpus_per_worker=16, + cpu_memory_per_worker=8 * 1024**3, + ) + pg = AutoAcceleratorWorkers.build_placement_group(resource_config, name="dense_api_pg") + dense_worker_log_dir = os.path.join(self.worker_log_dir, "dense") + rollout_config = RolloutConfig( + env="test_rollout_api_server_dense", + model_path=MODEL_PATH, + model_name=os.path.basename(MODEL_PATH).lower(), + tokenizer_path=MODEL_PATH, + tensor_parallel_size=4, + expert_parallel_size=1, + context_length=self.context_length, + worker_log_dir=dense_worker_log_dir, + dist_port_base=38000, + api_host="127.0.0.1", + api_port=28000, + ) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) + try: + metadata = ray.get(rollout_controller.get_rollout_metadata.remote(), timeout=1800) + base_url = metadata["api_server_url"] + self._wait_until_ready(base_url) + + text_prompt = self.tokenizer.apply_chat_template( + TEST_TEXT_MESSAGES, + tokenize=False, + add_generation_prompt=True, + ) + test_input_ids = self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"] + + with httpx.Client(timeout=300.0) as client: + generate = client.post( + f"{base_url}/generate", + json={ + "message": TEST_TEXT_MESSAGES, + "tokens": test_input_ids, + "sample_params": { + "return_token_ids": True, + "temperature": 0.0, + "top_k": 1, + "max_tokens": 16, + }, + }, + ) + self.assertEqual(generate.status_code, 200, generate.text) + generate_body = generate.json() + self.assertEqual(generate_body["status"], "completed") + self.assertIn(generate_body["finish_reason"], {"stop", "length"}) + self.assertTrue(generate_body["extra_fields"]["request_id"]) + self.assertGreater(len(generate_body["response_ids"]), 0) + self.assertIsInstance(generate_body["response"], str) + + chat = client.post( + f"{base_url}/v1/chat/completions", + json={ + "model": rollout_config.model_name, + "messages": TEST_TEXT_MESSAGES, + "temperature": 0.0, + "top_p": 1.0, + "max_tokens": 16, + }, + ) + self.assertEqual(chat.status_code, 200, chat.text) + chat_body = chat.json() + print("chat_body: ", chat_body) + self.assertEqual(chat_body["object"], "chat.completion") + self.assertEqual(chat_body["model"], 
rollout_config.model_name) + self.assertTrue(chat_body["id"].startswith("chatcmpl-")) + self.assertEqual(chat_body["choices"][0]["message"]["role"], "assistant") + self.assertTrue(chat_body["choices"][0]["message"]["content"]) + self.assertIn(chat_body["choices"][0]["finish_reason"], {"stop", "length"}) + self.assertGreater(chat_body["usage"]["prompt_tokens"], 0) + self.assertGreater(chat_body["usage"]["total_tokens"], chat_body["usage"]["completion_tokens"]) + + anthropic = client.post( + f"{base_url}/v1/messages", + json={ + "model": rollout_config.model_name, + "system": "You are helpful.", + "messages": TEST_TEXT_MESSAGES, + "max_tokens": 16, + "temperature": 0.0, + "top_p": 1.0, + }, + ) + self.assertEqual(anthropic.status_code, 200, anthropic.text) + anthropic_body = anthropic.json() + self.assertEqual(anthropic_body["type"], "message") + self.assertEqual(anthropic_body["role"], "assistant") + self.assertEqual(anthropic_body["model"], rollout_config.model_name) + self.assertTrue(anthropic_body["id"].startswith("msg_")) + self.assertTrue(anthropic_body["content"][0]["text"]) + self.assertIn(anthropic_body["stop_reason"], {"stop", "length"}) + self.assertGreater(anthropic_body["usage"]["input_tokens"], 0) + self.assertGreaterEqual(anthropic_body["usage"]["output_tokens"], 1) + + invalid_block = client.post( + f"{base_url}/v1/messages", + json={ + "model": rollout_config.model_name, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "look"}, + {"type": "image", "text": ""}, + ], + } + ], + "max_tokens": 8, + }, + timeout=30.0, + ) + self.assertEqual(invalid_block.status_code, 400, invalid_block.text) + self.assertEqual(invalid_block.json()["type"], "error") + self.assertEqual(invalid_block.json()["error"]["type"], "invalid_request_error") + + health = client.get(f"{base_url}/healthz", timeout=30.0) + meta = client.get(f"{base_url}/metadata", timeout=30.0) + self.assertEqual(health.status_code, 200, health.text) + self.assertEqual(health.json()["status"], "ok") + self.assertGreaterEqual(health.json()["active_workers"], 1) + self.assertEqual(meta.status_code, 200, meta.text) + self.assertEqual(meta.json()["api_server_url"], base_url) + self.assertEqual(metadata["api_server_url"], base_url) + self.assertEqual(meta.json()["api_server_url"].rsplit(":", 1)[-1], str(rollout_config.api_port)) + self.assertTrue(all(meta.json()["worker_server_urls_status"].values())) + + offload = client.post(f"{base_url}/offload", timeout=120.0) + self.assertEqual(offload.status_code, 200, offload.text) + self.assertEqual(offload.json()["action"], "offload") + + onload = client.post(f"{base_url}/onload", timeout=120.0) + self.assertEqual(onload.status_code, 200, onload.text) + self.assertEqual(onload.json()["action"], "onload") + + regenerated = client.post( + f"{base_url}/generate", + json={ + "message": TEST_TEXT_MESSAGES, + "sample_params": { + "return_token_ids": True, + "temperature": 0.0, + "top_k": 1, + "max_tokens": 8, + }, + }, + ) + self.assertEqual(regenerated.status_code, 200, regenerated.text) + self.assertEqual(regenerated.json()["status"], "completed") + finally: + try: + ray.get(rollout_controller.shutdown.remote(), timeout=300) + finally: + ray.util.remove_placement_group(pg) + + def test_moe_model(self): + resource_config = AcceleratorResourcesConfig( + accelerator=RESOURCE_MAP[torch.accelerator.current_accelerator().type], + num_workers=4, + num_cpus_per_worker=16, + cpu_memory_per_worker=8 * 1024**3, + ) + pg = 
AutoAcceleratorWorkers.build_placement_group(resource_config, name="moe_api_pg") + moe_worker_log_dir = os.path.join(self.worker_log_dir, "moe") + rollout_config = RolloutConfig( + env="test_rollout_api_server_moe", + model_path=MOE_MODEL_PATH, + model_name=os.path.basename(MOE_MODEL_PATH).lower(), + tokenizer_path=MOE_MODEL_PATH, + tensor_parallel_size=1, + expert_parallel_size=4, + context_length=self.context_length, + worker_log_dir=moe_worker_log_dir, + dist_port_base=38000 + 1024 * 4, + api_host="127.0.0.1", + api_port=29000, + enable_return_routed_experts=True, + ) + rollout_controller = ray.remote(RolloutController).remote(rollout_config, pg) + try: + metadata = ray.get(rollout_controller.get_rollout_metadata.remote(), timeout=1800) + base_url = metadata["api_server_url"] + self._wait_until_ready(base_url) + + request = RolloutState( + message=[{"role": "user", "content": "Briefly explain what mixture of experts means."}], + sample_params=SampleParams( + return_token_ids=True, + return_logprob=False, + temperature=0.0, + top_k=1, + max_tokens=32, + ), + ) + with httpx.Client(timeout=300.0) as client: + response = client.post( + f"{base_url}/generate", + json=request.model_dump(mode="json"), + ) + meta = client.get(f"{base_url}/metadata", timeout=30.0) + + self.assertEqual(response.status_code, 200, response.text) + rollout_state = RolloutState.model_validate_json(response.text) + self.assertIsNotNone(rollout_state.routed_experts) + self.assertEqual(meta.status_code, 200, meta.text) + self.assertEqual(meta.json()["api_server_url"], base_url) + self.assertEqual(metadata["api_server_url"], base_url) + self.assertEqual(meta.json()["api_server_url"].rsplit(":", 1)[-1], str(rollout_config.api_port)) + finally: + try: + ray.get(rollout_controller.shutdown.remote(), timeout=300) + finally: + ray.util.remove_placement_group(pg) + +if __name__ == "__main__": + unittest.main() diff --git a/xtuner/v1/data_proto/rl_data.py b/xtuner/v1/data_proto/rl_data.py index e8009eb34..2a7f93163 100644 --- a/xtuner/v1/data_proto/rl_data.py +++ b/xtuner/v1/data_proto/rl_data.py @@ -1,10 +1,11 @@ from __future__ import annotations +import base64 from enum import Enum from typing import TYPE_CHECKING, Any, TypeAlias import torch -from pydantic import BaseModel, ConfigDict, field_serializer +from pydantic import BaseModel, ConfigDict, field_serializer, field_validator from typing_extensions import NotRequired, TypedDict # ==================================== @@ -76,7 +77,7 @@ class RolloutState(CacheObj, BaseModel): reward_model: dict[str, Any] | None = None num_tokens: int | None = None # 用于 cache 管理 - # --- InferEngine 输入 --- + # --- InferEngine 输入 ---å session_uid: int | None = None tokens: list[int] | None = None # 每一次推理引擎的实际输入 tools: list | None = None @@ -104,20 +105,44 @@ class RolloutState(CacheObj, BaseModel): extra_fields: dict[str, Any] = {} @field_serializer("routed_experts") - def _serialize_routed_experts(self, value: list[int] | RayObjectRef | None) -> list[int] | None: - """Dump 时跳过 ray.ObjectRef,序列化为 None,避免 PydanticSerializationError。""" + def _serialize_routed_experts(self, value: list[int] | RayObjectRef | None) -> list[int] | str | None: + """序列化 routed_experts 字段: + + - None -> None + - list[int] -> list[int](原样保留) + - RayObjectRef -> base64 编码的字符串(通过 ray.cloudpickle 序列化) + """ + import ray + if value is None: return None - try: - import ray - - if isinstance(value, ray.ObjectRef): - return None - except ImportError: - pass - if type(value).__name__ == "ObjectRef" and "ray" in 
getattr(type(value), "__module__", ""): + if isinstance(value, ray.ObjectRef): + data = ray.cloudpickle.dumps(value) + return base64.b64encode(data).decode("utf-8") + return value + + @field_validator("routed_experts", mode="before") + @classmethod + def _deserialize_routed_experts(cls, value: Any) -> list[int] | RayObjectRef | None: + """反序列化 routed_experts 字段: + + - None -> None + - list[int] -> list[int](原样保留) + - str(base64 编码)-> RayObjectRef(通过 ray.cloudpickle 反序列化) + - RayObjectRef -> RayObjectRef(原样保留) + """ + import ray + + if value is None: return None - return value # list[int] + if isinstance(value, ray.ObjectRef): + return value + if isinstance(value, str): + data = base64.b64decode(value) + return ray.cloudpickle.loads(data) + if isinstance(value, list): + return value + return value def update_status_from_finish_reason(finish_reason: str | None) -> Status: diff --git a/xtuner/v1/rl/rollout/__init__.py b/xtuner/v1/rl/rollout/__init__.py index 349cd2fad..2cdf4b892 100644 --- a/xtuner/v1/rl/rollout/__init__.py +++ b/xtuner/v1/rl/rollout/__init__.py @@ -1,6 +1,8 @@ import os +from .anthropic_chat import AnthropicChatAdapter from .controller import RolloutController +from .openai_chat import OpenAIChatAdapter from .worker import RolloutWorker diff --git a/xtuner/v1/rl/rollout/anthropic_chat.py b/xtuner/v1/rl/rollout/anthropic_chat.py new file mode 100644 index 000000000..060d0e27d --- /dev/null +++ b/xtuner/v1/rl/rollout/anthropic_chat.py @@ -0,0 +1,220 @@ +from collections.abc import Awaitable, Callable +from typing import Any, Literal + +from pydantic import BaseModel, ConfigDict + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status +from xtuner.v1.utils import get_logger + +from .utils import ensure_rollout_request_id + + +logger = get_logger(__name__) +GenerateHandler = Callable[[RolloutState], Awaitable[RolloutState]] + + +class AnthropicTextContent(BaseModel): + model_config = ConfigDict(extra="allow") + + type: str = "text" + text: str + + +AnthropicContentBlock = AnthropicTextContent + + +class AnthropicMessage(BaseModel): + model_config = ConfigDict(extra="allow") + + role: Literal["user", "assistant"] + content: str | list[AnthropicContentBlock] + + +class AnthropicMessagesRequest(BaseModel): + model_config = ConfigDict(extra="allow") + + model: str | None = None + system: str | list[AnthropicTextContent] | None = None + messages: list[AnthropicMessage] + max_tokens: int + stream: bool = False + temperature: float | None = None + top_p: float | None = None + stop_sequences: list[str] | None = None + + +class AnthropicUsage(BaseModel): + model_config = ConfigDict(extra="allow") + + input_tokens: int + output_tokens: int + + +class AnthropicMessagesResponse(BaseModel): + model_config = ConfigDict(extra="allow") + + id: str + type: Literal["message"] = "message" + role: Literal["assistant"] = "assistant" + content: list[AnthropicTextContent] + model: str + stop_reason: str | None = None + stop_sequence: str | None = None + usage: AnthropicUsage + + +class AnthropicChatAdapterError(RuntimeError): + def __init__(self, message: str, error_type: str, request_id: str | None = None): + super().__init__(message) + self.message = message + self.error_type = error_type + self.request_id = request_id + + +class AnthropicChatAdapter: + def __init__( + self, + generate_handler: GenerateHandler, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str, + 
default_model_name: str | None = None, + ): + self._generate_handler = generate_handler + self._default_model_name = default_model_name + if isinstance(tokenizer, str): + self._tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + else: + self._tokenizer = tokenizer + + async def messages(self, request: AnthropicMessagesRequest) -> AnthropicMessagesResponse: + if request.stream: + raise AnthropicChatAdapterError( + "stream=true is not supported yet", + "invalid_request_error", + ) + + rollout_state = self._build_rollout_state(request) + request_id = ensure_rollout_request_id(rollout_state) + response = await self._generate_handler(rollout_state) + + if not response.extra_fields.get("request_id"): + response.extra_fields["request_id"] = request_id + + if response.status == Status.FAILED: + raise AnthropicChatAdapterError( + response.error_msg or "Rollout generation failed", + "api_error", + request_id, + ) + + return self._build_messages_response(response, request) + + def _build_rollout_state(self, request: AnthropicMessagesRequest) -> RolloutState: + messages = self._build_internal_messages(request) + rollout_state = RolloutState( + message=messages, + sample_params=self._build_sample_params(request), + ) + logger.info(f"rollout_state built for request: {rollout_state}") + ensure_rollout_request_id(rollout_state) + return rollout_state + + def _build_internal_messages(self, request: AnthropicMessagesRequest) -> list[dict[str, str]]: + messages: list[dict[str, str]] = [] + + if request.system: + if isinstance(request.system, str): + system_text = request.system + else: + system_text = self._join_text_blocks(request.system, context="system") + messages.append({"role": "system", "content": system_text}) + + for message in request.messages: + if isinstance(message.content, str): + content = message.content + else: + content = self._join_text_blocks(message.content, context=f"messages[{message.role}]") + messages.append({"role": message.role, "content": content}) + + return messages + + def _join_text_blocks(self, blocks: list[AnthropicContentBlock], context: str) -> str: + unsupported_types = [block.type for block in blocks if block.type != "text"] + if unsupported_types: + unsupported_str = ", ".join(sorted(set(unsupported_types))) + raise AnthropicChatAdapterError( + f"Unsupported Anthropic content block type(s) in {context}: {unsupported_str}", + "invalid_request_error", + ) + return "\n".join(block.text for block in blocks) + + def _build_sample_params(self, request: AnthropicMessagesRequest) -> SampleParams: + kwargs = { + "return_token_ids": False, + "return_logprob": False, + "stream": request.stream, + "max_tokens": request.max_tokens, + "stops": request.stop_sequences or [], + } + if request.temperature is not None: + kwargs["temperature"] = request.temperature + if request.top_p is not None: + kwargs["top_p"] = request.top_p + return SampleParams(**kwargs) + + def _build_messages_response( + self, + rollout_state: RolloutState, + request: AnthropicMessagesRequest, + ) -> AnthropicMessagesResponse: + request_id = ensure_rollout_request_id(rollout_state) + model_name = request.model or self._default_model_name or "rollout-controller" + prompt_tokens = self._count_prompt_tokens(rollout_state) + completion_tokens = self._count_completion_tokens(rollout_state) + + return AnthropicMessagesResponse( + id=f"msg_{request_id}", + content=[AnthropicTextContent(text=rollout_state.response or "")], + model=model_name, + stop_reason=rollout_state.finish_reason, + 
usage=AnthropicUsage( + input_tokens=prompt_tokens, + output_tokens=completion_tokens, + ), + ) + + def _count_prompt_tokens(self, rollout_state: RolloutState) -> int: + if rollout_state.tokens is not None: + return len(rollout_state.tokens) + if rollout_state.prompt_ids is not None: + return len(rollout_state.prompt_ids) + if self._tokenizer is not None and rollout_state.message: + text_prompt = self._tokenizer.apply_chat_template( + rollout_state.message, + tokenize=False, + add_generation_prompt=True, + ) + return len(self._tokenizer(text_prompt, add_special_tokens=False)["input_ids"]) + return 0 + + def _count_completion_tokens(self, rollout_state: RolloutState) -> int: + if rollout_state.response_ids is not None: + return len(rollout_state.response_ids) + if self._tokenizer is not None and rollout_state.response: + return len(self._tokenizer(rollout_state.response, add_special_tokens=False)["input_ids"]) + return 0 + + +def bind_anthropic_chat_interface( + rollout_controller: Any, + default_model_name: str | None = None, + tokenizer: Any | None = None, +) -> Any: + if getattr(rollout_controller, "anthropic_chat_adapter", None) is None: + rollout_controller.anthropic_chat_adapter = AnthropicChatAdapter( + rollout_controller.generate, + default_model_name=default_model_name, + tokenizer=tokenizer, + ) + rollout_controller.anthropic_messages = rollout_controller.anthropic_chat_adapter.messages + return rollout_controller diff --git a/xtuner/v1/rl/rollout/api_server.py b/xtuner/v1/rl/rollout/api_server.py new file mode 100644 index 000000000..fddfc0c90 --- /dev/null +++ b/xtuner/v1/rl/rollout/api_server.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse + +from xtuner.v1.data_proto.rl_data import RolloutState, Status + +from .anthropic_chat import AnthropicChatAdapterError, AnthropicMessagesRequest, AnthropicMessagesResponse +from .openai_chat import ( + ChatCompletionRequest, + ChatCompletionResponse, + OpenAIChatAdapterError, +) +from .utils import ensure_rollout_request_id + + +if TYPE_CHECKING: + from .controller import RolloutControllerProxy + + +def _build_error_response( + status_code: int, + message: str, + error_type: str, + code: str | None = None, + request_id: str | None = None, + protocol: str = "openai", +) -> JSONResponse: + if protocol == "anthropic": + payload = { + "type": "error", + "error": { + "type": error_type, + "message": message, + }, + } + if request_id is not None: + payload["request_id"] = request_id + else: + payload = { + "error": { + "message": message, + "type": error_type, + "code": code, + "request_id": request_id, + } + } + return JSONResponse(status_code=status_code, content=payload) + + +def create_rollout_api_app( + rollout_controller: RolloutControllerProxy, + logger: Any, +) -> FastAPI: + """Build the rollout API app around the provided rollout controller.""" + app = FastAPI() + + @app.exception_handler(HTTPException) + async def handle_http_exception(request: Request, exc: HTTPException) -> JSONResponse: + request_id = request.headers.get("X-Request-Id") + if isinstance(exc.detail, dict) and "error" in exc.detail: + return JSONResponse(status_code=exc.status_code, content=exc.detail) + return _build_error_response( + status_code=exc.status_code, + message=str(exc.detail), + error_type="invalid_request_error" if exc.status_code < 500 else "server_error", + code="http_error", + request_id=request_id, + ) + 
+ @app.post("/generate") + async def generate(request: RolloutState) -> RolloutState: + request_id = ensure_rollout_request_id(request) + try: + response = await rollout_controller.generate(request) + if not response.extra_fields.get("request_id"): + response.extra_fields["request_id"] = request_id + return response + except Exception as exc: + logger.error(f"Generate failed in API server for request_id={request_id}: {exc}") + request.status = Status.FAILED + request.error_msg = f"Generate failed in API server with error: {str(exc)}" + return request + + @app.post("/v1/chat/completions") + async def chat_completions(request: ChatCompletionRequest, http_request: Request) -> ChatCompletionResponse: + try: + return await rollout_controller.chat(request) + except OpenAIChatAdapterError as exc: + status_code = 400 if exc.error_type == "invalid_request_error" else 500 + raise HTTPException( + status_code=status_code, + detail={ + "error": { + "message": exc.message, + "type": exc.error_type, + "code": exc.code, + "request_id": exc.request_id, + } + }, + ) + + @app.post("/v1/messages") + async def anthropic_messages( + request: AnthropicMessagesRequest, http_request: Request + ) -> AnthropicMessagesResponse: + try: + return await rollout_controller.anthropic_messages(request) + except AnthropicChatAdapterError as exc: + status_code = 400 if exc.error_type == "invalid_request_error" else 500 + return _build_error_response( + status_code=status_code, + message=exc.message, + error_type=exc.error_type, + request_id=exc.request_id, + protocol="anthropic", + ) + + @app.get("/healthz") + async def healthz(): + is_ready, payload = rollout_controller.get_ready_status() + if is_ready: + return {"status": "ok", **payload} + return JSONResponse(status_code=503, content={"status": "not_ready", **payload}) + + @app.get("/metadata") + async def metadata(): + return rollout_controller.get_rollout_metadata() + + @app.post("/pause") + async def pause(): + rollout_controller.pause_generation() + return {"status": "ok", "action": "pause"} + + @app.post("/continue") + async def continue_generation(): + rollout_controller.continue_generation() + return {"status": "ok", "action": "continue"} + + @app.post("/offload") + async def offload(): + rollout_controller.offload() + return {"status": "ok", "action": "offload"} + + @app.post("/onload") + async def onload(): + rollout_controller.onload() + return {"status": "ok", "action": "onload"} + + @app.post("/shutdown") + async def shutdown(): + rollout_controller.shutdown() + return {"status": "ok", "action": "shutdown"} + + return app diff --git a/xtuner/v1/rl/rollout/controller.py b/xtuner/v1/rl/rollout/controller.py index b82d6428b..f1a7875c6 100644 --- a/xtuner/v1/rl/rollout/controller.py +++ b/xtuner/v1/rl/rollout/controller.py @@ -1,11 +1,13 @@ import asyncio import os +import socket import threading from dataclasses import dataclass -from typing import Dict, List, Optional, TypeAlias, TypedDict +from typing import Any, Dict, List, Optional, TypeAlias, TypedDict from uuid import uuid4 import ray +import uvicorn from ray.actor import ActorProxy from ray.util.placement_group import PlacementGroup @@ -13,6 +15,9 @@ from xtuner.v1.rl.utils import AutoAcceleratorWorkers from xtuner.v1.utils import get_logger +from .anthropic_chat import bind_anthropic_chat_interface +from .api_server import create_rollout_api_app +from .openai_chat import bind_openai_chat_interface from .utils import ROLLOUT_RAY_GET_TIMEOUT, RolloutHealthChecker, SessionRouter from .worker import 
RolloutConfig, RolloutWorker @@ -53,6 +58,9 @@ class RolloutWorkerMetadata(TypedDict): # 值:布尔值,True 表示该 worker 处于活跃状态,False 表示已失效或停用 worker_server_urls_status: Dict[str, bool] + # Rollout Controller API 服务器的 URL 地址, + api_server_url: str + class RolloutController: """Controller for managing and coordinating multiple RolloutWorker @@ -77,12 +85,22 @@ def __init__( else self.config.tensor_parallel_size ) self.logger = get_logger(log_dir=infer_config.worker_log_dir, tag="RolloutController") + self.api_server_url = "" self.engine_rank_mesh_array: List[List[int]] = [] self.worker_server_urls_map: dict[str, List[str]] = {} self.rank2info: dict[int, WorkerInfo] = {} self.engine_rank_mesh_array, self.worker_server_urls_map, self.rank2info = self._init_workers(placement_group) self.num_active_workers = len(self.rank2info) self.worker_info_lock = threading.RLock() + bind_openai_chat_interface( + self, default_model_name=self.config.model_name, tokenizer=self.config.tokenizer_path + ) + bind_anthropic_chat_interface( + self, + default_model_name=self.config.model_name, + tokenizer=self.config.tokenizer_path, + ) + self._start_api_server() # The timeout for the environment to wait for the rollout controller's response. # This should be longer than the controller's internal timeout (`rollout_timeout`) # to account for potential queuing delays and other overheads. @@ -109,9 +127,19 @@ def get_rollout_metadata(self) -> RolloutWorkerMetadata: "server_url_dict": self.worker_server_urls_map, "rollout_config": self.config, "worker_server_urls_status": worker_server_urls_status, + "api_server_url": self.api_server_url, } return rollout_metadata + def get_ready_status(self) -> tuple[bool, dict[str, Any]]: + with self.worker_info_lock: + active_workers = sum(1 for info in self.rank2info.values() if info.is_active) + total_workers = len(self.rank2info) + return active_workers > 0, { + "active_workers": active_workers, + "total_workers": total_workers, + } + async def generate(self, rollout_state: RolloutState) -> RolloutState: session_id = rollout_state.session_uid if rollout_state.session_uid else uuid4().int worker = await self.router.get_worker(session_id) @@ -322,6 +350,35 @@ def _update_active_workers_and_urls_map(self, active_rollout_workers, worker_ser active_worker_server_urls = list(worker_server_urls_map.values())[::active_worker_interval] return active_rollout_workers[::active_worker_interval], dict(zip(active_rank, active_worker_server_urls)) + @staticmethod + def _is_port_in_use(host: str, port: int) -> bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.settimeout(0.2) + return sock.connect_ex((host, port)) == 0 + + def _start_api_server(self, host: str | None = None, port: int | None = None): + """Starts the API server to expose the rollout functionality.""" + host = host or self.config.api_host + port = self.config.api_port if self.config.api_port else (port or 8000) + + original_port = port + while self._is_port_in_use(host, port): + self.logger.warning(f"Port {port} is in use, trying port {port + 1}") + port += 1 + + if original_port != port: + self.logger.info(f"API server will use port {port} instead of the originally configured {original_port}.") + + app = create_rollout_api_app(self, self.logger) + + config = uvicorn.Config(app, host=host, port=port) + server = uvicorn.Server(config) + server_thread = threading.Thread(target=server.run, daemon=True) + server_thread.start() + self.config.api_port = port + self.api_server_url = f"http://{host}:{port}" + 
self.logger.info(f"Rollout API server started at {self.api_server_url}") + def _init_workers(self, placement_group: PlacementGroup): """Initializes and configures the pool of RolloutWorker actors. diff --git a/xtuner/v1/rl/rollout/lmdeploy.py b/xtuner/v1/rl/rollout/lmdeploy.py index 65f6fc3d8..c8453865f 100644 --- a/xtuner/v1/rl/rollout/lmdeploy.py +++ b/xtuner/v1/rl/rollout/lmdeploy.py @@ -93,12 +93,14 @@ def _get_request_payload(self, rollout_state: RolloutState) -> dict: message = rollout_state.message input_tokens = rollout_state.tokens + optional_fields: dict[str, object] = {} + if tools is not None: + optional_fields["tools"] = tools + if tool_choice is not None: + optional_fields["tool_choice"] = tool_choice + if sample_params.return_token_ids: - payload = { - "model": self.model_name, - "tools": tools, - "tool_choice": tool_choice, - } + payload = {"model": self.model_name, **optional_fields} if input_tokens is not None: payload["input_ids"] = input_tokens else: @@ -107,19 +109,27 @@ def _get_request_payload(self, rollout_state: RolloutState) -> dict: payload["input_ids"] = prompt_token_ids sample_params.return_routed_experts = True if self.enable_return_routed_experts else False lmdeploy_sample_params = self._transform_sample_params(sample_params) - payload.update(sample_params) + payload.update(lmdeploy_sample_params) else: payload = { "model": self.model_name, "messages": rollout_state.message, - "tools": tools, - "tool_choice": tool_choice, + **optional_fields, } - lmdeploy_sample_params = self._transform_sample_params(sample_params) - lmdeploy_sample_params.pop("no_stop_trim", None) - lmdeploy_sample_params.pop("return_logprob", None) - lmdeploy_sample_params.pop("stop_token_ids", None) - lmdeploy_sample_params["min_new_tokens"] = sample_params.min_tokens + lmdeploy_sample_params = { + "temperature": sample_params.temperature, + "top_p": sample_params.top_p, + "n": sample_params.n, + "stream": sample_params.stream, + "max_tokens": sample_params.max_tokens, + "repetition_penalty": sample_params.repetition_penalty, + "top_k": sample_params.top_k, + "skip_special_tokens": sample_params.skip_special_tokens, + } + if sample_params.stops: + lmdeploy_sample_params["stop"] = sample_params.stops + if sample_params.min_tokens > 0: + lmdeploy_sample_params["min_new_tokens"] = sample_params.min_tokens payload.update(lmdeploy_sample_params) return payload diff --git a/xtuner/v1/rl/rollout/openai_chat.py b/xtuner/v1/rl/rollout/openai_chat.py new file mode 100644 index 000000000..dacf31c34 --- /dev/null +++ b/xtuner/v1/rl/rollout/openai_chat.py @@ -0,0 +1,215 @@ +import time +from collections.abc import Awaitable, Callable +from typing import Any + +from pydantic import BaseModel + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from xtuner.v1.data_proto.rl_data import RolloutState, SampleParams, Status +from xtuner.v1.utils import get_logger + +from .utils import ensure_rollout_request_id + + +logger = get_logger(__name__) +GenerateHandler = Callable[[RolloutState], Awaitable[RolloutState]] + + +class ChatCompletionRequest(BaseModel): + messages: list[dict[str, Any]] + model: str | None = None + tools: list[dict[str, Any]] | None = None + tool_choice: str | dict | None = None + stream: bool = False + temperature: float | None = None + top_p: float | None = None + n: int | None = None + max_tokens: int | None = None + stop: str | list[str] | None = None + presence_penalty: float | None = None + frequency_penalty: float | None = None + + +class 
ChatCompletionMessage(BaseModel): + role: str = "assistant" + content: str | None = None + + +class ChatCompletionChoice(BaseModel): + index: int + message: ChatCompletionMessage + finish_reason: str | None = None + + +class ChatCompletionUsage(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +class ChatCompletionResponse(BaseModel): + id: str + object: str = "chat.completion" + created: int + model: str + choices: list[ChatCompletionChoice] + usage: ChatCompletionUsage + + +class OpenAIChatAdapterError(RuntimeError): + def __init__( + self, + message: str, + error_type: str, + code: str, + request_id: str | None = None, + ): + super().__init__(message) + self.message = message + self.error_type = error_type + self.code = code + self.request_id = request_id + + +class OpenAIChatAdapter: + def __init__( + self, + generate_handler: GenerateHandler, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast | str, + default_model_name: str | None = None, + ): + self._generate_handler = generate_handler + self._default_model_name = default_model_name + if isinstance(tokenizer, str): + self._tokenizer = AutoTokenizer.from_pretrained(tokenizer, trust_remote_code=True) + else: + self._tokenizer = tokenizer + + async def chat(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + if request.stream: + raise OpenAIChatAdapterError( + "stream=true is not supported yet", + "invalid_request_error", + "stream_not_supported", + ) + rollout_state = self._build_rollout_state(request) + request_id = ensure_rollout_request_id(rollout_state) + response = await self._generate_handler(rollout_state) + response.extra_fields.setdefault("request_id", request_id) + + if response.status == Status.FAILED: + raise OpenAIChatAdapterError( + response.error_msg or "Rollout generation failed", + "server_error", + "rollout_failed", + request_id, + ) + + return self._build_chat_completion_response(response, request) + + def _build_rollout_state(self, request: ChatCompletionRequest) -> RolloutState: + if request.tool_choice is not None and not isinstance(request.tool_choice, str): + raise OpenAIChatAdapterError( + "tool_choice object form is not supported yet", + "invalid_request_error", + "unsupported_tool_choice", + ) + rollout_state = RolloutState( + message=request.messages, + tools=request.tools, + tool_choice=request.tool_choice, + sample_params=self._build_sample_params(request), + ) + return rollout_state + + def _build_sample_params(self, request: ChatCompletionRequest) -> SampleParams: + stops: list[str] + if request.stop is None: + stops = [] + elif isinstance(request.stop, str): + stops = [request.stop] + else: + stops = request.stop + + kwargs = { + "return_token_ids": False, + "return_logprob": False, + "stream": request.stream, + "stops": stops, + } + if request.temperature is not None: + kwargs["temperature"] = request.temperature + if request.top_p is not None: + kwargs["top_p"] = request.top_p + if request.n is not None: + kwargs["n"] = request.n + if request.max_tokens is not None: + kwargs["max_tokens"] = request.max_tokens + if request.presence_penalty is not None: + kwargs["presence_penalty"] = request.presence_penalty + if request.frequency_penalty is not None: + kwargs["frequency_penalty"] = request.frequency_penalty + return SampleParams(**kwargs) + + def _build_chat_completion_response( + self, + rollout_state: RolloutState, + request: ChatCompletionRequest, + ) -> ChatCompletionResponse: + request_id = ensure_rollout_request_id(rollout_state) + 
response_id = f"chatcmpl-{request_id}" + model_name = request.model or self._default_model_name or "rollout-controller" + prompt_tokens = self._count_prompt_tokens(rollout_state) + completion_tokens = self._count_completion_tokens(rollout_state) + usage = ChatCompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + choice = ChatCompletionChoice( + index=0, + message=ChatCompletionMessage(content=rollout_state.response), + finish_reason=rollout_state.finish_reason, + ) + return ChatCompletionResponse( + id=response_id, + created=int(time.time()), + model=model_name, + choices=[choice], + usage=usage, + ) + + def _count_prompt_tokens(self, rollout_state: RolloutState) -> int: + if rollout_state.tokens is not None: + return len(rollout_state.tokens) + if rollout_state.prompt_ids is not None: + return len(rollout_state.prompt_ids) + if self._tokenizer is not None and rollout_state.message: + text_prompt = self._tokenizer.apply_chat_template( + rollout_state.message, + tokenize=False, + add_generation_prompt=True, + ) + return len(self._tokenizer(text_prompt, add_special_tokens=False)["input_ids"]) + return 0 + + def _count_completion_tokens(self, rollout_state: RolloutState) -> int: + if rollout_state.response_ids is not None: + return len(rollout_state.response_ids) + if self._tokenizer is not None and rollout_state.response: + return len(self._tokenizer(rollout_state.response, add_special_tokens=False)["input_ids"]) + return 0 + + +def bind_openai_chat_interface( + rollout_controller: Any, + tokenizer: str | PreTrainedTokenizer | PreTrainedTokenizerFast, + default_model_name: str | None = None, +) -> Any: + rollout_controller.openai_chat_adapter = OpenAIChatAdapter( + rollout_controller.generate, + tokenizer=tokenizer, + default_model_name=default_model_name, + ) + rollout_controller.chat = rollout_controller.openai_chat_adapter.chat + return rollout_controller diff --git a/xtuner/v1/rl/rollout/utils.py b/xtuner/v1/rl/rollout/utils.py index e19bb2864..2afe87113 100644 --- a/xtuner/v1/rl/rollout/utils.py +++ b/xtuner/v1/rl/rollout/utils.py @@ -5,10 +5,12 @@ from collections import OrderedDict from itertools import cycle from typing import TYPE_CHECKING, Any, Optional +from uuid import uuid4 import httpx import ray +from xtuner.v1.data_proto.rl_data import RolloutState from xtuner.v1.rl.utils import asyncio_run from xtuner.v1.utils import get_logger @@ -277,3 +279,13 @@ async def check_worker_health( f"Exception during health check for worker {rank} at {url}: {e}. Failure count: {failing_count}" ) return False + + +def ensure_rollout_request_id(rollout_state: RolloutState) -> str: + request_id = str(rollout_state.extra_fields.get("request_id", "")) + if request_id: + return request_id + + request_id = uuid4().hex + rollout_state.extra_fields["request_id"] = request_id + return request_id diff --git a/xtuner/v1/rl/rollout/worker.py b/xtuner/v1/rl/rollout/worker.py index 156c9a054..f5fecb1aa 100644 --- a/xtuner/v1/rl/rollout/worker.py +++ b/xtuner/v1/rl/rollout/worker.py @@ -117,6 +117,10 @@ class RolloutConfig(BaseModel): int, Parameter(group=infer_group, help="Port number for the rollout API server. 
If not set, 8000 will be used."), ] = 8000 + api_host: Annotated[ + str, + Parameter(group=infer_group, help="Host for the rollout API server."), + ] = "0.0.0.0" gpus_per_node: Annotated[int, Parameter(group=infer_group, help="Number of GPUs allocated per node.")] = 8 dtype: Annotated[ str, @@ -313,7 +317,7 @@ def model_post_init(self, __context: Any) -> None: while True: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: - s.bind(("localhost", port)) + s.bind((self.api_host if self.api_host != "0.0.0.0" else "localhost", port)) break except OSError: port += 1 diff --git a/xtuner/v1/rl/utils/ray_utils.py b/xtuner/v1/rl/utils/ray_utils.py index 987ba700f..14d94323d 100644 --- a/xtuner/v1/rl/utils/ray_utils.py +++ b/xtuner/v1/rl/utils/ray_utils.py @@ -180,6 +180,6 @@ def bind_train_rollout( train_workers: A list of training worker actors. rollout_controller: The rollout controller actor. """ - info_dict = ray.get(rollout_controller.get_rollout_info.remote()) # type: ignore[attr-defined] + info_dict = ray.get(rollout_controller.get_rollout_metadata.remote()) # type: ignore[attr-defined] ray.get([worker.update_rollout_info.remote(**info_dict) for worker in train_workers]) # type: ignore[attr-defined] return
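
Taken together, this patch has the RolloutController start a uvicorn API server that exposes a native /generate endpoint (serialized RolloutState in and out) plus OpenAI- and Anthropic-compatible chat endpoints, with the server address published through get_rollout_metadata()["api_server_url"]. Below is a minimal client sketch, not part of the patch itself: the base URL and port are placeholders, and the payloads mirror the ones exercised in tests/rl/test_rollout_api_server.py.

import httpx

# Hypothetical address; in practice read it from
# RolloutController.get_rollout_metadata()["api_server_url"].
BASE_URL = "http://127.0.0.1:28000"

with httpx.Client(timeout=300.0) as client:
    # Native endpoint: accepts a RolloutState-shaped JSON body and returns the updated state.
    generated = client.post(
        f"{BASE_URL}/generate",
        json={
            "message": [{"role": "user", "content": "Hello!"}],
            "sample_params": {"temperature": 0.0, "top_k": 1, "max_tokens": 16},
        },
    )

    # OpenAI-compatible endpoint; "model" is optional and falls back to the controller default.
    chat = client.post(
        f"{BASE_URL}/v1/chat/completions",
        json={
            "messages": [{"role": "user", "content": "Hello!"}],
            "temperature": 0.0,
            "max_tokens": 16,
        },
    )

    # Anthropic-compatible endpoint; max_tokens is required by the request schema.
    msg = client.post(
        f"{BASE_URL}/v1/messages",
        json={
            "messages": [{"role": "user", "content": "Hello!"}],
            "max_tokens": 16,
        },
    )

print(generated.json()["response"])
print(chat.json()["choices"][0]["message"]["content"])
print(msg.json()["content"][0]["text"])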