2 files changed: 61 additions, 0 deletions.
New file, backends/exllamav3/grammar.py (all lines added):

from typing import List
import traceback

from exllamav3 import (
    Tokenizer,
    Filter,
    FormatronFilter,
)
from formatron.formatter import FormatterBuilder
from formatron.schemas import json_schema
from loguru import logger


class ExLlamaV3Grammar:
    """ExLlamaV3 class for various grammar filters/parsers."""

    filters: List[Filter]

    def __init__(self):
        self.filters = []

    def add_json_schema_filter(
        self,
        schema: dict,
        tokenizer: Tokenizer,
    ):
        """Adds an ExLlamaV3 filter based on a JSON schema."""

        leading_character = "[" if schema.get("type") == "array" else "{"

        try:
            # Add fields required by formatron if not present
            if "$id" not in schema:
                schema["$id"] = "https://example.com/example.json"
            if "$schema" not in schema:
                schema["$schema"] = "http://json-schema.org/draft-07/schema#"

            # Validate schema and create formatter
            schema = json_schema.create_schema(schema)
        except Exception:
            traceback.print_exc()
            logger.error(
                "Skipping because the JSON schema couldn't be parsed. "
                "Please read the above error for more information."
            )
            return

        f = FormatterBuilder()
        f.append_line(f"{f.json(schema)}")
        self.filters.append(
            FormatronFilter(tokenizer, eos_after_completed=True, formatter_builder=f)
        )

        # Additional constraint to force leading character
        f = FormatterBuilder()
        f.append_line(leading_character)
        self.filters.append(FormatronFilter(tokenizer, formatter_builder=f))
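For reference, here is a minimal standalone sketch of the schema preprocessing that add_json_schema_filter performs before handing the schema to formatron's json_schema.create_schema(). The example schema is illustrative only and not part of this change, and no exllamav3 or formatron objects are needed to run it:

# Illustrative sketch of the fix-up logic above (plain dicts only).
schema = {"type": "array", "items": {"type": "string"}}

# Arrays are forced to start with "[", everything else with "{".
leading_character = "[" if schema.get("type") == "array" else "{"

# Formatron requires "$id" and "$schema"; fill them in when absent.
schema.setdefault("$id", "https://example.com/example.json")
schema.setdefault("$schema", "http://json-schema.org/draft-07/schema#")

print(leading_character)   # prints: [
print(sorted(schema))      # prints: ['$id', '$schema', 'items', 'type']

The second FormatronFilter, which appends only the leading character, presumably exists so generation cannot start with free-form text before the constrained JSON begins.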
Modified file: the ExLlamaV3 model container.

     Tokenizer,
 )
 from exllamav3.cache import CacheLayer_quant
+from backends.exllamav3.grammar import ExLlamaV3Grammar
 from loguru import logger

 from backends.base_model_container import BaseModelContainer

@@ -929,6 +930,7 @@ async def generate_gen(
         prompts = [prompt]
         stop_conditions = params.stop
         add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())
+        grammar_handler = ExLlamaV3Grammar()

         # Get multimodal embeddings if present
         mm_embeddings_content = mm_embeddings.content if mm_embeddings else []

@@ -970,6 +972,9 @@ async def generate_gen(
             request_id,
         )

+        if params.json_schema:
+            grammar_handler.add_json_schema_filter(params.json_schema, self.tokenizer)
+
         generation = {}
         job = AsyncJob(
             self.generator,

@@ -981,6 +986,7 @@ async def generate_gen(
             embeddings=mm_embeddings_content,
             return_top_tokens=params.logprobs,
             max_rq_tokens=self.max_rq_tokens,
+            filters=grammar_handler.filters,
         )

         generated_tokens = 0
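As a hedged illustration of what ends up in params.json_schema: the json_schema field name matches the parameter read in the hunk above, but the surrounding request shape is an assumption about the existing completion API rather than something this diff defines.

# Hypothetical request body; only the "json_schema" key is taken from this PR,
# the other fields are assumptions about the existing API.
request_body = {
    "prompt": "Describe today's weather as JSON.",
    "max_tokens": 128,
    "json_schema": {
        "type": "object",
        "properties": {
            "summary": {"type": "string"},
            "temperature_c": {"type": "number"},
        },
        "required": ["summary", "temperature_c"],
    },
}

With a request like this, the handler queues two Formatron filters and the AsyncJob receives them through filters=grammar_handler.filters, constraining the generated text to schema-conforming JSON.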