diff --git a/Cargo.lock b/Cargo.lock index 7768dac710feff..41f71c330c9312 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -458,12 +458,15 @@ dependencies = [ "assistant_tool", "collections", "command_palette_hooks", + "context_server", "editor", "feature_flags", "futures 0.3.31", "gpui", "language_model", "language_model_selector", + "log", + "project", "proto", "serde", "serde_json", diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml index ca563b05c8d469..ff49801c46af4f 100644 --- a/crates/assistant2/Cargo.toml +++ b/crates/assistant2/Cargo.toml @@ -17,12 +17,15 @@ anyhow.workspace = true assistant_tool.workspace = true collections.workspace = true command_palette_hooks.workspace = true +context_server.workspace = true editor.workspace = true feature_flags.workspace = true futures.workspace = true gpui.workspace = true language_model.workspace = true language_model_selector.workspace = true +log.workspace = true +project.workspace = true proto.workspace = true settings.workspace = true serde.workspace = true diff --git a/crates/assistant2/src/assistant.rs b/crates/assistant2/src/assistant.rs index 1b33e27928a609..8ef4a1d9dcf057 100644 --- a/crates/assistant2/src/assistant.rs +++ b/crates/assistant2/src/assistant.rs @@ -1,6 +1,7 @@ mod assistant_panel; mod message_editor; mod thread; +mod thread_store; use command_palette_hooks::CommandPaletteFilter; use feature_flags::{Assistant2FeatureFlag, FeatureFlagAppExt}; diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index bf457d6c71826e..7d8405dc78722c 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -14,6 +14,7 @@ use workspace::Workspace; use crate::message_editor::MessageEditor; use crate::thread::{Message, Thread, ThreadEvent}; +use crate::thread_store::ThreadStore; use crate::{NewThread, ToggleFocus, ToggleModelSelector}; pub fn init(cx: &mut AppContext) { @@ -29,6 +30,8 @@ pub fn init(cx: &mut AppContext) { pub struct AssistantPanel { workspace: WeakView, + #[allow(unused)] + thread_store: Model, thread: Model, message_editor: View, tools: Arc, @@ -42,13 +45,25 @@ impl AssistantPanel { ) -> Task>> { cx.spawn(|mut cx| async move { let tools = Arc::new(ToolWorkingSet::default()); + let thread_store = workspace + .update(&mut cx, |workspace, cx| { + let project = workspace.project().clone(); + ThreadStore::new(project, tools.clone(), cx) + })? + .await?; + workspace.update(&mut cx, |workspace, cx| { - cx.new_view(|cx| Self::new(workspace, tools, cx)) + cx.new_view(|cx| Self::new(workspace, thread_store, tools, cx)) }) }) } - fn new(workspace: &Workspace, tools: Arc, cx: &mut ViewContext) -> Self { + fn new( + workspace: &Workspace, + thread_store: Model, + tools: Arc, + cx: &mut ViewContext, + ) -> Self { let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx)); let subscriptions = vec![ cx.observe(&thread, |_, _, cx| cx.notify()), @@ -57,6 +72,7 @@ impl AssistantPanel { Self { workspace: workspace.weak_handle(), + thread_store, thread: thread.clone(), message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)), tools, diff --git a/crates/assistant2/src/thread_store.rs b/crates/assistant2/src/thread_store.rs new file mode 100644 index 00000000000000..99f90eace8304e --- /dev/null +++ b/crates/assistant2/src/thread_store.rs @@ -0,0 +1,114 @@ +use std::sync::Arc; + +use anyhow::Result; +use assistant_tool::{ToolId, ToolWorkingSet}; +use collections::HashMap; +use context_server::manager::ContextServerManager; +use context_server::{ContextServerFactoryRegistry, ContextServerTool}; +use gpui::{prelude::*, AppContext, Model, ModelContext, Task}; +use project::Project; +use util::ResultExt as _; + +pub struct ThreadStore { + #[allow(unused)] + project: Model, + tools: Arc, + context_server_manager: Model, + context_server_tool_ids: HashMap, Vec>, +} + +impl ThreadStore { + pub fn new( + project: Model, + tools: Arc, + cx: &mut AppContext, + ) -> Task>> { + cx.spawn(|mut cx| async move { + let this = cx.new_model(|cx: &mut ModelContext| { + let context_server_factory_registry = + ContextServerFactoryRegistry::default_global(cx); + let context_server_manager = cx.new_model(|cx| { + ContextServerManager::new(context_server_factory_registry, project.clone(), cx) + }); + + let this = Self { + project, + tools, + context_server_manager, + context_server_tool_ids: HashMap::default(), + }; + this.register_context_server_handlers(cx); + + this + })?; + + Ok(this) + }) + } + + fn register_context_server_handlers(&self, cx: &mut ModelContext) { + cx.subscribe( + &self.context_server_manager.clone(), + Self::handle_context_server_event, + ) + .detach(); + } + + fn handle_context_server_event( + &mut self, + context_server_manager: Model, + event: &context_server::manager::Event, + cx: &mut ModelContext, + ) { + let tool_working_set = self.tools.clone(); + match event { + context_server::manager::Event::ServerStarted { server_id } => { + if let Some(server) = context_server_manager.read(cx).get_server(server_id) { + let context_server_manager = context_server_manager.clone(); + cx.spawn({ + let server = server.clone(); + let server_id = server_id.clone(); + |this, mut cx| async move { + let Some(protocol) = server.client() else { + return; + }; + + if protocol.capable(context_server::protocol::ServerCapability::Tools) { + if let Some(tools) = protocol.list_tools().await.log_err() { + let tool_ids = tools + .tools + .into_iter() + .map(|tool| { + log::info!( + "registering context server tool: {:?}", + tool.name + ); + tool_working_set.insert(Arc::new( + ContextServerTool::new( + context_server_manager.clone(), + server.id(), + tool, + ), + )) + }) + .collect::>(); + + this.update(&mut cx, |this, _cx| { + this.context_server_tool_ids.insert(server_id, tool_ids); + }) + .log_err(); + } + } + } + }) + .detach(); + } + } + context_server::manager::Event::ServerStopped { server_id } => { + if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) { + tool_working_set.remove(&tool_ids); + } + } + } + } +}