diff --git a/datalad/runner/runnertools.py b/datalad/runner/runnertools.py new file mode 100644 index 0000000000..9de90c87fa --- /dev/null +++ b/datalad/runner/runnertools.py @@ -0,0 +1,61 @@ +"""Helper that extend protocols and runner with new functionality""" +from __future__ import annotations + +from .coreprotocols import WitlessProtocol + + +def wrap_logger(klass: type[WitlessProtocol], + stdout_io, + stderr_io + ) -> type[WitlessProtocol]: + """ + This function modifies the class given in ``klass`` in such a way that it + copies all data received in ``pipe_data_received`` to ``stdout_io`` or + ``stderr_io`` depending on the file descriptor, and then passes it on to + the original ``pipe_data_received`` method. + + Parameters + ---------- + klass: type[WitlessProtocol] + The class to modify + stdout_io: + The io object to write stdout data to, it has to support the method + ``write``, which is called with ``bytes`` data. + stderr_io + The io object to write stderr data to, it has to support the method + ``write``, which is called with ``bytes`` data. + + Returns + ------- + type[WitlessProtocol] + The modified class + """ + def pipe_data_received(self, fd: int, data: bytes) -> None: + if fd == 1: + stdout_io.write(data) + elif fd == 2: + stderr_io.write(data) + original_pipe_data_received(self, fd, data) + + original_pipe_data_received = klass.pipe_data_received + klass.pipe_data_received = pipe_data_received + return klass + + +class WrapLogger: + """A decorator for classes that allows logging of stdout and stderr data""" + def __init__(self, stdout_io, stderr_io): + self.stdout_io = stdout_io + self.stderr_io = stderr_io + + def __call__(self, klass): + def pipe_data_received(obj, fd: int, data: bytes) -> None: + if fd == 1: + self.stdout_io.write(data) + elif fd == 2: + self.stderr_io.write(data) + original_pipe_data_received(obj, fd, data) + + original_pipe_data_received = klass.pipe_data_received + klass.pipe_data_received = pipe_data_received + return klass diff --git a/datalad/runner/tests/test_loggingwrapper.py b/datalad/runner/tests/test_loggingwrapper.py new file mode 100644 index 0000000000..0b8fbdf11b --- /dev/null +++ b/datalad/runner/tests/test_loggingwrapper.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from io import BytesIO +from typing import ( + Any, + Optional, +) + +from datalad.runner.coreprotocols import StdOutErrCapture +from datalad.runner.runnertools import ( + WrapLogger, + wrap_logger, +) + + +class TestProtocol(StdOutErrCapture): + + __test__ = False # class is not a class of tests + + def __init__(self, + output_list: list, + done_future: Any = None, + encoding: Optional[str] = None, + ) -> None: + + StdOutErrCapture.__init__( + self, + done_future=done_future, + encoding=encoding) + self.output_list = output_list + + def pipe_data_received(self, fd: int, data: bytes) -> None: + self.output_list.append((fd, data)) + + +def test_wrapping_function_on_class() -> None: + output_list = [] + stdout_io = BytesIO() + stderr_io = BytesIO() + + wrap_logger(TestProtocol, stdout_io, stderr_io) + _test_protocol_instance(TestProtocol( + output_list=output_list), + output_list, + stdout_io, + stderr_io + ) + + +def test_decorator(): + output = [] + stdout_io = BytesIO() + stderr_io = BytesIO() + + @WrapLogger(stdout_io, stderr_io) + class TestProtocol2(StdOutErrCapture): + def pipe_data_received(self, fd: int, data: bytes) -> None: + output.append((fd, data)) + + _test_protocol_instance(TestProtocol2(), output, stdout_io, stderr_io) + + +def _test_protocol_instance(protocol_instance, output, stdout_io, stderr_io): + protocol_instance.pipe_data_received(1, b"stdout") + protocol_instance.pipe_data_received(2, b"stderr") + + assert output == [(1, b"stdout"), (2, b"stderr")] + assert stdout_io.getvalue() == b"stdout" + assert stderr_io.getvalue() == b"stderr"