Skip to content

Commit 9715046

Browse files
authored
Text editor + Playground UI (#21)
* Working on text editor
* Fix failing tests
* Move ToolContext -> io_context.py
* rename ToolContext -> IOContext
* Text editor agent
* Refactor to encapsulate TextEditorState
* Separate tools out of AgentTextEditor
* AgentTextEditor eval fixes/improvements
  - system message at end
  - fix: include all prior messages when returning response
  - return error if replace_lines buffer is not open, instead of trying to open it
  - disable file closing for now
* Configurable project in programmer ui
* Parallel file range writing
* Start of playground for "programmer ui"
* Better editor state prompt, show message before and after
* Fix tests
* Stable text editing! sandwich message, temp0, fix docstrings, 1-index
* Major editor eval improvements + playground UI
* Lint
1 parent f47cff1 commit 9715046

17 files changed

+1533
-261
lines changed

programmer-ui/ui.py

+286-26
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,17 @@
66
import streamlit as st
77
import weave
88
import os
9+
import openai
10+
import copy
911
from weave.trace.weave_client import WeaveClient
1012

1113
from programmer.weave_next.api import init_local_client
12-
from programmer.weave_next.weave_query import calls, expand_refs
14+
from programmer.weave_next.weave_query import (
15+
calls,
16+
expand_refs,
17+
get_call,
18+
expand_json_refs,
19+
)
1320
from programmer.settings_manager import SettingsManager
1421

1522
st.set_page_config(layout="wide")
@@ -47,7 +54,46 @@ def init_from_settings() -> WeaveClient:
4754
raise ValueError(f"Invalid weave_logging setting: {weave_logging_setting}")
4855

4956

50-
client = init_from_settings()
57+
# Sidebar controls that choose which Weave project backs this UI session.
with st.sidebar:
    st.header("Weave Project Configuration")

    # Seed the widgets from the persisted "weave_logging" setting; any value
    # other than "local" is presented as a cloud project.
    initial_weave_logging = SettingsManager.get_setting("weave_logging")
    initial_project_type = "local" if initial_weave_logging == "local" else "cloud"
    initial_project_path = (
        os.path.join(SettingsManager.PROGRAMMER_DIR, "weave.db")
        if initial_weave_logging == "local"
        else ""
    )
    initial_project_name = (
        f"programmer-{os.path.basename(os.path.abspath(os.curdir))}"
        if initial_weave_logging == "cloud"
        else ""
    )

    project_type = st.radio(
        "Project Type",
        ["local", "cloud"],
        index=0 if initial_project_type == "local" else 1,
    )

    # The choice made here only selects the client for this page; nothing is
    # written back to SettingsManager.
    if project_type == "local":
        project_path = st.text_input("Local DB Path", value=initial_project_path)
        client = init_local_weave(project_path)
    else:
        project_name = st.text_input("Cloud Project Name", value=initial_project_name)
        client = init_remote_weave(project_name)
5197

5298

5399
def set_focus_step_id(call_id):
@@ -76,6 +122,16 @@ def cached_expand_refs(wc: WeaveClient, refs: Sequence[str]):
76122
return expand_refs(wc, refs).to_pandas()
77123

78124

125+
@st.cache_data(hash_funcs=ST_HASH_FUNCS)
def cached_get_call(wc: WeaveClient, call_id: str):
    """Return ``get_call(wc, call_id)``, memoized through Streamlit's data cache."""
    fetched_call = get_call(wc, call_id)
    return fetched_call
128+
129+
130+
@st.cache_data(hash_funcs=ST_HASH_FUNCS)
def cached_expand_json_refs(wc: WeaveClient, json: dict):
    """Return ``expand_json_refs(wc, json)``, memoized through Streamlit's data cache.

    Note: the ``json`` parameter shadows the ``json`` module inside this
    function only; the name is kept because it is part of the signature.
    """
    expanded = expand_json_refs(wc, json)
    return expanded
133+
134+
79135
def print_step_call(call):
80136
start_history = call["inputs.state.history"]
81137
end_history = call["output.history"]
@@ -174,30 +230,234 @@ def print_session_call(session_id):
174230
)
175231

176232

177-
session_calls_df = cached_calls(client, "session", expand_refs=["inputs.agent_state"])
178-
if len(session_calls_df) == 0:
179-
st.error("No programmer sessions found.")
180-
st.stop()
181-
session_user_message_df = session_calls_df["inputs.agent_state.history"].apply(
182-
lambda v: v[-1]["content"]
183-
)
184-
233+
def sessions_page():
    """Render the "Sessions" page: a sidebar session picker and the chosen session."""
    sessions = cached_calls(client, "session", expand_refs=["inputs.agent_state"])
    if len(sessions) == 0:
        st.error("No programmer sessions found.")
        st.stop()

    # Each session is labelled with the content of the last message in its
    # input history.
    last_messages = sessions["inputs.agent_state.history"].apply(
        lambda history: history[-1]["content"]
    )

    with st.sidebar:
        st.header("Session Selection")
        if st.button("Refresh"):
            st.cache_data.clear()
            st.rerun()

        # Walk the rows in reverse order; a later duplicate label overwrites
        # an earlier one, exactly as in a dict comprehension.
        id_message_pairs = list(zip(sessions["id"], last_messages))
        label_to_id = {}
        for session_id, message in id_message_pairs[::-1]:
            label_to_id[f"{session_id[-5:]}: {message}"] = session_id

        chosen_label = st.radio("Session", options=label_to_id.keys())
        sel_id = label_to_id.get(chosen_label) if chosen_label else None

    # NOTE(review): nesting reconstructed from a flattened diff — the session
    # body is assumed to render in the main area, not the sidebar; confirm.
    if not sel_id:
        return
    st.header(f"Session: {sel_id}")
    print_session_call(sel_id)
261+
262+
263+
sessions_pg = st.Page(sessions_page, title="Sessions")
264+
265+
266+
# def write_chat_message(m, key):
267+
# with st.chat_message(m["role"]):
268+
# if "content" in m:
269+
# st.text_area(
270+
# "", value=str(m["content"]), label_visibility="collapsed", key=key
271+
# )
272+
def write_chat_message(m, key, readonly=False):
    """Render one chat message, and any tool calls it carries, in a chat bubble.

    Args:
        m: Message dict. ``role`` is required; ``content`` and ``tool_calls``
            are optional. When the message is editable, ``original_index``
            is read by the edit callback to locate the message inside the
            playground's editable call.
        key: Streamlit widget key for the editable text area.
        readonly: If true, show the content with ``st.code`` instead of an
            editable text area.
    """

    def on_change_content():
        # Write the edited text back into the editable call kept in session
        # state, at the message's position in the original message list.
        st.session_state.playground_state["editable_call"]["inputs"]["messages"][
            m["original_index"]
        ]["content"] = st.session_state[key]

    with st.chat_message(m["role"]):
        if m.get("content"):
            if readonly:
                st.code(m["content"])
            else:
                st.text_area(
                    # Streamlit asks for a non-empty label even when it is
                    # hidden, so assistive technology can announce the field.
                    "Message content",
                    value=m["content"],
                    label_visibility="collapsed",
                    key=key,
                    on_change=on_change_content,
                )
        for t in m.get("tool_calls") or []:
            st.write(t["function"]["name"])
            st.json(
                {
                    "arguments": t["function"]["arguments"],
                    # "response" is only present when a tool reply was attached.
                    "response": t.get("response", {}).get("content"),
                },
                expanded=True,
            )
301+
302+
303+
def attach_tool_call_responses(messages):
    """Pair each assistant tool call with the tool message that answers it.

    The input list and its dicts are never modified; everything returned is a
    deep copy.

    Args:
        messages: List of chat message dicts. Messages are expected to have a
            ``role``; tool replies carry ``tool_call_id``; assistant messages
            may carry ``tool_calls`` (a list of dicts with an ``id``) or
            ``tool_calls=None``.

    Returns:
        A new list containing a copy of every message that has no
        ``tool_call_id`` key, each tagged with ``original_index`` (its
        position in ``messages``). For assistant messages with tool calls,
        each call gains ``response``: a copy of the first message whose
        ``tool_call_id`` equals the call's ``id``, itself tagged with
        ``original_index``. Calls without a reply get no ``response`` key.
    """
    # Index of the first reply for every tool_call_id, built in one pass so
    # each tool call is resolved without rescanning the whole conversation.
    first_reply_index = {}
    for reply_idx, candidate in enumerate(messages):
        reply_to = candidate.get("tool_call_id")
        if reply_to is not None:
            first_reply_index.setdefault(reply_to, reply_idx)

    new_messages = []
    for i, m in enumerate(messages):
        # Tool replies are only surfaced through the call they answer.
        if "tool_call_id" in m:
            continue
        new_m = copy.deepcopy(m)
        new_m["original_index"] = i
        # ``tool_calls`` may be present but None (e.g. a dumped API response).
        if new_m["role"] == "assistant" and new_m.get("tool_calls") is not None:
            # Kept for backward compatibility; this list is never populated.
            new_m["tool_call_responses"] = []
            for t in new_m["tool_calls"]:
                reply_idx = first_reply_index.get(t["id"])
                if reply_idx is None:
                    continue
                # Copy before tagging so the caller's message stays clean.
                reply = copy.deepcopy(messages[reply_idx])
                reply["original_index"] = reply_idx
                t["response"] = reply
        new_messages.append(new_m)
    return new_messages
320+
321+
322+
def playground_page():
    """Render the "Playground" page: load a logged chat call, edit it, re-run it.

    State lives in ``st.session_state.playground_state`` with four keys:
    ``call_id``, ``call`` (as returned by ``cached_get_call``),
    ``expanded_call`` (the result of ``cached_expand_json_refs``) and
    ``editable_call`` (a deep copy of ``expanded_call`` that the widgets
    mutate). Pressing "Generate" sends the editable inputs to
    ``openai.chat.completions.create`` and stores the dumped response as the
    editable call's ``output``.
    """
    # NOTE(review): the sidebar extent was reconstructed from a flattened
    # diff — confirm the "modified" warning belongs in the sidebar.
    with st.sidebar:
        if not st.session_state.get("playground_state"):
            st.session_state.playground_state = {
                "call_id": None,
                "call": None,
                "expanded_call": None,
                "editable_call": None,
            }
        playground_state = st.session_state.playground_state
        call_id = st.text_input("Call ID")
        if not call_id:
            st.error("Please set call ID")
            st.stop()

        if playground_state["expanded_call"] != playground_state["editable_call"]:
            st.warning("Call has been modified")
            if st.button("Restore original call"):
                st.session_state.playground_state["editable_call"] = copy.deepcopy(
                    playground_state["expanded_call"]
                )
                st.rerun()

    if call_id != st.session_state.playground_state["call_id"]:
        # st.spinner only displays while used as a context manager.
        with st.spinner("Loading call..."):
            call = cached_get_call(client, call_id)
            editable_call = cached_expand_json_refs(client, call)
        st.session_state.playground_state = {
            "call_id": call_id,
            "call": call,
            "expanded_call": editable_call,
            "editable_call": copy.deepcopy(editable_call),
        }
        st.rerun()

    call = st.session_state.playground_state["call"]
    editable_call = st.session_state.playground_state["editable_call"]
    if call is None or editable_call is None:
        st.warning("call not yet loaded")
        st.stop()

    st.write(call["op_name"])

    def on_change_temperature():
        st.session_state.playground_state["editable_call"]["inputs"][
            "temperature"
        ] = st.session_state["temperature"]

    st.slider(
        "Temperature",
        min_value=0.0,
        max_value=1.0,
        # The slider needs the same numeric type as its bounds, so an int
        # temperature (e.g. 0) is cast; 1.0 mirrors the API default when the
        # logged call did not set a temperature.
        value=float(editable_call["inputs"].get("temperature", 1.0)),
        key="temperature",
        on_change=on_change_temperature,
    )

    def on_change_tool(idx):
        # ``idx`` arrives through the widget's ``args`` so each text area
        # updates its own tool; a closure over the loop variable would always
        # see the last index by the time the callback fires.
        st.session_state.playground_state["editable_call"]["inputs"]["tools"][
            idx
        ] = json.loads(st.session_state[f"tool-{idx}"])

    tools = call["inputs"].get("tools", [])
    if tools:
        st.write("Tools")
        for tool_idx, t in enumerate(tools):
            with st.expander(t["function"]["name"]):
                st.text_area(
                    "json",
                    value=json.dumps(t, indent=2),
                    height=300,
                    key=f"tool-{tool_idx}",
                    on_change=on_change_tool,
                    args=(tool_idx,),
                )

    def on_change_parallel_tool_calls():
        st.session_state.playground_state["editable_call"]["inputs"][
            "parallel_tool_calls"
        ] = st.session_state["parallel_tool_calls"]

    st.checkbox(
        "Parallel tool calls",
        value=editable_call["inputs"].get("parallel_tool_calls", True),
        key="parallel_tool_calls",
        on_change=on_change_parallel_tool_calls,
    )

    all_input_messages = editable_call["inputs"]["messages"]
    tool_call_attached_messages = attach_tool_call_responses(all_input_messages)
    for i, m in enumerate(tool_call_attached_messages):
        write_chat_message(m, f"message-{i}")

    n_choices = st.number_input(
        "Number of choices", value=1, min_value=1, max_value=100
    )
    if st.button("Generate"):
        chat_inputs = {**editable_call["inputs"]}
        # "stream" and "self" are dropped before the API call; pop() tolerates
        # logged calls where either key is absent.
        chat_inputs.pop("stream", None)
        chat_inputs.pop("self", None)
        chat_inputs["n"] = n_choices
        call_resp = openai.chat.completions.create(**chat_inputs).model_dump()

        editable_call["output"] = call_resp
        st.rerun()

    response = editable_call["output"]
    st.write("full response")
    st.json(response, expanded=False)
    st.write("**system fingerprint**", response["system_fingerprint"])
    st.write("**usage**", response["usage"])
    for i, choice in enumerate(response["choices"]):
        output = choice["message"]
        st.write(f"Choice {i+1}")
        write_chat_message(output, f"output_message-{i}", readonly=True)
451+
452+
# all_messages = [*all_input_messages, output]
453+
# st.json(st.session_state.playground_state, expanded=False)
454+
# st.json(all_messages, expanded=False)
455+
456+
# st.write(expanded_call)
457+
458+
459+
playground_pg = st.Page(playground_page, title="Playground")
460+
461+
462+
pg = st.navigation([sessions_pg, playground_pg])
463+
pg.run()

programmer/agent.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ class AgentState(weave.Object):
3939
history: list[Any] = Field(default_factory=list)
4040
env_snapshot_key: Optional[EnvironmentSnapshotKey] = None
4141

42+
def with_history(self, history: list[Any]) -> "AgentState":
    """Return a new state of this class for ``history`` with a fresh snapshot.

    The snapshot key comes from ``make_snapshot`` on the current environment,
    called with the commit message derived from ``history``.
    """
    env = get_current_environment()
    commit_msg = get_commit_message(history)
    return self.__class__(
        history=history,
        env_snapshot_key=env.make_snapshot(commit_msg),
    )
47+
4248

4349
def unweavify(v: Any) -> Any:
4450
if isinstance(v, list):
@@ -55,6 +61,9 @@ class Agent(weave.Object):
5561
system_message: str
5662
tools: list[Any] = Field(default_factory=list)
5763

64+
def initial_state(self, history: list[Any]) -> AgentState:
    """Build the starting ``AgentState`` for ``history`` via ``with_history``."""
    blank_state = AgentState()
    return blank_state.with_history(history)
66+
5867
@weave.op()
5968
def step(self, state: AgentState) -> AgentState:
6069
"""Run a step of the agent.
@@ -118,12 +127,8 @@ def step(self, state: AgentState) -> AgentState:
118127

119128
# new_history = state.history + new_messages
120129
new_history = weavelist_add(state.history, new_messages)
121-
msg = get_commit_message(new_history)
122-
123-
environment = get_current_environment()
124-
snapshot_key = environment.make_snapshot(msg)
125130

126-
return AgentState(history=new_history, env_snapshot_key=snapshot_key)
131+
return state.with_history(new_history)
127132

128133
@weave.op()
129134
def run(self, state: AgentState, max_runtime_seconds: int = -1):

0 commit comments

Comments
 (0)