|
11 | 11 |
|
import calendar

import requests
|
13 | 13 |
|
# strptime format of the HTTP ``Last-Modified`` header,
# e.g. "Sun, 06 Nov 1994 08:49:37 GMT"
_LASTMOD_FMT = "%a, %d %b %Y %H:%M:%S %Z"
| 15 | + |
| 16 | + |
| 17 | +def _base_cache_dir() -> str | None: |
| 18 | + sysname = platform.system() |
| 19 | + |
| 20 | + # on windows, try to get the appdata env var |
| 21 | + # this *could* result in cache_dir=None, which is fine, just skip caching in |
| 22 | + # that case |
| 23 | + if sysname == "Windows": |
| 24 | + cache_dir = os.getenv("LOCALAPPDATA", os.getenv("APPDATA")) |
| 25 | + # macOS -> app support dir |
| 26 | + elif sysname == "Darwin": |
| 27 | + cache_dir = os.path.expanduser("~/Library/Caches") |
| 28 | + # default for unknown platforms, namely linux behavior |
| 29 | + # use XDG env var and default to ~/.cache/ |
| 30 | + else: |
| 31 | + cache_dir = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) |
| 32 | + |
| 33 | + return cache_dir |
| 34 | + |
| 35 | + |
| 36 | +def _resolve_cache_dir(dirname: str = "downloads") -> str | None: |
| 37 | + cache_dir = _base_cache_dir() |
| 38 | + if cache_dir: |
| 39 | + cache_dir = os.path.join(cache_dir, "check_jsonschema", dirname) |
| 40 | + return cache_dir |
| 41 | + |
| 42 | + |
def _lastmod_from_response(response: requests.Response) -> float:
    """Parse the ``Last-Modified`` header of ``response`` into a POSIX timestamp.

    Returns 0.0 (the epoch) when the header is missing or unparseable, so
    callers comparing against a local file mtime treat such responses as older
    than any cached file.
    """
    try:
        # Last-Modified is always expressed in GMT (RFC 7231), so convert the
        # parsed struct_time with calendar.timegm(), not time.mktime() --
        # mktime() would interpret it in the *local* timezone and skew the
        # cache-hit mtime comparison by the machine's UTC offset.
        return float(
            calendar.timegm(
                time.strptime(response.headers["last-modified"], _LASTMOD_FMT)
            )
        )
    # OverflowError: time outside of platform-specific bounds
    # ValueError: malformed/unparseable
    # LookupError: no such header
    except (OverflowError, ValueError, LookupError):
        return 0.0
| 53 | + |
| 54 | + |
def _get_request(
    file_url: str, *, response_ok: t.Callable[[requests.Response], bool]
) -> requests.Response:
    """GET ``file_url`` as a streaming request, retrying up to two times.

    A response is accepted when it has an ok status code *and* passes the
    ``response_ok`` predicate. Raises ``FailedDownloadError`` once all
    attempts are spent, chaining the connection error if one occurred on the
    final try.
    """
    attempts = 3  # initial request + 2 retries
    last_response: requests.Response | None = None
    for remaining in range(attempts - 1, -1, -1):
        try:
            last_response = requests.get(file_url, stream=True)
        except requests.RequestException as err:
            if remaining == 0:
                raise FailedDownloadError("encountered error during download") from err
            continue
        if last_response.ok and response_ok(last_response):
            return last_response
    # every attempt produced a (rejected) response, so this cannot be None
    assert last_response is not None
    raise FailedDownloadError(
        f"got response with status={last_response.status_code}, retries exhausted"
    )
| 73 | + |
| 74 | + |
| 75 | +def _atomic_write(dest: str, content: bytes) -> None: |
| 76 | + # download to a temp file and then move to the dest |
| 77 | + # this makes the download safe if run in parallel (parallel runs |
| 78 | + # won't create a new empty file for writing and cause failures) |
| 79 | + fp = tempfile.NamedTemporaryFile(mode="wb", delete=False) |
| 80 | + fp.write(content) |
| 81 | + fp.close() |
| 82 | + shutil.copy(fp.name, dest) |
| 83 | + os.remove(fp.name) |
| 84 | + |
| 85 | + |
def _cache_hit(cachefile: str, response: requests.Response) -> bool:
    """Decide whether ``cachefile`` can be used instead of downloading.

    A hit requires that the file exists locally and that its mtime is at
    least as new as the remote Last-Modified time from ``response``.
    """
    # no file? miss
    if not os.path.exists(cachefile):
        return False
    # hit exactly when the local copy is no older than the remote one
    return os.path.getmtime(cachefile) >= _lastmod_from_response(response)
| 96 | + |
14 | 97 |
|
class FailedDownloadError(Exception):
    """Raised when a remote file could not be fetched after retries."""
|
17 | 100 |
|
18 | 101 |
|
class CacheDownloader:
    """Download-and-cache manager.

    Files are stored under the per-platform cache directory resolved by
    ``_resolve_cache_dir``. When no cache directory is available or
    ``disable_cache`` is set, content is served straight from memory instead.
    """

    def __init__(self, cache_dir: str | None = None, disable_cache: bool = False):
        # `cache_dir` names a subdirectory under the app cache root; the
        # resolver's own default ("downloads") applies when it is omitted
        if cache_dir is None:
            self._cache_dir = _resolve_cache_dir()
        else:
            self._cache_dir = _resolve_cache_dir(cache_dir)
        self._disable_cache = disable_cache

    def _download(
        self,
        file_url: str,
        filename: str,
        response_ok: t.Callable[[requests.Response], bool],
    ) -> str:
        """Fetch ``file_url`` into the cache (unless fresh) and return the local path."""
        assert self._cache_dir is not None
        os.makedirs(self._cache_dir, exist_ok=True)
        dest = os.path.join(self._cache_dir, filename)

        def acceptable(r: requests.Response) -> bool:
            # a cache hit short-circuits any further evaluation immediately;
            # only on a miss do we validate the content (forcing the download)
            return _cache_hit(dest, r) or response_ok(r)

        response = _get_request(file_url, response_ok=acceptable)
        # write only on a cache miss -- a hit means the local file already
        # matches (or is newer than) what the connection would deliver
        if not _cache_hit(dest, response):
            _atomic_write(dest, response.content)
        return dest

    @contextlib.contextmanager
    def open(
        self,
        file_url: str,
        filename: str,
        validate_response: t.Callable[[requests.Response], bool],
    ) -> t.Iterator[t.IO[bytes]]:
        """Yield a binary stream for ``file_url``, via the on-disk cache when possible."""
        if self._disable_cache or not self._cache_dir:
            # caching unavailable/disabled: keep the validated body in memory
            data = _get_request(file_url, response_ok=validate_response).content
            yield io.BytesIO(data)
        else:
            path = self._download(file_url, filename, response_ok=validate_response)
            with open(path, "rb") as fp:
                yield fp

    def bind(
        self,
        file_url: str,
        filename: str | None = None,
        validation_callback: t.Callable[[bytes], t.Any] | None = None,
    ) -> BoundCacheDownloader:
        """Tie this downloader to a single URL, returning a BoundCacheDownloader."""
        return BoundCacheDownloader(
            file_url, filename, self, validation_callback=validation_callback
        )
| 163 | + |
| 164 | + |
class BoundCacheDownloader:
    """A ``CacheDownloader`` pre-bound to one URL and an optional content validator."""

    def __init__(
        self,
        file_url: str,
        filename: str | None,
        downloader: CacheDownloader,
        *,
        validation_callback: t.Callable[[bytes], t.Any] | None = None,
    ):
        self._file_url = file_url
        # default the cache filename to the final path segment of the URL
        self._filename = filename or file_url.split("/")[-1]
        self._downloader = downloader
        self._validation_callback = validation_callback

    @contextlib.contextmanager
    def open(self) -> t.Iterator[t.IO[bytes]]:
        """Open the bound URL through the underlying downloader."""
        with self._downloader.open(
            self._file_url,
            self._filename,
            validate_response=self._validate_response,
        ) as fp:
            yield fp

    def _validate_response(self, response: requests.Response) -> bool:
        """Run the validation callback against the body; ValueError means invalid."""
        # no callback configured -> everything is considered valid
        if not self._validation_callback:
            return True
        try:
            self._validation_callback(response.content)
        except ValueError:
            return False
        return True
0 commit comments