Skip to content

Commit 4dbb984

Browse files
committed
Call cat_ranges in blockcache for async filesystems
1 parent 961412d commit 4dbb984

File tree

3 files changed

+91
-2
lines changed

3 files changed

+91
-2
lines changed

fsspec/caching.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
logger = logging.getLogger("fsspec")
3838

3939
Fetcher = Callable[[int, int], bytes] # Maps (start, end) to bytes
40+
# Maps [(start, end), ...] to a sequence of bytes objects, one per requested range
MultiFetcher = Callable[[list[tuple[int, int]]], list[bytes]]
4041

4142

4243
class BaseCache:
@@ -109,6 +110,26 @@ class MMapCache(BaseCache):
109110
Ensure there is enough disc space in the temporary location.
110111
111112
This cache method might only work on posix
113+
114+
Parameters
115+
----------
116+
blocksize: int
117+
How far to read ahead in numbers of bytes
118+
fetcher: Fetcher
119+
Function of the form f(start, end) which gets bytes from remote as
120+
specified
121+
size: int
122+
How big this file is
123+
location: str
124+
Where to create the temporary file. If None, a temporary file is
125+
created using tempfile.TemporaryFile().
126+
blocks: set[int]
127+
Set of block numbers that have already been fetched. If None, an empty
128+
set is created.
129+
multi_fetcher: MultiFetcher
130+
Function of the form f([(start, end)]) which gets bytes from remote
131+
as specified. This function is used to fetch multiple blocks at once.
132+
If not specified, the fetcher function is used instead.
112133
"""
113134

114135
name = "mmap"
@@ -120,10 +141,12 @@ def __init__(
120141
size: int,
121142
location: str | None = None,
122143
blocks: set[int] | None = None,
144+
multi_fetcher: MultiFetcher | None = None,
123145
) -> None:
124146
super().__init__(blocksize, fetcher, size)
125147
self.blocks = set() if blocks is None else blocks
126148
self.location = location
149+
self.multi_fetcher = multi_fetcher
127150
self.cache = self._makefile()
128151

129152
def _makefile(self) -> mmap.mmap | bytearray:
@@ -164,6 +187,8 @@ def _fetch(self, start: int | None, end: int | None) -> bytes:
164187
# Count the number of blocks already cached
165188
self.hit_count += sum(1 for i in block_range if i in self.blocks)
166189

190+
ranges = []
191+
167192
# Consolidate needed blocks.
168193
# Algorithm adapted from Python 2.x itertools documentation.
169194
# We are grouping an enumerated sequence of blocks. By comparing when the difference
@@ -185,13 +210,27 @@ def _fetch(self, start: int | None, end: int | None) -> bytes:
185210
logger.debug(
186211
f"MMap get blocks {_blocks[0]}-{_blocks[-1]} ({sstart}-{send})"
187212
)
188-
self.cache[sstart:send] = self.fetcher(sstart, send)
213+
ranges.append((sstart, send))
189214

190215
# Update set of cached blocks
191216
self.blocks.update(_blocks)
192217
# Update cache statistics with number of blocks we had to cache
193218
self.miss_count += len(_blocks)
194219

220+
if not ranges:
221+
return self.cache[start:end]
222+
223+
if self.multi_fetcher:
224+
logger.debug(f"MMap get blocks {ranges}")
225+
for idx, r in enumerate(self.multi_fetcher(ranges)):
226+
(sstart, send) = ranges[idx]
227+
logger.debug(f"MMap copy block ({sstart}-{send}")
228+
self.cache[sstart:send] = r
229+
else:
230+
for (sstart, send) in ranges:
231+
logger.debug(f"MMap get block ({sstart}-{send}")
232+
self.cache[sstart:send] = self.fetcher(sstart, send)
233+
195234
return self.cache[start:end]
196235

197236
def __getstate__(self) -> dict[str, Any]:

fsspec/implementations/cached.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,19 @@ def _open(
362362
)
363363
else:
364364
detail["blocksize"] = f.blocksize
365-
f.cache = MMapCache(f.blocksize, f._fetch_range, f.size, fn, blocks)
365+
366+
def _fetch_ranges(ranges):
367+
return self.fs.cat_ranges(
368+
[path] * len(ranges),
369+
[r[0] for r in ranges],
370+
[r[1] for r in ranges],
371+
**kwargs,
372+
)
373+
374+
multi_fetcher = None if self.compression else _fetch_ranges
375+
f.cache = MMapCache(
376+
f.blocksize, f._fetch_range, f.size, fn, blocks, multi_fetcher=multi_fetcher
377+
)
366378
close = f.close
367379
f.close = lambda: self.close_and_update(f, close)
368380
self.save_cache()

fsspec/tests/test_caches.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from fsspec.caching import (
77
BlockCache,
88
FirstChunkCache,
9+
MMapCache,
910
ReadAheadCache,
1011
caches,
1112
register_cache,
@@ -144,9 +145,14 @@ def _fetcher(start, end):
144145

145146

146147
def letters_fetcher(start, end):
    """Return ``string.ascii_letters[start:end]`` encoded as bytes.

    Acts as a stand-in for a remote range fetcher over a fixed 52-character
    payload, so cache tests can compare results against known data.
    """
    return string.ascii_letters[start:end].encode()
148150

149151

152+
def multi_letters_fetcher(ranges):
    """Return one bytes object per ``(start, end)`` pair in *ranges*.

    Each element is ``string.ascii_letters[start:end]`` encoded as bytes, in
    the same order as the requested ranges.
    """
    return [string.ascii_letters[start:end].encode() for start, end in ranges]
154+
155+
150156
not_parts_caches = {k: v for k, v in caches.items() if k != "parts"}
151157

152158

@@ -174,6 +180,38 @@ def test_cache_pickleable(Cache_imp):
174180
assert unpickled._fetch(0, 10) == b"0" * 10
175181

176182

183+
def test_first_cache():
    """Check when FirstChunkCache populates its cache and that it reuses it."""
    c = FirstChunkCache(5, letters_fetcher, 52)
    assert c.cache is None

    # A read starting past the first 5 bytes must leave the cache empty.
    assert c._fetch(12, 15) == letters_fetcher(12, 15)
    assert c.cache is None

    # A read overlapping the first block stores exactly bytes 0-5.
    assert c._fetch(3, 10) == letters_fetcher(3, 10)
    assert c.cache == letters_fetcher(0, 5)

    # With the fetcher removed, this read can only be served from the cache.
    c.fetcher = None
    assert c._fetch(1, 4) == letters_fetcher(1, 4)
192+
193+
194+
def test_mmap_cache(mocker):
    """Check MMapCache results and fetch-call counts with and without multi_fetcher."""
    # Without a multi_fetcher, the single-range fetcher is called for each
    # missing range; the expected cumulative call counts are asserted below.
    fetcher = mocker.Mock(wraps=letters_fetcher)
    c = MMapCache(5, fetcher, 52)
    assert c._fetch(6, 8) == letters_fetcher(6, 8)
    assert fetcher.call_count == 1
    assert c._fetch(17, 22) == letters_fetcher(17, 22)
    assert fetcher.call_count == 2
    # Presumably the still-missing blocks form three separate runs here,
    # hence three more fetcher calls.
    assert c._fetch(1, 38) == letters_fetcher(1, 38)
    assert fetcher.call_count == 5

    # With a multi_fetcher, each _fetch that needs new data makes exactly one
    # multi_fetcher call, however many ranges are missing.
    multi_fetcher = mocker.Mock(wraps=multi_letters_fetcher)
    m = MMapCache(5, fetcher, size=52, multi_fetcher=multi_fetcher)
    assert m._fetch(6, 8) == letters_fetcher(6, 8)
    assert multi_fetcher.call_count == 1
    assert m._fetch(17, 22) == letters_fetcher(17, 22)
    assert multi_fetcher.call_count == 2
    assert m._fetch(1, 38) == letters_fetcher(1, 38)
    assert multi_fetcher.call_count == 3
    # The single-range fetcher must not have been used by the second cache.
    assert fetcher.call_count == 5
213+
214+
177215
@pytest.mark.parametrize(
178216
"size_requests",
179217
[[(0, 30), (0, 35), (51, 52)], [(0, 1), (1, 11), (1, 52)], [(0, 52), (11, 15)]],

0 commit comments

Comments
 (0)