Skip to content

Commit 8524c8c

Browse files
committed
feat: Add model parameter management and UI controls
Add support for reading generation_config.json for model parameters; implement model_loaded_callback for parameter updates; add UI controls to switch between model and app defaults; add visual highlighting for modified parameters; add new API endpoints for parameter management; improve session handling of model parameters; update UI styling for parameter controls.
1 parent f4d9478 commit 8524c8c

File tree

7 files changed

+531
-39
lines changed

7 files changed

+531
-39
lines changed

backend/models.py

+91-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,26 @@
2424
from backend.config import config_filename
2525
from backend.util import *
2626

27+
from typing import Callable, Optional, Dict, Any

# Signature of the hook: receives the model's config dict, returns nothing.
ModelLoadedCallback = Callable[[Dict[str, Any]], None]

# Module-level hook fired when model parameters are loaded/updated.
model_loaded_callback: Optional[ModelLoadedCallback] = None


def set_model_loaded_callback(callback: Optional[ModelLoadedCallback]) -> None:
    """Register (or clear) the hook notified when model parameters load.

    Args:
        callback: A callable that takes the model dict as its single
            argument, or None to unregister the current hook.

    Raises:
        TypeError: If ``callback`` is neither None nor callable.
    """
    global model_loaded_callback
    # Reject anything that is not None and not callable up front, so a bad
    # registration fails immediately rather than at notification time.
    if not (callback is None or callable(callback)):
        raise TypeError("Model loaded callback must be callable")
    model_loaded_callback = callback
45+
46+
# Reserve memory for auto-split functionality
2747
auto_split_reserve_bytes = 512 * 1024**2
2848

2949
models = {}
@@ -158,7 +178,59 @@ def prepare_draft_model(model):
158178
if "draft_rope_alpha_auto" not in model: model["draft_rope_alpha_auto"] = True
159179

160180

161-
def prepare_model(model):
181+
def prepare_model(model: Dict[str, Any]) -> None:
182+
"""Prepare model for loading by configuring parameters and resources.
183+
184+
Args:
185+
model: Dictionary containing model configuration
186+
187+
Raises:
188+
ValueError: If model directory is invalid
189+
JSONDecodeError: If generation_config.json exists but is malformed
190+
"""
191+
# Read generation_config.json if present
192+
config_path = os.path.join(expanduser(model["model_directory"]), "generation_config.json")
193+
if os.path.exists(config_path):
194+
try:
195+
with open(config_path, encoding='utf-8') as f:
196+
gen_config = json.load(f)
197+
198+
if not isinstance(gen_config, dict):
199+
raise ValueError("generation_config.json must contain a JSON object")
200+
201+
print(f"Found generation_config.json: {gen_config}")
202+
203+
# Map generation config parameters to internal names
204+
params_to_check = {
205+
"temperature": "temperature",
206+
"top_k": "top_k",
207+
"top_p": "top_p",
208+
"repetition_penalty": "repp"
209+
}
210+
211+
# Store original values for logging
212+
orig_values = {k: model.get(k) for k in params_to_check.values()}
213+
214+
# Update model with values from generation_config.json
215+
for config_name, internal_name in params_to_check.items():
216+
if config_name in gen_config:
217+
# Validate parameter types
218+
value = gen_config[config_name]
219+
if not isinstance(value, (int, float)):
220+
print(f"Warning: Invalid type for {config_name} in generation_config.json. Expected number, got {type(value)}")
221+
continue
222+
223+
model[internal_name] = value
224+
print(f"Setting {internal_name} from {orig_values.get(internal_name)} to {value}")
225+
226+
# Save updated model config
227+
save_models()
228+
except json.JSONDecodeError as e:
229+
print(f"Error parsing generation_config.json: {e}")
230+
print("Using default parameter values")
231+
except Exception as e:
232+
print(f"Unexpected error reading generation_config.json: {e}")
233+
print("Using default parameter values")
162234

163235
prep_config = ExLlamaV2Config()
164236
prep_config.fasttensors = False
@@ -194,6 +266,14 @@ def prepare_model(model):
194266
if "gpu_split" not in model: model["gpu_split"] = ""
195267
if "gpu_split_auto" not in model: model["gpu_split_auto"] = True
196268

269+
# Log final parameter state
270+
print("Final model parameters:", {
271+
"temperature": model.get("temperature", 0.8),
272+
"top_k": model.get("top_k", 50),
273+
"top_p": model.get("top_p", 0.8),
274+
"repp": model.get("repp", 1.01)
275+
})
276+
197277

198278
class ModelContainer:
199279

@@ -413,6 +493,16 @@ def load_model(data):
413493
yield json.dumps(result) + "\n"
414494
return ""
415495

496+
# Notify about model load via callback
497+
if success and model_loaded_callback is not None:
498+
print("Calling model_loaded_callback with params:", {
499+
"temperature": model.get("temperature", 0.8),
500+
"top_k": model.get("top_k", 50),
501+
"top_p": model.get("top_p", 0.8),
502+
"repp": model.get("repp", 1.01)
503+
})
504+
model_loaded_callback(model)
505+
416506
result = { "result": "ok" }
417507
# print(json.dumps(result) + "\n")
418508
yield json.dumps(result) + "\n"
@@ -430,4 +520,3 @@ def unload_model():
430520

431521
result = { "result": "ok" }
432522
return result
433-

backend/sessions.py

+52-16
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,22 @@
1717
)
1818

1919
from backend.config import set_config_dir, global_state, config_filename
20-
from backend.models import get_loaded_model
20+
from backend.models import set_model_loaded_callback
2121
from backend.prompts import prompt_formats
2222
from backend.util import MultiTimer
23+
import backend.models as models # Import as module to avoid circular dependency
2324
import threading
2425

2526
session_list: dict or None = None
2627
current_session = None
2728

29+
def handle_model_loaded(model):
    """Model-load notification hook for the sessions module.

    Deliberately a no-op: only new sessions pick up model parameters;
    existing sessions keep the sampling settings they were saved with.
    """
    pass
32+
33+
# Register callback to handle model loading
34+
set_model_loaded_callback(handle_model_loaded)
35+
2836
# Cancel
2937

3038
abort_event = threading.Event()
@@ -92,9 +100,14 @@ def delete_session(d_session):
92100
current_session = None
93101

94102

95-
def get_default_session_settings():
96-
return \
97-
{
103+
def get_default_session_settings(use_model_params=False):
104+
"""Get default session settings
105+
106+
Args:
107+
use_model_params: If True and a model is loaded with custom params,
108+
apply those params instead of defaults
109+
"""
110+
settings = {
98111
"prompt_format": "Chat-RP",
99112
"roles": [ "User", "Assistant", "", "", "", "", "", "" ],
100113
"system_prompt_default": True,
@@ -119,6 +132,23 @@ def get_default_session_settings():
119132
"temperature_last": False,
120133
"skew": 0.0,
121134
}
135+
136+
if use_model_params:
137+
# If requested, try to use model parameters
138+
loaded_model = models.get_loaded_model()
139+
if loaded_model is not None:
140+
model_dict = loaded_model.model_dict
141+
# Only apply if model has custom params defined
142+
if any(param in model_dict for param in ["temperature", "top_k", "top_p", "repp"]):
143+
settings.update({
144+
"temperature": model_dict.get("temperature", settings["temperature"]),
145+
"top_k": model_dict.get("top_k", settings["top_k"]),
146+
"top_p": model_dict.get("top_p", settings["top_p"]),
147+
"repp": model_dict.get("repp", settings["repp"])
148+
})
149+
print("Updated settings with model params:", settings)
150+
151+
return settings
122152

123153
class Session:
124154

@@ -145,7 +175,8 @@ def init_new(self):
145175
self.session_uuid = str(uuid.uuid4())
146176
self.history = []
147177
# self.mode = ""
148-
self.settings = get_default_session_settings()
178+
# New sessions get app defaults
179+
self.settings = get_default_session_settings(use_model_params=False)
149180

150181

151182
def to_json(self):
@@ -163,9 +194,13 @@ def from_json(self, j):
163194
self.session_uuid = j["session_uuid"]
164195
self.history = j["history"]
165196
# self.mode = j["mode"]
166-
settings = get_default_session_settings()
167-
if "settings" in j: settings.update(j["settings"])
168-
self.settings = settings
197+
198+
# Start with hardcoded defaults (no model params)
199+
self.settings = get_default_session_settings(use_model_params=False)
200+
201+
# Apply ALL saved settings including sampling params
202+
if "settings" in j:
203+
self.settings.update(j["settings"])
169204

170205

171206
def load(self):
@@ -244,7 +279,7 @@ def create_context(self, prompt_format, max_len, min_len, uptoblock = None, pref
244279

245280
def create_context_instruct(self, prompt_format, max_len, min_len, uptoblock = None, prefix = ""):
246281

247-
tokenizer = get_loaded_model().tokenizer
282+
tokenizer = models.get_loaded_model().tokenizer
248283
prompts = []
249284
responses = []
250285

@@ -347,7 +382,7 @@ def create_context_instruct(self, prompt_format, max_len, min_len, uptoblock = N
347382

348383
def create_context_raw(self, prompt_format, max_len, min_len, uptoblock = None, prefix=""):
349384

350-
tokenizer = get_loaded_model().tokenizer
385+
tokenizer = models.get_loaded_model().tokenizer
351386
history_copy = []
352387
for h in self.history:
353388
if h["block_uuid"] == uptoblock: break
@@ -413,16 +448,17 @@ def generate(self, data):
413448
gen_prefix = data.get("prefix", "")
414449
block_id = data.get("block_id", None)
415450

416-
if get_loaded_model() is None:
451+
if models.get_loaded_model() is None:
417452
packet = { "result": "fail", "error": "No model loaded." }
418453
yield json.dumps(packet) + "\n"
419454
return packet
420455

421-
model = get_loaded_model().model
422-
generator = get_loaded_model().generator
423-
tokenizer = get_loaded_model().tokenizer
424-
cache = get_loaded_model().cache
425-
speculative_mode = get_loaded_model().speculative_mode
456+
loaded_model = models.get_loaded_model()
457+
model = loaded_model.model
458+
generator = loaded_model.generator
459+
tokenizer = loaded_model.tokenizer
460+
cache = loaded_model.cache
461+
speculative_mode = loaded_model.speculative_mode
426462

427463
prompt_format = prompt_formats[self.settings["prompt_format"]]()
428464

server.py

+91-2
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def api_get_default_settings():
125125
if verbose: print("/api/get_default_settings")
126126
with api_lock:
127127
result = { "result": "ok",
128-
"session_settings": get_default_session_settings(),
128+
"session_settings": get_default_session_settings(use_model_params=False), # Use hardcoded defaults
129129
"notepad_settings": get_default_notepad_settings(),
130130
"prompt_formats": list_prompt_formats() }
131131
return json.dumps(result) + "\n"
@@ -439,6 +439,96 @@ def api_cancel_notepad_generate():
439439
if verbose: print("->", result)
440440
return result
441441

442+
@app.route("/api/get_model_params")
def api_get_model_params():
    """Report which sampling parameters the loaded model defines.

    Response: {"has_params": bool, "model_params": {name: bool, ...}};
    "model_params" is omitted when no model is loaded.
    """
    global api_lock, verbose
    if verbose: print("/api/get_model_params")
    with api_lock:
        model = get_loaded_model()
        if model is None:
            result = { "has_params": False }
        else:
            model_dict = model.model_dict
            # One presence flag per sampling parameter the model may define.
            model_params = {
                name: name in model_dict
                for name in ("temperature", "top_k", "top_p", "repp")
            }
            result = {
                "has_params": any(model_params.values()),
                "model_params": model_params
            }
        if verbose: print("->", result)
        return json.dumps(result) + "\n"
467+
468+
@app.route("/api/reset_to_app_defaults", methods=['POST'])
def api_reset_to_app_defaults():
    """Reset the current session's sampling parameters to app defaults.

    Non-sampling settings are left untouched. Response contains only the
    parameters that were reset, or a failure object if no session is loaded.
    """
    global api_lock, verbose
    if verbose: print("/api/reset_to_app_defaults")
    with api_lock:
        session = get_session()
        if session is None:
            result = { "result": "fail", "error": "No session loaded" }
        else:
            default_settings = get_default_session_settings(use_model_params=False)

            # The sampling-related subset of session settings.
            sampling_params = [
                "temperature", "top_k", "top_p", "min_p", "tfs",
                "mirostat", "mirostat_tau", "mirostat_eta", "typical",
                "repp", "repr", "repd", "quad_sampling", "temperature_last", "skew"
            ]

            # Reset only the sampling parameters, preserving everything else.
            updated_params = {p: default_settings[p] for p in sampling_params}
            session.settings.update(updated_params)

            session.save()
            result = { "result": "ok", "settings": updated_params }
        if verbose: print("->", result)
        return json.dumps(result) + "\n"
497+
498+
@app.route("/api/apply_model_params", methods=['POST'])
def api_apply_model_params():
    """Apply the loaded model's sampling parameters to the current session.

    Only parameters actually present in the model's config dict are copied
    into the session; all other session settings keep their current values.

    Response: {"result": "ok", "settings": {...changed params...}} on
    success, or {"result": "fail", "error": ...} when no model or session
    is loaded.
    """
    global api_lock, verbose
    if verbose: print("/api/apply_model_params")
    with api_lock:
        model = get_loaded_model()
        session = get_session()
        if model is not None and session is not None:
            model_dict = model.model_dict
            updated_params = {}

            # Copy each sampling parameter the model defines; a single loop
            # replaces four copy-pasted if-blocks (same keys, same order).
            for param in ("temperature", "top_k", "top_p", "repp"):
                if param in model_dict:
                    updated_params[param] = model_dict[param]
                    session.settings[param] = model_dict[param]

            session.save()
            # Only return the sampling parameters that were changed
            result = { "result": "ok", "settings": updated_params }
        else:
            result = { "result": "fail", "error": "No model or session loaded" }
        if verbose: print("->", result)
        return json.dumps(result) + "\n"
531+
442532

443533
# Prepare torch
444534

@@ -467,4 +557,3 @@ def api_cancel_notepad_generate():
467557
print(f" -- Opening UI in default web browser")
468558

469559
serve(app, host = host, port = port, threads = 8)
470-

static/chat.css

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
.session-list {
32
background-color: var(--background-color-body);
43
display: flex;
@@ -492,7 +491,6 @@
492491
display: block;
493492
}
494493

495-
496494
.save-btn {
497495
background-color: var(--button-background);
498496
font-size: var(--font-size-small);
@@ -564,3 +562,7 @@
564562
.chat-popup action:last-child {
565563
margin-bottom: 0;
566564
}
565+
566+
.highlight {
567+
font-weight: bold;
568+
}

0 commit comments

Comments
 (0)