Commit 42e4e87

Merge remote-tracking branch 'LlamaEnjoyer/read_sampling_params_from_model_config'
2 parents dd48d47 + dfe827b commit 42e4e87

7 files changed: +531 −38 lines


backend/models.py (+91 −2)
@@ -24,6 +24,26 @@
 from backend.config import config_filename
 from backend.util import *
 
+from typing import Callable, Optional, Dict, Any
+
+# Callback type for model parameter updates
+ModelLoadedCallback = Callable[[Dict[str, Any]], None]
+
+# Global callback that will be called when model parameters are loaded/updated
+model_loaded_callback: Optional[ModelLoadedCallback] = None
+
+def set_model_loaded_callback(callback: Optional[ModelLoadedCallback]) -> None:
+    """Set callback to be notified when model parameters are loaded/updated.
+
+    Args:
+        callback: Function that takes model dict as argument, or None to clear
+    """
+    global model_loaded_callback
+    if callback is not None and not callable(callback):
+        raise TypeError("Model loaded callback must be callable")
+    model_loaded_callback = callback
+
+
 # Reserve memory for auto-split functionality
 auto_split_reserve_bytes = 512 * 1024**2
 
 models = {}
@@ -158,7 +178,59 @@ def prepare_draft_model(model):
     if "draft_rope_alpha_auto" not in model: model["draft_rope_alpha_auto"] = True
 
 
-def prepare_model(model):
+def prepare_model(model: Dict[str, Any]) -> None:
+    """Prepare model for loading by configuring parameters and resources.
+
+    Args:
+        model: Dictionary containing model configuration
+
+    Raises:
+        ValueError: If model directory is invalid
+        JSONDecodeError: If generation_config.json exists but is malformed
+    """
+    # Read generation_config.json if present
+    config_path = os.path.join(expanduser(model["model_directory"]), "generation_config.json")
+    if os.path.exists(config_path):
+        try:
+            with open(config_path, encoding='utf-8') as f:
+                gen_config = json.load(f)
+
+            if not isinstance(gen_config, dict):
+                raise ValueError("generation_config.json must contain a JSON object")
+
+            print(f"Found generation_config.json: {gen_config}")
+
+            # Map generation config parameters to internal names
+            params_to_check = {
+                "temperature": "temperature",
+                "top_k": "top_k",
+                "top_p": "top_p",
+                "repetition_penalty": "repp"
+            }
+
+            # Store original values for logging
+            orig_values = {k: model.get(k) for k in params_to_check.values()}
+
+            # Update model with values from generation_config.json
+            for config_name, internal_name in params_to_check.items():
+                if config_name in gen_config:
+                    # Validate parameter types
+                    value = gen_config[config_name]
+                    if not isinstance(value, (int, float)):
+                        print(f"Warning: Invalid type for {config_name} in generation_config.json. Expected number, got {type(value)}")
+                        continue
+
+                    model[internal_name] = value
+                    print(f"Setting {internal_name} from {orig_values.get(internal_name)} to {value}")
+
+            # Save updated model config
+            save_models()
+        except json.JSONDecodeError as e:
+            print(f"Error parsing generation_config.json: {e}")
+            print("Using default parameter values")
+        except Exception as e:
+            print(f"Unexpected error reading generation_config.json: {e}")
+            print("Using default parameter values")
 
     prep_config = ExLlamaV2Config()
     prep_config.fasttensors = False
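
A note on the type check in the hunk above: in Python, bool is a subclass of int, so isinstance(value, (int, float)) also accepts JSON true/false. If that ever matters, a stricter guard would look like the following sketch (not part of this commit):

    # Reject booleans explicitly; isinstance(True, int) is True in Python
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        print(f"Warning: Invalid type for {config_name} in generation_config.json")
        continue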
@@ -194,6 +266,14 @@ def prepare_model(model):
     if "gpu_split" not in model: model["gpu_split"] = ""
     if "gpu_split_auto" not in model: model["gpu_split_auto"] = True
 
+    # Log final parameter state
+    print("Final model parameters:", {
+        "temperature": model.get("temperature", 0.8),
+        "top_k": model.get("top_k", 50),
+        "top_p": model.get("top_p", 0.8),
+        "repp": model.get("repp", 1.01)
+    })
+
 
 class ModelContainer:
 
@@ -413,6 +493,16 @@ def load_model(data):
             yield json.dumps(result) + "\n"
             return ""
 
+    # Notify about model load via callback
+    if success and model_loaded_callback is not None:
+        print("Calling model_loaded_callback with params:", {
+            "temperature": model.get("temperature", 0.8),
+            "top_k": model.get("top_k", 50),
+            "top_p": model.get("top_p", 0.8),
+            "repp": model.get("repp", 1.01)
+        })
+        model_loaded_callback(model)
+
     result = { "result": "ok" }
     # print(json.dumps(result) + "\n")
     yield json.dumps(result) + "\n"
@@ -430,4 +520,3 @@ def unload_model():
 
     result = { "result": "ok" }
     return result
-
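
For reference, a minimal generation_config.json that the parser above would pick up (values are illustrative, not from the commit). Only the four mapped keys are read, repetition_penalty is stored internally as repp, and non-numeric values are skipped with a warning:

    {
        "temperature": 0.7,
        "top_k": 40,
        "top_p": 0.9,
        "repetition_penalty": 1.05
    }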

backend/sessions.py (+52 −16)
@@ -17,14 +17,22 @@
 )
 
 from backend.config import set_config_dir, global_state, config_filename
-from backend.models import get_loaded_model
+from backend.models import set_model_loaded_callback
 from backend.prompts import prompt_formats
 from backend.util import MultiTimer
+import backend.models as models  # Import as module to avoid circular dependency
 import threading
 
 session_list: dict or None = None
 current_session = None
 
+def handle_model_loaded(model):
+    """Handle model loading - only update new sessions with model params"""
+    pass
+
+# Register callback to handle model loading
+set_model_loaded_callback(handle_model_loaded)
+
 # Cancel
 
 abort_event = threading.Event()
@@ -92,9 +100,14 @@ def delete_session(d_session):
         current_session = None
 
 
-def get_default_session_settings():
-    return \
-    {
+def get_default_session_settings(use_model_params=False):
+    """Get default session settings
+
+    Args:
+        use_model_params: If True and a model is loaded with custom params,
+            apply those params instead of defaults
+    """
+    settings = {
         "prompt_format": "Chat-RP",
         "roles": [ "User", "Assistant", "", "", "", "", "", "" ],
         "system_prompt_default": True,
@@ -119,6 +132,23 @@ def get_default_session_settings():
         "temperature_last": False,
         "skew": 0.0,
     }
+
+    if use_model_params:
+        # If requested, try to use model parameters
+        loaded_model = models.get_loaded_model()
+        if loaded_model is not None:
+            model_dict = loaded_model.model_dict
+            # Only apply if model has custom params defined
+            if any(param in model_dict for param in ["temperature", "top_k", "top_p", "repp"]):
+                settings.update({
+                    "temperature": model_dict.get("temperature", settings["temperature"]),
+                    "top_k": model_dict.get("top_k", settings["top_k"]),
+                    "top_p": model_dict.get("top_p", settings["top_p"]),
+                    "repp": model_dict.get("repp", settings["repp"])
+                })
+                print("Updated settings with model params:", settings)
+
+    return settings
 
 class Session:
 
@@ -145,7 +175,8 @@ def init_new(self):
         self.session_uuid = str(uuid.uuid4())
         self.history = []
         # self.mode = ""
-        self.settings = get_default_session_settings()
+        # New sessions get app defaults
+        self.settings = get_default_session_settings(use_model_params=False)
 
 
     def to_json(self):
@@ -163,9 +194,13 @@ def from_json(self, j):
         self.session_uuid = j["session_uuid"]
         self.history = j["history"]
         # self.mode = j["mode"]
-        settings = get_default_session_settings()
-        if "settings" in j: settings.update(j["settings"])
-        self.settings = settings
+
+        # Start with hardcoded defaults (no model params)
+        self.settings = get_default_session_settings(use_model_params=False)
+
+        # Apply ALL saved settings including sampling params
+        if "settings" in j:
+            self.settings.update(j["settings"])
 
 
     def load(self):
@@ -244,7 +279,7 @@ def create_context(self, prompt_format, max_len, min_len, uptoblock = None, pref
 
     def create_context_instruct(self, prompt_format, max_len, min_len, uptoblock = None, prefix = ""):
 
-        tokenizer = get_loaded_model().tokenizer
+        tokenizer = models.get_loaded_model().tokenizer
         prompts = []
         responses = []
 
@@ -347,7 +382,7 @@ def create_context_instruct(self, prompt_format, max_len, min_len, uptoblock = N
 
     def create_context_raw(self, prompt_format, max_len, min_len, uptoblock = None, prefix=""):
 
-        tokenizer = get_loaded_model().tokenizer
+        tokenizer = models.get_loaded_model().tokenizer
         history_copy = []
         for h in self.history:
             if h["block_uuid"] == uptoblock: break
@@ -413,16 +448,17 @@ def generate(self, data):
         gen_prefix = data.get("prefix", "")
         block_id = data.get("block_id", None)
 
-        if get_loaded_model() is None:
+        if models.get_loaded_model() is None:
             packet = { "result": "fail", "error": "No model loaded." }
             yield json.dumps(packet) + "\n"
             return packet
 
-        model = get_loaded_model().model
-        generator = get_loaded_model().generator
-        tokenizer = get_loaded_model().tokenizer
-        cache = get_loaded_model().cache
-        speculative_mode = get_loaded_model().speculative_mode
+        loaded_model = models.get_loaded_model()
+        model = loaded_model.model
+        generator = loaded_model.generator
+        tokenizer = loaded_model.tokenizer
+        cache = loaded_model.cache
+        speculative_mode = loaded_model.speculative_mode
 
         prompt_format = prompt_formats[self.settings["prompt_format"]]()
 
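
The committed handle_model_loaded is deliberately a no-op, so loading a model never silently rewrites session settings; model params are only applied on explicit request via the server endpoints below. For reference, a hypothetical consumer of the new hook might look like this (a sketch only, not code from the commit):

    from backend.models import set_model_loaded_callback

    def log_model_params(model):
        # model is the model config dict that load_model() passes through
        print("Model loaded:", {k: model.get(k) for k in ("temperature", "top_k", "top_p", "repp")})

    set_model_loaded_callback(log_model_params)  # pass None to clear the hook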

server.py (+91 −1)
@@ -125,7 +125,7 @@ def api_get_default_settings():
     if verbose: print("/api/get_default_settings")
     with api_lock:
         result = { "result": "ok",
-                   "session_settings": get_default_session_settings(),
+                   "session_settings": get_default_session_settings(use_model_params=False),  # Use hardcoded defaults
                    "notepad_settings": get_default_notepad_settings(),
                    "prompt_formats": list_prompt_formats() }
         return json.dumps(result) + "\n"
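
The hunk below adds three endpoints for inspecting and applying model-defined sampling params. They can be exercised directly; a hypothetical client sketch, assuming the server listens on localhost:5000 and the requests package is available (adjust host/port to your setup):

    import requests  # assumption: not a dependency of this project

    # Which sampling params does the loaded model define?
    print(requests.get("http://localhost:5000/api/get_model_params").text)

    # Copy the model-defined params into the current session
    print(requests.post("http://localhost:5000/api/apply_model_params").text)

    # Restore the app's default sampling params in the current session
    print(requests.post("http://localhost:5000/api/reset_to_app_defaults").text)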
@@ -458,6 +458,96 @@ def api_cancel_notepad_generate():
         if verbose: print("->", result)
         return result
 
+@app.route("/api/get_model_params")
+def api_get_model_params():
+    global api_lock, verbose
+    if verbose: print("/api/get_model_params")
+    with api_lock:
+        model = get_loaded_model()
+        if model is None:
+            result = { "has_params": False }
+        else:
+            # Check if model has any sampling params defined
+            model_dict = model.model_dict
+            # Track which parameters are defined in the model
+            model_params = {
+                "temperature": "temperature" in model_dict,
+                "top_k": "top_k" in model_dict,
+                "top_p": "top_p" in model_dict,
+                "repp": "repp" in model_dict
+            }
+            has_params = any(model_params.values())
+            result = {
+                "has_params": has_params,
+                "model_params": model_params
+            }
+        if verbose: print("->", result)
+        return json.dumps(result) + "\n"
+
+@app.route("/api/reset_to_app_defaults", methods=['POST'])
+def api_reset_to_app_defaults():
+    global api_lock, verbose
+    if verbose: print("/api/reset_to_app_defaults")
+    with api_lock:
+        session = get_session()
+        if session is not None:
+            # Get default settings
+            default_settings = get_default_session_settings(use_model_params=False)
+
+            # Define which parameters are sampling-related
+            sampling_params = [
+                "temperature", "top_k", "top_p", "min_p", "tfs",
+                "mirostat", "mirostat_tau", "mirostat_eta", "typical",
+                "repp", "repr", "repd", "quad_sampling", "temperature_last", "skew"
+            ]
+
+            # Reset only sampling parameters to defaults
+            updated_params = {}
+            for param in sampling_params:
+                updated_params[param] = default_settings[param]
+                session.settings[param] = default_settings[param]
+
+            session.save()
+            result = { "result": "ok", "settings": updated_params }
+        else:
+            result = { "result": "fail", "error": "No session loaded" }
+        if verbose: print("->", result)
+        return json.dumps(result) + "\n"
+
+@app.route("/api/apply_model_params", methods=['POST'])
+def api_apply_model_params():
+    global api_lock, verbose
+    if verbose: print("/api/apply_model_params")
+    with api_lock:
+        model = get_loaded_model()
+        session = get_session()
+        if model is not None and session is not None:
+            # Get model's defined parameters
+            model_dict = model.model_dict
+            updated_params = {}
+
+            # Only update parameters that are defined in the model
+            if "temperature" in model_dict:
+                updated_params["temperature"] = model_dict["temperature"]
+                session.settings["temperature"] = model_dict["temperature"]
+            if "top_k" in model_dict:
+                updated_params["top_k"] = model_dict["top_k"]
+                session.settings["top_k"] = model_dict["top_k"]
+            if "top_p" in model_dict:
+                updated_params["top_p"] = model_dict["top_p"]
+                session.settings["top_p"] = model_dict["top_p"]
+            if "repp" in model_dict:
+                updated_params["repp"] = model_dict["repp"]
+                session.settings["repp"] = model_dict["repp"]
+
+            session.save()
+            # Only return the sampling parameters that were changed
+            result = { "result": "ok", "settings": updated_params }
+        else:
+            result = { "result": "fail", "error": "No model or session loaded" }
+        if verbose: print("->", result)
+        return json.dumps(result) + "\n"
+
 
 
 # Prepare torch
static/chat.css (+4 −2)
@@ -1,4 +1,3 @@
-
 .session-list {
     background-color: var(--background-color-body);
     display: flex;
@@ -522,7 +521,6 @@
     display: block;
 }
 
-
 .save-btn {
     background-color: var(--button-background);
     font-size: var(--font-size-small);
@@ -594,3 +592,7 @@
 .chat-popup action:last-child {
     margin-bottom: 0;
 }
+
+.highlight {
+    font-weight: bold;
+}
