diff --git a/Cargo.lock b/Cargo.lock index 9004719ae87873..25b1b33cb2af53 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -450,6 +450,7 @@ version = "0.1.0" dependencies = [ "anyhow", "assistant_context_editor", + "assistant_scripting", "assistant_settings", "assistant_slash_command", "assistant_tool", @@ -563,6 +564,25 @@ dependencies = [ "workspace", ] +[[package]] +name = "assistant_scripting" +version = "0.1.0" +dependencies = [ + "anyhow", + "collections", + "futures 0.3.31", + "gpui", + "mlua", + "parking_lot", + "project", + "rand 0.8.5", + "regex", + "serde", + "serde_json", + "settings", + "util", +] + [[package]] name = "assistant_settings" version = "0.1.0" @@ -11910,26 +11930,6 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3cf7c11c38cb994f3d40e8a8cde3bbd1f72a435e4c49e85d6553d8312306152" -[[package]] -name = "scripting_tool" -version = "0.1.0" -dependencies = [ - "anyhow", - "assistant_tool", - "collections", - "futures 0.3.31", - "gpui", - "mlua", - "parking_lot", - "project", - "regex", - "schemars", - "serde", - "serde_json", - "settings", - "util", -] - [[package]] name = "scrypt" version = "0.11.0" @@ -16984,7 +16984,6 @@ dependencies = [ "repl", "reqwest_client", "rope", - "scripting_tool", "search", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 8c18627262cde3..7f2824c038a216 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -118,7 +118,7 @@ members = [ "crates/rope", "crates/rpc", "crates/schema_generator", - "crates/scripting_tool", + "crates/assistant_scripting", "crates/search", "crates/semantic_index", "crates/semantic_version", @@ -318,7 +318,7 @@ reqwest_client = { path = "crates/reqwest_client" } rich_text = { path = "crates/rich_text" } rope = { path = "crates/rope" } rpc = { path = "crates/rpc" } -scripting_tool = { path = "crates/scripting_tool" } +assistant_scripting = { path = "crates/assistant_scripting" } search = { path = "crates/search" } semantic_index = { path = "crates/semantic_index" } semantic_version = { path = "crates/semantic_version" } diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml index b59b2f2e7b5043..b33bd37de21211 100644 --- a/crates/assistant2/Cargo.toml +++ b/crates/assistant2/Cargo.toml @@ -63,6 +63,7 @@ serde.workspace = true serde_json.workspace = true settings.workspace = true smol.workspace = true +assistant_scripting.workspace = true streaming_diff.workspace = true telemetry_events.workspace = true terminal.workspace = true diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index f50689ec4cd5ac..8f3677b6d0faf6 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -1,11 +1,12 @@ use std::sync::Arc; -use collections::HashMap; +use assistant_scripting::{ScriptId, ScriptState}; +use collections::{HashMap, HashSet}; use editor::{Editor, MultiBuffer}; use gpui::{ list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty, Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, - Task, TextStyleRefinement, UnderlineStyle, + Task, TextStyleRefinement, UnderlineStyle, WeakEntity, }; use language::{Buffer, LanguageRegistry}; use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role}; @@ -14,6 +15,7 @@ use settings::Settings as _; use theme::ThemeSettings; use ui::{prelude::*, Disclosure, KeyBinding}; use util::ResultExt as _; +use workspace::Workspace; use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent}; use crate::thread_store::ThreadStore; @@ -21,6 +23,7 @@ use crate::tool_use::{ToolUse, ToolUseStatus}; use crate::ui::ContextPill; pub struct ActiveThread { + workspace: WeakEntity, language_registry: Arc, thread_store: Entity, thread: Entity, @@ -30,6 +33,7 @@ pub struct ActiveThread { rendered_messages_by_id: HashMap>, editing_message: Option<(MessageId, EditMessageState)>, expanded_tool_uses: HashMap, + expanded_scripts: HashSet, last_error: Option, _subscriptions: Vec, } @@ -40,6 +44,7 @@ struct EditMessageState { impl ActiveThread { pub fn new( + workspace: WeakEntity, thread: Entity, thread_store: Entity, language_registry: Arc, @@ -52,6 +57,7 @@ impl ActiveThread { ]; let mut this = Self { + workspace, language_registry, thread_store, thread: thread.clone(), @@ -59,6 +65,7 @@ impl ActiveThread { messages: Vec::new(), rendered_messages_by_id: HashMap::default(), expanded_tool_uses: HashMap::default(), + expanded_scripts: HashSet::default(), list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), { let this = cx.entity().downgrade(); move |ix, window: &mut Window, cx: &mut App| { @@ -241,7 +248,7 @@ impl ActiveThread { fn handle_thread_event( &mut self, - _: &Entity, + _thread: &Entity, event: &ThreadEvent, window: &mut Window, cx: &mut Context, @@ -306,6 +313,14 @@ impl ActiveThread { } } } + ThreadEvent::ScriptFinished => { + let model_registry = LanguageModelRegistry::read_global(cx); + if let Some(model) = model_registry.active_model() { + self.thread.update(cx, |thread, cx| { + thread.send_to_model(model, RequestKind::Chat, false, cx); + }); + } + } } } @@ -445,12 +460,16 @@ impl ActiveThread { return Empty.into_any(); }; - let context = self.thread.read(cx).context_for_message(message_id); - let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id); - let colors = cx.theme().colors(); + let thread = self.thread.read(cx); + + let context = thread.context_for_message(message_id); + let tool_uses = thread.tool_uses_for_message(message_id); // Don't render user messages that are just there for returning tool results. - if message.role == Role::User && self.thread.read(cx).message_has_tool_results(message_id) { + if message.role == Role::User + && (thread.message_has_tool_results(message_id) + || thread.message_has_script_output(message_id)) + { return Empty.into_any(); } @@ -463,6 +482,8 @@ impl ActiveThread { .filter(|(id, _)| *id == message_id) .map(|(_, state)| state.editor.clone()); + let colors = cx.theme().colors(); + let message_content = v_flex() .child( if let Some(edit_message_editor) = edit_message_editor.clone() { @@ -597,6 +618,7 @@ impl ActiveThread { Role::Assistant => div() .id(("message-container", ix)) .child(message_content) + .children(self.render_script(message_id, cx)) .map(|parent| { if tool_uses.is_empty() { return parent; @@ -716,6 +738,139 @@ impl ActiveThread { }), ) } + + fn render_script(&self, message_id: MessageId, cx: &mut Context) -> Option { + let script = self.thread.read(cx).script_for_message(message_id, cx)?; + + let is_open = self.expanded_scripts.contains(&script.id); + let colors = cx.theme().colors(); + + let element = div().px_2p5().child( + v_flex() + .gap_1() + .rounded_lg() + .border_1() + .border_color(colors.border) + .child( + h_flex() + .justify_between() + .py_0p5() + .pl_1() + .pr_2() + .bg(colors.editor_foreground.opacity(0.02)) + .when(is_open, |element| element.border_b_1().rounded_t(px(6.))) + .when(!is_open, |element| element.rounded_md()) + .border_color(colors.border) + .child( + h_flex() + .gap_1() + .child(Disclosure::new("script-disclosure", is_open).on_click( + cx.listener({ + let script_id = script.id; + move |this, _event, _window, _cx| { + if this.expanded_scripts.contains(&script_id) { + this.expanded_scripts.remove(&script_id); + } else { + this.expanded_scripts.insert(script_id); + } + } + }), + )) + // TODO: Generate script description + .child(Label::new("Script")), + ) + .child( + h_flex() + .gap_1() + .child( + Label::new(match script.state { + ScriptState::Generating => "Generating", + ScriptState::Running { .. } => "Running", + ScriptState::Succeeded { .. } => "Finished", + ScriptState::Failed { .. } => "Error", + }) + .size(LabelSize::XSmall) + .buffer_font(cx), + ) + .child( + IconButton::new("view-source", IconName::Eye) + .icon_color(Color::Muted) + .disabled(matches!(script.state, ScriptState::Generating)) + .on_click(cx.listener({ + let source = script.source.clone(); + move |this, _event, window, cx| { + this.open_script_source(source.clone(), window, cx); + } + })), + ), + ), + ) + .when(is_open, |parent| { + let stdout = script.stdout_snapshot(); + let error = script.error(); + + parent.child( + v_flex() + .p_2() + .bg(colors.editor_background) + .gap_2() + .child(if stdout.is_empty() && error.is_none() { + Label::new("No output yet") + .size(LabelSize::Small) + .color(Color::Muted) + } else { + Label::new(stdout).size(LabelSize::Small).buffer_font(cx) + }) + .children(script.error().map(|err| { + Label::new(err.to_string()) + .size(LabelSize::Small) + .color(Color::Error) + })), + ) + }), + ); + + Some(element.into_any()) + } + + fn open_script_source( + &mut self, + source: SharedString, + window: &mut Window, + cx: &mut Context<'_, ActiveThread>, + ) { + let language_registry = self.language_registry.clone(); + let workspace = self.workspace.clone(); + let source = source.clone(); + + cx.spawn_in(window, |_, mut cx| async move { + let lua = language_registry.language_for_name("Lua").await.log_err(); + + workspace.update_in(&mut cx, |workspace, window, cx| { + let project = workspace.project().clone(); + + let buffer = project.update(cx, |project, cx| { + project.create_local_buffer(&source.trim(), lua, cx) + }); + + let buffer = cx.new(|cx| { + MultiBuffer::singleton(buffer, cx) + // TODO: Generate script description + .with_title("Assistant script".into()) + }); + + let editor = cx.new(|cx| { + let mut editor = + Editor::for_multibuffer(buffer, Some(project), true, window, cx); + editor.set_read_only(true); + editor + }); + + workspace.add_item_to_active_pane(Box::new(editor), None, true, window, cx); + }) + }) + .detach_and_log_err(cx); + } } impl Render for ActiveThread { diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index 301466ae5f8363..123ae8f45943b7 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -166,22 +166,25 @@ impl AssistantPanel { let history_store = cx.new(|cx| HistoryStore::new(thread_store.clone(), context_store.clone(), cx)); + let thread = cx.new(|cx| { + ActiveThread::new( + workspace.clone(), + thread.clone(), + thread_store.clone(), + language_registry.clone(), + window, + cx, + ) + }); + Self { active_view: ActiveView::Thread, workspace, project: project.clone(), fs: fs.clone(), - language_registry: language_registry.clone(), + language_registry, thread_store: thread_store.clone(), - thread: cx.new(|cx| { - ActiveThread::new( - thread.clone(), - thread_store.clone(), - language_registry, - window, - cx, - ) - }), + thread, message_editor, context_store, context_editor: None, @@ -239,6 +242,7 @@ impl AssistantPanel { self.active_view = ActiveView::Thread; self.thread = cx.new(|cx| { ActiveThread::new( + self.workspace.clone(), thread.clone(), self.thread_store.clone(), self.language_registry.clone(), @@ -372,6 +376,7 @@ impl AssistantPanel { this.active_view = ActiveView::Thread; this.thread = cx.new(|cx| { ActiveThread::new( + this.workspace.clone(), thread.clone(), this.thread_store.clone(), this.language_registry.clone(), diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 95f42f75453d55..4dc194e0e93fe4 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -1,11 +1,14 @@ use std::sync::Arc; use anyhow::Result; +use assistant_scripting::{ + Script, ScriptEvent, ScriptId, ScriptSession, ScriptTagParser, SCRIPTING_PROMPT, +}; use assistant_tool::ToolWorkingSet; use chrono::{DateTime, Utc}; use collections::{BTreeMap, HashMap, HashSet}; use futures::StreamExt as _; -use gpui::{App, Context, Entity, EventEmitter, SharedString, Task}; +use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Subscription, Task}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, @@ -75,14 +78,21 @@ pub struct Thread { project: Entity, tools: Arc, tool_use: ToolUseState, + scripts_by_assistant_message: HashMap, + script_output_messages: HashSet, + script_session: Entity, + _script_session_subscription: Subscription, } impl Thread { pub fn new( project: Entity, tools: Arc, - _cx: &mut Context, + cx: &mut Context, ) -> Self { + let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx)); + let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event); + Self { id: ThreadId::new(), updated_at: Utc::now(), @@ -97,6 +107,10 @@ impl Thread { project, tools, tool_use: ToolUseState::new(), + scripts_by_assistant_message: HashMap::default(), + script_output_messages: HashSet::default(), + script_session, + _script_session_subscription: script_session_subscription, } } @@ -105,7 +119,7 @@ impl Thread { saved: SavedThread, project: Entity, tools: Arc, - _cx: &mut Context, + cx: &mut Context, ) -> Self { let next_message_id = MessageId( saved @@ -115,6 +129,8 @@ impl Thread { .unwrap_or(0), ); let tool_use = ToolUseState::from_saved_messages(&saved.messages); + let script_session = cx.new(|cx| ScriptSession::new(project.clone(), cx)); + let script_session_subscription = cx.subscribe(&script_session, Self::handle_script_event); Self { id, @@ -138,6 +154,10 @@ impl Thread { project, tools, tool_use, + scripts_by_assistant_message: HashMap::default(), + script_output_messages: HashSet::default(), + script_session, + _script_session_subscription: script_session_subscription, } } @@ -223,17 +243,22 @@ impl Thread { self.tool_use.message_has_tool_results(message_id) } + pub fn message_has_script_output(&self, message_id: MessageId) -> bool { + self.script_output_messages.contains(&message_id) + } + pub fn insert_user_message( &mut self, text: impl Into, context: Vec, cx: &mut Context, - ) { + ) -> MessageId { let message_id = self.insert_message(Role::User, text, cx); let context_ids = context.iter().map(|context| context.id).collect::>(); self.context .extend(context.into_iter().map(|context| (context.id, context))); self.context_by_message.insert(message_id, context_ids); + message_id } pub fn insert_message( @@ -302,6 +327,39 @@ impl Thread { text } + pub fn script_for_message<'a>( + &'a self, + message_id: MessageId, + cx: &'a App, + ) -> Option<&'a Script> { + self.scripts_by_assistant_message + .get(&message_id) + .map(|script_id| self.script_session.read(cx).get(*script_id)) + } + + fn handle_script_event( + &mut self, + _script_session: Entity, + event: &ScriptEvent, + cx: &mut Context, + ) { + match event { + ScriptEvent::Spawned(_) => {} + ScriptEvent::Exited(script_id) => { + if let Some(output_message) = self + .script_session + .read(cx) + .get(*script_id) + .output_message_for_llm() + { + let message_id = self.insert_user_message(output_message, vec![], cx); + self.script_output_messages.insert(message_id); + cx.emit(ThreadEvent::ScriptFinished) + } + } + } + } + pub fn send_to_model( &mut self, model: Arc, @@ -330,7 +388,7 @@ impl Thread { pub fn to_completion_request( &self, request_kind: RequestKind, - _cx: &App, + cx: &App, ) -> LanguageModelRequest { let mut request = LanguageModelRequest { messages: vec![], @@ -339,6 +397,12 @@ impl Thread { temperature: None, }; + request.messages.push(LanguageModelRequestMessage { + role: Role::System, + content: vec![SCRIPTING_PROMPT.to_string().into()], + cache: true, + }); + let mut referenced_context_ids = HashSet::default(); for message in &self.messages { @@ -351,6 +415,7 @@ impl Thread { content: Vec::new(), cache: false, }; + match request_kind { RequestKind::Chat => { self.tool_use @@ -371,11 +436,20 @@ impl Thread { RequestKind::Chat => { self.tool_use .attach_tool_uses(message.id, &mut request_message); + + if matches!(message.role, Role::Assistant) { + if let Some(script_id) = self.scripts_by_assistant_message.get(&message.id) + { + let script = self.script_session.read(cx).get(*script_id); + + request_message.content.push(script.source_tag().into()); + } + } } RequestKind::Summarize => { // We don't care about tool use during summarization. } - } + }; request.messages.push(request_message); } @@ -412,6 +486,8 @@ impl Thread { let stream_completion = async { let mut events = stream.await?; let mut stop_reason = StopReason::EndTurn; + let mut script_tag_parser = ScriptTagParser::new(); + let mut script_id = None; while let Some(event) = events.next().await { let event = event?; @@ -426,19 +502,43 @@ impl Thread { } LanguageModelCompletionEvent::Text(chunk) => { if let Some(last_message) = thread.messages.last_mut() { - if last_message.role == Role::Assistant { - last_message.text.push_str(&chunk); + let chunk = script_tag_parser.parse_chunk(&chunk); + + let message_id = if last_message.role == Role::Assistant { + last_message.text.push_str(&chunk.content); cx.emit(ThreadEvent::StreamedAssistantText( last_message.id, - chunk, + chunk.content, )); + last_message.id } else { // If we won't have an Assistant message yet, assume this chunk marks the beginning // of a new Assistant response. // // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it // will result in duplicating the text of the chunk in the rendered Markdown. - thread.insert_message(Role::Assistant, chunk, cx); + thread.insert_message(Role::Assistant, chunk.content, cx) + }; + + if script_id.is_none() && script_tag_parser.found_script() { + let id = thread + .script_session + .update(cx, |session, _cx| session.new_script()); + thread.scripts_by_assistant_message.insert(message_id, id); + + script_id = Some(id); + } + + if let (Some(script_source), Some(script_id)) = + (chunk.script_source, script_id) + { + // TODO: move buffer to script and run as it streams + thread + .script_session + .update(cx, |this, cx| { + this.run_script(script_id, script_source, cx) + }) + .detach_and_log_err(cx); } } } @@ -661,6 +761,7 @@ pub enum ThreadEvent { #[allow(unused)] tool_use_id: LanguageModelToolUseId, }, + ScriptFinished, } impl EventEmitter for Thread {} diff --git a/crates/scripting_tool/Cargo.toml b/crates/assistant_scripting/Cargo.toml similarity index 86% rename from crates/scripting_tool/Cargo.toml rename to crates/assistant_scripting/Cargo.toml index ab80d96fe05509..d1e06b0506c06f 100644 --- a/crates/scripting_tool/Cargo.toml +++ b/crates/assistant_scripting/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "scripting_tool" +name = "assistant_scripting" version = "0.1.0" edition.workspace = true publish.workspace = true @@ -9,12 +9,11 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/scripting_tool.rs" +path = "src/assistant_scripting.rs" doctest = false [dependencies] anyhow.workspace = true -assistant_tool.workspace = true collections.workspace = true futures.workspace = true gpui.workspace = true @@ -22,7 +21,6 @@ mlua.workspace = true parking_lot.workspace = true project.workspace = true regex.workspace = true -schemars.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true @@ -32,4 +30,5 @@ util.workspace = true collections = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } project = { workspace = true, features = ["test-support"] } +rand.workspace = true settings = { workspace = true, features = ["test-support"] } diff --git a/crates/scripting_tool/LICENSE-GPL b/crates/assistant_scripting/LICENSE-GPL similarity index 100% rename from crates/scripting_tool/LICENSE-GPL rename to crates/assistant_scripting/LICENSE-GPL diff --git a/crates/assistant_scripting/src/assistant_scripting.rs b/crates/assistant_scripting/src/assistant_scripting.rs new file mode 100644 index 00000000000000..fbe335d7fbd2fb --- /dev/null +++ b/crates/assistant_scripting/src/assistant_scripting.rs @@ -0,0 +1,7 @@ +mod session; +mod tag; + +pub use session::*; +pub use tag::*; + +pub const SCRIPTING_PROMPT: &str = include_str!("./system_prompt.txt"); diff --git a/crates/scripting_tool/src/sandbox_preamble.lua b/crates/assistant_scripting/src/sandbox_preamble.lua similarity index 100% rename from crates/scripting_tool/src/sandbox_preamble.lua rename to crates/assistant_scripting/src/sandbox_preamble.lua diff --git a/crates/scripting_tool/src/session.rs b/crates/assistant_scripting/src/session.rs similarity index 80% rename from crates/scripting_tool/src/session.rs rename to crates/assistant_scripting/src/session.rs index 36bd395fd2bdc4..59769b5cfb2918 100644 --- a/crates/scripting_tool/src/session.rs +++ b/crates/assistant_scripting/src/session.rs @@ -1,10 +1,9 @@ -use anyhow::Result; use collections::{HashMap, HashSet}; use futures::{ channel::{mpsc, oneshot}, pin_mut, SinkExt, StreamExt, }; -use gpui::{AppContext, AsyncApp, Context, Entity, Task, WeakEntity}; +use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity}; use mlua::{Lua, MultiValue, Table, UserData, UserDataMethods}; use parking_lot::Mutex; use project::{search::SearchQuery, Fs, Project}; @@ -16,24 +15,23 @@ use std::{ }; use util::{paths::PathMatcher, ResultExt}; -pub struct ScriptOutput { - pub stdout: String, -} +use crate::{SCRIPT_END_TAG, SCRIPT_START_TAG}; -struct ForegroundFn(Box, AsyncApp) + Send>); +struct ForegroundFn(Box, AsyncApp) + Send>); -pub struct Session { +pub struct ScriptSession { project: Entity, // TODO Remove this fs_changes: Arc>>>, foreground_fns_tx: mpsc::Sender, _invoke_foreground_fns: Task<()>, + scripts: Vec