diff --git a/src/config.py b/src/config.py index 8f8a1a78..59bdd857 100644 --- a/src/config.py +++ b/src/config.py @@ -11,7 +11,7 @@ class Config(BaseSettings): DATABASE_POOL_SIZE: int = 20 DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes DATABASE_POOL_PRE_PING: bool = True - DATABASE_POOL_TIMEOUT: int = 5 # seconds to wait for a connection before raising; keep low to fail fast + DATABASE_POOL_TIMEOUT: int = 5 # seconds; keep low to fail fast on pool exhaustion REDIS_URL: RedisDsn REDIS_HEALTH_CHECK_INTERVAL: int = 30 diff --git a/src/database.py b/src/database.py index 54238685..1cab0763 100644 --- a/src/database.py +++ b/src/database.py @@ -20,10 +20,12 @@ Table, UniqueConstraint, Update, + event, func, text, ) from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.exc import DBAPIError from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.pool import NullPool @@ -58,6 +60,30 @@ engine = create_async_engine(DATABASE_URL, **_engine_kwargs) + +@event.listens_for(engine.sync_engine, "handle_error") +def _mark_stale_connection_as_disconnect(context: object) -> None: + """Tell SQLAlchemy to invalidate pool connections that asyncpg reports as closed. + + pool_pre_ping catches most stale connections, but a race is possible: the server + closes the connection between the ping and the actual query. When that happens + asyncpg raises ConnectionDoesNotExistError (a PostgresError subclass, SQLSTATE 08003 + — not InterfaceError, hence the name-based check below). + Marking it as a disconnect causes the pool to discard that connection so it is + never handed out again. 
+ """ + original = getattr(context, "original_exception", None) + if original is not None and type(original).__name__ == "ConnectionDoesNotExistError": + context.is_disconnect = True # type: ignore[union-attr] + + +def _is_stale_connection_error(exc: BaseException) -> bool: + return ( + isinstance(exc, DBAPIError) + and exc.__cause__ is not None + and type(exc.__cause__).__name__ == "ConnectionDoesNotExistError" + ) + + metadata = MetaData(naming_convention=DB_NAMING_CONVENTION) @@ -507,6 +533,16 @@ async def fetch_one( select_query: Select | Insert | Update, params: dict[str, Any] | None = None, ) -> dict[str, Any] | None: + try: + async with engine.begin() as conn: + cursor: CursorResult = await conn.execute(select_query, params or {}) + row = cursor.first() + return row._asdict() if row is not None else None + except Exception as exc: + if not _is_stale_connection_error(exc): + raise + # Retry once — pool_pre_ping catches most stale connections, but a race can + # still return a connection that Postgres closed between the ping and the query. 
async with engine.begin() as conn: cursor: CursorResult = await conn.execute(select_query, params or {}) row = cursor.first() @@ -517,6 +553,13 @@ async def fetch_all( select_query: Select | Insert | Update, params: dict[str, Any] | None = None, ) -> list[dict[str, Any]]: + try: + async with engine.begin() as conn: + cursor: CursorResult = await conn.execute(select_query, params or {}) + return [r._asdict() for r in cursor.all()] + except Exception as exc: + if not _is_stale_connection_error(exc): + raise async with engine.begin() as conn: cursor: CursorResult = await conn.execute(select_query, params or {}) return [r._asdict() for r in cursor.all()] @@ -526,5 +569,11 @@ async def execute( select_query: Insert | Update, params: dict[str, Any] | None = None, ) -> CursorResult: + try: + async with engine.begin() as conn: + return await conn.execute(select_query, params or {}) + except Exception as exc: + if not _is_stale_connection_error(exc): + raise async with engine.begin() as conn: return await conn.execute(select_query, params or {}) diff --git a/src/flows/storage/describe_memes.py b/src/flows/storage/describe_memes.py index 2e0d8214..8e0456d3 100644 --- a/src/flows/storage/describe_memes.py +++ b/src/flows/storage/describe_memes.py @@ -38,7 +38,7 @@ VISION_MODELS = [ "google/gemma-3-27b-it:free", # best quality, 140+ languages, 131k context "google/gemma-3-12b-it:free", # good fallback, smaller, 32k context - "google/gemma-3-4b-it:free", # smallest Gemma, lowest rate limits hit + "google/gemma-3-4b-it:free", # smallest Gemma, lowest rate limits hit "nvidia/nemotron-nano-12b-v2-vl:free", # last resort: still listed but unreliable JSON ] diff --git a/src/tgbot/handlers/chat/agent/runner.py b/src/tgbot/handlers/chat/agent/runner.py index e224bd8b..e9187972 100644 --- a/src/tgbot/handlers/chat/agent/runner.py +++ b/src/tgbot/handlers/chat/agent/runner.py @@ -110,8 +110,16 @@ async def run_chat_agent( tool_calls_count = 0 for raw in result.raw_responses: if 
hasattr(raw, "usage") and raw.usage: - total_prompt_tokens += getattr(raw.usage, "input_tokens", None) or getattr(raw.usage, "prompt_tokens", None) or 0 - total_completion_tokens += getattr(raw.usage, "output_tokens", None) or getattr(raw.usage, "completion_tokens", None) or 0 + total_prompt_tokens += ( + getattr(raw.usage, "input_tokens", None) + or getattr(raw.usage, "prompt_tokens", None) + or 0 + ) + total_completion_tokens += ( + getattr(raw.usage, "output_tokens", None) + or getattr(raw.usage, "completion_tokens", None) + or 0 + ) for choice in getattr(raw, "choices", []): if hasattr(choice, "message") and choice.message and choice.message.tool_calls: tool_calls_count += len(choice.message.tool_calls) diff --git a/src/tgbot/handlers/chat/feedback.py b/src/tgbot/handlers/chat/feedback.py index cfd5a858..b1fd7390 100644 --- a/src/tgbot/handlers/chat/feedback.py +++ b/src/tgbot/handlers/chat/feedback.py @@ -37,6 +37,8 @@ async def handle_feedback_message(update: Update, context: ContextTypes.DEFAULT_ async def handle_feedback_reply(update: Update, context: ContextTypes.DEFAULT_TYPE): reply_text = update.message.text + if not reply_text: + return # Admin sent a non-text message (sticker, voice, etc.) 
— can't forward header, _ = update.message.reply_to_message.text.split("\n", 1) user_id, message_id = header.split(":") diff --git a/src/tgbot/handlers/stats/wrapped.py b/src/tgbot/handlers/stats/wrapped.py index 73c00a8f..6e5eba8c 100644 --- a/src/tgbot/handlers/stats/wrapped.py +++ b/src/tgbot/handlers/stats/wrapped.py @@ -1008,9 +1008,7 @@ async def generate_wrapped_data( # Pick oneliner meme (avoid your_meme) oneliner_meme_id = None if liked: - oneliner_candidates = [ - m for m in liked[:10] if m["meme_id"] not in global_used_memes - ] + oneliner_candidates = [m for m in liked[:10] if m["meme_id"] not in global_used_memes] if oneliner_candidates: oneliner_meme_id = random.choice(oneliner_candidates)["meme_id"] else: @@ -1187,9 +1185,7 @@ def _build_zodiac_slide(p: dict, is_ru: bool = True) -> str: return f"{header}\n\n" f"{html_escape(sign)}\n\n" f"{html_escape(why)}" -def _attach_memes_to_absurd( - p: dict, liked: list, used_ids: set | None = None -) -> list: +def _attach_memes_to_absurd(p: dict, liked: list, used_ids: set | None = None) -> list: """Attach meme IDs to each absurd comparison, ensuring no duplicates.""" comparisons = p.get("absurd_comparisons", []) result = []