diff --git a/jupyterlab_chat/handlers.py b/jupyterlab_chat/handlers.py index 46a7efd..4a7ef55 100644 --- a/jupyterlab_chat/handlers.py +++ b/jupyterlab_chat/handlers.py @@ -2,12 +2,9 @@ import json import time import uuid -from asyncio import AbstractEventLoop -from dataclasses import asdict -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import Dict, List from jupyter_server.base.handlers import APIHandler as BaseAPIHandler, JupyterHandler -from jupyter_server.utils import url_path_join from langchain.pydantic_v1 import ValidationError from tornado import web, websocket @@ -153,11 +150,13 @@ async def on_message(self, message): return # message broadcast to chat clients - chat_message_id = str(uuid.uuid4()) + if not chat_request.id: + chat_request.id = str(uuid.uuid4()) + chat_message = ChatMessage( - id=chat_message_id, + id=chat_request.id, time=time.time(), - body=chat_request.prompt, + body=chat_request.body, sender=self.chat_client, ) diff --git a/jupyterlab_chat/models.py b/jupyterlab_chat/models.py index 64ea2ab..95cff48 100644 --- a/jupyterlab_chat/models.py +++ b/jupyterlab_chat/models.py @@ -8,7 +8,8 @@ # the type of message used to chat with the agent class ChatRequest(BaseModel): - prompt: str + body: str + id: str class ChatUser(BaseModel): diff --git a/package.json b/package.json index 798e0d6..74b6801 100644 --- a/package.json +++ b/package.json @@ -60,6 +60,9 @@ "@jupyterlab/rendermime": "^4.0.5", "@jupyterlab/services": "^7.0.5", "@jupyterlab/ui-components": "^4.0.5", + "@lumino/coreutils": "2.1.2", + "@lumino/disposable": "2.1.2", + "@lumino/signaling": "2.1.2", "@mui/icons-material": "5.11.0", "@mui/material": "^5.11.0", "react": "^18.2.0", diff --git a/src/components/chat-messages.tsx b/src/components/chat-messages.tsx index 0082a76..02953dc 100644 --- a/src/components/chat-messages.tsx +++ b/src/components/chat-messages.tsx @@ -1,9 +1,8 @@ -import React, { useState, useEffect } from 'react'; - +import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { Avatar, Box, Typography } from '@mui/material'; import type { SxProps, Theme } from '@mui/material'; +import React, { useState, useEffect } from 'react'; -import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { RendermimeMarkdown } from './rendermime-markdown'; import { ChatService } from '../services'; diff --git a/src/components/chat-settings.tsx b/src/components/chat-settings.tsx index 2cd87a4..0949882 100644 --- a/src/components/chat-settings.tsx +++ b/src/components/chat-settings.tsx @@ -11,14 +11,14 @@ import { } from '@mui/material'; import React, { useEffect, useState } from 'react'; -import { ChatService } from '../services'; +import { useStackingAlert } from './mui-extras/stacking-alert'; import { ServerInfoState, useServerInfo } from './settings/use-server-info'; import { minifyUpdate } from './settings/minify'; -import { useStackingAlert } from './mui-extras/stacking-alert'; +import { ChatService } from '../services'; -// /** -// * Component that returns the settings view in the chat panel. -// */ +/** + * Component that returns the settings view in the chat panel. + */ export function ChatSettings(): JSX.Element { // state fetched on initial render const server = useServerInfo(); diff --git a/src/components/chat.tsx b/src/components/chat.tsx index 06434c2..eaaac71 100644 --- a/src/components/chat.tsx +++ b/src/components/chat.tsx @@ -10,17 +10,17 @@ import { JlThemeProvider } from './jl-theme-provider'; import { ChatMessages } from './chat-messages'; import { ChatInput } from './chat-input'; import { ChatSettings } from './chat-settings'; -import { ChatHandler } from '../chat-handler'; import { ScrollContainer } from './scroll-container'; +import { IChatModel } from '../model'; import { ChatService } from '../services'; type ChatBodyProps = { - chatHandler: ChatHandler; + chatModel: IChatModel; rmRegistry: IRenderMimeRegistry; }; function ChatBody({ - chatHandler, + chatModel, rmRegistry: renderMimeRegistry }: ChatBodyProps): JSX.Element { const [messages, setMessages] = useState([]); @@ -34,7 +34,8 @@ function ChatBody({ async function fetchHistory() { try { const [history, config] = await Promise.all([ - chatHandler.getHistory(), + chatModel.getHistory?.() ?? + new Promise(r => r({ messages: [] })), ChatService.getConfig() ]); setSendWithShiftEnter(config.send_with_shift_enter ?? false); @@ -45,13 +46,13 @@ function ChatBody({ } fetchHistory(); - }, [chatHandler]); + }, [chatModel]); /** * Effect: listen to chat messages */ useEffect(() => { - function handleChatEvents(message: ChatService.IMessage) { + function handleChatEvents(_: IChatModel, message: ChatService.IMessage) { if (message.type === 'connection') { return; } else if (message.type === 'clear') { @@ -62,11 +63,11 @@ function ChatBody({ setMessages(messageGroups => [...messageGroups, message]); } - chatHandler.addListener(handleChatEvents); + chatModel.incomingMessage.connect(handleChatEvents); return function cleanup() { - chatHandler.removeListener(handleChatEvents); + chatModel.incomingMessage.disconnect(handleChatEvents); }; - }, [chatHandler]); + }, [chatModel]); // no need to append to messageGroups imperatively here. all of that is // handled by the listeners registered in the effect hooks above. @@ -74,7 +75,7 @@ function ChatBody({ setInput(''); // send message to backend - chatHandler.sendMessage({ prompt: input }); + chatModel.sendMessage({ body: input }); }; return ( @@ -100,7 +101,7 @@ function ChatBody({ } export type ChatProps = { - chatHandler: ChatHandler; + chatModel: IChatModel; themeManager: IThemeManager | null; rmRegistry: IRenderMimeRegistry; chatView?: ChatView; @@ -147,10 +148,7 @@ export function Chat(props: ChatProps): JSX.Element { {/* body */} {view === ChatView.Chat && ( - + )} {view === ChatView.Settings && } diff --git a/src/handler.ts b/src/handlers/handler.ts similarity index 99% rename from src/handler.ts rename to src/handlers/handler.ts index d5a1f57..5ffb554 100644 --- a/src/handler.ts +++ b/src/handlers/handler.ts @@ -1,5 +1,4 @@ import { URLExt } from '@jupyterlab/coreutils'; - import { ServerConnection } from '@jupyterlab/services'; const API_NAMESPACE = 'api/chat'; diff --git a/src/chat-handler.ts b/src/handlers/websocket-handler.ts similarity index 60% rename from src/chat-handler.ts rename to src/handlers/websocket-handler.ts index 62b5433..0cc0403 100644 --- a/src/chat-handler.ts +++ b/src/handlers/websocket-handler.ts @@ -1,26 +1,27 @@ -import { IDisposable } from '@lumino/disposable'; -import { ServerConnection } from '@jupyterlab/services'; import { URLExt } from '@jupyterlab/coreutils'; +import { ServerConnection } from '@jupyterlab/services'; +import { UUID } from '@lumino/coreutils'; + import { requestAPI } from './handler'; -import { ChatService } from './services'; +import { ChatModel, IChatModel } from '../model'; +import { ChatService } from '../services'; const CHAT_SERVICE_URL = 'api/chat'; -export class ChatHandler implements IDisposable { +/** + * An implementation of the chat model based on websocket handler. + */ +export class WebSocketHandler extends ChatModel { /** * The server settings used to make API requests. */ readonly serverSettings: ServerConnection.ISettings; - /** - * ID of the connection. Requires `await initialize()`. - */ - id = ''; - /** * Create a new chat handler. */ - constructor(options: ChatHandler.IOptions = {}) { + constructor(options: WebSocketHandler.IOptions = {}) { + super(options); this.serverSettings = options.serverSettings ?? ServerConnection.makeSettings(); } @@ -30,7 +31,7 @@ export class ChatHandler implements IDisposable { * resolved when server acknowledges connection and sends the client ID. This * must be awaited before calling any other method. */ - public async initialize(): Promise { + async initialize(): Promise { await this._initialize(); } @@ -38,27 +39,15 @@ export class ChatHandler implements IDisposable { * Sends a message across the WebSocket. Promise resolves to the message ID * when the server sends the same message back, acknowledging receipt. */ - public sendMessage(message: ChatService.ChatRequest): Promise { + sendMessage(message: ChatService.ChatRequest): Promise { + message.id = UUID.uuid4(); return new Promise(resolve => { this._socket?.send(JSON.stringify(message)); - this._sendResolverQueue.push(resolve); + this._sendResolverQueue.set(message.id!, resolve); }); } - public addListener(handler: (message: ChatService.IMessage) => void): void { - this._listeners.push(handler); - } - - public removeListener( - handler: (message: ChatService.IMessage) => void - ): void { - const index = this._listeners.indexOf(handler); - if (index > -1) { - this._listeners.splice(index, 1); - } - } - - public async getHistory(): Promise { + async getHistory(): Promise { let data: ChatService.ChatHistory = { messages: [] }; try { data = await requestAPI('history', { @@ -70,22 +59,11 @@ export class ChatHandler implements IDisposable { return data; } - /** - * Whether the chat handler is disposed. - */ - get isDisposed(): boolean { - return this._isDisposed; - } - /** * Dispose the chat handler. */ dispose(): void { - if (this.isDisposed) { - return; - } - this._isDisposed = true; - this._listeners = []; + super.dispose(); // Clean up socket. const socket = this._socket; @@ -99,35 +77,15 @@ export class ChatHandler implements IDisposable { } } - /** - * A function called before transferring the message to the panel(s). - * Can be useful if some actions are required on the message. - */ - protected formatChatMessage( - message: ChatService.IChatMessage - ): ChatService.IChatMessage { - return message; - } - - private _onMessage(message: ChatService.IMessage): void { + onMessage(message: ChatService.IMessage): void { // resolve promise from `sendMessage()` if (message.type === 'msg' && message.sender.id === this.id) { - this._sendResolverQueue.shift()?.(message.id); - } - - if (message.type === 'msg') { - message = this.formatChatMessage(message as ChatService.IChatMessage); + this._sendResolverQueue.get(message.id)?.(true); } - // call listeners in serial - this._listeners.forEach(listener => listener(message)); + super.onMessage(message); } - /** - * Queue of Promise resolvers pushed onto by `send()` - */ - private _sendResolverQueue: ((value: string) => void)[] = []; - private _onClose(e: CloseEvent, reject: any) { reject(new Error('Chat UI websocket disconnected')); console.error('Chat UI websocket disconnected'); @@ -155,31 +113,39 @@ export class ChatHandler implements IDisposable { socket.onclose = e => this._onClose(e, reject); socket.onerror = e => reject(e); socket.onmessage = msg => - msg.data && this._onMessage(JSON.parse(msg.data)); + msg.data && this.onMessage(JSON.parse(msg.data)); - const listenForConnection = (message: ChatService.IMessage) => { + const listenForConnection = ( + _: IChatModel, + message: ChatService.IMessage + ) => { if (message.type !== 'connection') { return; } this.id = message.client_id; resolve(); - this.removeListener(listenForConnection); + this.incomingMessage.disconnect(listenForConnection); }; - this.addListener(listenForConnection); + this.incomingMessage.connect(listenForConnection); }); } - private _isDisposed = false; private _socket: WebSocket | null = null; - private _listeners: ((msg: any) => void)[] = []; + /** + * Queue of Promise resolvers pushed onto by `send()` + */ + private _sendResolverQueue = new Map void>(); } -export namespace ChatHandler { +/** + * The websocket namespace. + */ +export namespace WebSocketHandler { /** * The instantiation options for a data registry handler. */ - export interface IOptions { + export interface IOptions extends ChatModel.IOptions { serverSettings?: ServerConnection.ISettings; } } diff --git a/src/index.ts b/src/index.ts index aeb0822..6127c09 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,4 +1,5 @@ -export * from './chat-handler'; +export * from './handlers/websocket-handler'; +export * from './model'; +export * from './services'; export * from './widgets/chat-error'; export * from './widgets/chat-sidebar'; -export * from './services'; diff --git a/src/model.ts b/src/model.ts new file mode 100644 index 0000000..7e6e978 --- /dev/null +++ b/src/model.ts @@ -0,0 +1,191 @@ +import { IDisposable } from '@lumino/disposable'; +import { ISignal, Signal } from '@lumino/signaling'; + +import { ChatService } from './services'; + +export interface IChatModel extends IDisposable { + /** + * The chat model ID. + */ + id: string; + + /** + * The signal emitted when a new message is received. + */ + get incomingMessage(): ISignal; + + /** + * The signal emitted when a message is updated. + */ + get messageUpdated(): ISignal; + + /** + * The signal emitted when a message is updated. + */ + get messageDeleted(): ISignal; + + /** + * Send a message, to be defined depending on the chosen technology. + * Default to no-op. + * + * @param message - the message to send. + * @returns whether the message has been sent or not, or nothing if not needed. + */ + sendMessage( + message: ChatService.ChatRequest + ): Promise | boolean | void; + + /** + * Optional, to update a message from the chat. + * + * @param id - the unique ID of the message. + * @param message - the message to update. + */ + updateMessage?( + id: string, + message: ChatService.ChatRequest + ): Promise | boolean | void; + + /** + * Optional, to get messages history. + */ + getHistory?(): Promise; + + /** + * Dispose the chat model. + */ + dispose(): void; + + /** + * Whether the chat handler is disposed. + */ + isDisposed: boolean; + + /** + * Function to call when a message is received. + * + * @param message - the new message, containing user information and body. + */ + onMessage(message: ChatService.IMessage): void; + + /** + * Function to call when a message is updated. + * + * @param message - the message updated, containing user information and body. + */ + onMessageUpdated?(message: ChatService.IMessage): void; +} + +/** + * The default chat model implementation. + * It is not able to send or update a message by itself, since it depends on the + * chosen technology. + */ +export class ChatModel implements IChatModel { + /** + * Create a new chat model. + */ + constructor(options: ChatModel.IOptions) {} + + /** + * The chat model ID. + */ + get id(): string { + return this._id; + } + set id(value: string) { + this._id = value; + } + + /** + * The signal emitted when a new message is received. + */ + get incomingMessage(): ISignal { + return this._newMessage; + } + + /** + * The signal emitted when a message is updated. + */ + get messageUpdated(): ISignal { + return this._messageUpdated; + } + + /** + * The signal emitted when a message is updated. + */ + get messageDeleted(): ISignal { + return this._messageDeleted; + } + + /** + * Send a message, to be defined depending on the chosen technology. + * Default to no-op. + * + * @param message - the message to send. + * @returns whether the message has been sent or not. + */ + sendMessage( + message: ChatService.ChatRequest + ): Promise | boolean | void {} + + /** + * Dispose the chat model. + */ + dispose(): void { + if (this.isDisposed) { + return; + } + this._isDisposed = true; + } + + /** + * Whether the chat handler is disposed. + */ + get isDisposed(): boolean { + return this._isDisposed; + } + + /** + * A function called before transferring the message to the panel(s). + * Can be useful if some actions are required on the message. + */ + protected formatChatMessage( + message: ChatService.IChatMessage + ): ChatService.IChatMessage { + return message; + } + + /** + * Function to call when a message is received. + * + * @param message - the message with user information and body. + */ + onMessage(message: ChatService.IMessage): void { + if (message.type === 'msg') { + message = this.formatChatMessage(message as ChatService.IChatMessage); + } + + this._newMessage.emit(message); + } + + private _id: string = ''; + private _isDisposed = false; + private _newMessage = new Signal(this); + private _messageUpdated = new Signal( + this + ); + private _messageDeleted = new Signal( + this + ); +} + +/** + * The chat model namespace. + */ +export namespace ChatModel { + /** + * The instantiation options for a ChatModel. + */ + export interface IOptions {} +} diff --git a/src/services.ts b/src/services.ts index f92212f..d63381a 100644 --- a/src/services.ts +++ b/src/services.ts @@ -1,4 +1,4 @@ -import { requestAPI } from './handler'; +import { requestAPI } from './handlers/handler'; export namespace ChatService { export interface IUser { @@ -35,7 +35,8 @@ export namespace ChatService { }; export type ChatRequest = { - prompt: string; + body: string; + id?: string; }; export type DescribeConfigResponse = { diff --git a/src/widgets/chat-error.tsx b/src/widgets/chat-error.tsx index 8ae9cbb..c31adb2 100644 --- a/src/widgets/chat-error.tsx +++ b/src/widgets/chat-error.tsx @@ -1,7 +1,6 @@ -import React from 'react'; -import { ReactWidget } from '@jupyterlab/apputils'; -import type { IThemeManager } from '@jupyterlab/apputils'; +import { IThemeManager, ReactWidget } from '@jupyterlab/apputils'; import { Alert, Box } from '@mui/material'; +import React from 'react'; import { chatIcon } from '../icons'; import { JlThemeProvider } from '../components/jl-theme-provider'; diff --git a/src/widgets/chat-sidebar.tsx b/src/widgets/chat-sidebar.tsx index 1ddb21b..61ca44e 100644 --- a/src/widgets/chat-sidebar.tsx +++ b/src/widgets/chat-sidebar.tsx @@ -1,20 +1,19 @@ +import { IThemeManager, ReactWidget } from '@jupyterlab/apputils'; +import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import React from 'react'; -import { ReactWidget } from '@jupyterlab/apputils'; -import type { IThemeManager } from '@jupyterlab/apputils'; import { Chat } from '../components/chat'; +import { WebSocketHandler } from '../handlers/websocket-handler'; import { chatIcon } from '../icons'; -import { ChatHandler } from '../chat-handler'; -import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; export function buildChatSidebar( - chatHandler: ChatHandler, + chatHandler: WebSocketHandler, themeManager: IThemeManager | null, rmRegistry: IRenderMimeRegistry ): ReactWidget { const ChatWidget = ReactWidget.create( diff --git a/yarn.lock b/yarn.lock index 7f9772f..9990beb 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2467,6 +2467,9 @@ __metadata: "@jupyterlab/services": ^7.0.5 "@jupyterlab/testutils": ^4.0.0 "@jupyterlab/ui-components": ^4.0.5 + "@lumino/coreutils": 2.1.2 + "@lumino/disposable": 2.1.2 + "@lumino/signaling": 2.1.2 "@mui/icons-material": 5.11.0 "@mui/material": ^5.11.0 "@types/jest": ^29.2.0 @@ -3382,14 +3385,14 @@ __metadata: languageName: node linkType: hard -"@lumino/coreutils@npm:^1.11.0 || ^2.0.0, @lumino/coreutils@npm:^1.11.0 || ^2.1.2, @lumino/coreutils@npm:^2.1.2": +"@lumino/coreutils@npm:2.1.2, @lumino/coreutils@npm:^1.11.0 || ^2.0.0, @lumino/coreutils@npm:^1.11.0 || ^2.1.2, @lumino/coreutils@npm:^2.1.2": version: 2.1.2 resolution: "@lumino/coreutils@npm:2.1.2" checksum: 7865317ac0676b448d108eb57ab5d8b2a17c101995c0f7a7106662d9fe6c859570104525f83ee3cda12ae2e326803372206d6f4c1f415a5b59e4158a7b81066f languageName: node linkType: hard -"@lumino/disposable@npm:^1.10.0 || ^2.0.0, @lumino/disposable@npm:^2.1.2": +"@lumino/disposable@npm:2.1.2, @lumino/disposable@npm:^1.10.0 || ^2.0.0, @lumino/disposable@npm:^2.1.2": version: 2.1.2 resolution: "@lumino/disposable@npm:2.1.2" dependencies: @@ -3450,7 +3453,7 @@ __metadata: languageName: node linkType: hard -"@lumino/signaling@npm:^1.10.0 || ^2.0.0, @lumino/signaling@npm:^2.1.2": +"@lumino/signaling@npm:2.1.2, @lumino/signaling@npm:^1.10.0 || ^2.0.0, @lumino/signaling@npm:^2.1.2": version: 2.1.2 resolution: "@lumino/signaling@npm:2.1.2" dependencies: