Skip to content

Commit 921c24e

Browse files
authored
assistant2: Add helper methods to Thread for dealing with tool use (#26310)
This PR adds two new helper methods to the `Thread` for dealing with tool use: - `use_pending_tools` - This uses all of the tools that are pending - The reason we aren't calling this directly in `stream_completion` is that we still might need to have a way for users to confirm that they want tools to be run, which would need to happen at the UI layer in the `ActiveThread`. - `send_tool_results_to_model` - This encapsulates inserting a new user message that contains the tool results and sending them up to the model. Release Notes: - N/A
1 parent 18f3f80 commit 921c24e

File tree

4 files changed

+61
-52
lines changed

4 files changed

+61
-52
lines changed

crates/assistant2/src/active_thread.rs

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
use std::sync::Arc;
22

3-
use assistant_tool::ToolWorkingSet;
43
use collections::HashMap;
54
use editor::{Editor, MultiBuffer};
65
use gpui::{
76
list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
87
Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
9-
Task, TextStyleRefinement, UnderlineStyle, WeakEntity,
8+
Task, TextStyleRefinement, UnderlineStyle,
109
};
1110
use language::{Buffer, LanguageRegistry};
1211
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
1312
use markdown::{Markdown, MarkdownStyle};
14-
use project::Project;
1513
use settings::Settings as _;
1614
use theme::ThemeSettings;
1715
use ui::{prelude::*, Disclosure, KeyBinding};
@@ -23,9 +21,7 @@ use crate::tool_use::{ToolUse, ToolUseStatus};
2321
use crate::ui::ContextPill;
2422

2523
pub struct ActiveThread {
26-
project: WeakEntity<Project>,
2724
language_registry: Arc<LanguageRegistry>,
28-
tools: Arc<ToolWorkingSet>,
2925
thread_store: Entity<ThreadStore>,
3026
thread: Entity<Thread>,
3127
save_thread_task: Option<Task<()>>,
@@ -46,9 +42,7 @@ impl ActiveThread {
4642
pub fn new(
4743
thread: Entity<Thread>,
4844
thread_store: Entity<ThreadStore>,
49-
project: WeakEntity<Project>,
5045
language_registry: Arc<LanguageRegistry>,
51-
tools: Arc<ToolWorkingSet>,
5246
window: &mut Window,
5347
cx: &mut Context<Self>,
5448
) -> Self {
@@ -58,9 +52,7 @@ impl ActiveThread {
5852
];
5953

6054
let mut this = Self {
61-
project,
6255
language_registry,
63-
tools,
6456
thread_store,
6557
thread: thread.clone(),
6658
save_thread_task: None,
@@ -300,24 +292,9 @@ impl ActiveThread {
300292
cx.notify();
301293
}
302294
ThreadEvent::UsePendingTools => {
303-
let pending_tool_uses = self
304-
.thread
305-
.read(cx)
306-
.pending_tool_uses()
307-
.into_iter()
308-
.filter(|tool_use| tool_use.status.is_idle())
309-
.cloned()
310-
.collect::<Vec<_>>();
311-
312-
for tool_use in pending_tool_uses {
313-
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
314-
let task = tool.run(tool_use.input, self.project.clone(), cx);
315-
316-
self.thread.update(cx, |thread, cx| {
317-
thread.insert_tool_output(tool_use.id.clone(), task, cx);
318-
});
319-
}
320-
}
295+
self.thread.update(cx, |thread, cx| {
296+
thread.use_pending_tools(cx);
297+
});
321298
}
322299
ThreadEvent::ToolFinished { .. } => {
323300
let all_tools_finished = self
@@ -330,16 +307,7 @@ impl ActiveThread {
330307
let model_registry = LanguageModelRegistry::read_global(cx);
331308
if let Some(model) = model_registry.active_model() {
332309
self.thread.update(cx, |thread, cx| {
333-
// Insert a user message to contain the tool results.
334-
thread.insert_user_message(
335-
// TODO: Sending up a user message without any content results in the model sending back
336-
// responses that also don't have any content. We currently don't handle this case well,
337-
// so for now we provide some text to keep the model on track.
338-
"Here are the tool results.",
339-
Vec::new(),
340-
cx,
341-
);
342-
thread.send_to_model(model, RequestKind::Chat, true, cx);
310+
thread.send_tool_results_to_model(model, cx);
343311
});
344312
}
345313
}

crates/assistant2/src/assistant_panel.rs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ pub struct AssistantPanel {
9292
context_editor: Option<Entity<ContextEditor>>,
9393
configuration: Option<Entity<AssistantConfiguration>>,
9494
configuration_subscription: Option<Subscription>,
95-
tools: Arc<ToolWorkingSet>,
9695
local_timezone: UtcOffset,
9796
active_view: ActiveView,
9897
history_store: Entity<HistoryStore>,
@@ -133,7 +132,7 @@ impl AssistantPanel {
133132
log::info!("[assistant2-debug] finished initializing ContextStore");
134133

135134
workspace.update_in(&mut cx, |workspace, window, cx| {
136-
cx.new(|cx| Self::new(workspace, thread_store, context_store, tools, window, cx))
135+
cx.new(|cx| Self::new(workspace, thread_store, context_store, window, cx))
137136
})
138137
})
139138
}
@@ -142,7 +141,6 @@ impl AssistantPanel {
142141
workspace: &Workspace,
143142
thread_store: Entity<ThreadStore>,
144143
context_store: Entity<assistant_context_editor::ContextStore>,
145-
tools: Arc<ToolWorkingSet>,
146144
window: &mut Window,
147145
cx: &mut Context<Self>,
148146
) -> Self {
@@ -179,9 +177,7 @@ impl AssistantPanel {
179177
ActiveThread::new(
180178
thread.clone(),
181179
thread_store.clone(),
182-
project.downgrade(),
183180
language_registry,
184-
tools.clone(),
185181
window,
186182
cx,
187183
)
@@ -191,7 +187,6 @@ impl AssistantPanel {
191187
context_editor: None,
192188
configuration: None,
193189
configuration_subscription: None,
194-
tools,
195190
local_timezone: UtcOffset::from_whole_seconds(
196191
chrono::Local::now().offset().local_minus_utc(),
197192
)
@@ -246,9 +241,7 @@ impl AssistantPanel {
246241
ActiveThread::new(
247242
thread.clone(),
248243
self.thread_store.clone(),
249-
self.project.downgrade(),
250244
self.language_registry.clone(),
251-
self.tools.clone(),
252245
window,
253246
cx,
254247
)
@@ -381,9 +374,7 @@ impl AssistantPanel {
381374
ActiveThread::new(
382375
thread.clone(),
383376
this.thread_store.clone(),
384-
this.project.downgrade(),
385377
this.language_registry.clone(),
386-
this.tools.clone(),
387378
window,
388379
cx,
389380
)

crates/assistant2/src/thread.rs

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ use assistant_tool::ToolWorkingSet;
55
use chrono::{DateTime, Utc};
66
use collections::{BTreeMap, HashMap, HashSet};
77
use futures::StreamExt as _;
8-
use gpui::{App, Context, EventEmitter, SharedString, Task};
8+
use gpui::{App, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
99
use language_model::{
1010
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
1111
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
1212
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
1313
Role, StopReason,
1414
};
15+
use project::Project;
1516
use serde::{Deserialize, Serialize};
1617
use util::{post_inc, TryFutureExt as _};
1718
use uuid::Uuid;
@@ -71,12 +72,17 @@ pub struct Thread {
7172
context_by_message: HashMap<MessageId, Vec<ContextId>>,
7273
completion_count: usize,
7374
pending_completions: Vec<PendingCompletion>,
75+
project: WeakEntity<Project>,
7476
tools: Arc<ToolWorkingSet>,
7577
tool_use: ToolUseState,
7678
}
7779

7880
impl Thread {
79-
pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut Context<Self>) -> Self {
81+
pub fn new(
82+
project: Entity<Project>,
83+
tools: Arc<ToolWorkingSet>,
84+
_cx: &mut Context<Self>,
85+
) -> Self {
8086
Self {
8187
id: ThreadId::new(),
8288
updated_at: Utc::now(),
@@ -88,6 +94,7 @@ impl Thread {
8894
context_by_message: HashMap::default(),
8995
completion_count: 0,
9096
pending_completions: Vec::new(),
97+
project: project.downgrade(),
9198
tools,
9299
tool_use: ToolUseState::new(),
93100
}
@@ -96,6 +103,7 @@ impl Thread {
96103
pub fn from_saved(
97104
id: ThreadId,
98105
saved: SavedThread,
106+
project: Entity<Project>,
99107
tools: Arc<ToolWorkingSet>,
100108
_cx: &mut Context<Self>,
101109
) -> Self {
@@ -127,6 +135,7 @@ impl Thread {
127135
context_by_message: HashMap::default(),
128136
completion_count: 0,
129137
pending_completions: Vec::new(),
138+
project: project.downgrade(),
130139
tools,
131140
tool_use,
132141
}
@@ -550,6 +559,23 @@ impl Thread {
550559
});
551560
}
552561

562+
pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
563+
let pending_tool_uses = self
564+
.pending_tool_uses()
565+
.into_iter()
566+
.filter(|tool_use| tool_use.status.is_idle())
567+
.cloned()
568+
.collect::<Vec<_>>();
569+
570+
for tool_use in pending_tool_uses {
571+
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
572+
let task = tool.run(tool_use.input, self.project.clone(), cx);
573+
574+
self.insert_tool_output(tool_use.id.clone(), task, cx);
575+
}
576+
}
577+
}
578+
553579
pub fn insert_tool_output(
554580
&mut self,
555581
tool_use_id: LanguageModelToolUseId,
@@ -576,6 +602,23 @@ impl Thread {
576602
.run_pending_tool(tool_use_id, insert_output_task);
577603
}
578604

605+
pub fn send_tool_results_to_model(
606+
&mut self,
607+
model: Arc<dyn LanguageModel>,
608+
cx: &mut Context<Self>,
609+
) {
610+
// Insert a user message to contain the tool results.
611+
self.insert_user_message(
612+
// TODO: Sending up a user message without any content results in the model sending back
613+
// responses that also don't have any content. We currently don't handle this case well,
614+
// so for now we provide some text to keep the model on track.
615+
"Here are the tool results.",
616+
Vec::new(),
617+
cx,
618+
);
619+
self.send_to_model(model, RequestKind::Chat, true, cx);
620+
}
621+
579622
/// Cancels the last pending completion, if there are any pending.
580623
///
581624
/// Returns whether a completion was canceled.

crates/assistant2/src/thread_store.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ pub fn init(cx: &mut App) {
2626
}
2727

2828
pub struct ThreadStore {
29-
#[allow(unused)]
3029
project: Entity<Project>,
3130
tools: Arc<ToolWorkingSet>,
3231
context_server_manager: Entity<ContextServerManager>,
@@ -78,7 +77,7 @@ impl ThreadStore {
7877
}
7978

8079
pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
81-
cx.new(|cx| Thread::new(self.tools.clone(), cx))
80+
cx.new(|cx| Thread::new(self.project.clone(), self.tools.clone(), cx))
8281
}
8382

8483
pub fn open_thread(
@@ -96,7 +95,15 @@ impl ThreadStore {
9695
.ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
9796

9897
this.update(&mut cx, |this, cx| {
99-
cx.new(|cx| Thread::from_saved(id.clone(), thread, this.tools.clone(), cx))
98+
cx.new(|cx| {
99+
Thread::from_saved(
100+
id.clone(),
101+
thread,
102+
this.project.clone(),
103+
this.tools.clone(),
104+
cx,
105+
)
106+
})
100107
})
101108
})
102109
}

0 commit comments

Comments
 (0)