diff --git a/src/mopidy_spotify/playlists.py b/src/mopidy_spotify/playlists.py index be4545d6..9168f17c 100644 --- a/src/mopidy_spotify/playlists.py +++ b/src/mopidy_spotify/playlists.py @@ -13,22 +13,17 @@ class SpotifyPlaylistsProvider(backend.PlaylistsProvider): def __init__(self, backend): self._backend = backend self._timeout = self._backend._config["spotify"]["timeout"] - self._loaded = False - - self._refreshing = False + self._refresh_mutex = threading.Lock() def as_list(self): with utils.time_logger("playlists.as_list()", logging.DEBUG): - if not self._loaded: - return [] - return list(self._get_flattened_playlist_refs()) - def _get_flattened_playlist_refs(self): + def _get_flattened_playlist_refs(self, *, refresh=False): if not self._backend._web_client.logged_in: return [] - user_playlists = self._backend._web_client.get_user_playlists() + user_playlists = self._backend._web_client.get_user_playlists(refresh=refresh) return translator.to_playlist_refs( user_playlists, self._backend._web_client.user_id ) @@ -50,33 +45,38 @@ def _get_playlist(self, uri, *, as_items=False): ) def refresh(self): - if not self._backend._web_client.logged_in or self._refreshing: + if not self._backend._web_client.logged_in: return - - self._refreshing = True - - logger.info("Refreshing Spotify playlists") - - def refresher(): - try: - with utils.time_logger("playlists.refresh()", logging.DEBUG): - self._backend._web_client.clear_cache() - count = 0 - for playlist_ref in self._get_flattened_playlist_refs(): - self._get_playlist(playlist_ref.uri) - count += 1 - logger.info(f"Refreshed {count} Spotify playlists") - - listener.CoreListener.send("playlists_loaded") - self._loaded = True - except Exception: - logger.exception("An error occurred while refreshing Spotify playlists") - finally: - self._refreshing = False - - thread = threading.Thread(target=refresher) - thread.daemon = True - thread.start() + if not self._refresh_mutex.acquire(blocking=False): + logger.info("Refreshing Spotify playlists already in progress") + return + try: + uris = [ref.uri for ref in self._get_flattened_playlist_refs(refresh=True)] + logger.info(f"Refreshing {len(uris)} Spotify playlists in background") + threading.Thread( + target=self._refresh_tracks, + args=(uris,), + daemon=True, + ).start() + except Exception: + logger.exception("Error occurred while refreshing Spotify playlists") + + def _refresh_tracks(self, playlist_uris): + if not self._refresh_mutex.locked(): + logger.error("Lock must be held before calling this method") + return [] + try: + with utils.time_logger("playlists._refresh_tracks()", logging.DEBUG): + refreshed = [uri for uri in playlist_uris if self.lookup(uri)] + logger.info(f"Refreshed {len(refreshed)} Spotify playlists") + + listener.CoreListener.send("playlists_loaded") + except Exception: + logger.exception("Error occurred while refreshing Spotify playlists tracks") + else: + return refreshed # For test + finally: + self._refresh_mutex.release() def create(self, name): pass # TODO diff --git a/src/mopidy_spotify/web.py b/src/mopidy_spotify/web.py index 7c93faa0..dbdc1c8d 100644 --- a/src/mopidy_spotify/web.py +++ b/src/mopidy_spotify/web.py @@ -2,6 +2,7 @@ import logging import os import re +import threading import time import urllib.parse from dataclasses import dataclass @@ -9,7 +10,6 @@ from email.utils import parsedate_to_datetime from enum import Enum, unique from http import HTTPStatus -from threading import Lock import requests @@ -65,8 +65,9 @@ def __init__( # noqa: PLR0913 self._headers = {"Content-Type": "application/json"} self._session = utils.get_requests_session(proxy_config or {}) - self._cache_mutex = Lock() - self._refresh_mutex = Lock() + # TODO: Move _cache_mutex to the object it actually protects. + self._cache_mutex = threading.Lock() # Protects get() cache param. + self._refresh_mutex = threading.Lock() # Protects _headers and _expires. def get(self, path, cache=None, *args, **kwargs): if self._authorization_failed: @@ -78,10 +79,10 @@ def get(self, path, cache=None, *args, **kwargs): _trace(f"Get '{path}'") - ignore_expiry = kwargs.pop("ignore_expiry", False) + expiry_strategy = kwargs.pop("expiry_strategy", None) if cache is not None and path in cache: cached_result = cache.get(path) - if cached_result.still_valid(ignore_expiry=ignore_expiry): + if cached_result.still_valid(expiry_strategy=expiry_strategy): return cached_result kwargs.setdefault("headers", {}).update(cached_result.etag_headers) @@ -120,11 +121,16 @@ def _should_cache_response(self, cache, response): def _should_refresh_token(self): # TODO: Add jitter to margin? + if not self._refresh_mutex.locked(): + raise OAuthTokenRefreshError("Lock must be held before calling.") return not self._auth or time.time() > self._expires - self._margin def _refresh_token(self): logger.debug(f"Fetching OAuth token from {self._refresh_url}") + if not self._refresh_mutex.locked(): + raise OAuthTokenRefreshError("Lock must be held before calling.") + data = {"grant_type": "client_credentials"} result = self._request_with_retries( "POST", self._refresh_url, auth=self._auth, data=data @@ -264,6 +270,12 @@ def _parse_retry_after(self, response): return max(0, seconds) +@unique +class ExpiryStrategy(Enum): + FORCE_FRESH = "force-fresh" + FORCE_EXPIRED = "force-expired" + + class WebResponse(dict): def __init__( # noqa: PLR0913 self, @@ -335,19 +347,20 @@ def _parse_etag(response): return None - def still_valid(self, *, ignore_expiry=False): - if ignore_expiry: - result = True - status = "forced" - elif self._expires >= time.time(): - result = True - status = "fresh" + def still_valid(self, *, expiry_strategy=None): + if expiry_strategy is None: + if self._expires >= time.time(): + valid = True + status = "fresh" + else: + valid = False + status = "expired" else: - result = False - status = "expired" - self._from_cache = result + valid = expiry_strategy is ExpiryStrategy.FORCE_FRESH + status = expiry_strategy.value + self._from_cache = valid _trace("Cached data %s for %s", status, self) - return result + return valid @property def status_unchanged(self): @@ -439,8 +452,13 @@ def login(self): def logged_in(self): return self.user_id is not None - def get_user_playlists(self): - pages = self.get_all(f"users/{self.user_id}/playlists", params={"limit": 50}) + def get_user_playlists(self, *, refresh=False): + expiry_strategy = ExpiryStrategy.FORCE_EXPIRED if refresh else None + pages = self.get_all( + f"users/{self.user_id}/playlists", + params={"limit": 50}, + expiry_strategy=expiry_strategy, + ) for page in pages: yield from page.get("items", []) @@ -451,7 +469,9 @@ def _with_all_tracks(self, obj, params=None): track_pages = self.get_all( tracks_path, params=params, - ignore_expiry=obj.status_unchanged, + expiry_strategy=( + ExpiryStrategy.FORCE_FRESH if obj.status_unchanged else None + ), ) more_tracks = [] @@ -532,12 +552,6 @@ def get_track(self, web_link): return self.get_one(f"tracks/{web_link.id}", params={"market": "from_token"}) - def clear_cache( - self, - extra_expiry=None, # noqa: ARG002 - ): - self._cache.clear() - @unique class LinkType(Enum): diff --git a/tests/__init__.py b/tests/__init__.py index 44b6d90b..30953daa 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -2,7 +2,7 @@ class ThreadJoiner: - def __init__(self, timeout: int): + def __init__(self, timeout: int = 1): self.timeout = timeout def __enter__(self): diff --git a/tests/test_backend.py b/tests/test_backend.py index 262e65db..e0321e41 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -60,7 +60,7 @@ def test_on_start_configures_proxy(web_mock, config): "password": "s3cret", } backend = get_backend(config) - with ThreadJoiner(timeout=1.0): + with ThreadJoiner(): backend.on_start() assert True @@ -77,7 +77,7 @@ def test_on_start_configures_web_client(web_mock, config): config["spotify"]["client_secret"] = "AbCdEfG" backend = get_backend(config) - with ThreadJoiner(timeout=1.0): + with ThreadJoiner(): backend.on_start() web_mock.SpotifyOAuthClient.assert_called_once_with( @@ -96,13 +96,13 @@ def test_on_start_logs_in(web_mock, config): def test_on_start_refreshes_playlists(web_mock, config, caplog): backend = get_backend(config) - with ThreadJoiner(timeout=1.0): + with ThreadJoiner(): backend.on_start() client_mock = web_mock.SpotifyOAuthClient.return_value - client_mock.get_user_playlists.assert_called_once() + client_mock.get_user_playlists.assert_called_once_with(refresh=True) + assert "Refreshing 0 Spotify playlists in background" in caplog.text assert "Refreshed 0 Spotify playlists" in caplog.text - assert backend.playlists._loaded def test_on_start_doesnt_refresh_playlists_if_not_allowed(web_mock, config, caplog): diff --git a/tests/test_playlists.py b/tests/test_playlists.py index 7e0a8ee8..31c62e49 100644 --- a/tests/test_playlists.py +++ b/tests/test_playlists.py @@ -1,3 +1,4 @@ +import logging from unittest import mock import pytest @@ -44,9 +45,7 @@ def get_playlist(*args, **kwargs): @pytest.fixture() def provider(backend_mock, web_client_mock): backend_mock._web_client = web_client_mock - provider = playlists.SpotifyPlaylistsProvider(backend_mock) - provider._loaded = True - return provider + return playlists.SpotifyPlaylistsProvider(backend_mock) def test_is_a_playlists_provider(provider): @@ -69,14 +68,6 @@ def test_as_list_when_offline(web_client_mock, provider): assert len(result) == 0 -def test_as_list_when_not_loaded(provider): - provider._loaded = False - - result = provider.as_list() - - assert len(result) == 0 - - def test_as_list_when_playlist_wont_translate(provider): result = provider.as_list() @@ -119,15 +110,6 @@ def test_get_items_when_offline(web_client_mock, provider, caplog): ) -def test_get_items_when_not_loaded(provider): - provider._loaded = False - - result = provider.get_items("spotify:user:alice:playlist:foo") - - assert len(result) == 1 - assert result[0] == Ref.track(uri="spotify:track:abc", name="ABC 123") - - def test_get_items_when_playlist_wont_translate(provider): assert provider.get_items("spotify:user:alice:playlist:malformed") is None @@ -141,7 +123,7 @@ def test_get_items_when_playlist_is_unknown(provider, caplog): def test_refresh_loads_all_playlists(provider, web_client_mock): - with ThreadJoiner(timeout=1.0): + with ThreadJoiner(): provider.refresh() web_client_mock.get_user_playlists.assert_called_once() @@ -154,40 +136,69 @@ def test_refresh_loads_all_playlists(provider, web_client_mock): def test_refresh_when_not_logged_in(provider, web_client_mock): - provider._loaded = False web_client_mock.logged_in = False - with ThreadJoiner(timeout=1.0): + with ThreadJoiner(): provider.refresh() web_client_mock.get_user_playlists.assert_not_called() web_client_mock.get_playlist.assert_not_called() - assert not provider._loaded -def test_refresh_sets_loaded(provider, web_client_mock): - provider._loaded = False +def test_refresh_in_progress(provider, web_client_mock, caplog): + assert provider._refresh_mutex.acquire(blocking=False) - with ThreadJoiner(timeout=1.0): + with ThreadJoiner(): provider.refresh() - web_client_mock.get_user_playlists.assert_called_once() - web_client_mock.get_playlist.assert_called() - assert provider._loaded + web_client_mock.get_user_playlists.assert_not_called() + web_client_mock.get_playlist.assert_not_called() + assert provider._refresh_mutex.locked() + assert "Refreshing Spotify playlists already in progress" in caplog.text -def test_refresh_counts_playlists(provider, caplog): - with ThreadJoiner(timeout=1.0): +def test_refresh_counts_valid_playlists(provider, caplog): + caplog.set_level(logging.INFO) # To avoid log corruption from debug logging. + with ThreadJoiner(): provider.refresh() + assert "Refreshing 2 Spotify playlists in background" in caplog.text assert "Refreshed 2 Spotify playlists" in caplog.text -def test_refresh_clears_caches(provider, web_client_mock): - with ThreadJoiner(timeout=1.0): +@mock.patch("mopidy.core.listener.CoreListener.send") +def test_refresh_triggers_playlists_loaded_event(send, provider): + with ThreadJoiner(): + provider.refresh() + + send.assert_called_once_with("playlists_loaded") + + +def test_refresh_with_refresh_true_arg(provider, web_client_mock): + with ThreadJoiner(): provider.refresh() - web_client_mock.clear_cache.assert_called_once() + web_client_mock.get_user_playlists.assert_called_once_with(refresh=True) + + +def test_refresh_tracks_needs_lock(provider, web_client_mock, caplog): + assert provider._refresh_tracks("foo") == [] + + assert "Lock must be held before calling this method" in caplog.text + web_client_mock.get_playlist.assert_not_called() + + +def test_refresh_tracks(provider, web_client_mock, caplog): + uris = ["spotify:user:alice:playlist:foo", "spotify:user:bob:playlist:baz"] + + assert provider._refresh_mutex.acquire(blocking=False) + assert provider._refresh_tracks(uris) == uris + + expected_calls = [ + mock.call("spotify:user:alice:playlist:foo"), + mock.call("spotify:user:bob:playlist:baz"), + ] + web_client_mock.get_playlist.assert_has_calls(expected_calls) def test_lookup(provider): @@ -206,15 +217,6 @@ def test_lookup_when_not_logged_in(web_client_mock, provider): assert playlist is None -def test_lookup_when_not_loaded(provider): - provider._loaded = False - - playlist = provider.lookup("spotify:user:alice:playlist:foo") - - assert playlist.uri == "spotify:user:alice:playlist:foo" - assert playlist.name == "Foo" - - def test_lookup_when_playlist_is_empty(provider, caplog): assert provider.lookup("nothing") is None assert "Failed to lookup Spotify playlist URI 'nothing'" in caplog.text diff --git a/tests/test_web.py b/tests/test_web.py index 6839dc0a..a151835d 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -49,20 +49,33 @@ def skip_refresh_token(): patcher.stop() +def test_should_refresh_token_requires_lock(oauth_client): + with pytest.raises(web.OAuthTokenRefreshError): + oauth_client._should_refresh_token() + + +def test_refresh_token_requires_lock(oauth_client): + with pytest.raises(web.OAuthTokenRefreshError): + oauth_client._refresh_token() + + def test_initial_refresh_token(oauth_client): - assert oauth_client._should_refresh_token() + with oauth_client._refresh_mutex: + assert oauth_client._should_refresh_token() def test_expired_refresh_token(oauth_client, mock_time): oauth_client._expires = 1060 mock_time.return_value = 1001 - assert oauth_client._should_refresh_token() + with oauth_client._refresh_mutex: + assert oauth_client._should_refresh_token() def test_still_valid_refresh_token(oauth_client, mock_time): oauth_client._expires = 1060 mock_time.return_value = 1000 - assert not oauth_client._should_refresh_token() + with oauth_client._refresh_mutex: + assert not oauth_client._should_refresh_token() def test_user_agent(oauth_client): @@ -347,7 +360,7 @@ def test_web_response_status_unchanged_from_cache(): assert not response.status_unchanged - response.still_valid(ignore_expiry=True) + response.still_valid(expiry_strategy=web.ExpiryStrategy.FORCE_FRESH) assert response.status_unchanged @@ -499,8 +512,20 @@ def test_cache_response_expired( assert result["uri"] == "new" +def test_cache_response_still_valid_strategy(mock_time): + response = web.WebResponse("foo", {}, expires=9999 + 1) + mock_time.return_value = 9999 + + assert response.still_valid() is True + assert response.still_valid(expiry_strategy=None) is True + assert response.still_valid(expiry_strategy=web.ExpiryStrategy.FORCE_FRESH) is True + assert ( + response.still_valid(expiry_strategy=web.ExpiryStrategy.FORCE_EXPIRED) is False + ) + + @responses.activate -def test_cache_response_ignore_expiry( +def test_cache_response_force_fresh( web_response_mock, skip_refresh_token, oauth_client, mock_time, caplog ): cache = {"https://api.spotify.com/v1/tracks/abc": web_response_mock} @@ -512,11 +537,15 @@ def test_cache_response_ignore_expiry( mock_time.return_value = 9999 assert not web_response_mock.still_valid() - assert web_response_mock.still_valid(ignore_expiry=True) - assert "Cached data forced for" in caplog.text + assert "Cached data expired for" in caplog.text + + assert web_response_mock.still_valid(expiry_strategy=web.ExpiryStrategy.FORCE_FRESH) + assert "Cached data force-fresh for" in caplog.text result = oauth_client.get( - "https://api.spotify.com/v1/tracks/abc", cache, ignore_expiry=True + "https://api.spotify.com/v1/tracks/abc", + cache, + expiry_strategy=web.ExpiryStrategy.FORCE_FRESH, ) assert len(responses.calls) == 0 assert result["uri"] == "spotify:track:abc" @@ -928,6 +957,25 @@ def test_get_user_playlists_empty(self, spotify_client): assert len(responses.calls) == 1 assert len(result) == 0 + @pytest.mark.parametrize( + ("refresh", "strategy"), + [ + (True, web.ExpiryStrategy.FORCE_EXPIRED), + (False, None), + ], + ) + def test_get_user_playlists_get_all(self, spotify_client, refresh, strategy): + spotify_client.get_all = mock.Mock(return_value=[]) + + result = list(spotify_client.get_user_playlists(refresh=refresh)) + + spotify_client.get_all.assert_called_once_with( + "users/alice/playlists", + params={"limit": 50}, + expiry_strategy=strategy, + ) + assert len(result) == 0 + @responses.activate def test_get_user_playlists_sets_params(self, spotify_client): responses.add(responses.GET, url("users/alice/playlists"), json={}) @@ -1131,13 +1179,6 @@ def test_get_playlist_error_msg(self, spotify_client, caplog, uri, msg): assert spotify_client.get_playlist(uri) == {} assert f"Could not parse {uri!r} as a {msg} URI" in caplog.text - def test_clear_cache(self, spotify_client): - spotify_client._cache = {"foo": "bar"} - - spotify_client.clear_cache() - - assert {} == spotify_client._cache - @pytest.mark.parametrize(("user_id", "expected"), [("alice", True), (None, False)]) def test_logged_in(self, spotify_client, user_id, expected): spotify_client.user_id = user_id