diff --git a/src/fsspec_xrootd/xrootd.py b/src/fsspec_xrootd/xrootd.py
index b9c5841..214e05f 100644
--- a/src/fsspec_xrootd/xrootd.py
+++ b/src/fsspec_xrootd/xrootd.py
@@ -10,7 +10,7 @@
 from collections import defaultdict
 from dataclasses import dataclass
 from enum import IntEnum
-from typing import Any, Callable, Coroutine, Iterable, TypeVar
+from typing import Any, Callable, Coroutine, Iterable, TypeVar, cast
 
 from fsspec.asyn import AsyncFileSystem, _run_coros_in_chunks, sync, sync_wrapper
 from fsspec.exceptions import FSTimeoutError
@@ -373,9 +373,13 @@ async def _rm_file(self, path: str, **kwargs: Any) -> None:
 
     async def _touch(self, path: str, truncate: bool = False, **kwargs: Any) -> None:
         if truncate or not await self._exists(path):
-            status, _ = await _async_wrap(self._myclient.truncate)(
-                path, size=0, timeout=self.timeout
-            )
+            # NOTE(review): File.open() takes a URL, whereas truncate() took a
+            # server-side path -- confirm `path` resolves on the remote host.
+            f = client.File()
+            status, _ = await _async_wrap(f.open)(
+                path, OpenFlags.DELETE, timeout=self.timeout
+            )
+            await _async_wrap(f.close)(timeout=self.timeout)
             if not status.ok:
                 raise OSError(f"File not touched properly: {status.message}")
         else:
@@ -756,9 +760,9 @@ def __init__(
         from fsspec.core import caches
 
         self.timeout = fs.timeout
-        # by this point, mode will have a "b" in it
-        # update "+" mode removed for now since seek() is read only
-        if "x" in mode:
+        if mode == "r+b":
+            self.mode = OpenFlags.UPDATE
+        elif "x" in mode:
             self.mode = OpenFlags.NEW
         elif "a" in mode:
             self.mode = OpenFlags.UPDATE
@@ -834,7 +838,7 @@ def __init__(
 
         self.kwargs = kwargs
 
-        if mode not in {"ab", "rb", "wb"}:
+        if mode not in {"ab", "rb", "wb", "r+b"}:
             raise NotImplementedError("File mode not supported")
         if mode == "rb":
             if size is not None:
@@ -849,6 +853,18 @@ def __init__(
             self.forced = False
             self.location = None
         self.offset = 0
+        # Ask the server for the current size: seek(whence=2), append mode and
+        # the "r+b" read cache below all depend on it.
+        status, stat_info = self._myFile.stat(timeout=self.timeout)
+        if not status.ok:
+            raise OSError(f"File did not stat properly: {status.message}")
+        self.size = stat_info.size
+        if mode == "r+b":
+            self.cache = caches[cache_type](
+                self.blocksize, self._fetch_range, self.size, **cache_options
+            )
+        if "a" in mode:
+            self.loc = self.size
 
     def _locate_sources(self, logical_filename: str) -> list[str]:
         """Find hosts that have the desired file.
@@ -943,3 +959,97 @@ def close(self) -> None:
             if not status.ok:
                 raise OSError(f"File did not close properly: {status.message}")
         self.closed = True
+
+    def seek(self, loc: int, whence: int = 0) -> int:
+        """Set current file location
+
+        Parameters
+        ----------
+        loc: int
+            byte location
+        whence: {0, 1, 2}
+            from start of file, current location or end of file, resp.
+
+        Returns
+        -------
+        int
+            The new absolute position in the file.
+        """
+        loc = int(loc)
+        if whence == 0:
+            nloc = loc
+        elif whence == 1:
+            nloc = self.loc + loc
+        elif whence == 2:
+            nloc = self.size + loc
+        else:
+            raise ValueError(f"invalid whence ({whence}, should be 0, 1 or 2)")
+        if nloc < 0:
+            raise ValueError("Seek before start of file")
+        self.loc = nloc
+        return self.loc
+
+    def writable(self) -> bool:
+        """Whether opened for writing"""
+        return self.mode in {"wb", "ab", "xb", "r+b"} and not self.closed
+
+    def write(self, data: bytes) -> int:
+        """
+        Write data to the file at the current location.
+
+        The bytes are sent to the server immediately at ``self.loc``;
+        nothing is buffered locally.
+
+        Parameters
+        ----------
+        data: bytes
+            Set of bytes to be written.
+
+        Returns
+        -------
+        int
+            Number of bytes written.
+        """
+        # Check closed first: writable() is also False for a closed file and
+        # would otherwise mask this with a misleading "not in write mode".
+        if self.closed:
+            raise ValueError("I/O operation on closed file.")
+        if not self.writable():
+            raise ValueError("File not in write mode")
+        if self.forced:
+            raise ValueError("This file has been force-flushed, can only close")
+        status, _ = self._myFile.write(
+            data, self.loc, len(data), timeout=self.timeout
+        )
+        # Only advance the position once the write is known to have succeeded.
+        if not status.ok:
+            raise OSError(f"File did not write properly: {status.message}")
+        self.loc += len(data)
+        self.size = max(self.size, self.loc)
+        return len(data)
+
+    def read(self, length: int = -1) -> bytes:
+        """
+        Return data from cache, or fetch pieces as necessary
+
+        Parameters
+        ----------
+        length: int (-1)
+            Number of bytes to read; if <0, all remaining bytes.
+        """
+        length = int(length)
+        if self.mode not in {"rb", "r+b"}:
+            raise ValueError("File not in read mode")
+        if length < 0:
+            # Clamp so a position past EOF (after seek) reads nothing.
+            length = max(self.size - self.loc, 0)
+        if self.closed:
+            raise ValueError("I/O operation on closed file.")
+        if length == 0:
+            # don't even bother calling fetch
+            return b""
+        # for mypy
+        out = cast(bytes, self.cache._fetch(self.loc, self.loc + length))
+
+        self.loc += len(out)
+        return out
diff --git a/tests/test_basicio.py b/tests/test_basicio.py
index db8dec8..454b3f5 100644
--- a/tests/test_basicio.py
+++ b/tests/test_basicio.py
@@ -183,6 +183,25 @@ def test_write_fsspec(localserver, clear_server):
         assert f.read() == TESTDATA1
 
 
+def test_write_rpb_fsspec(localserver, clear_server):
+    """Test writing with r+b as in uproot"""
+    remoteurl, localpath = localserver
+    fs, _ = fsspec.core.url_to_fs(remoteurl)
+    filename = "test.bin"
+    fs.touch(localpath + "/" + filename)
+    with fsspec.open(remoteurl + "/" + filename, "r+b") as f:
+        f.write(b"Hello, this is a test file for r+b mode.")
+        f.flush()
+    with fsspec.open(remoteurl + "/" + filename, "r+b") as f:
+        assert f.read() == b"Hello, this is a test file for r+b mode."
+    with fsspec.open(remoteurl + "/" + filename, "r+b") as f:
+        f.seek(len(b"Hello, this is a "))
+        f.write(b"REPLACED ")
+        f.flush()
+    with fsspec.open(remoteurl + "/" + filename, "r+b") as f:
+        assert f.read() == b"Hello, this is a REPLACED  for r+b mode."
+
+
 @pytest.mark.parametrize("start, end", [(None, None), (None, 10), (1, None), (1, 10)])
 def test_read_bytes_fsspec(localserver, clear_server, start, end):
     remoteurl, localpath = localserver