Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Config(BaseSettings):
DATABASE_POOL_SIZE: int = 20
DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes
DATABASE_POOL_PRE_PING: bool = True
DATABASE_POOL_TIMEOUT: int = 5 # seconds to wait for a connection before raising; keep low to fail fast
DATABASE_POOL_TIMEOUT: int = 5 # seconds; keep low to fail fast on pool exhaustion

REDIS_URL: RedisDsn
REDIS_HEALTH_CHECK_INTERVAL: int = 30
Expand Down
49 changes: 49 additions & 0 deletions src/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
Table,
UniqueConstraint,
Update,
event,
func,
text,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.exc import DBAPIError
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import NullPool

Expand Down Expand Up @@ -58,6 +60,30 @@

engine = create_async_engine(DATABASE_URL, **_engine_kwargs)


@event.listens_for(engine.sync_engine, "handle_error")
def _mark_stale_connection_as_disconnect(context: object) -> None:
    """Flag pooled connections that asyncpg reports as already closed.

    pool_pre_ping catches most stale connections, but a race window remains:
    the server may close the connection after the ping succeeds and before the
    real query executes. asyncpg then raises ConnectionDoesNotExistError (an
    InterfaceError subclass). Setting ``is_disconnect`` makes SQLAlchemy's
    pool discard that connection rather than hand it out again.
    """
    exc = getattr(context, "original_exception", None)
    if exc is None:
        return
    if type(exc).__name__ == "ConnectionDoesNotExistError":
        context.is_disconnect = True  # type: ignore[union-attr]


def _is_stale_connection_error(exc: BaseException) -> bool:
    """Return True when *exc* is a DBAPIError caused by asyncpg's
    ConnectionDoesNotExistError (i.e. the pool handed out a connection that
    Postgres had already closed).

    The cause is matched by class name so this module never has to import
    asyncpg directly.
    """
    if not isinstance(exc, DBAPIError):
        return False
    cause = exc.__cause__
    return cause is not None and type(cause).__name__ == "ConnectionDoesNotExistError"


metadata = MetaData(naming_convention=DB_NAMING_CONVENTION)


Expand Down Expand Up @@ -507,6 +533,16 @@ async def fetch_one(
select_query: Select | Insert | Update,
params: dict[str, Any] | None = None,
) -> dict[str, Any] | None:
try:
async with engine.begin() as conn:
cursor: CursorResult = await conn.execute(select_query, params or {})
row = cursor.first()
return row._asdict() if row is not None else None
except Exception as exc:
if not _is_stale_connection_error(exc):
raise
# Retry once — pool_pre_ping catches most stale connections, but a race can
# still return a connection that Postgres closed between the ping and the query.
async with engine.begin() as conn:
cursor: CursorResult = await conn.execute(select_query, params or {})
row = cursor.first()
Expand All @@ -517,6 +553,13 @@ async def fetch_all(
select_query: Select | Insert | Update,
params: dict[str, Any] | None = None,
) -> list[dict[str, Any]]:
try:
async with engine.begin() as conn:
cursor: CursorResult = await conn.execute(select_query, params or {})
return [r._asdict() for r in cursor.all()]
except Exception as exc:
if not _is_stale_connection_error(exc):
raise
async with engine.begin() as conn:
cursor: CursorResult = await conn.execute(select_query, params or {})
return [r._asdict() for r in cursor.all()]
Expand All @@ -526,5 +569,11 @@ async def execute(
select_query: Insert | Update,
params: dict[str, Any] | None = None,
) -> CursorResult:
try:
async with engine.begin() as conn:
return await conn.execute(select_query, params or {})
except Exception as exc:
if not _is_stale_connection_error(exc):
raise
async with engine.begin() as conn:
return await conn.execute(select_query, params or {})
2 changes: 1 addition & 1 deletion src/flows/storage/describe_memes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
VISION_MODELS = [
"google/gemma-3-27b-it:free", # best quality, 140+ languages, 131k context
"google/gemma-3-12b-it:free", # good fallback, smaller, 32k context
"google/gemma-3-4b-it:free", # smallest Gemma, lowest rate limits hit
"google/gemma-3-4b-it:free", # smallest Gemma, lowest rate limits hit
"nvidia/nemotron-nano-12b-v2-vl:free", # last resort: still listed but unreliable JSON
]

Expand Down
12 changes: 10 additions & 2 deletions src/tgbot/handlers/chat/agent/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,16 @@ async def run_chat_agent(
tool_calls_count = 0
for raw in result.raw_responses:
if hasattr(raw, "usage") and raw.usage:
total_prompt_tokens += getattr(raw.usage, "input_tokens", None) or getattr(raw.usage, "prompt_tokens", None) or 0
total_completion_tokens += getattr(raw.usage, "output_tokens", None) or getattr(raw.usage, "completion_tokens", None) or 0
total_prompt_tokens += (
getattr(raw.usage, "input_tokens", None)
or getattr(raw.usage, "prompt_tokens", None)
or 0
)
total_completion_tokens += (
getattr(raw.usage, "output_tokens", None)
or getattr(raw.usage, "completion_tokens", None)
or 0
)
for choice in getattr(raw, "choices", []):
if hasattr(choice, "message") and choice.message and choice.message.tool_calls:
tool_calls_count += len(choice.message.tool_calls)
Expand Down
2 changes: 2 additions & 0 deletions src/tgbot/handlers/chat/feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ async def handle_feedback_message(update: Update, context: ContextTypes.DEFAULT_

async def handle_feedback_reply(update: Update, context: ContextTypes.DEFAULT_TYPE):
reply_text = update.message.text
if not reply_text:
return # Admin sent a non-text message (sticker, voice, etc.) — can't forward
header, _ = update.message.reply_to_message.text.split("\n", 1)
user_id, message_id = header.split(":")

Expand Down
8 changes: 2 additions & 6 deletions src/tgbot/handlers/stats/wrapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,9 +1008,7 @@ async def generate_wrapped_data(
# Pick oneliner meme (avoid your_meme)
oneliner_meme_id = None
if liked:
oneliner_candidates = [
m for m in liked[:10] if m["meme_id"] not in global_used_memes
]
oneliner_candidates = [m for m in liked[:10] if m["meme_id"] not in global_used_memes]
if oneliner_candidates:
oneliner_meme_id = random.choice(oneliner_candidates)["meme_id"]
else:
Expand Down Expand Up @@ -1187,9 +1185,7 @@ def _build_zodiac_slide(p: dict, is_ru: bool = True) -> str:
return f"{header}\n\n" f"<b>{html_escape(sign)}</b>\n\n" f"<i>{html_escape(why)}</i>"


def _attach_memes_to_absurd(
p: dict, liked: list, used_ids: set | None = None
) -> list:
def _attach_memes_to_absurd(p: dict, liked: list, used_ids: set | None = None) -> list:
"""Attach meme IDs to each absurd comparison, ensuring no duplicates."""
comparisons = p.get("absurd_comparisons", [])
result = []
Expand Down
Loading