Skip to content

Commit 1c95789

Browse files
committed
Call cat_ranges in blockcache for async filesystems
1 parent 2fbe8de commit 1c95789

File tree

3 files changed

+75
-6
lines changed

3 files changed

+75
-6
lines changed

Diff for: fsspec/caching.py

+39-4
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,35 @@ class MMapCache(BaseCache):
5050
Ensure there is enough disc space in the temporary location.
5151
5252
This cache method might only work on posix
53+
54+
Parameters
55+
----------
56+
blocksize: int
57+
How far to read ahead, in number of bytes
58+
fetcher: func
59+
Function of the form f(start, end) which gets bytes from remote as
60+
specified
61+
size: int
62+
How big this file is
63+
location: str
64+
Where to create the temporary file. If None, a temporary file is
65+
created using tempfile.TemporaryFile().
66+
blocks: set
67+
Set of block numbers that have already been fetched. If None, an empty
68+
set is created.
69+
multi_fetcher: func
70+
Function of the form f([(start, end)]) which gets bytes from remote
71+
as specified. This function is used to fetch multiple blocks at once.
72+
If not specified, the fetcher function is used instead.
5373
"""
5474

5575
name = "mmap"
5676

57-
def __init__(self, blocksize, fetcher, size, location=None, blocks=None):
77+
def __init__(self, blocksize, fetcher, size, location=None, blocks=None, multi_fetcher=None):
5878
super().__init__(blocksize, fetcher, size)
5979
self.blocks = set() if blocks is None else blocks
6080
self.location = location
81+
self.multi_fetcher = multi_fetcher
6182
self.cache = self._makefile()
6283

6384
def _makefile(self):
@@ -93,16 +114,30 @@ def _fetch(self, start, end):
93114
start_block = start // self.blocksize
94115
end_block = end // self.blocksize
95116
need = [i for i in range(start_block, end_block + 1) if i not in self.blocks]
117+
ranges = []
96118
while need:
97119
# TODO: not a for loop so we can consolidate blocks later to
98-
# make fewer fetch calls; this could be parallel
120+
# make fewer fetch calls
99121
i = need.pop(0)
100122
sstart = i * self.blocksize
101123
send = min(sstart + self.blocksize, self.size)
102-
logger.debug(f"MMap get block #{i} ({sstart}-{send}")
103-
self.cache[sstart:send] = self.fetcher(sstart, send)
124+
ranges.append((sstart, send))
104125
self.blocks.add(i)
105126

127+
if not ranges:
128+
return self.cache[start:end]
129+
130+
if self.multi_fetcher:
131+
logger.debug(f"MMap get blocks {ranges}")
132+
for idx, r in enumerate(self.multi_fetcher(ranges)):
133+
(sstart, send) = ranges[idx]
134+
logger.debug(f"MMap get block ({sstart}-{send}")
135+
self.cache[sstart:send] = r
136+
else:
137+
for (sstart, send) in ranges:
138+
logger.debug(f"MMap get block ({sstart}-{send}")
139+
self.cache[sstart:send] = self.fetcher(sstart, send)
140+
106141
return self.cache[start:end]
107142

108143
def __getstate__(self):

Diff for: fsspec/implementations/cached.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,17 @@ def _open(
400400
)
401401
else:
402402
detail["blocksize"] = f.blocksize
403-
f.cache = MMapCache(f.blocksize, f._fetch_range, f.size, fn, blocks)
403+
404+
def _fetch_ranges(ranges):
405+
return self.fs.cat_ranges(
406+
[path] * len(ranges),
407+
[r[0] for r in ranges],
408+
[r[1] for r in ranges],
409+
**kwargs,
410+
)
411+
412+
multi_fetcher = None if not self.fs.async_impl or self.compression else _fetch_ranges
413+
f.cache = MMapCache(f.blocksize, f._fetch_range, f.size, fn, blocks, multi_fetcher=multi_fetcher)
404414
close = f.close
405415
f.close = lambda: self.close_and_update(f, close)
406416
self.save_cache()

Diff for: fsspec/tests/test_caches.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pytest
55

6-
from fsspec.caching import BlockCache, FirstChunkCache, caches, register_cache
6+
from fsspec.caching import BlockCache, FirstChunkCache, MMapCache, caches, register_cache
77

88

99
def test_cache_getitem(Cache_imp):
@@ -43,6 +43,10 @@ def _fetcher(start, end):
4343
def letters_fetcher(start, end):
    """Return the slice ``string.ascii_letters[start:end]`` encoded as bytes."""
    chunk = string.ascii_letters[start:end]
    return chunk.encode()
4545

46+
def multi_letters_fetcher(ranges):
    """Return the ASCII-letter bytes for every ``(start, end)`` pair in *ranges*.

    Test stand-in for a multi-range fetcher: yields one ``bytes`` object per
    requested range, in the same order as the request list.
    """
    # No debugging print here: the helper is called from tests and should not
    # write to stdout.
    return [string.ascii_letters[start:end].encode() for start, end in ranges]
49+
4650

4751
not_parts_caches = {k: v for k, v in caches.items() if k != "parts"}
4852

@@ -81,6 +85,26 @@ def test_first_cache():
8185
c.fetcher = None
8286
assert c._fetch(1, 4) == letters_fetcher(1, 4)
8387

88+
def test_mmap_cache(mocker):
    """Check MMapCache block fetching with and without a ``multi_fetcher``.

    The cache is built with a block size of 5 over 52 bytes. Each step issues
    a ``_fetch`` and then checks the cumulative number of calls made to the
    wrapped fetcher, so already-fetched blocks must not be requested again.
    """
    # (start, end) request -> cumulative calls expected on the per-block fetcher
    single_steps = [((12, 15), 2), ((3, 10), 4), ((1, 4), 4)]
    # same requests -> cumulative calls expected on the multi-range fetcher
    batched_steps = [((12, 15), 1), ((3, 10), 2), ((1, 4), 2)]

    single = mocker.Mock(wraps=letters_fetcher)
    plain_cache = MMapCache(5, single, 52)
    for (lo, hi), expected_calls in single_steps:
        assert plain_cache._fetch(lo, hi) == letters_fetcher(lo, hi)
        assert single.call_count == expected_calls

    batched = mocker.Mock(wraps=multi_letters_fetcher)
    batched_cache = MMapCache(5, single, size=52, multi_fetcher=batched)
    for (lo, hi), expected_calls in batched_steps:
        assert batched_cache._fetch(lo, hi) == letters_fetcher(lo, hi)
        assert batched.call_count == expected_calls

    # The per-block fetcher count is unchanged from the first half: the
    # second cache went through multi_fetcher only.
    assert single.call_count == 4
84108

85109
@pytest.mark.parametrize(
86110
"size_requests",

0 commit comments

Comments
 (0)