Skip to content

Commit 46ea642

Browse files
committed
Call cat_ranges in blockcache for async filesystems
1 parent 2fbe8de commit 46ea642

File tree

3 files changed

+90
-6
lines changed

3 files changed

+90
-6
lines changed

Diff for: fsspec/caching.py

+41-4
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,37 @@ class MMapCache(BaseCache):
5050
Ensure there is enough disc space in the temporary location.
5151
5252
This cache method might only work on posix
53+
54+
Parameters
55+
----------
56+
blocksize: int
57+
How far to read ahead, in number of bytes
58+
fetcher: func
59+
Function of the form f(start, end) which gets bytes from remote as
60+
specified
61+
size: int
62+
How big this file is
63+
location: str
64+
Where to create the temporary file. If None, a temporary file is
65+
created using tempfile.TemporaryFile().
66+
blocks: set
67+
Set of block numbers that have already been fetched. If None, an empty
68+
set is created.
69+
multi_fetcher: func
70+
Function of the form f([(start, end)]) which gets bytes from remote
71+
as specified. This function is used to fetch multiple blocks at once.
72+
If not specified, the fetcher function is used instead.
5373
"""
5474

5575
name = "mmap"
5676

57-
def __init__(self, blocksize, fetcher, size, location=None, blocks=None):
77+
def __init__(
78+
self, blocksize, fetcher, size, location=None, blocks=None, multi_fetcher=None
79+
):
5880
super().__init__(blocksize, fetcher, size)
5981
self.blocks = set() if blocks is None else blocks
6082
self.location = location
83+
self.multi_fetcher = multi_fetcher
6184
self.cache = self._makefile()
6285

6386
def _makefile(self):
@@ -93,16 +116,30 @@ def _fetch(self, start, end):
93116
start_block = start // self.blocksize
94117
end_block = end // self.blocksize
95118
need = [i for i in range(start_block, end_block + 1) if i not in self.blocks]
119+
ranges = []
96120
while need:
97121
# TODO: not a for loop so we can consolidate blocks later to
98-
# make fewer fetch calls; this could be parallel
122+
# make fewer fetch calls
99123
i = need.pop(0)
100124
sstart = i * self.blocksize
101125
send = min(sstart + self.blocksize, self.size)
102-
logger.debug(f"MMap get block #{i} ({sstart}-{send}")
103-
self.cache[sstart:send] = self.fetcher(sstart, send)
126+
ranges.append((sstart, send))
104127
self.blocks.add(i)
105128

129+
if not ranges:
130+
return self.cache[start:end]
131+
132+
if self.multi_fetcher:
133+
logger.debug(f"MMap get blocks {ranges}")
134+
for idx, r in enumerate(self.multi_fetcher(ranges)):
135+
(sstart, send) = ranges[idx]
136+
logger.debug(f"MMap copy block ({sstart}-{send}")
137+
self.cache[sstart:send] = r
138+
else:
139+
for (sstart, send) in ranges:
140+
logger.debug(f"MMap get block ({sstart}-{send}")
141+
self.cache[sstart:send] = self.fetcher(sstart, send)
142+
106143
return self.cache[start:end]
107144

108145
def __getstate__(self):

Diff for: fsspec/implementations/cached.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,21 @@ def _open(
400400
)
401401
else:
402402
detail["blocksize"] = f.blocksize
403-
f.cache = MMapCache(f.blocksize, f._fetch_range, f.size, fn, blocks)
403+
404+
def _fetch_ranges(ranges):
405+
return self.fs.cat_ranges(
406+
[path] * len(ranges),
407+
[r[0] for r in ranges],
408+
[r[1] for r in ranges],
409+
**kwargs,
410+
)
411+
412+
multi_fetcher = (
413+
None if not self.fs.async_impl or self.compression else _fetch_ranges
414+
)
415+
f.cache = MMapCache(
416+
f.blocksize, f._fetch_range, f.size, fn, blocks, multi_fetcher=multi_fetcher
417+
)
404418
close = f.close
405419
f.close = lambda: self.close_and_update(f, close)
406420
self.save_cache()

Diff for: fsspec/tests/test_caches.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33

44
import pytest
55

6-
from fsspec.caching import BlockCache, FirstChunkCache, caches, register_cache
6+
from fsspec.caching import (
7+
BlockCache,
8+
FirstChunkCache,
9+
MMapCache,
10+
caches,
11+
register_cache,
12+
)
713

814

915
def test_cache_getitem(Cache_imp):
@@ -44,6 +50,11 @@ def letters_fetcher(start, end):
4450
return string.ascii_letters[start:end].encode()
4551

4652

53+
def multi_letters_fetcher(ranges):
54+
print(ranges)
55+
return [string.ascii_letters[start:end].encode() for start, end in ranges]
56+
57+
4758
not_parts_caches = {k: v for k, v in caches.items() if k != "parts"}
4859

4960

@@ -82,6 +93,28 @@ def test_first_cache():
8293
assert c._fetch(1, 4) == letters_fetcher(1, 4)
8394

8495

96+
def test_mmap_cache(mocker):
97+
fetcher = mocker.Mock(wraps=letters_fetcher)
98+
99+
c = MMapCache(5, fetcher, 52)
100+
assert c._fetch(12, 15) == letters_fetcher(12, 15)
101+
assert fetcher.call_count == 2
102+
assert c._fetch(3, 10) == letters_fetcher(3, 10)
103+
assert fetcher.call_count == 4
104+
assert c._fetch(1, 4) == letters_fetcher(1, 4)
105+
assert fetcher.call_count == 4
106+
107+
multi_fetcher = mocker.Mock(wraps=multi_letters_fetcher)
108+
m = MMapCache(5, fetcher, size=52, multi_fetcher=multi_fetcher)
109+
assert m._fetch(12, 15) == letters_fetcher(12, 15)
110+
assert multi_fetcher.call_count == 1
111+
assert m._fetch(3, 10) == letters_fetcher(3, 10)
112+
assert multi_fetcher.call_count == 2
113+
assert m._fetch(1, 4) == letters_fetcher(1, 4)
114+
assert multi_fetcher.call_count == 2
115+
assert fetcher.call_count == 4
116+
117+
85118
@pytest.mark.parametrize(
86119
"size_requests",
87120
[[(0, 30), (0, 35), (51, 52)], [(0, 1), (1, 11), (1, 52)], [(0, 52), (11, 15)]],

0 commit comments

Comments
 (0)