[Bugfix] Backend option to disable xgrammar any_whitespace #12744

Merged: 25 commits from fix-xgrammar-whitespace into main, Feb 26, 2025

Commits
2756335 - [Bugfix] Env var to disable xgrammar any_whitespace (wallashss, Feb 4, 2025)
f01cfc7 - Merge branch 'main' into fix-xgrammar-whitespace (wallashss, Feb 5, 2025)
a24fdab - Merge branch 'main' into fix-xgrammar-whitespace (wallashss, Feb 6, 2025)
7a3927e - Merge branch 'main' into fix-xgrammar-whitespace (wallashss, Feb 12, 2025)
669b097 - Merge branch 'main' into fix-xgrammar-whitespace (wallashss, Feb 12, 2025)
435695a - Merge branch 'main' into fix-xgrammar-whitespace (wallashss, Feb 18, 2025)
65cff01 - :wrench: Add env var to disable guided decoding fallbacks (joerunde, Feb 19, 2025)
c971da7 - :rewind: revert envs change (joerunde, Feb 19, 2025)
15cac0c - :sparkles: add guided decoding backend options (joerunde, Feb 19, 2025)
f9d0e9d - :bug: handle missing backend name (joerunde, Feb 19, 2025)
c64df44 - :bug: fixup options (joerunde, Feb 19, 2025)
a8e73c3 - :memo: add docs and example (joerunde, Feb 19, 2025)
85b1558 - :sparkles: add CLI support (joerunde, Feb 20, 2025)
1d9f721 - Merge remote-tracking branch 'joe/no-gd-fallback' into fix-xgrammar-whitespace (wallashss, Feb 20, 2025)
2c29f4e - updated to use backend options from #13505 (wallashss, Feb 20, 2025)
a5e0189 - Merge branch 'main' into fix-xgrammar-whitespace (wallashss, Feb 20, 2025)
a2aa6d3 - Updated docs and removed code from env (wallashss, Feb 20, 2025)
44aeb64 - rewrite disable_any_whitespace in args to disable-any-whitespace (wallashss, Feb 20, 2025)
602c75f - added info to tell users how to use disable-any-whitespace (wallashss, Feb 21, 2025)
2e3eb0c - Merge branch 'main' into fix-xgrammar-whitespace (wallashss, Feb 21, 2025)
3bfbd2d - added test for disable any whitespace (wallashss, Feb 21, 2025)
7a10476 - minor test change (wallashss, Feb 24, 2025)
5f55f81 - Merge branch 'main' into fix-xgrammar-whitespace (wallashss, Feb 24, 2025)
c04a031 - fix pre-commit (wallashss, Feb 24, 2025)
75bc73c - Merge branch 'main' into fix-xgrammar-whitespace (wallashss, Feb 25, 2025)
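
In short, this PR threads a new disable-any-whitespace option through the xgrammar guided decoding backend so that schema-constrained generation can be forced to emit compact JSON. As a usage sketch only, mirroring the test added in the diff below (the model name is a placeholder, not something the PR prescribes):

# Usage sketch; the model name is a placeholder.
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

schema = {"type": "object", "properties": {"answer": {"type": "string"}}}
guided = GuidedDecodingParams(json=schema,
                              backend="xgrammar:disable-any-whitespace")
params = SamplingParams(max_tokens=256, guided_decoding=guided)

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")
outputs = llm.generate("Reply in JSON: what is 2 + 2?",
                       sampling_params=params)
print(outputs[0].outputs[0].text)  # compact JSON, no stray newlines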
54 changes: 54 additions & 0 deletions tests/entrypoints/llm/test_guided_generate.py
@@ -6,6 +6,7 @@

import jsonschema
import pytest
from pydantic import BaseModel

from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
@@ -322,3 +323,56 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
        # Parse to verify it is valid JSON
        parsed_json = json.loads(generated_text)
        assert isinstance(parsed_json, dict)


@pytest.mark.skip_global_cleanup
def test_json_with_any_whitespace_disabled(llm):

    class ResponseSchema(BaseModel):
        clarifying_question: str
        cost_per_serving: str
        calories: str
        type_dish_ids: str
        type_meal_ids: str
        product_ids: list[str]
        exclude_product_ids: list[str]
        allergen_ids: list[str]
        total_cooking_time: str
        kitchen_ids: str
        holiday_ids: str

    # Note: Without this setting, the response is sometimes full of `\n`
    # for some models. This option prevents that.
    guided_decoding_backend = 'xgrammar:disable-any-whitespace'

    schema = ResponseSchema.model_json_schema()
    guided_params = GuidedDecodingParams(json=schema,
                                         backend=guided_decoding_backend)
    sampling_params = SamplingParams(max_tokens=2000,
                                     frequency_penalty=0,
                                     presence_penalty=-1.1,
                                     repetition_penalty=1.3,
                                     guided_decoding=guided_params)

    prompt = ("<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. "
              "You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
              "I want a quick launch fast with $10.<|im_end|>\n"
              "<|im_start|>assistant\n")
    outputs = llm.generate(prompts=prompt,
                           sampling_params=sampling_params,
                           use_tqdm=True)

    assert outputs is not None

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)

        generated_text = output.outputs[0].text
        assert generated_text is not None
        assert "\n" not in generated_text

        # Parse to verify it is valid JSON
        parsed_json = json.loads(generated_text)
        assert isinstance(parsed_json, dict)
        jsonschema.validate(instance=parsed_json, schema=schema)
1 change: 1 addition & 0 deletions vllm/engine/arg_utils.py
@@ -385,6 +385,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
            'Backend-specific options can be supplied in a comma-separated '
            'list following a colon after the backend name. Valid backends and '
            'all available options are: [xgrammar:no-fallback, '
            'xgrammar:disable-any-whitespace, '
            'outlines:no-fallback, lm-format-enforcer:no-fallback]')
        parser.add_argument(
            '--logits-processor-pattern',
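
For context, the colon-and-comma syntax described in the help text above is parsed by GuidedDecodingParams.backend_options(), introduced in #13505 and not shown in this diff. A minimal sketch of equivalent parsing, for illustration only (split_backend is a hypothetical helper, not the vLLM API):

# Illustrative only; vLLM's actual parsing lives in
# GuidedDecodingParams.backend_options() from #13505.
def split_backend(backend: str) -> tuple[str, list[str]]:
    """Split 'xgrammar:disable-any-whitespace,no-fallback' into the
    backend name and its list of options."""
    name, _, opts = backend.partition(':')
    return name, (opts.split(',') if opts else [])

name, options = split_backend('xgrammar:disable-any-whitespace,no-fallback')
assert name == 'xgrammar'
assert options == ['disable-any-whitespace', 'no-fallback']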
36 changes: 33 additions & 3 deletions vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -20,6 +20,7 @@
    xgr_installed = False
    pass

from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
grammar_is_likely_lark)
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
@@ -30,6 +31,8 @@
    from vllm.config import ModelConfig
    from vllm.sampling_params import GuidedDecodingParams

logger = init_logger(__name__)


# TODO: passing batch size to max threads here
def get_local_xgrammar_guided_decoding_logits_processor(
@@ -162,6 +165,7 @@ class GrammarConfig:
    json_str: str | None = None
    grammar_str: str | None = None
    json_object: bool | None = None
    any_whitespace: bool = True
    max_threads: int = 8
    tokenizer_data: TokenizerData | None = None

@@ -181,19 +185,42 @@ def from_guided_params(cls,
            else:
                json_str = guided_params.json

            any_whitespace = 'disable-any-whitespace' not in \
                guided_params.backend_options()

            # Warn for models that have a known history of runaway
            # whitespace generation when used with xgrammar and
            # any_whitespace enabled.
            # References:
            # https://github.com/vllm-project/vllm/pull/12744
            # https://github.com/mlc-ai/xgrammar/issues/212
            model_with_warn = None

            if 'Mistral' in model_config.model:
                model_with_warn = 'Mistral'
            elif 'Qwen' in model_config.model:
                model_with_warn = 'Qwen'

            if model_with_warn is not None and any_whitespace:
                msg = (f"{model_with_warn} model detected, consider setting "
                       "`guided_decoding_backend=xgrammar:disable-any-whitespace` "
                       "to prevent runaway whitespace generation.")
                logger.info_once(msg)
            # Validate the schema and raise ValueError here if it is invalid.
            # This is to avoid exceptions in model execution, which would
            # crash the engine worker process.
            try:
                xgr.Grammar.from_json_schema(json_str,
                                             any_whitespace=any_whitespace)
            except RuntimeError as err:
                raise ValueError(str(err)) from err

            return cls(json_str=json_str,
                       vocab_size=model_config.hf_text_config.vocab_size,
                       tokenizer_hash=tokenizer_hash,
                       max_threads=max_threads,
                       tokenizer_data=tokenizer_data,
                       any_whitespace=any_whitespace)
        elif guided_params.grammar:
            # XGrammar only supports GBNF grammars, so we must convert Lark
            if grammar_is_likely_lark(guided_params.grammar):
@@ -291,7 +318,10 @@ def _ensure_ctx(self):
        if self.ctx is None:
            compiler = GrammarCompilerCache.get_compiler(self.config)
            if self.config.json_str is not None:
                any_whitespace = self.config.any_whitespace
                self.ctx = compiler.compile_json_schema(
                    self.config.json_str, any_whitespace=any_whitespace)
            elif self.config.grammar_str is not None:
                self.ctx = compiler.compile_grammar(self.config.grammar_str)
            elif self.config.json_object:
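
For reference, a minimal sketch of what the flag changes at the xgrammar level, assuming xgrammar is installed: with any_whitespace=True (the default) the compiled JSON grammar accepts arbitrary whitespace between tokens, which some models (e.g. Mistral, Qwen) can exploit to emit runaway newline runs, while any_whitespace=False constrains generation to compact JSON.

# Minimal sketch, assuming `xgrammar` is installed.
import xgrammar as xgr

schema = '{"type": "object", "properties": {"a": {"type": "string"}}}'

# Default: whitespace permitted between JSON tokens.
loose = xgr.Grammar.from_json_schema(schema)

# With the option added in this PR: compact output, no whitespace runs.
compact = xgr.Grammar.from_json_schema(schema, any_whitespace=False)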