import streamlit as st
import weave
import os
+import openai
+import copy
from weave.trace.weave_client import WeaveClient

from programmer.weave_next.api import init_local_client
-from programmer.weave_next.weave_query import calls, expand_refs
+from programmer.weave_next.weave_query import (
+    calls,
+    expand_refs,
+    get_call,
+    expand_json_refs,
+)
from programmer.settings_manager import SettingsManager

st.set_page_config(layout="wide")
@@ -47,7 +54,46 @@ def init_from_settings() -> WeaveClient:
    raise ValueError(f"Invalid weave_logging setting: {weave_logging_setting}")


-client = init_from_settings()
+# Add sidebar for Weave project configuration
+with st.sidebar:
+    st.header("Weave Project Configuration")
+
+    # Initialize from settings
+    initial_weave_logging = SettingsManager.get_setting("weave_logging")
+    initial_project_type = "local" if initial_weave_logging == "local" else "cloud"
+    initial_project_path = (
+        os.path.join(SettingsManager.PROGRAMMER_DIR, "weave.db")
+        if initial_weave_logging == "local"
+        else ""
+    )
+    initial_project_name = (
+        f"programmer-{os.path.basename(os.path.abspath(os.curdir))}"
+        if initial_weave_logging == "cloud"
+        else ""
+    )
+
+    project_type = st.radio(
+        "Project Type",
+        ["local", "cloud"],
+        index=0 if initial_project_type == "local" else 1,
+    )
+
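+    # Create a local or W&B cloud Weave client from the selected project type.
+    # (init_local_weave / init_remote_weave are assumed to be helpers defined
+    # elsewhere in this file; they are not part of this diff.)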
+    if project_type == "local":
+        project_path = st.text_input("Local DB Path", value=initial_project_path)
+        # SettingsManager.set_setting("weave_logging", "local")
+        # SettingsManager.set_setting("weave_db_path", project_path)
+        client = init_local_weave(project_path)
+        print("C2", client._project_id())
+    else:
+        # SettingsManager.set_setting("weave_logging", "cloud")
+        # SettingsManager.set_setting("weave_project_name", project_name)
+        project_name = st.text_input("Cloud Project Name", value=initial_project_name)
+        client = init_remote_weave(project_name)
+        print("C3", client._project_id())
+
+# Initialize client based on current settings
+# client = init_from_settings()
+print("CLIENT", client._project_id())


def set_focus_step_id(call_id):
@@ -76,6 +122,16 @@ def cached_expand_refs(wc: WeaveClient, refs: Sequence[str]):
    return expand_refs(wc, refs).to_pandas()


+@st.cache_data(hash_funcs=ST_HASH_FUNCS)
+def cached_get_call(wc: WeaveClient, call_id: str):
+    return get_call(wc, call_id)
+
+
+@st.cache_data(hash_funcs=ST_HASH_FUNCS)
+def cached_expand_json_refs(wc: WeaveClient, json: dict):
+    return expand_json_refs(wc, json)
+
+
def print_step_call(call):
    start_history = call["inputs.state.history"]
    end_history = call["output.history"]
@@ -174,30 +230,234 @@ def print_session_call(session_id):
    )


-session_calls_df = cached_calls(client, "session", expand_refs=["inputs.agent_state"])
-if len(session_calls_df) == 0:
-    st.error("No programmer sessions found.")
-    st.stop()
-session_user_message_df = session_calls_df["inputs.agent_state.history"].apply(
-    lambda v: v[-1]["content"]
-)
-
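+# Sessions page: pick a recorded programmer session in the sidebar and render
+# its trace (previously this logic ran at module level).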
+def sessions_page():
+    session_calls_df = cached_calls(
+        client, "session", expand_refs=["inputs.agent_state"]
+    )
+    if len(session_calls_df) == 0:
+        st.error("No programmer sessions found.")
+        st.stop()
+    session_user_message_df = session_calls_df["inputs.agent_state.history"].apply(
+        lambda v: v[-1]["content"]
+    )
+    with st.sidebar:
+        st.header("Session Selection")
+        if st.button("Refresh"):
+            st.cache_data.clear()
+            st.rerun()
+        message_ids = {
+            f"{cid[-5:]}: {m}": cid
+            for cid, m in reversed(
+                list(zip(session_calls_df["id"], session_user_message_df))
+            )
+        }
+        sel_message = st.radio("Session", options=message_ids.keys())
+        sel_id = None
+        if sel_message:
+            sel_id = message_ids.get(sel_message)
+    if sel_id:
+        st.header(f"Session: {sel_id}")
+        print_session_call(sel_id)
+
+
+sessions_pg = st.Page(sessions_page, title="Sessions")
+
+
+# def write_chat_message(m, key):
+#     with st.chat_message(m["role"]):
+#         if "content" in m:
+#             st.text_area(
+#                 "", value=str(m["content"]), label_visibility="collapsed", key=key
+#             )
+def write_chat_message(m, key, readonly=False):
+    def on_change_content():
+        new_value = st.session_state[key]
+        st.session_state.playground_state["editable_call"]["inputs"]["messages"][
+            m["original_index"]
+        ]["content"] = new_value
+
+    with st.chat_message(m["role"]):
+        if m.get("content"):
+            if readonly:
+                st.code(m["content"])
+            else:
+                st.text_area(
+                    "",
+                    value=m["content"],
+                    label_visibility="collapsed",
+                    key=key,
+                    on_change=on_change_content,
+                )
+        if m.get("tool_calls"):
+            for t in m["tool_calls"]:
+                st.write(t["function"]["name"])
+                st.json(
+                    {
+                        "arguments": t["function"]["arguments"],
+                        "response": t.get("response", {}).get("content"),
+                    },
+                    expanded=True,
+                )
+
+
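+# Pair each assistant tool_call with the tool message that answered it (matched
+# on tool_call_id) and drop top-level tool messages, so each response renders
+# inline with the call that produced it.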
+def attach_tool_call_responses(messages):
+    new_messages = []
+    for i, m in enumerate(messages):
+        new_m = copy.deepcopy(m)
+        new_m["original_index"] = i
+        if new_m["role"] == "assistant" and "tool_calls" in new_m:
+            new_m["tool_call_responses"] = []
+            for t in new_m["tool_calls"]:
+                t_id = t["id"]
+                for j, t_response in enumerate(messages):
+                    if t_response.get("tool_call_id") == t_id:
+                        t["response"] = t_response
+                        t["response"]["original_index"] = j
+                        break
+        if "tool_call_id" not in new_m:
+            new_messages.append(new_m)
+    return new_messages
+
+
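+# Playground page: load a logged LLM call by ID, edit its inputs (messages,
+# temperature, tools, parallel_tool_calls), and re-run it against the OpenAI API.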
+def playground_page():
+    with st.sidebar:
+        if not st.session_state.get("playground_state"):
+            st.session_state.playground_state = {
+                "call_id": None,
+                "call": None,
+                "expanded_call": None,
+                "editable_call": None,
+            }
+        playground_state = st.session_state.playground_state
+        call_id = st.text_input("Call ID")
+        if not call_id:
+            st.error("Please set call ID")
+            st.stop()
+
+        # st.write(playground_state)
+        if playground_state["expanded_call"] != playground_state["editable_call"]:
+            st.warning("Call has been modified")
+            if st.button("Restore original call"):
+                st.session_state.playground_state["editable_call"] = copy.deepcopy(
+                    playground_state["expanded_call"]
+                )
+                st.rerun()
+
+        if call_id != st.session_state.playground_state["call_id"]:
+            st.spinner("Loading call...")
+            call = cached_get_call(client, call_id)
+            editable_call = cached_expand_json_refs(client, call)
+            st.session_state.playground_state = {
+                "call_id": call_id,
+                "call": call,
+                "expanded_call": editable_call,
+                "editable_call": copy.deepcopy(editable_call),
+            }
+            st.rerun()
+
+        call = st.session_state.playground_state["call"]
+        editable_call = st.session_state.playground_state["editable_call"]
+        if call is None or editable_call is None:
+            st.warning("call not yet loaded")
+            st.stop()
+
+        st.write(call["op_name"])
+        # st.json(call["inputs"])
+        # st.json(call["inputs"]["tools"])
+
+        def on_change_temperature():
+            st.session_state.playground_state["editable_call"]["inputs"][
+                "temperature"
+            ] = st.session_state["temperature"]
+
+        st.slider(
+            "Temperature",
+            min_value=0.0,
+            max_value=1.0,
+            value=editable_call["inputs"]["temperature"],
+            key="temperature",
+            on_change=on_change_temperature,
+        )

-with st.sidebar:
-    if st.button("Refresh"):
-        st.cache_data.clear()
-        st.rerun()
-    message_ids = {
-        f"{cid[-5:]}: {m}": cid
-        for cid, m in reversed(
-            list(zip(session_calls_df["id"], session_user_message_df))
+        tools = call["inputs"].get("tools", [])
+        if tools:
+            st.write("Tools")
+            for tool_idx, t in enumerate(tools):
+                with st.expander(t["function"]["name"]):
+
+                    def on_change_tool():
+                        st.session_state.playground_state["editable_call"]["inputs"][
+                            "tools"
+                        ][tool_idx] = json.loads(st.session_state[f"tool-{tool_idx}"])
+                        st.rerun()
+
+                    st.text_area(
+                        "json",
+                        value=json.dumps(t, indent=2),
+                        height=300,
+                        key=f"tool-{tool_idx}",
+                        on_change=on_change_tool,
+                    )
+
+        def on_change_parallel_tool_calls():
+            st.session_state.playground_state["editable_call"]["inputs"][
+                "parallel_tool_calls"
+            ] = st.session_state["parallel_tool_calls"]
+
+        st.checkbox(
+            "Parallel tool calls",
+            value=editable_call["inputs"].get("parallel_tool_calls", True),
+            key="parallel_tool_calls",
+            on_change=on_change_parallel_tool_calls,
        )
+
+    inputs = editable_call["inputs"]
+    all_input_messages = inputs["messages"]
+    other_inputs = {
+        k: v
+        for k, v in inputs.items()
+        if (k != "messages" and k != "self" and k != "stream")
    }
-    sel_message = st.radio("Session", options=message_ids.keys())
-    sel_id = None
-    if sel_message:
-        sel_id = message_ids.get(sel_message)
-
-if sel_id:
-    st.header(f"Session: {sel_id}")
-    print_session_call(sel_id)
+
+    tool_call_attached_messages = attach_tool_call_responses(all_input_messages)
+    for i, m in enumerate(tool_call_attached_messages):
+        write_chat_message(m, f"message-{i}")
+    # output = editable_call["output"]["choices"][0]["message"]
+    n_choices = st.number_input(
+        "Number of choices", value=1, min_value=1, max_value=100
+    )
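+    # Re-issue the (possibly edited) request through the OpenAI client,
+    # requesting n_choices alternative completions.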
+    if st.button("Generate"):
+        chat_inputs = {**editable_call["inputs"]}
+        # st.json(chat_inputs, expanded=False)
+        del chat_inputs["stream"]
+        del chat_inputs["self"]
+        chat_inputs["n"] = n_choices
+        call_resp = openai.chat.completions.create(**chat_inputs).model_dump()
+
+        editable_call["output"] = call_resp
+        st.rerun()
+    # st.json(response, expanded=False)
+    # output = response["choices"][0]["message"]
+    # st.json(output)
+    response = editable_call["output"]
+    st.write("full response")
+    st.json(response, expanded=False)
+    st.write("**system fingerprint**", response["system_fingerprint"])
+    st.write("**usage**", response["usage"])
+    for i, choice in enumerate(response["choices"]):
+        output = choice["message"]
+        st.write(f"Choice {i+1}")
+        write_chat_message(output, f"output_message-{i}", readonly=True)
+
+    # all_messages = [*all_input_messages, output]
+    # st.json(st.session_state.playground_state, expanded=False)
+    # st.json(all_messages, expanded=False)
+
+    # st.write(expanded_call)
+
+
+playground_pg = st.Page(playground_page, title="Playground")
+
+
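+# Register both pages with Streamlit's multipage navigation.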
+pg = st.navigation([sessions_pg, playground_pg])
+pg.run()