diff --git a/nanapi/utils/anilist.py b/nanapi/utils/anilist.py index b14d0ff..33988a0 100644 --- a/nanapi/utils/anilist.py +++ b/nanapi/utils/anilist.py @@ -2,9 +2,10 @@ import logging import time from abc import ABC, abstractmethod +from collections.abc import Generator from dataclasses import dataclass, field from itertools import batched, count, filterfalse -from typing import Any, Generator, Generic, Optional, Self, Type, TypeVar +from typing import Any, Generic, Self, TypeVar, override import aiohttp import orjson @@ -172,7 +173,7 @@ class ALAPI: def __init__(self): self.low_priority_ready = asyncio.Event() self.low_priority_ready.set() - self.reset_task: Optional[asyncio.Task] = None + self.reset_task: asyncio.Task | None = None self._reset_at = 0 self.last_request_time = 0 self._remaining = ALAPI.RATE_LIMIT @@ -326,6 +327,16 @@ def load(self, media_ids: list[MediaSelectAllIdsResult]): ######### # Lists # ######### + + +@dataclass(frozen=True) +class ListEntry: + id_al: int + status: str + progress: int + score: int + + class Userlist: service: AnilistService @@ -334,11 +345,11 @@ def __init__(self, username: str): async def refresh( self, media_type: MediaType, al_low_priority: bool = False - ) -> tuple[list[dict], set[ALMedia]]: - return [], set() + ) -> tuple[set[ListEntry], set[ALMedia]]: + return set(), set() def to_edgedb( - self, media_type: MediaType, entries: list[dict], medias: set[ALMedia] + self, media_type: MediaType, entries: set[ListEntry], medias: set[ALMedia] ) -> dict[str, Any]: nodes = [m.to_edgedb() for m in medias] edgedb_data = dict( @@ -346,7 +357,7 @@ def to_edgedb( service=self.service.value, username=self.username, type=media_type.value, - entries=entries, + entries=list(entries), ) return edgedb_data @@ -359,18 +370,21 @@ def __str__(self): class ALUserlist(Userlist): service = AnilistService.ANILIST - async def refresh(self, media_type: MediaType, al_low_priority=False): + @override + async def refresh( + self, media_type: MediaType, al_low_priority=False + ) -> tuple[set[ListEntry], set[ALMedia]]: await super().refresh(media_type) # fetch updated list entries = await self.fetch_entries(media_type, al_low_priority) - new_entries = [] - new_medias = set() + new_entries = set[ListEntry]() + new_medias = set[ALMedia]() for entry in entries: almedia = ALMedia.model_validate(entry['media']) - new_entries.append( - dict( + new_entries.add( + ListEntry( id_al=almedia.id, status=entry['status'], progress=entry['progress'], @@ -438,7 +452,10 @@ class MALUserlist(Userlist): refresh_lock = asyncio.Lock() - async def refresh(self, media_type: MediaType, al_low_priority=False): + @override + async def refresh( + self, media_type: MediaType, al_low_priority=False + ) -> tuple[set[ListEntry], set[ALMedia]]: await super().refresh(media_type) for _ in range(3): @@ -450,13 +467,13 @@ async def refresh(self, media_type: MediaType, al_low_priority=False): await asyncio.sleep(1) else: logger.error(f'MALUserlist: refresh failed for {self.username}') - return [], set() + return set(), set() al_ids, new_medias = await self.get_al_ids( media_type, set(entry.node.id for entry in userlist), low_priority=al_low_priority ) - new_entries = [] + new_entries = set[ListEntry]() for entry in userlist: if entry.node.id is not None: repeating = ( @@ -473,11 +490,11 @@ async def refresh(self, media_type: MediaType, al_low_priority=False): ) if id_al := al_ids.get(entry.node.id, None): - new_entries.append( - dict( + new_entries.add( + ListEntry( id_al=id_al, status=status, - progress=progress, + progress=progress or 0, # FIXME: can be None? score=entry.list_status.score, ) ) @@ -556,7 +573,7 @@ def __str__(self): return f'https://myanimelist.net/profile/{self.username}' -SERVICE_USER_LIST: dict[AnilistService, Type[Userlist]] = { +SERVICE_USER_LIST: dict[AnilistService, type[Userlist]] = { AnilistService.ANILIST: ALUserlist, AnilistService.MYANIMELIST: MALUserlist, }