Skip to content

Commit cef7560

Browse files
author
nya
committed
Constrained generation with json schema for ExllamaV3
1 parent fece479 commit cef7560

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

backends/exllamav3/grammar.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from typing import List
2+
import traceback
3+
4+
from exllamav3 import (
5+
Tokenizer,
6+
Filter,
7+
FormatronFilter,
8+
)
9+
from formatron.formatter import FormatterBuilder
10+
from formatron.schemas import json_schema
11+
from loguru import logger
12+
13+
14+
class ExLlamaV3Grammar:
    """ExLlamaV3 class for various grammar filters/parsers."""

    # Filters collected by the add_* methods, in the order they were added.
    filters: List[Filter]

    def __init__(self):
        self.filters = []

    def add_json_schema_filter(
        self,
        schema: dict,
        tokenizer: Tokenizer,
    ):
        """Adds an ExllamaV3 filter based on a JSON schema.

        On success two filters are appended to ``self.filters``: one built
        from the JSON schema and one that constrains the leading character
        of the output ("[" for array schemas, "{" otherwise).

        If the schema cannot be turned into a formatron schema, the error is
        printed and logged and no filter is added.

        Args:
            schema: JSON schema as a dict. It is not modified; the fields
                formatron requires are added to a copy.
            tokenizer: Tokenizer handed to each FormatronFilter.
        """

        leading_character = "[" if schema.get("type") == "array" else "{"

        try:
            # Work on a shallow copy so the caller's schema (e.g. the request
            # parameters) is not mutated. Only top-level keys are added, so a
            # shallow copy is sufficient.
            schema_spec = dict(schema)

            # Add fields required by formatron if not present
            schema_spec.setdefault("$id", "https://example.com/example.json")
            schema_spec.setdefault(
                "$schema", "http://json-schema.org/draft-07/schema#"
            )

            # Validate schema and create formatter
            compiled_schema = json_schema.create_schema(schema_spec)
        except Exception:
            traceback.print_exc()
            logger.error(
                "Skipping because the JSON schema couldn't be parsed. "
                "Please read the above error for more information."
            )
            return

        f = FormatterBuilder()
        f.append_line(f"{f.json(compiled_schema)}")
        self.filters.append(
            FormatronFilter(tokenizer, eos_after_completed=True, formatter_builder=f)
        )

        # Additional constraint to force leading character
        f = FormatterBuilder()
        f.append_line(leading_character)
        self.filters.append(FormatronFilter(tokenizer, formatter_builder=f))

backends/exllamav3/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
Tokenizer,
2222
)
2323
from exllamav3.cache import CacheLayer_quant
24+
from backends.exllamav3.grammar import ExLlamaV3Grammar
2425
from loguru import logger
2526

2627
from backends.base_model_container import BaseModelContainer
@@ -929,6 +930,7 @@ async def generate_gen(
929930
prompts = [prompt]
930931
stop_conditions = params.stop
931932
add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())
933+
grammar_handler = ExLlamaV3Grammar()
932934

933935
# Get multimodal embeddings if present
934936
mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
@@ -970,6 +972,9 @@ async def generate_gen(
970972
request_id,
971973
)
972974

975+
if params.json_schema:
976+
grammar_handler.add_json_schema_filter(params.json_schema, self.tokenizer)
977+
973978
generation = {}
974979
job = AsyncJob(
975980
self.generator,
@@ -981,6 +986,7 @@ async def generate_gen(
981986
embeddings=mm_embeddings_content,
982987
return_top_tokens=params.logprobs,
983988
max_rq_tokens=self.max_rq_tokens,
989+
filters=grammar_handler.filters,
984990
)
985991

986992
generated_tokens = 0

0 commit comments

Comments
 (0)