Skip to content

Commit

Permalink
Refactor: change row_factory from tuple to dict
Browse files Browse the repository at this point in the history
Make arguments in class initialisation not order-sensitive
Adjust existing calls to match the new format
Assign some default values based on existing calls
Fix inconsistent indentation
  • Loading branch information
Manitary committed Jun 23, 2023
1 parent 3950b59 commit b08843b
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 82 deletions.
97 changes: 51 additions & 46 deletions src/data/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ def living_in(the_database):
return None
return DatabaseDatabase(db)

def dict_factory(cursor, row) -> dict:
fields = [column[0] for column in cursor.description]
return dict(zip(fields, row))

# Database

def db_error(f):
Expand Down Expand Up @@ -50,9 +54,10 @@ def protected(*args, **kwargs):
return decorate

class DatabaseDatabase:
def __init__(self, db):
def __init__(self, db: sqlite3.Connection):
self._db = db
self.q = db.cursor()
self._db.row_factory = dict_factory
self.q = self._db.cursor()

# Set up collations
self._db.create_collation("alphanum", _collate_alphanum)
Expand All @@ -63,7 +68,7 @@ def __getattr__(self, attr):
return getattr(self._db, attr)

def get_count(self):
return self.q.fetchone()[0]
return self.q.fetchone()['count(*)']

def save(self):
self.commit()
Expand Down Expand Up @@ -217,19 +222,19 @@ def get_service(self, id=None, key=None) -> Optional[Service]:
error("ID or key required to get service")
return None
service = self.q.fetchone()
return Service(*service)
return Service(**service)

@db_error_default(list())
def get_services(self, enabled=True, disabled=False) -> List[Service]:
services = list()
if enabled:
self.q.execute("SELECT id, key, name, enabled, use_in_post FROM Services WHERE enabled = 1")
for service in self.q.fetchall():
services.append(Service(*service))
services.append(Service(**service))
if disabled:
self.q.execute("SELECT id, key, name, enabled, use_in_post FROM Services WHERE enabled = 0")
for service in self.q.fetchall():
services.append(Service(*service))
services.append(Service(**service))
return services

@db_error_default(None)
Expand All @@ -242,7 +247,7 @@ def get_stream(self, id=None, service_tuple=None) -> Optional[Stream]:
if stream is None:
error("Stream {} not found".format(id))
return None
stream = Stream(*stream)
stream = Stream(**stream)
elif service_tuple is not None:
service, show_key = service_tuple
debug("Getting stream for {}/{}".format(service, show_key))
Expand All @@ -252,7 +257,7 @@ def get_stream(self, id=None, service_tuple=None) -> Optional[Stream]:
if stream is None:
error("Stream {} not found".format(id))
return None
stream = Stream(*stream)
stream = Stream(**stream)
else:
error("Nothing provided to get stream")
return None
Expand Down Expand Up @@ -299,7 +304,7 @@ def get_streams(self, service=None, show=None, active=True, unmatched=False, mis
return list()

streams = self.q.fetchall()
streams = [Stream(*stream) for stream in streams]
streams = [Stream(**stream) for stream in streams]
for stream in streams:
stream.show = self.get_show(id=stream.show) # convert show id to show model
return streams
Expand Down Expand Up @@ -359,7 +364,7 @@ def get_lite_streams(self, service=None, show=None, missing_link=False) -> List[
return list()

lite_streams = self.q.fetchall()
lite_streams = [LiteStream(*lite_stream) for lite_stream in lite_streams]
lite_streams = [LiteStream(**lite_stream) for lite_stream in lite_streams]
return lite_streams

@db_error
Expand All @@ -381,19 +386,19 @@ def get_link_site(self, id:str=None, key:str=None) -> Optional[LinkSite]:
site = self.q.fetchone()
if site is None:
return None
return LinkSite(*site)
return LinkSite(**site)

@db_error_default(list())
def get_link_sites(self, enabled=True, disabled=False) -> List[LinkSite]:
sites = list()
if enabled:
self.q.execute("SELECT id, key, name, enabled FROM LinkSites WHERE enabled = 1")
for link in self.q.fetchall():
sites.append(LinkSite(*link))
sites.append(LinkSite(**link))
if disabled:
self.q.execute("SELECT id, key, name, enabled FROM LinkSites WHERE enabled = 0")
for link in self.q.fetchall():
sites.append(LinkSite(*link))
sites.append(LinkSite(**link))
return sites

@db_error_default(list())
Expand All @@ -404,7 +409,7 @@ def get_links(self, show:Show=None) -> List[Link]:
# Get all streams with show ID
self.q.execute("SELECT site, show, site_key FROM Links WHERE show = ?", (show.id,))
links = self.q.fetchall()
links = [Link(*link) for link in links]
links = [Link(**link) for link in links]
return links
else:
error("A show must be provided to get links")
Expand All @@ -418,7 +423,7 @@ def get_link(self, show: Show, link_site: LinkSite) -> Optional[Link]:
link = self.q.fetchone()
if link is None:
return None
link = Link(*link)
link = Link(**link)
return link

@db_error_default(False)
Expand Down Expand Up @@ -449,15 +454,15 @@ def add_link(self, raw_show: UnprocessedShow, show_id, commit=True):

# Shows
@db_error_default(list())
def get_shows(self, missing_length=False, missing_stream=False, enabled=True, delayed=False) -> [Show]:
def get_shows(self, missing_length=False, missing_stream=False, enabled=True, delayed=False) -> list[Show]:
shows = list()
if missing_length:
self.q.execute(
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows \
WHERE (length IS NULL OR length = '' OR length = 0) AND enabled = ?", (enabled,))
elif missing_stream:
self.q.execute(
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows show\
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows show\
WHERE (SELECT count(*) FROM Streams stream, Services service \
WHERE stream.show = show.id \
AND stream.active = 1 \
Expand All @@ -467,14 +472,14 @@ def get_shows(self, missing_length=False, missing_stream=False, enabled=True, de
(enabled,))
elif delayed:
self.q.execute(
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows \
WHERE delayed = 1 AND enabled = ?", (enabled,))
else:
self.q.execute(
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows \
WHERE enabled = ?", (enabled,))
for show in self.q.fetchall():
show = Show(*show)
show = Show(**show)
show.aliases = self.get_aliases(show)
shows.append(show)
return shows
Expand All @@ -492,12 +497,12 @@ def get_show(self, id=None, stream=None) -> Optional[Show]:
error("Show ID not provided to get_show")
return None
self.q.execute(
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows \
WHERE id = ?", (id,))
show = self.q.fetchone()
if show is None:
return None
show = Show(*show)
show = Show(**show)
show.aliases = self.get_aliases(show)
return show

Expand All @@ -506,19 +511,19 @@ def get_show_by_name(self, name) -> Optional[Show]:
#debug("Getting show from database")

self.q.execute(
"SELECT id, name, name_en, length, type, has_source, is_nsfw, enabled, delayed FROM Shows \
"SELECT id, name, name_en, length, type AS show_type, has_source, is_nsfw, enabled, delayed FROM Shows \
WHERE name = ?", (name,))
show = self.q.fetchone()
if show is None:
return None
show = Show(*show)
show = Show(**show)
show.aliases = self.get_aliases(show)
return show

@db_error_default(list())
def get_aliases(self, show: Show) -> [str]:
def get_aliases(self, show: Show) -> list[str]:
self.q.execute("SELECT alias FROM Aliases where show = ?", (show.id,))
return [s for s, in self.q.fetchall()]
return [s["alias"] for s in self.q.fetchall()]

@db_error_default(None)
def add_show(self, raw_show: UnprocessedShow, commit=True) -> int:
Expand Down Expand Up @@ -556,7 +561,7 @@ def update_show(self, show_id: str, raw_show: UnprocessedShow, commit=True):
is_nsfw = raw_show.is_nsfw

if name_en:
self.q.execute("UPDATE Shows SET name_en = ? WHERE id = ?", (name_en, show_id))
self.q.execute("UPDATE Shows SET name_en = ? WHERE id = ?", (name_en, show_id))
if length != 0:
self.q.execute("UPDATE Shows SET length = ? WHERE id = ?", (length, show_id))
self.q.execute("UPDATE Shows SET type = ?, has_source = ?, is_nsfw = ? WHERE id = ?", (show_type, has_source, is_nsfw, show_id))
Expand Down Expand Up @@ -599,10 +604,10 @@ def stream_has_episode(self, stream: Stream, episode_num) -> bool:

@db_error_default(None)
def get_latest_episode(self, show: Show) -> Optional[Episode]:
self.q.execute("SELECT episode, post_url FROM Episodes WHERE show = ? ORDER BY episode DESC LIMIT 1", (show.id,))
self.q.execute("SELECT episode AS number, post_url AS link FROM Episodes WHERE show = ? ORDER BY episode DESC LIMIT 1", (show.id,))
data = self.q.fetchone()
if data is not None:
return Episode(data[0], None, data[1], None)
return Episode(**data)
return None

@db_error
Expand All @@ -614,9 +619,9 @@ def add_episode(self, show, episode_num, post_url):
@db_error_default(list())
def get_episodes(self, show, ensure_sorted=True) -> List[Episode]:
episodes = list()
self.q.execute("SELECT episode, post_url FROM Episodes WHERE show = ?", (show.id,))
self.q.execute("SELECT episode AS number, post_url AS link FROM Episodes WHERE show = ?", (show.id,))
for data in self.q.fetchall():
episodes.append(Episode(data[0], None, data[1], None))
episodes.append(Episode(**data))

if ensure_sorted:
episodes = sorted(episodes, key=lambda e: e.number)
Expand All @@ -625,23 +630,23 @@ def get_episodes(self, show, ensure_sorted=True) -> List[Episode]:
# Scores
@db_error_default(list())
def get_show_scores(self, show: Show) -> List[EpisodeScore]:
self.q.execute("SELECT episode, site, score FROM Scores WHERE show=?", (show.id,))
return [EpisodeScore(show.id, *s) for s in self.q.fetchall()]
self.q.execute("SELECT episode, site AS site_id, score FROM Scores WHERE show=?", (show.id,))
return [EpisodeScore(show_id=show.id, **s) for s in self.q.fetchall()]

@db_error_default(list())
def get_episode_scores(self, show: Show, episode: Episode) -> List[EpisodeScore]:
self.q.execute("SELECT site, score FROM Scores WHERE show=? AND episode=?", (show.id, episode.number))
return [EpisodeScore(show.id, episode.number, *s) for s in self.q.fetchall()]
self.q.execute("SELECT site AS site_id, score FROM Scores WHERE show=? AND episode=?", (show.id, episode.number))
return [EpisodeScore(show_id=show.id, episode=episode.number, **s) for s in self.q.fetchall()]

@db_error_default(None)
def get_episode_score_avg(self, show: Show, episode: Episode) -> Optional[EpisodeScore]:
debug("Calculating avg score for {} ({})".format(show.name, show.id))
self.q.execute("SELECT score FROM Scores WHERE show=? AND episode=?", (show.id, episode.number))
scores = [s[0] for s in self.q.fetchall()]
scores = [s["score"] for s in self.q.fetchall()]
if len(scores) > 0:
score = sum(scores)/len(scores)
debug(" Score: {} (from {} scores)".format(score, len(scores)))
return EpisodeScore(show.id, episode.number, None, score)
return EpisodeScore(show_id=show.id, episode=episode.number, score=score)
return None

@db_error
Expand All @@ -664,7 +669,7 @@ def get_poll_site(self, id:str=None, key:str=None) -> Optional[PollSite]:
site = self.q.fetchone()
if site is None:
return None
return PollSite(*site)
return PollSite(**site)

@db_error
def add_poll(self, show: Show, episode: Episode, site: PollSite, poll_id, commit=True):
Expand All @@ -681,24 +686,24 @@ def update_poll_score(self, poll: Poll, score, commit=True):

@db_error_default(None)
def get_poll(self, show: Show, episode: Episode):
self.q.execute("SELECT show, episode, poll_service, poll_id, timestamp, score FROM Polls WHERE show = ? AND episode = ?", (show.id, episode.number))
self.q.execute("SELECT show AS show_id, episode, poll_service AS service, poll_id AS id, timestamp AS date, score FROM Polls WHERE show = ? AND episode = ?", (show.id, episode.number))
poll = self.q.fetchone()
if poll is None:
return None
return Poll(*poll)
return Poll(**poll)

@db_error_default(list())
def get_polls(self, show: Show=None, missing_score=False):
polls = list()
if show is not None:
self.q.execute("SELECT show, episode, poll_service, poll_id, timestamp, score FROM Polls WHERE show = ?", (show.id,))
self.q.execute("SELECT show AS show_id, episode, poll_service AS service, poll_id AS id, timestamp AS date, score FROM Polls WHERE show = ?", (show.id,))
elif missing_score:
self.q.execute("SELECT show, episode, poll_service, poll_id, timestamp, score FROM Polls WHERE score is NULL AND show IN (SELECT id FROM Shows where enabled = 1)")
self.q.execute("SELECT show AS show_id, episode, poll_service AS service, poll_id AS id, timestamp AS date, score FROM Polls WHERE score is NULL AND show IN (SELECT id FROM Shows where enabled = 1)")
else:
error("Need to select a show to get polls")
return list()
for poll in self.q.fetchall():
polls.append(Poll(*poll))
polls.append(Poll(**poll))
return polls

# Searching
Expand All @@ -713,8 +718,8 @@ def search_show_ids_by_names(self, *names, exact=False) -> Set[Show]:
self.q.execute("SELECT show, name FROM ShowNames WHERE name = ? COLLATE alphanum", (name,))
matched = self.q.fetchall()
for match in matched:
debug(" Found match: {} | {}".format(match[0], match[1]))
shows.add(match[0])
debug(" Found match: {} | {}".format(match['show'], match['name']))
shows.add(match['show'])
return shows

# Helper methods
Expand Down
Loading

0 comments on commit b08843b

Please sign in to comment.