diff --git a/src/config.py b/src/config.py
index 8f8a1a78..59bdd857 100644
--- a/src/config.py
+++ b/src/config.py
@@ -11,7 +11,7 @@ class Config(BaseSettings):
DATABASE_POOL_SIZE: int = 20
DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes
DATABASE_POOL_PRE_PING: bool = True
- DATABASE_POOL_TIMEOUT: int = 5 # seconds to wait for a connection before raising; keep low to fail fast
+ DATABASE_POOL_TIMEOUT: int = 5 # seconds; keep low to fail fast on pool exhaustion
REDIS_URL: RedisDsn
REDIS_HEALTH_CHECK_INTERVAL: int = 30
diff --git a/src/database.py b/src/database.py
index 54238685..1cab0763 100644
--- a/src/database.py
+++ b/src/database.py
@@ -20,10 +20,12 @@
Table,
UniqueConstraint,
Update,
+ event,
func,
text,
)
from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.exc import DBAPIError
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import NullPool
@@ -58,6 +60,30 @@
engine = create_async_engine(DATABASE_URL, **_engine_kwargs)
+
+@event.listens_for(engine.sync_engine, "handle_error")
+def _mark_stale_connection_as_disconnect(context: object) -> None:
+ """Tell SQLAlchemy to invalidate pool connections that asyncpg reports as closed.
+
+ pool_pre_ping catches most stale connections, but a race is possible: the server
+ closes the connection between the ping and the actual query. When that happens
+    asyncpg raises ConnectionDoesNotExistError for the now-dead connection.
+ Marking it as a disconnect causes the pool to discard that connection so it is
+ never handed out again.
+ """
+ original = getattr(context, "original_exception", None)
+ if original is not None and type(original).__name__ == "ConnectionDoesNotExistError":
+ context.is_disconnect = True # type: ignore[union-attr]
+
+
+def _is_stale_connection_error(exc: BaseException) -> bool:
+ return (
+ isinstance(exc, DBAPIError)
+ and exc.__cause__ is not None
+ and type(exc.__cause__).__name__ == "ConnectionDoesNotExistError"
+ )
+
+
metadata = MetaData(naming_convention=DB_NAMING_CONVENTION)
@@ -507,6 +533,16 @@ async def fetch_one(
select_query: Select | Insert | Update,
params: dict[str, Any] | None = None,
) -> dict[str, Any] | None:
+ try:
+ async with engine.begin() as conn:
+ cursor: CursorResult = await conn.execute(select_query, params or {})
+ row = cursor.first()
+ return row._asdict() if row is not None else None
+ except Exception as exc:
+ if not _is_stale_connection_error(exc):
+ raise
+ # Retry once — pool_pre_ping catches most stale connections, but a race can
+ # still return a connection that Postgres closed between the ping and the query.
async with engine.begin() as conn:
cursor: CursorResult = await conn.execute(select_query, params or {})
row = cursor.first()
@@ -517,6 +553,13 @@ async def fetch_all(
select_query: Select | Insert | Update,
params: dict[str, Any] | None = None,
) -> list[dict[str, Any]]:
+ try:
+ async with engine.begin() as conn:
+ cursor: CursorResult = await conn.execute(select_query, params or {})
+ return [r._asdict() for r in cursor.all()]
+ except Exception as exc:
+ if not _is_stale_connection_error(exc):
+ raise
async with engine.begin() as conn:
cursor: CursorResult = await conn.execute(select_query, params or {})
return [r._asdict() for r in cursor.all()]
@@ -526,5 +569,11 @@ async def execute(
select_query: Insert | Update,
params: dict[str, Any] | None = None,
) -> CursorResult:
+ try:
+ async with engine.begin() as conn:
+ return await conn.execute(select_query, params or {})
+ except Exception as exc:
+ if not _is_stale_connection_error(exc):
+ raise
async with engine.begin() as conn:
return await conn.execute(select_query, params or {})
diff --git a/src/flows/storage/describe_memes.py b/src/flows/storage/describe_memes.py
index 2e0d8214..8e0456d3 100644
--- a/src/flows/storage/describe_memes.py
+++ b/src/flows/storage/describe_memes.py
@@ -38,7 +38,7 @@
VISION_MODELS = [
"google/gemma-3-27b-it:free", # best quality, 140+ languages, 131k context
"google/gemma-3-12b-it:free", # good fallback, smaller, 32k context
- "google/gemma-3-4b-it:free", # smallest Gemma, lowest rate limits hit
+ "google/gemma-3-4b-it:free", # smallest Gemma, lowest rate limits hit
"nvidia/nemotron-nano-12b-v2-vl:free", # last resort: still listed but unreliable JSON
]
diff --git a/src/tgbot/handlers/chat/agent/runner.py b/src/tgbot/handlers/chat/agent/runner.py
index e224bd8b..e9187972 100644
--- a/src/tgbot/handlers/chat/agent/runner.py
+++ b/src/tgbot/handlers/chat/agent/runner.py
@@ -110,8 +110,16 @@ async def run_chat_agent(
tool_calls_count = 0
for raw in result.raw_responses:
if hasattr(raw, "usage") and raw.usage:
- total_prompt_tokens += getattr(raw.usage, "input_tokens", None) or getattr(raw.usage, "prompt_tokens", None) or 0
- total_completion_tokens += getattr(raw.usage, "output_tokens", None) or getattr(raw.usage, "completion_tokens", None) or 0
+ total_prompt_tokens += (
+ getattr(raw.usage, "input_tokens", None)
+ or getattr(raw.usage, "prompt_tokens", None)
+ or 0
+ )
+ total_completion_tokens += (
+ getattr(raw.usage, "output_tokens", None)
+ or getattr(raw.usage, "completion_tokens", None)
+ or 0
+ )
for choice in getattr(raw, "choices", []):
if hasattr(choice, "message") and choice.message and choice.message.tool_calls:
tool_calls_count += len(choice.message.tool_calls)
diff --git a/src/tgbot/handlers/chat/feedback.py b/src/tgbot/handlers/chat/feedback.py
index cfd5a858..b1fd7390 100644
--- a/src/tgbot/handlers/chat/feedback.py
+++ b/src/tgbot/handlers/chat/feedback.py
@@ -37,6 +37,8 @@ async def handle_feedback_message(update: Update, context: ContextTypes.DEFAULT_
async def handle_feedback_reply(update: Update, context: ContextTypes.DEFAULT_TYPE):
reply_text = update.message.text
+ if not reply_text:
+ return # Admin sent a non-text message (sticker, voice, etc.) — can't forward
header, _ = update.message.reply_to_message.text.split("\n", 1)
user_id, message_id = header.split(":")
diff --git a/src/tgbot/handlers/stats/wrapped.py b/src/tgbot/handlers/stats/wrapped.py
index 73c00a8f..6e5eba8c 100644
--- a/src/tgbot/handlers/stats/wrapped.py
+++ b/src/tgbot/handlers/stats/wrapped.py
@@ -1008,9 +1008,7 @@ async def generate_wrapped_data(
# Pick oneliner meme (avoid your_meme)
oneliner_meme_id = None
if liked:
- oneliner_candidates = [
- m for m in liked[:10] if m["meme_id"] not in global_used_memes
- ]
+ oneliner_candidates = [m for m in liked[:10] if m["meme_id"] not in global_used_memes]
if oneliner_candidates:
oneliner_meme_id = random.choice(oneliner_candidates)["meme_id"]
else:
@@ -1187,9 +1185,7 @@ def _build_zodiac_slide(p: dict, is_ru: bool = True) -> str:
return f"{header}\n\n" f"{html_escape(sign)}\n\n" f"{html_escape(why)}"
-def _attach_memes_to_absurd(
- p: dict, liked: list, used_ids: set | None = None
-) -> list:
+def _attach_memes_to_absurd(p: dict, liked: list, used_ids: set | None = None) -> list:
"""Attach meme IDs to each absurd comparison, ensuring no duplicates."""
comparisons = p.get("absurd_comparisons", [])
result = []