Skip to content

Commit c749971

Browse files
maxgallipre-commit-ci[bot]nsmith-
authored
fix writing via XRootD (called from uproot) (#76)
* fix writing via XRootD (called from uproot) * make calls to file inside _touch async * remove outdated comment * add type annotations * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix mypy errors * fix for test test_basicio.py::test_read_fsspec * update self.loc in case of mode==a, otherwise the write function overwrites * add test for r+b * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicholas Smith <[email protected]>
1 parent 4f30f18 commit c749971

File tree

2 files changed

+112
-8
lines changed

2 files changed

+112
-8
lines changed

src/fsspec_xrootd/xrootd.py

Lines changed: 93 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from collections import defaultdict
1111
from dataclasses import dataclass
1212
from enum import IntEnum
13-
from typing import Any, Callable, Coroutine, Iterable, TypeVar
13+
from typing import Any, Callable, Coroutine, Iterable, TypeVar, cast
1414

1515
from fsspec.asyn import AsyncFileSystem, _run_coros_in_chunks, sync, sync_wrapper
1616
from fsspec.exceptions import FSTimeoutError
@@ -373,9 +373,9 @@ async def _rm_file(self, path: str, **kwargs: Any) -> None:
373373

374374
async def _touch(self, path: str, truncate: bool = False, **kwargs: Any) -> None:
375375
if truncate or not await self._exists(path):
376-
status, _ = await _async_wrap(self._myclient.truncate)(
377-
path, size=0, timeout=self.timeout
378-
)
376+
f = client.File()
377+
status, _ = await _async_wrap(f.open)(path, OpenFlags.DELETE)
378+
await _async_wrap(f.close)()
379379
if not status.ok:
380380
raise OSError(f"File not touched properly: {status.message}")
381381
else:
@@ -756,9 +756,9 @@ def __init__(
756756
from fsspec.core import caches
757757

758758
self.timeout = fs.timeout
759-
# by this point, mode will have a "b" in it
760-
# update "+" mode removed for now since seek() is read only
761-
if "x" in mode:
759+
if mode == "r+b":
760+
self.mode = OpenFlags.UPDATE
761+
elif "x" in mode:
762762
self.mode = OpenFlags.NEW
763763
elif "a" in mode:
764764
self.mode = OpenFlags.UPDATE
@@ -834,7 +834,7 @@ def __init__(
834834

835835
self.kwargs = kwargs
836836

837-
if mode not in {"ab", "rb", "wb"}:
837+
if mode not in {"ab", "rb", "wb", "r+b"}:
838838
raise NotImplementedError("File mode not supported")
839839
if mode == "rb":
840840
if size is not None:
@@ -849,6 +849,13 @@ def __init__(
849849
self.forced = False
850850
self.location = None
851851
self.offset = 0
852+
self.size = self._myFile.stat()[1].size
853+
if mode == "r+b":
854+
self.cache = caches[cache_type](
855+
self.blocksize, self._fetch_range, self.size, **cache_options
856+
)
857+
if "a" in mode:
858+
self.loc = self.size
852859

853860
def _locate_sources(self, logical_filename: str) -> list[str]:
854861
"""Find hosts that have the desired file.
@@ -943,3 +950,81 @@ def close(self) -> None:
943950
if not status.ok:
944951
raise OSError(f"File did not close properly: {status.message}")
945952
self.closed = True
953+
954+
def seek(self, loc: int, whence: int = 0) -> int:
    """Move the file position and return the new absolute offset.

    Parameters
    ----------
    loc: int
        Offset in bytes, interpreted relative to the origin selected by
        ``whence``.
    whence: {0, 1, 2}
        Origin of the offset: 0 is the start of the file, 1 is the current
        position (``self.loc``) and 2 is the end of the file (``self.size``).

    Returns
    -------
    int
        The updated value of ``self.loc``.

    Raises
    ------
    ValueError
        If ``whence`` is not 0, 1 or 2, or if the resulting position would
        lie before the start of the file.
    """
    offset = int(loc)

    # Resolve the origin first; the new position is always origin + offset.
    if whence == 0:
        origin = 0
    elif whence == 1:
        origin = self.loc
    elif whence == 2:
        origin = self.size
    else:
        raise ValueError(f"invalid whence ({whence}, should be 0, 1 or 2)")

    target = origin + offset
    if target < 0:
        raise ValueError("Seek before start of file")

    # Positions beyond the current size are accepted as-is.
    self.loc = target
    return self.loc
977+
978+
def writable(self) -> bool:
    """Return True when the file is open in a write-capable mode.

    The file counts as writable only if ``self.mode`` is one of ``"wb"``,
    ``"ab"``, ``"xb"`` or ``"r+b"`` and the file has not been closed.
    """
    mode_allows_writing = self.mode in {"wb", "ab", "xb", "r+b"}
    if not mode_allows_writing:
        return False
    return not self.closed
981+
982+
def write(self, data: bytes) -> int:
    """Write ``data`` to the file at the current position.

    The bytes are passed directly to ``self._myFile.write`` at offset
    ``self.loc``; nothing is buffered in this method. The position and the
    recorded file size are advanced only after the write has been reported
    as successful, so a failed write leaves ``self.loc`` and ``self.size``
    untouched.

    Parameters
    ----------
    data: bytes
        Set of bytes to be written.

    Returns
    -------
    int
        Number of bytes written, i.e. ``len(data)``.

    Raises
    ------
    ValueError
        If the file is closed, is not open in a writable mode, or has been
        force-flushed.
    OSError
        If the underlying write reports a non-ok status.
    """
    # Check ``closed`` first: writable() also returns False for a closed
    # file, which would otherwise hide the more specific message below.
    if self.closed:
        raise ValueError("I/O operation on closed file.")
    if not self.writable():
        raise ValueError("File not in write mode")
    if self.forced:
        raise ValueError("This file has been force-flushed, can only close")

    nbytes = len(data)
    status, _n = self._myFile.write(data, self.loc, nbytes, timeout=self.timeout)
    if not status.ok:
        raise OSError(f"File did not write properly: {status.message}")

    # Only move the cursor and grow the known size once the write succeeded.
    self.loc += nbytes
    self.size = max(self.size, self.loc)
    return nbytes
1006+
1007+
def read(self, length: int = -1) -> bytes:
    """Return up to ``length`` bytes starting at the current position.

    Data is obtained through ``self.cache._fetch`` and ``self.loc`` is
    advanced by the number of bytes actually returned.

    Parameters
    ----------
    length: int (-1)
        Number of bytes to read; if <0 (or ``None``, following the usual
        file-object convention), all remaining bytes.

    Returns
    -------
    bytes
        The bytes read; empty when nothing remains, including when the
        current position is at or beyond the end of the file.

    Raises
    ------
    ValueError
        If the file is not open in ``"rb"`` or ``"r+b"`` mode, or is closed.
    """
    length = -1 if length is None else int(length)
    if self.mode not in {"rb", "r+b"}:
        raise ValueError("File not in read mode")
    if self.closed:
        raise ValueError("I/O operation on closed file.")
    if length < 0:
        # seek() may place the cursor past EOF; clamp so a negative-length
        # range is never requested from the cache.
        length = max(self.size - self.loc, 0)
    if length == 0:
        # don't even bother calling fetch
        return b""
    # for mypy
    out = cast(bytes, self.cache._fetch(self.loc, self.loc + length))

    self.loc += len(out)
    return out

tests/test_basicio.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,25 @@ def test_write_fsspec(localserver, clear_server):
183183
assert f.read() == TESTDATA1
184184

185185

186+
def test_write_rpb_fsspec(localserver, clear_server):
    """Test writing with r+b as in uproot.

    Writes an initial payload to a freshly touched file, reads it back, then
    seeks into the middle of the file and overwrites part of it. The expected
    final content is derived from the initial payload and the patch so that
    it is exact to the byte (including whitespace left over from the
    original content after the patched region).
    """
    remoteurl, localpath = localserver
    fs, _ = fsspec.core.url_to_fs(remoteurl)
    filename = "test.bin"
    url = remoteurl + "/" + filename

    initial = b"Hello, this is a test file for r+b mode."
    prefix = b"Hello, this is a "
    patch = b"REPLACED "
    # An in-place write replaces exactly len(patch) bytes after the prefix
    # and leaves everything else in the file untouched.
    start = len(prefix)
    expected = initial[:start] + patch + initial[start + len(patch):]

    fs.touch(localpath + "/" + filename)
    with fsspec.open(url, "r+b") as f:
        f.write(initial)
        f.flush()
    with fsspec.open(url, "r+b") as f:
        assert f.read() == initial
    with fsspec.open(url, "r+b") as f:
        f.seek(start)
        f.write(patch)
        f.flush()
    with fsspec.open(url, "r+b") as f:
        assert f.read() == expected
203+
204+
186205
@pytest.mark.parametrize("start, end", [(None, None), (None, 10), (1, None), (1, 10)])
187206
def test_read_bytes_fsspec(localserver, clear_server, start, end):
188207
remoteurl, localpath = localserver

0 commit comments

Comments
 (0)