Skip to content

Commit 5be168b

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 2cf6fb8 commit 5be168b

22 files changed

Lines changed: 134 additions & 139 deletions

ai_chatbots/api.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import json
44
import logging
5-
from typing import Any, Optional, Union
5+
from typing import Any, Union
66
from uuid import uuid4
77

88
from channels.db import database_sync_to_async
@@ -96,14 +96,14 @@ def serialize_tool_calls(tool_calls: list[dict]) -> list[dict]:
9696

9797

9898
@database_sync_to_async
99-
def query_tutorbot_output(thread_id: str) -> Optional[TutorBotOutput]:
99+
def query_tutorbot_output(thread_id: str) -> TutorBotOutput | None:
100100
"""Return the latest TutorBotOutput for a given thread_id"""
101101
return TutorBotOutput.objects.filter(thread_id=thread_id).last()
102102

103103

104104
@database_sync_to_async
105105
def create_tutorbot_output_and_checkpoints(
106-
thread_id: str, chat_json: Union[str, dict], edx_module_id: Optional[str]
106+
thread_id: str, chat_json: Union[str, dict], edx_module_id: str | None
107107
) -> tuple[TutorBotOutput, list[DjangoCheckpoint]]:
108108
"""Atomically create both TutorBotOutput and DjangoCheckpoint objects"""
109109
with transaction.atomic():
@@ -133,7 +133,7 @@ def _should_create_checkpoint(msg: dict) -> bool:
133133

134134

135135
def _identify_new_messages(
136-
filtered_messages: list[dict], previous_chat_json: Optional[Union[str, dict]]
136+
filtered_messages: list[dict], previous_chat_json: Union[str, dict] | None
137137
) -> list[dict]:
138138
"""Identify which messages are new by comparing with previous chat data."""
139139
if not previous_chat_json:
@@ -222,7 +222,7 @@ def _create_checkpoint_metadata(
222222
def create_tutor_checkpoints(
223223
thread_id: str,
224224
chat_json: Union[str, dict],
225-
previous_chat_json: Optional[Union[str, dict]] = None,
225+
previous_chat_json: Union[str, dict] | None = None,
226226
) -> list[DjangoCheckpoint]:
227227
"""Create DjangoCheckpoint records from tutor chat data (synchronous)"""
228228
# Get the associated session

ai_chatbots/chatbots.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from abc import ABC, abstractmethod
77
from collections.abc import AsyncGenerator
88
from operator import add
9-
from typing import Annotated, Any, Optional
9+
from typing import Annotated, Any
1010
from uuid import uuid4
1111

1212
import posthog
@@ -73,10 +73,10 @@ def __init__( # noqa: PLR0913
7373
checkpointer: BaseCheckpointSaver,
7474
*,
7575
name: str = "MIT Open Learning Chatbot",
76-
model: Optional[str] = None,
77-
temperature: Optional[float] = None,
78-
instructions: Optional[str] = None,
79-
thread_id: Optional[str] = None,
76+
model: str | None = None,
77+
temperature: float | None = None,
78+
instructions: str | None = None,
79+
thread_id: str | None = None,
8080
):
8181
"""Initialize the AI chat agent service"""
8282
self.bot_name = name
@@ -224,7 +224,7 @@ async def validate_and_clean_checkpoint(self) -> None:
224224
except Exception:
225225
log.exception("Error while cleaning checkpoint")
226226

227-
async def _get_latest_checkpoint_id(self) -> Optional[str]:
227+
async def _get_latest_checkpoint_id(self) -> str | None:
228228
"""Get the most recent assistant response checkpoint"""
229229
checkpoint = (
230230
await DjangoCheckpoint.objects.prefetch_related("session", "session__user")
@@ -237,7 +237,7 @@ async def _get_latest_checkpoint_id(self) -> Optional[str]:
237237
return checkpoint.id if checkpoint else None
238238

239239
async def set_callbacks(
240-
self, properties: Optional[dict] = None
240+
self, properties: dict | None = None
241241
) -> list[CallbackHandler]:
242242
"""Set callbacks for the agent LLM"""
243243
if settings.POSTHOG_PROJECT_API_KEY and settings.POSTHOG_API_HOST:
@@ -282,7 +282,7 @@ async def get_completion(
282282
self,
283283
message: str,
284284
*,
285-
extra_state: Optional[dict[str, Any]] = None,
285+
extra_state: dict[str, Any] | None = None,
286286
debug: bool = settings.AI_DEBUG,
287287
) -> AsyncGenerator[str, None]:
288288
"""
@@ -424,13 +424,13 @@ class ResourceRecommendationBot(TruncatingChatbot):
424424
def __init__( # noqa: PLR0913
425425
self,
426426
user_id: str,
427-
checkpointer: Optional[BaseCheckpointSaver] = None,
427+
checkpointer: BaseCheckpointSaver | None = None,
428428
*,
429429
name: str = "MIT Open Learning Chatbot",
430-
model: Optional[str] = None,
431-
temperature: Optional[float] = None,
432-
instructions: Optional[str] = None,
433-
thread_id: Optional[str] = None,
430+
model: str | None = None,
431+
temperature: float | None = None,
432+
instructions: str | None = None,
433+
thread_id: str | None = None,
434434
):
435435
"""Initialize the AI search agent service"""
436436
super().__init__(
@@ -466,7 +466,7 @@ class SyllabusAgentState(SummaryState):
466466
related_courses: Annotated[list[str], add]
467467
# str representation of a boolean value, because the
468468
# langgraph JsonPlusSerializer can't handle booleans
469-
exclude_canvas: Annotated[Optional[list[str]], add]
469+
exclude_canvas: Annotated[list[str] | None, add]
470470

471471

472472
class SyllabusBot(TruncatingChatbot):
@@ -483,11 +483,11 @@ def __init__( # noqa: PLR0913
483483
checkpointer: BaseCheckpointSaver,
484484
*,
485485
name: str = "MIT Open Learning Syllabus Chatbot",
486-
model: Optional[str] = None,
487-
temperature: Optional[float] = None,
488-
instructions: Optional[str] = None,
489-
thread_id: Optional[str] = None,
490-
enable_related_courses: Optional[bool] = False,
486+
model: str | None = None,
487+
temperature: float | None = None,
488+
instructions: str | None = None,
489+
thread_id: str | None = None,
490+
enable_related_courses: bool | None = False,
491491
):
492492
self.enable_related_courses = enable_related_courses
493493
super().__init__(
@@ -546,16 +546,16 @@ class TutorBot(BaseChatbot):
546546
def __init__( # noqa: PLR0913
547547
self,
548548
user_id: str,
549-
checkpointer: Optional[BaseCheckpointSaver] = BaseCheckpointSaver,
549+
checkpointer: BaseCheckpointSaver | None = BaseCheckpointSaver,
550550
*,
551551
name: str = "MIT Open Learning Tutor Chatbot",
552-
model: Optional[str] = None,
553-
temperature: Optional[float] = None,
554-
thread_id: Optional[str] = None,
555-
block_siblings: Optional[list[str]] = None,
556-
edx_module_id: Optional[str] = None,
557-
run_readable_id: Optional[str] = None,
558-
problem_set_title: Optional[str] = None,
552+
model: str | None = None,
553+
temperature: float | None = None,
554+
thread_id: str | None = None,
555+
block_siblings: list[str] | None = None,
556+
edx_module_id: str | None = None,
557+
run_readable_id: str | None = None,
558+
problem_set_title: str | None = None,
559559
):
560560
super().__init__(
561561
user_id,
@@ -600,7 +600,7 @@ async def get_completion(
600600
self,
601601
message: str,
602602
*,
603-
extra_state: Optional[dict[str, Any]] = None, # noqa: ARG002
603+
extra_state: dict[str, Any] | None = None, # noqa: ARG002
604604
debug: bool = settings.AI_DEBUG,
605605
) -> AsyncGenerator[str, None]:
606606
"""Call message_tutor with the user query and return the response"""
@@ -810,10 +810,10 @@ def __init__( # noqa: PLR0913
810810
checkpointer: BaseCheckpointSaver,
811811
*,
812812
name: str = "MIT Open Learning VideoGPT Chatbot",
813-
model: Optional[str] = None,
814-
temperature: Optional[float] = None,
815-
instructions: Optional[str] = None,
816-
thread_id: Optional[str] = None,
813+
model: str | None = None,
814+
temperature: float | None = None,
815+
instructions: str | None = None,
816+
thread_id: str | None = None,
817817
):
818818
super().__init__(
819819
user_id,

ai_chatbots/chatbots_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -975,9 +975,11 @@ async def test_tutor_get_completion(posthog_settings, mocker, variant):
975975
assert "Let's start by thinking about the problem. " in results
976976

977977
checkpoint = await database_sync_to_async(
978-
lambda: DjangoCheckpoint.objects.select_related("session")
979-
.filter(thread_id=thread_id)
980-
.last()
978+
lambda: (
979+
DjangoCheckpoint.objects.select_related("session")
980+
.filter(thread_id=thread_id)
981+
.last()
982+
)
981983
)()
982984
history = await database_sync_to_async(
983985
lambda: TutorBotOutput.objects.filter(thread_id=thread_id).last()

ai_chatbots/checkpointers.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from collections.abc import AsyncGenerator
55
from typing import (
66
Any,
7-
Optional,
87
)
98

109
from django.conf import settings
@@ -106,8 +105,8 @@ def _load_writes(
106105
def _parse_checkpoint_data(
107106
serde: JsonPlusSerializer,
108107
data: DjangoCheckpoint,
109-
pending_writes: Optional[list[PendingWrite]] = None,
110-
) -> Optional[CheckpointTuple]:
108+
pending_writes: list[PendingWrite] | None = None,
109+
) -> CheckpointTuple | None:
111110
"""
112111
Parse checkpoint data retrieved from the database.
113112
"""
@@ -163,9 +162,9 @@ async def create_with_session( # noqa: PLR0913
163162
thread_id: str,
164163
message: str,
165164
agent: str,
166-
user: Optional[USER_MODEL] = None,
167-
dj_session_key: Optional[str] = "",
168-
object_id: Optional[str] = "",
165+
user: USER_MODEL | None = None,
166+
dj_session_key: str | None = "",
167+
object_id: str | None = "",
169168
):
170169
"""
171170
Initialize the DjangoSaver and create a UserChatSession if applicable.
@@ -317,7 +316,7 @@ async def aput_writes(
317316
},
318317
)
319318

320-
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
319+
async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
321320
"""Get a checkpoint tuple from the database asynchronously.
322321
323322
This method retrieves a checkpoint tuple from the database based on the
@@ -362,11 +361,11 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
362361

363362
async def alist(
364363
self,
365-
config: Optional[RunnableConfig],
364+
config: RunnableConfig | None,
366365
*,
367-
filter: Optional[dict[str, Any]] = None, # noqa: ARG002, A002
368-
before: Optional[RunnableConfig] = None,
369-
limit: Optional[int] = None,
366+
filter: dict[str, Any] | None = None, # noqa: ARG002, A002
367+
before: RunnableConfig | None = None,
368+
limit: int | None = None,
370369
) -> AsyncGenerator[CheckpointTuple, None]:
371370
"""List checkpoints from the database asynchronously.
372371

ai_chatbots/constants.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import dataclasses
44
import datetime
5-
from typing import Optional
65

76
from named_enum import ExtendedEnum
87

@@ -57,7 +56,7 @@ class ChatbotCookie:
5756
name: str
5857
value: str
5958
path: str = "/"
60-
max_age: Optional[datetime.datetime] = None
59+
max_age: datetime.datetime | None = None
6160

6261
def __str__(self) -> str:
6362
"""

ai_chatbots/consumers.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import logging
33
from abc import ABC, abstractmethod
44
from http.cookies import SimpleCookie
5-
from typing import Optional
65
from uuid import uuid4
76

87
import litellm
@@ -104,9 +103,9 @@ async def assign_thread_cookies(
104103
self,
105104
user: User,
106105
*,
107-
clear_history: Optional[bool] = False,
108-
thread_id: Optional[str] = None,
109-
object_id: Optional[str] = None,
106+
clear_history: bool | None = False,
107+
thread_id: str | None = None,
108+
object_id: str | None = None,
110109
) -> tuple[str, list[str]]:
111110
"""
112111
Extract and update separate cookie values for logged in vs anonymous users.
@@ -220,7 +219,7 @@ async def assign_thread_cookies(
220219
return current_thread_id, cookies
221220

222221
async def prepare_response(
223-
self, serializer: ChatRequestSerializer, object_id_field: Optional[str] = None
222+
self, serializer: ChatRequestSerializer, object_id_field: str | None = None
224223
) -> tuple[str, list[str]]:
225224
"""Prepare consumer for the API response"""
226225
if object_id_field:
@@ -249,9 +248,9 @@ def process_extra_state(self, data: dict) -> dict: # noqa: ARG002
249248

250249
async def start_response(
251250
self,
252-
thread_id: Optional[str] = None,
253-
status: Optional[int] = HTTP_200_OK,
254-
cookies: Optional[list[str]] = None,
251+
thread_id: str | None = None,
252+
status: int | None = HTTP_200_OK,
253+
cookies: list[str] | None = None,
255254
):
256255
headers = (
257256
[
@@ -507,7 +506,7 @@ def process_extra_state(self, data: dict) -> dict:
507506
def prepare_response(
508507
self,
509508
serializer: SyllabusChatRequestSerializer,
510-
object_id_field: Optional[str] = None,
509+
object_id_field: str | None = None,
511510
) -> tuple[str, list[str]]:
512511
"""Set the course id as the default object id field"""
513512
object_id_field = object_id_field or "course_id"
@@ -620,7 +619,7 @@ def create_chatbot(
620619
def prepare_response(
621620
self,
622621
serializer: TutorChatRequestSerializer,
623-
object_id_field: Optional[str] = None,
622+
object_id_field: str | None = None,
624623
) -> tuple[str, list[str]]:
625624
"""Set the edx_module_id as the default object id field"""
626625
object_id_field = object_id_field or "edx_module_id"
@@ -673,7 +672,7 @@ def create_chatbot(
673672
def prepare_response(
674673
self,
675674
serializer: TutorChatRequestSerializer,
676-
object_id_field: Optional[str] = None,
675+
object_id_field: str | None = None,
677676
) -> tuple[str, list[str]]:
678677
"""Set the edx_module_id as the default object id field"""
679678
object_id_field = "object_id"
@@ -752,7 +751,7 @@ def process_extra_state(self, data: dict) -> dict:
752751
def prepare_response(
753752
self,
754753
serializer: VideoGPTRequestSerializer,
755-
object_id_field: Optional[str] = None,
754+
object_id_field: str | None = None,
756755
) -> tuple[str, list[str]]:
757756
"""Set the problem code as the default object id field"""
758757
object_id_field = object_id_field or "transcript_asset_id"

ai_chatbots/consumers_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1017,7 +1017,7 @@ async def test_anonymous_user_login_session_association( # noqa: PLR0913
10171017
for i in range(3):
10181018
anon_consumer = anonymous_consumer_setup(test_session_key)
10191019
payload = {
1020-
"message": f"Anonymous question {i+1}",
1020+
"message": f"Anonymous question {i + 1}",
10211021
"course_id": "MITx+6.00.1x",
10221022
}
10231023
await anon_consumer.handle(json.dumps(payload))

0 commit comments

Comments
 (0)