Skip to content

Commit

Permalink
assistant2: Add ability to open past threads (#21548)
Browse files Browse the repository at this point in the history
This PR adds the ability to open past threads in Assistant 2.

There are also some mocked threads in the history for testing purposes.

Release Notes:

- N/A
  • Loading branch information
maxdeviant authored Dec 4, 2024
1 parent 44264ff commit 0bde0f8
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 60 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/assistant2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,7 @@ settings.workspace = true
smol.workspace = true
theme.workspace = true
ui.workspace = true
unindent.workspace = true
util.workspace = true
uuid.workspace = true
workspace.workspace = true
149 changes: 91 additions & 58 deletions crates/assistant2/src/assistant_panel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use workspace::dock::{DockPosition, Panel, PanelEvent};
use workspace::Workspace;

use crate::message_editor::MessageEditor;
use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent, ThreadId};
use crate::thread_store::ThreadStore;
use crate::{NewThread, OpenHistory, ToggleFocus, ToggleModelSelector};

Expand Down Expand Up @@ -77,7 +77,7 @@ impl AssistantPanel {
tools: Arc<ToolWorkingSet>,
cx: &mut ViewContext<Self>,
) -> Self {
let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx));
let thread = thread_store.update(cx, |this, cx| this.create_thread(cx));
let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
cx.subscribe(&thread, Self::handle_thread_event),
Expand Down Expand Up @@ -105,8 +105,27 @@ impl AssistantPanel {
}

fn new_thread(&mut self, cx: &mut ViewContext<Self>) {
let tools = self.thread.read(cx).tools().clone();
let thread = cx.new_model(|cx| Thread::new(tools, cx));
let thread = self
.thread_store
.update(cx, |this, cx| this.create_thread(cx));
self.reset_thread(thread, cx);
}

fn open_thread(&mut self, thread_id: &ThreadId, cx: &mut ViewContext<Self>) {
let Some(thread) = self
.thread_store
.update(cx, |this, cx| this.open_thread(thread_id, cx))
else {
return;
};
self.reset_thread(thread.clone(), cx);

for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
self.push_message(&message.id, message.text.clone(), cx);
}
}

fn reset_thread(&mut self, thread: Model<Thread>, cx: &mut ViewContext<Self>) {
let subscriptions = vec![
cx.observe(&thread, |_, _, cx| cx.notify()),
cx.subscribe(&thread, Self::handle_thread_event),
Expand All @@ -122,6 +141,56 @@ impl AssistantPanel {
self.message_editor.focus_handle(cx).focus(cx);
}

fn push_message(&mut self, id: &MessageId, text: String, cx: &mut ViewContext<Self>) {
let old_len = self.thread_messages.len();
self.thread_messages.push(*id);
self.thread_list_state.splice(old_len..old_len, 1);

let theme_settings = ThemeSettings::get_global(cx);
let ui_font_size = TextSize::Default.rems(cx);
let buffer_font_size = theme_settings.buffer_font_size;

let mut text_style = cx.text_style();
text_style.refine(&TextStyleRefinement {
font_family: Some(theme_settings.ui_font.family.clone()),
font_size: Some(ui_font_size.into()),
color: Some(cx.theme().colors().text),
..Default::default()
});

let markdown_style = MarkdownStyle {
base_text_style: text_style,
syntax: cx.theme().syntax().clone(),
selection_background_color: cx.theme().players().local().selection,
code_block: StyleRefinement {
text: Some(TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_size: Some(buffer_font_size.into()),
..Default::default()
}),
..Default::default()
},
inline_code: TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_size: Some(ui_font_size.into()),
background_color: Some(cx.theme().colors().editor_background),
..Default::default()
},
..Default::default()
};

let markdown = cx.new_view(|cx| {
Markdown::new(
text,
markdown_style,
Some(self.language_registry.clone()),
None,
cx,
)
});
self.rendered_messages_by_id.insert(*id, markdown);
}

fn handle_thread_event(
&mut self,
_: Model<Thread>,
Expand All @@ -141,59 +210,13 @@ impl AssistantPanel {
}
}
ThreadEvent::MessageAdded(message_id) => {
let old_len = self.thread_messages.len();
self.thread_messages.push(*message_id);
self.thread_list_state.splice(old_len..old_len, 1);

if let Some(message_text) = self
.thread
.read(cx)
.message(*message_id)
.map(|message| message.text.clone())
{
let theme_settings = ThemeSettings::get_global(cx);
let ui_font_size = TextSize::Default.rems(cx);
let buffer_font_size = theme_settings.buffer_font_size;

let mut text_style = cx.text_style();
text_style.refine(&TextStyleRefinement {
font_family: Some(theme_settings.ui_font.family.clone()),
font_size: Some(ui_font_size.into()),
color: Some(cx.theme().colors().text),
..Default::default()
});

let markdown_style = MarkdownStyle {
base_text_style: text_style,
syntax: cx.theme().syntax().clone(),
selection_background_color: cx.theme().players().local().selection,
code_block: StyleRefinement {
text: Some(TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_size: Some(buffer_font_size.into()),
..Default::default()
}),
..Default::default()
},
inline_code: TextStyleRefinement {
font_family: Some(theme_settings.buffer_font.family.clone()),
font_size: Some(ui_font_size.into()),
background_color: Some(cx.theme().colors().editor_background),
..Default::default()
},
..Default::default()
};

let markdown = cx.new_view(|cx| {
Markdown::new(
message_text,
markdown_style,
Some(self.language_registry.clone()),
None,
cx,
)
});
self.rendered_messages_by_id.insert(*message_id, markdown);
self.push_message(message_id, message_text, cx);
}

cx.notify();
Expand Down Expand Up @@ -401,8 +424,9 @@ impl AssistantPanel {

fn render_message_list(&self, cx: &mut ViewContext<Self>) -> AnyElement {
if self.thread_messages.is_empty() {
#[allow(clippy::useless_vec)]
let recent_threads = vec![1, 2, 3];
let recent_threads = self
.thread_store
.update(cx, |this, cx| this.recent_threads(3, cx));

return v_flex()
.gap_2()
Expand Down Expand Up @@ -467,8 +491,8 @@ impl AssistantPanel {
.child(
v_flex().gap_2().children(
recent_threads
.iter()
.map(|_thread| self.render_past_thread(cx)),
.into_iter()
.map(|thread| self.render_past_thread(thread, cx)),
),
)
.child(
Expand Down Expand Up @@ -534,10 +558,16 @@ impl AssistantPanel {
.into_any()
}

fn render_past_thread(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
ListItem::new("temp")
fn render_past_thread(
&self,
thread: Model<Thread>,
cx: &mut ViewContext<Self>,
) -> impl IntoElement {
let id = thread.read(cx).id().clone();

ListItem::new(("past-thread", thread.entity_id()))
.start_slot(Icon::new(IconName::MessageBubbles))
.child(Label::new("Some Thread Title"))
.child(Label::new(format!("Thread {id}")))
.end_slot(
h_flex()
.gap_2()
Expand All @@ -548,6 +578,9 @@ impl AssistantPanel {
.icon_size(IconSize::Small),
),
)
.on_click(cx.listener(move |this, _event, cx| {
this.open_thread(&id, cx);
}))
}

fn render_last_error(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
Expand Down
37 changes: 36 additions & 1 deletion crates/assistant2/src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,28 @@ use language_model::{
use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError};
use serde::{Deserialize, Serialize};
use util::post_inc;
use uuid::Uuid;

#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
Chat,
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
pub struct ThreadId(Arc<str>);

impl ThreadId {
pub fn new() -> Self {
Self(Uuid::new_v4().to_string().into())
}
}

impl std::fmt::Display for ThreadId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(usize);

Expand All @@ -39,6 +55,7 @@ pub struct Message {

/// A thread of conversation with the LLM.
pub struct Thread {
id: ThreadId,
messages: Vec<Message>,
next_message_id: MessageId,
completion_count: usize,
Expand All @@ -52,6 +69,7 @@ pub struct Thread {
impl Thread {
pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
Self {
id: ThreadId::new(),
messages: Vec::new(),
next_message_id: MessageId(0),
completion_count: 0,
Expand All @@ -63,10 +81,18 @@ impl Thread {
}
}

pub fn id(&self) -> &ThreadId {
&self.id
}

pub fn message(&self, id: MessageId) -> Option<&Message> {
self.messages.iter().find(|message| message.id == id)
}

pub fn messages(&self) -> impl Iterator<Item = &Message> {
self.messages.iter()
}

pub fn tools(&self) -> &Arc<ToolWorkingSet> {
&self.tools
}
Expand All @@ -76,10 +102,19 @@ impl Thread {
}

pub fn insert_user_message(&mut self, text: impl Into<String>, cx: &mut ModelContext<Self>) {
self.insert_message(Role::User, text, cx)
}

pub fn insert_message(
&mut self,
role: Role,
text: impl Into<String>,
cx: &mut ModelContext<Self>,
) {
let id = self.next_message_id.post_inc();
self.messages.push(Message {
id,
role: Role::User,
role,
text: text.into(),
});
cx.emit(ThreadEvent::MessageAdded(id));
Expand Down
Loading

0 comments on commit 0bde0f8

Please sign in to comment.