Skip to content

Commit

Permalink
feat: Allow downloading republished episodes
Browse files Browse the repository at this point in the history
  • Loading branch information
janw committed Dec 2, 2024
1 parent 8f4ae3c commit 1249a67
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 36 deletions.
1 change: 1 addition & 0 deletions hack/rich-codex.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ export DELETED_FILES="deleted.txt"
# Run rich-codex non-interactively: skip confirmation prompts and git
# working-tree cleanliness checks.
export NO_CONFIRM="true"
export SKIP_GIT_CHECKS="true"
# Glob of previously generated images to clean up before regeneration.
export CLEAN_IMG_PATHS='./assets/*.svg'
# NOTE(review): CI=1 presumably makes rich-codex behave as in a CI
# environment (plain output, no TTY features) — confirm against its docs.
export CI=1

exec poetry run rich-codex
75 changes: 64 additions & 11 deletions podcast_archiver/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import sqlite3
from abc import abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import UTC, datetime
from threading import Lock
from typing import TYPE_CHECKING, Iterator

Expand All @@ -12,22 +14,42 @@
from podcast_archiver.models import Episode


def adapt_datetime_iso(val: datetime) -> str:
if not val.tzinfo:
val = val.replace(tzinfo=UTC)
return val.isoformat()


def convert_datetime_iso(val: bytes) -> datetime:
    """Deserialize a raw TIMESTAMP column value back into a datetime."""
    text = val.decode("utf-8")
    return datetime.fromisoformat(text)


# Wire the adapter/converter pair into sqlite3: datetimes are stored as
# ISO-8601 text, and columns declared as TIMESTAMP are parsed back into
# datetime objects (requires detect_types=sqlite3.PARSE_DECLTYPES on connect,
# which Database.get_conn sets).
sqlite3.register_adapter(datetime, adapt_datetime_iso)
sqlite3.register_converter("TIMESTAMP", convert_datetime_iso)


@dataclass(frozen=True, slots=True)
class EpisodeInDb:
    """Metadata stored for an episode that is already in the database.

    Returned by the database ``exists`` methods so callers can compare the
    stored state against the current feed entry (e.g. to detect a
    republished or changed episode).
    """

    # Stored enclosure `length` value; None if unknown.
    length: int | None = None
    # Stored publication timestamp; None if unknown.
    published_time: datetime | None = None


class BaseDatabase:
    """Interface for the episode bookkeeping database.

    NOTE(review): methods are marked @abstractmethod but the class does not
    inherit from abc.ABC, so instantiation is not actually prevented —
    confirm whether that is intentional.
    """

    @abstractmethod
    def add(self, episode: Episode) -> None:
        """Record `episode` in the database."""
        pass  # pragma: no cover

    @abstractmethod
    def exists(self, episode: Episode) -> EpisodeInDb | None:
        """Return stored metadata for `episode`, or None if it is unknown."""
        pass  # pragma: no cover


class DummyDatabase(BaseDatabase):
    """No-op database used when episode bookkeeping is disabled.

    `add` discards the episode and `exists` reports every episode as
    unknown, so callers always fall back to their non-database checks.
    """

    def add(self, episode: Episode) -> None:
        pass

    def exists(self, episode: Episode) -> EpisodeInDb | None:
        return None


class Database(BaseDatabase):
Expand All @@ -43,7 +65,8 @@ def __init__(self, filename: str, ignore_existing: bool) -> None:

@contextmanager
def get_conn(self) -> Iterator[sqlite3.Connection]:
with self.lock, sqlite3.connect(self.filename) as conn:
with self.lock, sqlite3.connect(self.filename, detect_types=sqlite3.PARSE_DECLTYPES) as conn:
conn.row_factory = sqlite3.Row
yield conn

def migrate(self) -> None:
Expand All @@ -53,26 +76,56 @@ def migrate(self) -> None:
"""\
CREATE TABLE IF NOT EXISTS episodes(
guid TEXT UNIQUE NOT NULL,
title TEXT
title TEXT,
length UNSIGNED BIG INT,
published_time TIMESTAMP
)"""
)

self._add_column_if_missing(
"length",
"ALTER TABLE episodes ADD COLUMN length UNSIGNED BIG INT",
)
self._add_column_if_missing(
"published_time",
"ALTER TABLE episodes ADD COLUMN published_time TIMESTAMP",
)

def _add_column_if_missing(self, name: str, alter_stmt: str) -> None:
with self.get_conn() as conn:
if not self._has_column(conn, name):
logger.debug(f"Adding missing DB column {name}")
conn.execute(alter_stmt)

def _has_column(self, conn: sqlite3.Connection, name: str) -> bool:
result = conn.execute(
"SELECT EXISTS(SELECT 1 FROM pragma_table_info('episodes') WHERE name = ?)",
(name,),
)
return bool(result.fetchone()[0])

def add(self, episode: Episode) -> None:
with self.get_conn() as conn:
try:
conn.execute(
"INSERT INTO episodes(guid, title) VALUES (?, ?)",
(episode.guid, episode.title),
"INSERT OR REPLACE INTO episodes(guid, title, length, published_time) VALUES (?, ?, ?, ?)",
(
episode.guid,
episode.title,
episode.enclosure.length,
episode.published_time,
),
)
except sqlite3.IntegrityError:
logger.debug(f"Episode exists: {episode}")

def exists(self, episode: Episode) -> bool:
def exists(self, episode: Episode) -> EpisodeInDb | None:
if self.ignore_existing:
return False
return None
with self.get_conn() as conn:
result = conn.execute(
"SELECT EXISTS(SELECT 1 FROM episodes WHERE guid = ?)",
"SELECT length, published_time FROM episodes WHERE guid = ?",
(episode.guid,),
)
return bool(result.fetchone()[0])
match = result.fetchone()
return EpisodeInDb(**match) if match else None
3 changes: 0 additions & 3 deletions podcast_archiver/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ def __call__(self) -> EpisodeResult:
return EpisodeResult(self.episode, DownloadResult.FAILED)

def run(self) -> EpisodeResult:
if self.target.exists():
return EpisodeResult(self.episode, DownloadResult.ALREADY_EXISTS)

self.target.parent.mkdir(parents=True, exist_ok=True)
logger.info("Downloading: %s", self.episode)
response = session.get_and_raise(self.episode.enclosure.href, stream=True)
Expand Down
20 changes: 18 additions & 2 deletions podcast_archiver/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from contextlib import suppress
from datetime import datetime, timezone
from functools import cached_property
from http import HTTPStatus
Expand All @@ -14,6 +15,9 @@
BaseModel,
ConfigDict,
Field,
ValidationError,
ValidatorFunctionWrapHandler,
WrapValidator,
field_validator,
model_validator,
)
Expand All @@ -29,19 +33,31 @@
from requests import Response


def parse_from_struct_time(value: Any) -> Any:
def parse_from_struct_time(value: struct_time | datetime) -> datetime:
if isinstance(value, struct_time):
return datetime.fromtimestamp(mktime(value)).replace(tzinfo=timezone.utc)
value = datetime.fromtimestamp(mktime(value)).replace(tzinfo=timezone.utc)

# value.astimezone(UTC)
return value


def val_or_none(value: Any, handler: ValidatorFunctionWrapHandler) -> Any:
    """Run the wrapped validator, mapping any ValidationError to None."""
    try:
        return handler(value)
    except ValidationError:
        return None


# Wrap-validator that swallows ValidationError and substitutes None.
FallbackToNone = WrapValidator(val_or_none)

# Accepts feedparser struct_time values in addition to regular datetimes.
LenientDatetime = Annotated[datetime, BeforeValidator(parse_from_struct_time)]
# Parses to int when possible; falls back to None on invalid input.
LenientInt = Annotated[int | None, FallbackToNone]


class Link(BaseModel):
    """A feed link entry (e.g. an episode enclosure)."""

    rel: str = ""
    # Aliased from the feed attribute `type` (avoids shadowing the builtin).
    link_type: str = Field("", alias="type")
    href: str
    # Enclosure `length` attribute; None when the feed value is not a valid int.
    length: LenientInt = None


class Chapter(BaseModel):
Expand Down
49 changes: 38 additions & 11 deletions podcast_archiver/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,43 @@ def process(self, url: str) -> ProcessingResult:
rprint(f"\n[bar.finished]✔ {completion_msg} for: {result.feed}[/]")
return result

def _preflight_check(self, episode: Episode, target: Path) -> DownloadResult | None:
if self.database.exists(episode):
logger.debug("Pre-flight check on episode '%s': already in database.", episode)
return DownloadResult.ALREADY_EXISTS

if target.exists():
logger.debug("Pre-flight check on episode '%s': already on disk.", episode)
return DownloadResult.ALREADY_EXISTS
def _does_already_exist(self, episode: Episode, *, target: Path) -> bool:
if not (existing := self.database.exists(episode)):
# NOTE on backwards-compatibility: if the episode is not in the DB we'd normally
# download it again outright. This might cause a complete replacement of
# episodes on disk for existing users who either used pre-v1.4 until now or
# always have `ignore_database` enabled.
#
# To avoid that, we fall back to the on-disk check if the episode is not in
# the DB (includes `ignore_database`). If the episode is indeed in the DB,
# we do the additional checks to possibly re-download an episode if it was
# republished/changed.
if target.exists():
logger.debug("Episode '%s': not in db but on disk", episode)
return True
logger.debug("Episode '%s': not in db", episode)
return False

if existing.length and episode.enclosure.length and existing.length != episode.enclosure.length:
logger.debug(
"Episode '%s': length differs in feed: %s (%s in db)",
episode,
episode.enclosure.length,
existing.length,
)
return False

if existing.published_time and episode.published_time and episode.published_time > existing.published_time:
logger.debug(
"Episode '%s': is newer in feed: %s (by %s sec)",
episode,
episode.published_time,
(episode.published_time - existing.published_time).total_seconds(),
)
return False

return None
logger.debug("Episode '%s': already in database.", episode)
return True

def _process_episodes(self, feed: Feed) -> tuple[EpisodeResultsList, QueueCompletionType]:
results: EpisodeResultsList = []
Expand All @@ -86,9 +113,9 @@ def _process_episode(
self, episode: Episode, feed_info: FeedInfo, results: EpisodeResultsList
) -> QueueCompletionType | None:
target = self.filename_formatter.format(episode=episode, feed_info=feed_info)
if result := self._preflight_check(episode, target):
if result := self._does_already_exist(episode, target=target):
rprint(f"[bar.finished]✔ {result}: {episode}[/]")
results.append(EpisodeResult(episode, result))
results.append(EpisodeResult(episode, DownloadResult.ALREADY_EXISTS))
if self.settings.update_archive:
logger.debug("Up to date with %r", episode)
return QueueCompletionType.FOUND_EXISTING
Expand Down
7 changes: 4 additions & 3 deletions tests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,17 @@ def test_download_job(tmp_path_cd: Path, feedobj_lautsprecher: dict[str, Any]) -
assert result == (episode, DownloadResult.COMPLETED_SUCCESSFULLY)


def test_download_already_exists(tmp_path_cd: Path, feedobj_lautsprecher: dict[str, Any]) -> None:
    feed = FeedPage.model_validate(feedobj_lautsprecher)
    episode = feed.episodes[0]

    job = download.DownloadJob(episode=episode, target=Path("file.mp3"))
    job.target.parent.mkdir(exist_ok=True)
    job.target.touch()
    result = job()

    # Behavioral change: DownloadJob no longer checks the filesystem itself;
    # existing-file detection now lives in the processor/database layer, so
    # the pre-existing file is simply overwritten.
    assert result == (episode, DownloadResult.COMPLETED_SUCCESSFULLY)


def test_download_partial(
Expand Down
16 changes: 10 additions & 6 deletions tests/test_processor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from datetime import UTC, datetime
from pathlib import Path
from typing import TYPE_CHECKING
from unittest.mock import patch

import pytest

from podcast_archiver.config import Settings
from podcast_archiver.enums import DownloadResult
from podcast_archiver.database import EpisodeInDb
from podcast_archiver.models import FeedPage
from podcast_archiver.processor import FeedProcessor, ProcessingResult

Expand All @@ -19,17 +20,20 @@
@pytest.mark.parametrize(
"file_exists,database_exists,expected_result",
[
(False, False, None),
(True, False, DownloadResult.ALREADY_EXISTS),
(False, True, DownloadResult.ALREADY_EXISTS),
(False, False, False),
(True, False, True),
(False, EpisodeInDb(), True),
(True, EpisodeInDb(length=1), False),
(True, EpisodeInDb(published_time=datetime(1970, 1, 1, tzinfo=UTC)), False),
(True, EpisodeInDb(published_time=datetime(2999, 1, 1, tzinfo=UTC)), True),
],
)
def test_preflight_check(
tmp_path_cd: Path,
feedobj_lautsprecher: Url,
file_exists: bool,
database_exists: bool,
expected_result: DownloadResult | None,
expected_result: bool,
) -> None:
settings = Settings()
feed = FeedPage.model_validate(feedobj_lautsprecher)
Expand All @@ -39,7 +43,7 @@ def test_preflight_check(
if file_exists:
target.touch()
with patch.object(proc.database, "exists", return_value=database_exists):
result = proc._preflight_check(episode, target=target)
result = proc._does_already_exist(episode, target=target)

assert result == expected_result

Expand Down

0 comments on commit 1249a67

Please sign in to comment.