This repository was archived by the owner on Feb 29, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathhandlers.py
210 lines (167 loc) · 6.41 KB
/
handlers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import getpass
import json
import time
import uuid
from typing import Dict, List
from jupyter_server.base.handlers import APIHandler as BaseAPIHandler, JupyterHandler
from langchain.pydantic_v1 import ValidationError
from tornado import web, websocket
from .config_manager import WriteConflictError
from .models import (
ChatClient,
ChatHistory,
ChatMessage,
ChatRequest,
ChatUser,
ConnectionMessage,
ChatMessage,
Message,
UpdateConfigRequest,
)
class ChatHistoryHandler(BaseAPIHandler):
    """Handler to return message history.

    The history is read from and written to ``self.settings["chat_history"]``
    (presumably the Tornado application settings shared by all handlers).
    """

    # NOTE(review): not referenced anywhere in this file; kept only for
    # backward compatibility. As a class-level list it would be shared by every
    # instance if anything ever mutated it.
    _messages = []

    @property
    def chat_history(self):
        """The chat history stored under ``self.settings["chat_history"]``."""
        return self.settings["chat_history"]

    @chat_history.setter
    def chat_history(self, new_history):
        # The setter must reuse the property's name. Registering it under a
        # different name creates a second property and leaves ``chat_history``
        # read-only, so ``self.chat_history = ...`` would raise AttributeError.
        self.settings["chat_history"] = new_history

    @web.authenticated
    async def get(self):
        """Respond with the full chat history serialized as a ``ChatHistory`` JSON document."""
        history = ChatHistory(messages=self.chat_history)
        self.finish(history.json())
class ChatHandler(JupyterHandler, websocket.WebSocketHandler):
    """
    A websocket handler for chat.

    On open, each connection gets a random client ID and is registered in two
    mappings read from ``self.settings``: ``root_chat_handlers`` (client ID ->
    handler) and ``chat_clients`` (client ID -> ``ChatClient`` metadata).
    Incoming chat requests are turned into ``ChatMessage`` objects, written to
    every registered handler, and appended to ``self.settings["chat_history"]``.
    """

    @property
    def root_chat_handlers(self) -> Dict[str, "ChatHandler"]:
        """Dictionary mapping client IDs to their corresponding ChatHandler
        instances."""
        return self.settings["root_chat_handlers"]

    @property
    def chat_clients(self) -> Dict[str, ChatClient]:
        """Dictionary mapping client IDs to their ChatClient objects that store
        metadata."""
        return self.settings["chat_clients"]

    @property
    def chat_client(self) -> ChatClient:
        """Returns ChatClient object associated with the current connection.

        Only valid after ``open`` has assigned ``self.client_id`` and
        registered the client.
        """
        return self.chat_clients[self.client_id]

    @property
    def chat_history(self) -> List[ChatMessage]:
        """List of chat messages stored under ``self.settings["chat_history"]``."""
        return self.settings["chat_history"]

    def initialize(self):
        """Log the request path of the incoming websocket connection."""
        self.log.debug("Initializing websocket connection %s", self.request.path)

    def pre_get(self):
        """Handles authentication/authorization.

        Raises:
            web.HTTPError: 403 if there is no authenticated user, or if the
                authorizer denies the ``"execute"`` action on ``"events"``.
        """
        # authenticate the request before opening the websocket
        user = self.current_user
        if user is None:
            self.log.warning("Couldn't authenticate WebSocket connection")
            raise web.HTTPError(403)

        # authorize the user.
        if not self.authorizer.is_authorized(self, user, "execute", "events"):
            raise web.HTTPError(403)

    async def get(self, *args, **kwargs):
        """Get an event socket.

        Runs the auth checks in ``pre_get`` first, then awaits ``super().get``
        (presumably tornado's ``WebSocketHandler.get``, which performs the
        websocket upgrade).
        """
        self.pre_get()
        res = super().get(*args, **kwargs)
        await res

    def get_chat_user(self) -> ChatUser:
        """Retrieves the current user synthesized from the server's current shell
        environment.

        The login name from ``getpass.getuser()`` is used as the username,
        name and display name. The initials are its first character,
        capitalized. No color or avatar is set.
        """
        login = getpass.getuser()
        # Slicing instead of indexing avoids an IndexError on an empty login
        # name; for a non-empty name the result is identical.
        initials = login[:1].capitalize()
        return ChatUser(
            username=login,
            initials=initials,
            name=login,
            display_name=login,
            color=None,
            avatar_url=None,
        )

    def generate_client_id(self):
        """Generates a client ID to identify the current WS connection.

        Returns:
            A random UUID4 rendered as a 32-character hex string.
        """
        return uuid.uuid4().hex

    def open(self):
        """Handles opening of a WebSocket connection. Client ID can be retrieved
        from `self.client_id`.

        Registers this handler and a ``ChatClient`` under a fresh client ID,
        then sends a ``ConnectionMessage`` carrying that ID to the client.
        """
        current_user = self.get_chat_user().dict()
        client_id = self.generate_client_id()

        self.root_chat_handlers[client_id] = self
        self.chat_clients[client_id] = ChatClient(**current_user, id=client_id)
        self.client_id = client_id
        self.write_message(ConnectionMessage(client_id=client_id).dict())

        self.log.info("Client connected. ID: %s", client_id)
        self.log.debug("Clients are : %s", self.root_chat_handlers.keys())

    def broadcast_message(self, message: Message):
        """Broadcasts message to all connected clients.
        Appends message to chat history.

        Only ``ChatMessage`` instances are appended to the history. Other
        ``Message`` types are sent but not recorded. A client whose websocket
        is already closed is skipped with a warning, so it cannot stop
        delivery to the remaining clients or the history append.
        """
        self.log.debug("Broadcasting message: %s to all clients...", message)
        payload = message.dict()

        # Iterate over a snapshot so the shared mapping can change (e.g. a
        # client unregistering in on_close) without breaking this loop.
        for client_id, client in list(self.root_chat_handlers.items()):
            if not client:
                continue
            try:
                client.write_message(payload)
            except websocket.WebSocketClosedError:
                self.log.warning(
                    "Could not send message to client %s: connection closed.",
                    client_id,
                )

        # Only append ChatMessage instances to history, not control messages
        if isinstance(message, ChatMessage):
            self.chat_history.append(message)

    async def on_message(self, message):
        """Handle an incoming websocket frame.

        The frame is parsed as JSON into a ``ChatRequest``. Malformed JSON, a
        payload that is not a JSON object, or a payload that fails validation
        is logged and dropped instead of escaping the handler. A valid request
        becomes a ``ChatMessage`` (a UUID4 id is generated if the request has
        none), which is broadcast to every connected client, the sender
        included.
        """
        self.log.debug("Message received: %s", message)

        try:
            payload = json.loads(message)
            # TypeError here means the JSON was not an object (e.g. a list).
            chat_request = ChatRequest(**payload)
        except (json.JSONDecodeError, TypeError, ValidationError) as e:
            self.log.error(e)
            return

        # message broadcast to chat clients
        if not chat_request.id:
            chat_request.id = str(uuid.uuid4())

        chat_message = ChatMessage(
            id=chat_request.id,
            time=time.time(),
            body=chat_request.body,
            sender=self.chat_client,
        )

        # broadcast the message to all connected clients (sender included)
        self.broadcast_message(message=chat_message)

    def on_close(self):
        """Unregister this connection from the shared client mappings.

        ``self.client_id`` is only assigned in ``open``. If that never
        happened, nothing is unregistered and the ID is logged as ``None``.
        """
        client_id = getattr(self, "client_id", None)
        self.log.debug("Disconnecting client with user %s", client_id)
        if client_id is not None:
            self.root_chat_handlers.pop(client_id, None)
            self.chat_clients.pop(client_id, None)
        self.log.info("Client disconnected. ID: %s", client_id)
        self.log.debug("Chat clients: %s", self.root_chat_handlers.keys())
class GlobalConfigHandler(BaseAPIHandler):
    """API handler for fetching and setting the
    model and embeddings config.
    """

    @property
    def config_manager(self):
        """The config manager stored under ``self.settings["chat_config_manager"]``."""
        return self.settings["chat_config_manager"]

    @web.authenticated
    def get(self):
        """Respond with the current config serialized as JSON.

        Raises:
            web.HTTPError: 500 if the config manager returns no (falsy) config.
        """
        config = self.config_manager.get_config()
        if not config:
            raise web.HTTPError(500, "No config found.")

        self.finish(config.json())

    @web.authenticated
    def post(self):
        """Apply a config update from the JSON request body; respond 204 on success.

        Raises:
            web.HTTPError: 400 if the body is not a JSON object. Any HTTPError
                raised by ``get_json_body`` itself propagates unchanged. 500 if
                building the ``UpdateConfigRequest`` or applying it fails.
        """
        # Parse outside the ``try`` so an HTTPError raised by get_json_body()
        # (presumably 400 for malformed JSON) is not rewritten as a generic 500
        # by the catch-all handler below.
        body = self.get_json_body()
        if not isinstance(body, dict):
            raise web.HTTPError(400, "Request body must be a JSON object.")

        try:
            config = UpdateConfigRequest(**body)
            self.config_manager.update_config(config)
        except (ValidationError, WriteConflictError) as e:
            # Listed before ValueError: pydantic's ValidationError presumably
            # subclasses ValueError and must keep its own message here.
            self.log.exception(e)
            raise web.HTTPError(500, str(e)) from e
        except ValueError as e:
            self.log.exception(e)
            raise web.HTTPError(
                500, str(e.cause) if hasattr(e, "cause") else str(e)
            ) from e
        except Exception as e:
            self.log.exception(e)
            raise web.HTTPError(
                500, "Unexpected error occurred while updating the config."
            ) from e

        self.set_status(204)
        self.finish()