diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py
index 314dc59328c..fce581c7828 100644
--- a/tests/entrypoints/llm/test_guided_generate.py
+++ b/tests/entrypoints/llm/test_guided_generate.py
@@ -6,6 +6,7 @@
 
 import jsonschema
 import pytest
+from pydantic import BaseModel
 
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.entrypoints.llm import LLM
@@ -322,3 +323,56 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
         # Parse to verify it is valid JSON
         parsed_json = json.loads(generated_text)
         assert isinstance(parsed_json, dict)
+
+
+@pytest.mark.skip_global_cleanup
+def test_json_with_any_whitespace_disabled(llm):
+
+    class ResponseSchema(BaseModel):
+        clarifying_question: str
+        cost_per_serving: str
+        calories: str
+        type_dish_ids: str
+        type_meal_ids: str
+        product_ids: list[str]
+        exclude_product_ids: list[str]
+        allergen_ids: list[str]
+        total_cooking_time: str
+        kitchen_ids: str
+        holiday_ids: str
+
+    # Note: Without this setting, the response is sometimes full of `\n`
+    # for some models. This option prevents that.
+    guided_decoding_backend = 'xgrammar:disable-any-whitespace'
+
+    schema = ResponseSchema.model_json_schema()
+    guided_params = GuidedDecodingParams(
+        json=schema,
+        backend=guided_decoding_backend)
+    sampling_params = SamplingParams(max_tokens=2000,
+                                     frequency_penalty=0,
+                                     presence_penalty=-1.1,
+                                     repetition_penalty=1.3,
+                                     guided_decoding=guided_params)
+
+    prompt = ("<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You"
+              " are a helpful assistant.<|im_end|>\n<|im_start|>user\nI want a "
+              "quick launch fast with $10.<|im_end|>\n<|im_start|>assistant\n")
+    outputs = llm.generate(prompts=prompt,
+                           sampling_params=sampling_params,
+                           use_tqdm=True)
+
+    assert outputs is not None
+
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+
+        generated_text = output.outputs[0].text
+        assert generated_text is not None
+        assert "\n" not in generated_text
+
+        # Parse to verify it is valid JSON
+        parsed_json = json.loads(generated_text)
+        assert isinstance(parsed_json, dict)
+        jsonschema.validate(instance=parsed_json, schema=schema)
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 663ea1ef8af..26d4a84b841 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -385,6 +385,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             'Backend-specific options can be supplied in a comma-separated '
             'list following a colon after the backend name. Valid backends and '
             'all available options are: [xgrammar:no-fallback, '
+            'xgrammar:disable-any-whitespace, '
             'outlines:no-fallback, lm-format-enforcer:no-fallback]')
         parser.add_argument(
             '--logits-processor-pattern',
diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
index 329b03a573d..d7b4970015b 100644
--- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py
+++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -20,6 +20,7 @@
     xgr_installed = False
     pass
 
+from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
                                                        grammar_is_likely_lark)
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
@@ -30,6 +31,8 @@
     from vllm.config import ModelConfig
     from vllm.sampling_params import GuidedDecodingParams
 
+logger = init_logger(__name__)
+
 
 # TODO: passing batch size to max threads here
 def get_local_xgrammar_guided_decoding_logits_processor(
@@ -162,6 +165,7 @@ class GrammarConfig:
     json_str: str | None = None
     grammar_str: str | None = None
     json_object: bool | None = None
+    any_whitespace: bool = True
     max_threads: int = 8
     tokenizer_data: TokenizerData | None = None
 
@@ -181,11 +185,33 @@ def from_guided_params(cls,
             else:
                 json_str = guided_params.json
 
+            any_whitespace = 'disable-any-whitespace' not in \
+                guided_params.backend_options()
+
+            # Log a hint for models known to produce runaway whitespace when
+            # used with xgrammar and any_whitespace enabled.
+            # References:
+            # https://github.com/vllm-project/vllm/pull/12744
+            # https://github.com/mlc-ai/xgrammar/issues/212
+            model_with_warn = None
+
+            if 'Mistral' in model_config.model:
+                model_with_warn = 'Mistral'
+            elif 'Qwen' in model_config.model:
+                model_with_warn = 'Qwen'
+
+            if model_with_warn is not None and any_whitespace:
+                msg = (f"{model_with_warn} "
+                       "model detected, consider setting "
+                       "`guided_backend=xgrammar:disable-any-whitespace` "
+                       "to prevent runaway generation of whitespace.")
+                logger.info_once(msg)
             # Validate the schema and raise ValueError here if it is invalid.
             # This is to avoid exceptions in model execution, which will crash
             # the engine worker process.
             try:
-                xgr.Grammar.from_json_schema(json_str)
+                xgr.Grammar.from_json_schema(json_str,
+                                             any_whitespace=any_whitespace)
             except RuntimeError as err:
                 raise ValueError(str(err)) from err
 
@@ -193,7 +219,8 @@ def from_guided_params(cls,
                        vocab_size=model_config.hf_text_config.vocab_size,
                        tokenizer_hash=tokenizer_hash,
                        max_threads=max_threads,
-                       tokenizer_data=tokenizer_data)
+                       tokenizer_data=tokenizer_data,
+                       any_whitespace=any_whitespace)
         elif guided_params.grammar:
             # XGrammar only supports GBNF grammars, so we must convert Lark
             if grammar_is_likely_lark(guided_params.grammar):
@@ -291,7 +318,10 @@ def _ensure_ctx(self):
         if self.ctx is None:
             compiler = GrammarCompilerCache.get_compiler(self.config)
             if self.config.json_str is not None:
-                self.ctx = compiler.compile_json_schema(self.config.json_str)
+                any_whitespace = self.config.any_whitespace
+                self.ctx = compiler.compile_json_schema(
+                    self.config.json_str,
+                    any_whitespace=any_whitespace)
             elif self.config.grammar_str is not None:
                 self.ctx = compiler.compile_grammar(self.config.grammar_str)
             elif self.config.json_object:
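
Usage note (not part of the patch): a minimal sketch of how the new `xgrammar:disable-any-whitespace` option is exercised from user code, mirroring the test added above. The model name, schema, and prompt are illustrative placeholders, not taken from the diff.

    from pydantic import BaseModel

    from vllm import LLM, SamplingParams
    from vllm.sampling_params import GuidedDecodingParams


    class Recipe(BaseModel):
        # Placeholder schema for illustration only.
        name: str
        calories: str


    # Placeholder model; any model served by vLLM can be used here.
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")

    schema = Recipe.model_json_schema()
    # The new backend option stops xgrammar from allowing arbitrary whitespace
    # between JSON tokens, which prevents runaway '\n' padding in the output.
    guided_params = GuidedDecodingParams(
        json=schema,
        backend="xgrammar:disable-any-whitespace")
    sampling_params = SamplingParams(max_tokens=512,
                                     guided_decoding=guided_params)

    outputs = llm.generate(prompts="Describe a quick $10 meal as JSON.",
                           sampling_params=sampling_params)
    print(outputs[0].outputs[0].text)

On the server side, the same option is documented by the arg_utils.py change and can be selected as the default via `--guided-decoding-backend xgrammar:disable-any-whitespace`.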