Skip to content

Commit

Permalink
assistant_tool: Decouple Tool from Workspace (#26309)
Browse files Browse the repository at this point in the history
This PR decouples the `Tool` trait from the `Workspace` (and from the
UI, in general).

`Tool::run` now takes a `WeakEntity<Project>` instead of a
`WeakEntity<Workspace>` and a `Window`.

Release Notes:

- N/A
  • Loading branch information
maxdeviant authored Mar 7, 2025
1 parent 4f6682c commit 18f3f80
Show file tree
Hide file tree
Showing 13 changed files with 35 additions and 52 deletions.
5 changes: 1 addition & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions crates/assistant2/src/active_thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@ use gpui::{
use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
use markdown::{Markdown, MarkdownStyle};
use project::Project;
use settings::Settings as _;
use theme::ThemeSettings;
use ui::{prelude::*, Disclosure, KeyBinding};
use util::ResultExt as _;
use workspace::Workspace;

use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
use crate::thread_store::ThreadStore;
use crate::tool_use::{ToolUse, ToolUseStatus};
use crate::ui::ContextPill;

pub struct ActiveThread {
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
language_registry: Arc<LanguageRegistry>,
tools: Arc<ToolWorkingSet>,
thread_store: Entity<ThreadStore>,
Expand All @@ -46,7 +46,7 @@ impl ActiveThread {
pub fn new(
thread: Entity<Thread>,
thread_store: Entity<ThreadStore>,
workspace: WeakEntity<Workspace>,
project: WeakEntity<Project>,
language_registry: Arc<LanguageRegistry>,
tools: Arc<ToolWorkingSet>,
window: &mut Window,
Expand All @@ -58,7 +58,7 @@ impl ActiveThread {
];

let mut this = Self {
workspace,
project,
language_registry,
tools,
thread_store,
Expand Down Expand Up @@ -311,7 +311,7 @@ impl ActiveThread {

for tool_use in pending_tool_uses {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
let task = tool.run(tool_use.input, self.project.clone(), cx);

self.thread.update(cx, |thread, cx| {
thread.insert_tool_output(tool_use.id.clone(), task, cx);
Expand Down
10 changes: 5 additions & 5 deletions crates/assistant2/src/assistant_panel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,16 @@ impl AssistantPanel {

Self {
active_view: ActiveView::Thread,
workspace: workspace.clone(),
project,
workspace,
project: project.clone(),
fs: fs.clone(),
language_registry: language_registry.clone(),
thread_store: thread_store.clone(),
thread: cx.new(|cx| {
ActiveThread::new(
thread.clone(),
thread_store.clone(),
workspace,
project.downgrade(),
language_registry,
tools.clone(),
window,
Expand Down Expand Up @@ -246,7 +246,7 @@ impl AssistantPanel {
ActiveThread::new(
thread.clone(),
self.thread_store.clone(),
self.workspace.clone(),
self.project.downgrade(),
self.language_registry.clone(),
self.tools.clone(),
window,
Expand Down Expand Up @@ -381,7 +381,7 @@ impl AssistantPanel {
ActiveThread::new(
thread.clone(),
this.thread_store.clone(),
this.workspace.clone(),
this.project.downgrade(),
this.language_registry.clone(),
this.tools.clone(),
window,
Expand Down
2 changes: 1 addition & 1 deletion crates/assistant_tool/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ collections.workspace = true
derive_more.workspace = true
gpui.workspace = true
parking_lot.workspace = true
project.workspace = true
serde.workspace = true
serde_json.workspace = true
workspace.workspace = true
7 changes: 3 additions & 4 deletions crates/assistant_tool/src/assistant_tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ mod tool_working_set;
use std::sync::Arc;

use anyhow::Result;
use gpui::{App, Task, WeakEntity, Window};
use workspace::Workspace;
use gpui::{App, Task, WeakEntity};
use project::Project;

pub use crate::tool_registry::*;
pub use crate::tool_working_set::*;
Expand All @@ -31,8 +31,7 @@ pub trait Tool: 'static + Send + Sync {
fn run(
self: Arc<Self>,
input: serde_json::Value,
workspace: WeakEntity<Workspace>,
window: &mut Window,
project: WeakEntity<Project>,
cx: &mut App,
) -> Task<Result<String>>;
}
1 change: 0 additions & 1 deletion crates/assistant_tools/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,3 @@ project.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
workspace.workspace = true
13 changes: 5 additions & 8 deletions crates/assistant_tools/src/list_worktrees_tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use std::sync::Arc;

use anyhow::{anyhow, Result};
use assistant_tool::Tool;
use gpui::{App, Task, WeakEntity, Window};
use gpui::{App, Task, WeakEntity};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use workspace::Workspace;

#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ListWorktreesToolInput {}
Expand Down Expand Up @@ -34,16 +34,13 @@ impl Tool for ListWorktreesTool {
fn run(
self: Arc<Self>,
_input: serde_json::Value,
workspace: WeakEntity<Workspace>,
_window: &mut Window,
project: WeakEntity<Project>,
cx: &mut App,
) -> Task<Result<String>> {
let Some(workspace) = workspace.upgrade() else {
return Task::ready(Err(anyhow!("workspace dropped")));
let Some(project) = project.upgrade() else {
return Task::ready(Err(anyhow!("project dropped")));
};

let project = workspace.read(cx).project().clone();

cx.spawn(|cx| async move {
cx.update(|cx| {
#[derive(Debug, Serialize)]
Expand Down
6 changes: 3 additions & 3 deletions crates/assistant_tools/src/now_tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use std::sync::Arc;
use anyhow::{anyhow, Result};
use assistant_tool::Tool;
use chrono::{Local, Utc};
use gpui::{App, Task, WeakEntity, Window};
use gpui::{App, Task, WeakEntity};
use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -41,8 +42,7 @@ impl Tool for NowTool {
fn run(
self: Arc<Self>,
input: serde_json::Value,
_workspace: WeakEntity<workspace::Workspace>,
_window: &mut Window,
_project: WeakEntity<Project>,
_cx: &mut App,
) -> Task<Result<String>> {
let input: NowToolInput = match serde_json::from_value(input) {
Expand Down
13 changes: 5 additions & 8 deletions crates/assistant_tools/src/read_file_tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ use std::sync::Arc;

use anyhow::{anyhow, Result};
use assistant_tool::Tool;
use gpui::{App, Task, WeakEntity, Window};
use project::{ProjectPath, WorktreeId};
use gpui::{App, Task, WeakEntity};
use project::{Project, ProjectPath, WorktreeId};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use workspace::Workspace;

#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct ReadFileToolInput {
Expand Down Expand Up @@ -38,20 +37,18 @@ impl Tool for ReadFileTool {
fn run(
self: Arc<Self>,
input: serde_json::Value,
workspace: WeakEntity<Workspace>,
_window: &mut Window,
project: WeakEntity<Project>,
cx: &mut App,
) -> Task<Result<String>> {
let Some(workspace) = workspace.upgrade() else {
return Task::ready(Err(anyhow!("workspace dropped")));
let Some(project) = project.upgrade() else {
return Task::ready(Err(anyhow!("project dropped")));
};

let input = match serde_json::from_value::<ReadFileToolInput>(input) {
Ok(input) => input,
Err(err) => return Task::ready(Err(anyhow!(err))),
};

let project = workspace.read(cx).project().clone();
let project_path = ProjectPath {
worktree_id: WorktreeId::from_usize(input.worktree_id),
path: input.path,
Expand Down
1 change: 0 additions & 1 deletion crates/context_server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,3 @@ settings.workspace = true
smol.workspace = true
url = { workspace = true, features = ["serde"] }
util.workspace = true
workspace.workspace = true
5 changes: 2 additions & 3 deletions crates/context_server/src/context_server_tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::Arc;

use anyhow::{anyhow, bail};
use assistant_tool::Tool;
use gpui::{App, Entity, Task, Window};
use gpui::{App, Entity, Task};

use crate::manager::ContextServerManager;
use crate::types;
Expand Down Expand Up @@ -51,8 +51,7 @@ impl Tool for ContextServerTool {
fn run(
self: std::sync::Arc<Self>,
input: serde_json::Value,
_workspace: gpui::WeakEntity<workspace::Workspace>,
_: &mut Window,
_project: gpui::WeakEntity<project::Project>,
cx: &mut App,
) -> gpui::Task<gpui::Result<String>> {
if let Some(server) = self.server_manager.read(cx).get_server(&self.server_id) {
Expand Down
2 changes: 0 additions & 2 deletions crates/scripting_tool/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,9 @@ serde.workspace = true
serde_json.workspace = true
settings.workspace = true
util.workspace = true
workspace.workspace = true

[dev-dependencies]
collections = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
workspace = { workspace = true, features = ["test-support"] }
12 changes: 5 additions & 7 deletions crates/scripting_tool/src/scripting_tool.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
mod session;

use project::Project;
pub(crate) use session::*;

use assistant_tool::{Tool, ToolRegistry};
use gpui::{App, AppContext as _, Task, WeakEntity, Window};
use gpui::{App, AppContext as _, Task, WeakEntity};
use schemars::JsonSchema;
use serde::Deserialize;
use std::sync::Arc;
use workspace::Workspace;

pub fn init(cx: &App) {
let registry = ToolRegistry::global(cx);
Expand Down Expand Up @@ -38,17 +38,15 @@ impl Tool for ScriptingTool {
fn run(
self: Arc<Self>,
input: serde_json::Value,
workspace: WeakEntity<Workspace>,
_window: &mut Window,
project: WeakEntity<Project>,
cx: &mut App,
) -> Task<anyhow::Result<String>> {
let input = match serde_json::from_value::<ScriptingToolInput>(input) {
Err(err) => return Task::ready(Err(err.into())),
Ok(input) => input,
};
let Ok(project) = workspace.read_with(cx, |workspace, _cx| workspace.project().clone())
else {
return Task::ready(Err(anyhow::anyhow!("No project found")));
let Some(project) = project.upgrade() else {
return Task::ready(Err(anyhow::anyhow!("project dropped")));
};

let session = cx.new(|cx| Session::new(project, cx));
Expand Down

0 comments on commit 18f3f80

Please sign in to comment.