Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

assistant2: Add support for using tools provided by context servers #21418

Merged
merged 2 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions crates/assistant2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@ anyhow.workspace = true
assistant_tool.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
context_server.workspace = true
editor.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
language_model.workspace = true
language_model_selector.workspace = true
log.workspace = true
project.workspace = true
proto.workspace = true
settings.workspace = true
serde.workspace = true
Expand Down
1 change: 1 addition & 0 deletions crates/assistant2/src/assistant.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod assistant_panel;
mod message_editor;
mod thread;
mod thread_store;

use command_palette_hooks::CommandPaletteFilter;
use feature_flags::{Assistant2FeatureFlag, FeatureFlagAppExt};
Expand Down
20 changes: 18 additions & 2 deletions crates/assistant2/src/assistant_panel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use workspace::Workspace;

use crate::message_editor::MessageEditor;
use crate::thread::{Message, Thread, ThreadEvent};
use crate::thread_store::ThreadStore;
use crate::{NewThread, ToggleFocus, ToggleModelSelector};

pub fn init(cx: &mut AppContext) {
Expand All @@ -29,6 +30,8 @@ pub fn init(cx: &mut AppContext) {

pub struct AssistantPanel {
workspace: WeakView<Workspace>,
#[allow(unused)]
thread_store: Model<ThreadStore>,
thread: Model<Thread>,
message_editor: View<MessageEditor>,
tools: Arc<ToolWorkingSet>,
Expand All @@ -42,13 +45,25 @@ impl AssistantPanel {
) -> Task<Result<View<Self>>> {
cx.spawn(|mut cx| async move {
let tools = Arc::new(ToolWorkingSet::default());
let thread_store = workspace
.update(&mut cx, |workspace, cx| {
let project = workspace.project().clone();
ThreadStore::new(project, tools.clone(), cx)
})?
.await?;

workspace.update(&mut cx, |workspace, cx| {
cx.new_view(|cx| Self::new(workspace, tools, cx))
cx.new_view(|cx| Self::new(workspace, thread_store, tools, cx))
})
})
}

fn new(workspace: &Workspace, tools: Arc<ToolWorkingSet>, cx: &mut ViewContext<Self>) -> Self {
fn new(
workspace: &Workspace,
thread_store: Model<ThreadStore>,
tools: Arc<ToolWorkingSet>,
cx: &mut ViewContext<Self>,
) -> Self {
let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx));
let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
Expand All @@ -57,6 +72,7 @@ impl AssistantPanel {

Self {
workspace: workspace.weak_handle(),
thread_store,
thread: thread.clone(),
message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
tools,
Expand Down
114 changes: 114 additions & 0 deletions crates/assistant2/src/thread_store.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
use std::sync::Arc;

use anyhow::Result;
use assistant_tool::{ToolId, ToolWorkingSet};
use collections::HashMap;
use context_server::manager::ContextServerManager;
use context_server::{ContextServerFactoryRegistry, ContextServerTool};
use gpui::{prelude::*, AppContext, Model, ModelContext, Task};
use project::Project;
use util::ResultExt as _;

pub struct ThreadStore {
#[allow(unused)]
project: Model<Project>,
tools: Arc<ToolWorkingSet>,
context_server_manager: Model<ContextServerManager>,
context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
}

impl ThreadStore {
pub fn new(
project: Model<Project>,
tools: Arc<ToolWorkingSet>,
cx: &mut AppContext,
) -> Task<Result<Model<Self>>> {
cx.spawn(|mut cx| async move {
let this = cx.new_model(|cx: &mut ModelContext<Self>| {
let context_server_factory_registry =
ContextServerFactoryRegistry::default_global(cx);
let context_server_manager = cx.new_model(|cx| {
ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
});

let this = Self {
project,
tools,
context_server_manager,
context_server_tool_ids: HashMap::default(),
};
this.register_context_server_handlers(cx);

this
})?;

Ok(this)
})
}

fn register_context_server_handlers(&self, cx: &mut ModelContext<Self>) {
cx.subscribe(
&self.context_server_manager.clone(),
Self::handle_context_server_event,
)
.detach();
}

fn handle_context_server_event(
&mut self,
context_server_manager: Model<ContextServerManager>,
event: &context_server::manager::Event,
cx: &mut ModelContext<Self>,
) {
let tool_working_set = self.tools.clone();
match event {
context_server::manager::Event::ServerStarted { server_id } => {
if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
let context_server_manager = context_server_manager.clone();
cx.spawn({
let server = server.clone();
let server_id = server_id.clone();
|this, mut cx| async move {
let Some(protocol) = server.client() else {
return;
};

if protocol.capable(context_server::protocol::ServerCapability::Tools) {
if let Some(tools) = protocol.list_tools().await.log_err() {
let tool_ids = tools
.tools
.into_iter()
.map(|tool| {
log::info!(
"registering context server tool: {:?}",
tool.name
);
tool_working_set.insert(Arc::new(
ContextServerTool::new(
context_server_manager.clone(),
server.id(),
tool,
),
))
})
.collect::<Vec<_>>();

this.update(&mut cx, |this, _cx| {
this.context_server_tool_ids.insert(server_id, tool_ids);
})
.log_err();
}
}
}
})
.detach();
}
}
context_server::manager::Event::ServerStopped { server_id } => {
if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
tool_working_set.remove(&tool_ids);
}
}
}
}
}
Loading