diff --git a/__templates__/driver/jumpstarter_driver/driver.py.tmpl b/__templates__/driver/jumpstarter_driver/driver.py.tmpl index 3e11ab6e..cbd3f1a6 100644 --- a/__templates__/driver/jumpstarter_driver/driver.py.tmpl +++ b/__templates__/driver/jumpstarter_driver/driver.py.tmpl @@ -10,7 +10,8 @@ class ${DRIVER_CLASS}(Driver): some_other_config: int = 69 def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() # some initialization here. @classmethod diff --git a/packages/jumpstarter-driver-can/jumpstarter_driver_can/client.py b/packages/jumpstarter-driver-can/jumpstarter_driver_can/client.py index a19738c6..c90d6fc5 100644 --- a/packages/jumpstarter-driver-can/jumpstarter_driver_can/client.py +++ b/packages/jumpstarter-driver-can/jumpstarter_driver_can/client.py @@ -33,12 +33,13 @@ class CanClient(DriverClient, can.BusABC): """ def __post_init__(self): + if hasattr(super(), "__post_init__"): + super().__post_init__() + self._periodic_tasks: List[_SelfRemovingCyclicTask] = [] self._filters = None self._is_shutdown: bool = False - super().__post_init__() - @property @validate_call(validate_return=True) def state(self) -> can.BusState: diff --git a/packages/jumpstarter-driver-can/jumpstarter_driver_can/driver.py b/packages/jumpstarter-driver-can/jumpstarter_driver_can/driver.py index fbcec8af..b8b04edc 100644 --- a/packages/jumpstarter-driver-can/jumpstarter_driver_can/driver.py +++ b/packages/jumpstarter-driver-can/jumpstarter_driver_can/driver.py @@ -45,7 +45,9 @@ def client(cls) -> str: return "jumpstarter_driver_can.client.CanClient" def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() + self.bus = can.Bus(channel=self.channel, interface=self.interface) @export @@ -195,7 +197,9 @@ def client(cls) -> str: return "jumpstarter_driver_can.client.IsoTpClient" def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() + self.bus = can.Bus(channel=self.channel, interface=self.interface) self.notifier = can.Notifier(self.bus, []) self.stack = isotp.NotifierBasedCanStack( diff --git a/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/conftest.py b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/conftest.py new file mode 100644 index 00000000..09fc64cb --- /dev/null +++ b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/conftest.py @@ -0,0 +1,13 @@ +import pytest +import usb + + +def pytest_runtest_call(item): + try: + item.runtest() + except FileNotFoundError: + pytest.skip("dutlink not available") + except usb.core.USBError: + pytest.skip("USB not available") + except usb.core.NoBackendError: + pytest.skip("No USB backend") diff --git a/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py index 552e0ca5..39de02db 100644 --- a/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py +++ b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py @@ -29,6 +29,9 @@ class DutlinkConfig: tty: str | None = field(init=False, default=None) def __post_init__(self): + if hasattr(super(), "__post_init__"): + super().__post_init__() + for dev in usb.core.find(idVendor=0x2B23, idProduct=0x1012, find_all=True): serial = usb.util.get_string(dev, dev.iSerialNumber) if serial == self.serial or self.serial is None: @@ -80,15 +83,17 @@ def control(self, direction, ty, actions, action, value): @dataclass(kw_only=True) -class DutlinkSerial(DutlinkConfig, PySerial): - url: str | None = field(init=False, default=None) - +class DutlinkSerialConfig(DutlinkConfig, Driver): def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() self.url = self.tty - super(PySerial, self).__post_init__() + +@dataclass(kw_only=True) +class DutlinkSerial(PySerial, DutlinkSerialConfig): + url: str | None = field(init=False, default=None) @dataclass(kw_only=True) @@ -247,7 +252,8 @@ class Dutlink(DutlinkConfig, CompositeInterface, Driver): """ def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() self.children["power"] = DutlinkPower(serial=self.serial, timeout_s=self.timeout_s) self.children["storage"] = DutlinkStorageMux( diff --git a/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver_test.py b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver_test.py index e4385707..9a1c4a95 100644 --- a/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver_test.py +++ b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver_test.py @@ -1,42 +1,64 @@ -import pytest -import usb +from time import sleep + from jumpstarter_driver_network.adapters import PexpectAdapter -from jumpstarter_driver_dutlink.driver import Dutlink +from jumpstarter_driver_dutlink.driver import Dutlink, DutlinkPower, DutlinkSerial, DutlinkStorageMux from jumpstarter.common.utils import serve +STORAGE_DEVICE = "/dev/null" # MANUAL: replace with path to block device -def test_drivers_dutlink(): - try: - instance = Dutlink( - storage_device="/dev/null", - ) - except FileNotFoundError: - pytest.skip("dutlink not available") - except usb.core.USBError: - pytest.skip("USB not available") - except usb.core.NoBackendError: - pytest.skip("No USB backend") + +def power_test(power): + power.on() # MANUAL: led DUT_ON should be turned on + sleep(1) + assert next(power.read()).current != 0 + power.off() # MANUAL: led DUT_ON should be turned off + + +def storage_test(storage): + storage.write_local_file("/dev/null") + + +def serial_test(serial): + with PexpectAdapter(client=serial) as expect: + expect.send("\x02" * 5) + + expect.send("about\r\n") + expect.expect("Jumpstarter test-harness") + + expect.send("console\r\n") + expect.expect("Entering console mode") + + expect.send("hello") + expect.expect("hello") + + +def test_drivers_dutlink_power(): + instance = DutlinkPower() with serve(instance) as client: - with PexpectAdapter(client=client.console) as expect: - expect.send("\x02" * 5) + power_test(client) - expect.send("about\r\n") - expect.expect("Jumpstarter test-harness") - expect.send("console\r\n") - expect.expect("Entering console mode") +def test_drivers_dutlink_storage_mux(): + instance = DutlinkStorageMux(storage_device=STORAGE_DEVICE) - client.power.off() + with serve(instance) as client: + storage_test(client) - client.storage.write_local_file("/dev/null") - client.storage.dut() - client.power.on() +def test_drivers_dutlink_serial(): + instance = DutlinkSerial() # MANUAL: connect tx to rx - expect.send("\x02" * 5) - expect.expect("Exiting console mode") + with serve(instance) as client: + serial_test(client) - client.power.off() + +def test_drivers_dutlink(): + instance = Dutlink(storage_device=STORAGE_DEVICE) + + with serve(instance) as client: + power_test(client.power) + storage_test(client.storage) + serial_test(client.console) diff --git a/packages/jumpstarter-driver-http/jumpstarter_driver_http/driver.py b/packages/jumpstarter-driver-http/jumpstarter_driver_http/driver.py index 0b41f259..649e7bdb 100644 --- a/packages/jumpstarter-driver-http/jumpstarter_driver_http/driver.py +++ b/packages/jumpstarter-driver-http/jumpstarter_driver_http/driver.py @@ -29,7 +29,9 @@ class HttpServer(Driver): runner: Optional[web.AppRunner] = field(init=False, default=None) def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() + os.makedirs(self.root_dir, exist_ok=True) self.app.router.add_routes( [ diff --git a/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/driver.py b/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/driver.py index 3de0a3b5..44a0d757 100644 --- a/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/driver.py +++ b/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/driver.py @@ -33,7 +33,9 @@ class PySerial(Driver): baudrate: int = field(default=115200) def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() + self.device = serial_for_url(self.url, baudrate=self.baudrate) @classmethod diff --git a/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/driver.py b/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/driver.py index 20c4a0cb..f159d234 100644 --- a/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/driver.py +++ b/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/driver.py @@ -15,7 +15,8 @@ def client(cls) -> str: return "jumpstarter_driver_raspberrypi.client.DigitalOutputClient" def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() # Initialize as InputDevice first self.device = InputDevice(pin=self.pin) @@ -49,7 +50,8 @@ def client(cls) -> str: return "jumpstarter_driver_raspberrypi.client.DigitalInputClient" def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() self.device = DigitalInputDevice(pin=self.pin) @export diff --git a/packages/jumpstarter-driver-sdwire/jumpstarter_driver_sdwire/driver.py b/packages/jumpstarter-driver-sdwire/jumpstarter_driver_sdwire/driver.py index f7c6d826..bcbf9906 100644 --- a/packages/jumpstarter-driver-sdwire/jumpstarter_driver_sdwire/driver.py +++ b/packages/jumpstarter-driver-sdwire/jumpstarter_driver_sdwire/driver.py @@ -22,7 +22,9 @@ class SDWire(StorageMuxInterface, Driver): storage_device: str | None = field(default=None) def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() + for dev in usb.core.find(idVendor=0x04E8, idProduct=0x6001, find_all=True): if self.storage_device is None: context = pyudev.Context() diff --git a/packages/jumpstarter-driver-tftp/examples/tftp_test.py b/packages/jumpstarter-driver-tftp/examples/tftp_test.py index ba7b7067..735fcc14 100644 --- a/packages/jumpstarter-driver-tftp/examples/tftp_test.py +++ b/packages/jumpstarter-driver-tftp/examples/tftp_test.py @@ -1,5 +1,4 @@ import logging -import time import pytest from jumpstarter_driver_tftp.driver import FileNotFound, TftpError @@ -7,45 +6,32 @@ log = logging.getLogger(__name__) - class TestResource(JumpstarterTest): filter_labels = {"board": "rpi4"} @pytest.fixture() - def test_tftp_upload(self, client): + def setup_tftp(self, client): + # Move the setup code to a fixture + client.tftp.start() + yield client + client.tftp.stop() + + def test_tftp_operations(self, setup_tftp): + client = setup_tftp + test_file = "test.bin" + + # Create test file + with open(test_file, "wb") as f: + f.write(b"Hello from TFTP streaming test!") + try: - client.tftp.start() - print("TFTP server started") - - time.sleep(1) - - test_file = "test.bin" - with open(test_file, "wb") as f: - f.write(b"Hello from TFTP streaming test!") - - try: - client.tftp.put_local_file(test_file) - print(f"Successfully uploaded {test_file}") - - files = client.tftp.list_files() - print(f"Files in TFTP root: {files}") - - if test_file in files: - client.tftp.delete_file(test_file) - print(f"Successfully deleted {test_file}") - else: - print(f"Warning: {test_file} not found in TFTP root") - - except TftpError as e: - print(f"TFTP operation failed: {e}") - except FileNotFound as e: - print(f"File not found: {e}") - - except Exception as e: - print(f"Error: {e}") - finally: - try: - client.tftp.stop() - print("TFTP server stopped") - except Exception as e: - print(f"Error stopping server: {e}") + # Test upload + client.tftp.put_local_file(test_file) + assert test_file in client.tftp.list_files() + + # Test delete + client.tftp.delete_file(test_file) + assert test_file not in client.tftp.list_files() + + except (TftpError, FileNotFound) as e: + pytest.fail(f"Test failed: {e}") diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/__init__.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/__init__.py index e69de29b..fc318846 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/__init__.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/__init__.py @@ -0,0 +1 @@ +CHUNK_SIZE = 1024 * 1024 * 4 # 4MB diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py index cc22910c..24081eea 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py @@ -1,9 +1,11 @@ +import hashlib from dataclasses import dataclass from pathlib import Path from jumpstarter_driver_opendal.adapter import OpendalAdapter from opendal import Operator +from . import CHUNK_SIZE from jumpstarter.client import DriverClient @@ -46,37 +48,30 @@ def list_files(self) -> list[str]: return self.call("list_files") def put_file(self, operator: Operator, path: str): - """ - Upload a file to the TFTP server using an OpenDAL operator + filename = Path(path).name + client_checksum = self._compute_checksum(operator, path) - Args: - operator (Operator): OpenDAL operator for accessing the source storage - path (str): Path to the file in the source storage system + if self.call("check_file_checksum", filename, client_checksum): + self.logger.info(f"Skipping upload of identical file: {filename}") + return filename - Returns: - str: Name of the uploaded file - """ - filename = Path(path).name with OpendalAdapter(client=self, operator=operator, path=path, mode="rb") as handle: - return self.call("put_file", filename, handle) + return self.call("put_file", filename, handle, client_checksum) def put_local_file(self, filepath: str): - """ - Upload a file from the local filesystem to the TFTP server - Note: this doesn't use TFTP to upload. + absolute = Path(filepath).resolve() + filename = absolute.name - Args: - filepath (str): Path to the local file to upload + operator = Operator("fs", root="/") + client_checksum = self._compute_checksum(operator, str(absolute)) - Returns: - str: Name of the uploaded file + if self.call("check_file_checksum", filename, client_checksum): + self.logger.info(f"Skipping upload of identical file: {filename}") + return filename - Example: - >>> client.put_local_file("/path/to/local/file.txt") - """ - absolute = Path(filepath).resolve() - with OpendalAdapter(client=self, operator=Operator("fs", root="/"), path=str(absolute), mode="rb") as handle: - return self.call("put_file", absolute.name, handle) + self.logger.info(f"checksum: {client_checksum}") + with OpendalAdapter(client=self, operator=operator, path=str(absolute), mode="rb") as handle: + return self.call("put_file", filename, handle, client_checksum) def delete_file(self, filename: str): """ @@ -108,3 +103,10 @@ def get_port(self) -> int: int: The port number (default is 69) """ return self.call("get_port") + + def _compute_checksum(self, operator: Operator, path: str) -> str: + hasher = hashlib.sha256() + with operator.open(path, "rb") as f: + while chunk := f.read(CHUNK_SIZE): + hasher.update(chunk) + return hasher.hexdigest() diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py index f30d3a99..5ad98809 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py @@ -1,4 +1,5 @@ import asyncio +import hashlib import os import socket import threading @@ -10,33 +11,28 @@ from jumpstarter_driver_tftp.server import TftpServer +from . import CHUNK_SIZE from jumpstarter.driver import Driver, export class TftpError(Exception): """Base exception for TFTP server errors""" - pass - class ServerNotRunning(TftpError): """Server is not running""" - pass - class FileNotFound(TftpError): """File not found""" - pass - @dataclass(kw_only=True) class Tftp(Driver): """TFTP Server driver for Jumpstarter""" root_dir: str = "/var/lib/tftpboot" - host: str = field(default=None) + host: str = field(default='') port: int = 69 server: Optional["TftpServer"] = field(init=False, default=None) server_thread: Optional[threading.Thread] = field(init=False, default=None) @@ -45,9 +41,11 @@ class Tftp(Driver): _loop: Optional[asyncio.AbstractEventLoop] = field(init=False, default=None) def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() + os.makedirs(self.root_dir, exist_ok=True) - if self.host is None: + if self.host == '': self.host = self.get_default_ip() def get_default_ip(self): @@ -69,10 +67,7 @@ def _start_server(self): asyncio.set_event_loop(self._loop) self.server = TftpServer(host=self.host, port=self.port, root_dir=self.root_dir) try: - # Signal that the loop is ready self._loop_ready.set() - - # Run the server until shutdown is requested self._loop.run_until_complete(self._run_server()) except Exception as e: self.logger.error(f"Error running TFTP server: {e}") @@ -82,7 +77,6 @@ def _start_server(self): self._loop.close() except Exception as e: self.logger.error(f"Error during event loop cleanup: {e}") - self._loop = None self.logger.info("TFTP server thread completed") @@ -107,11 +101,9 @@ def start(self): self.logger.warning("TFTP server is already running") return - # Clear any previous shutdown state self._shutdown_event.clear() self._loop_ready.clear() - # Start the server thread self.server_thread = threading.Thread(target=self._start_server, daemon=True) self.server_thread.start() @@ -129,7 +121,6 @@ def stop(self): return self.logger.info("Initiating TFTP server shutdown") - self._shutdown_event.set() self.server_thread.join(timeout=10) if self.server_thread.is_alive(): @@ -143,11 +134,10 @@ def list_files(self) -> list[str]: return os.listdir(self.root_dir) @export - async def put_file(self, filename: str, src_stream): - """Handle file upload using streaming""" - try: - file_path = os.path.join(self.root_dir, filename) + async def put_file(self, filename: str, src_stream, client_checksum: str): + file_path = os.path.join(self.root_dir, filename) + try: if not Path(file_path).resolve().is_relative_to(Path(self.root_dir).resolve()): raise TftpError("Invalid target path") @@ -157,19 +147,38 @@ async def put_file(self, filename: str, src_stream): await dst.send(chunk) return filename - except Exception as e: raise TftpError(f"Failed to upload file: {str(e)}") from e @export def delete_file(self, filename: str): + file_path = os.path.join(self.root_dir, filename) + + if not os.path.exists(file_path): + raise FileNotFound(f"File {filename} not found") + try: - os.remove(os.path.join(self.root_dir, filename)) - except FileNotFoundError as err: - raise FileNotFound(f"File {filename} not found") from err + os.remove(file_path) + return filename except Exception as e: raise TftpError(f"Failed to delete {filename}") from e + @export + def check_file_checksum(self, filename: str, client_checksum: str) -> bool: + file_path = os.path.join(self.root_dir, filename) + self.logger.debug(f"checking checksum for file: {filename}") + self.logger.debug(f"file path: {file_path}") + + if not os.path.exists(file_path): + self.logger.debug(f"File {filename} does not exist") + return False + + current_checksum = self._compute_checksum(file_path) + self.logger.debug(f"Computed checksum: {current_checksum}") + self.logger.debug(f"Client checksum: {client_checksum}") + + return current_checksum == client_checksum + @export def get_host(self) -> str: return self.host @@ -182,3 +191,10 @@ def close(self): if self.server_thread is not None: self.stop() super().close() + + def _compute_checksum(self, path: str) -> str: + hasher = hashlib.sha256() + with open(path, "rb") as f: + while chunk := f.read(CHUNK_SIZE): + hasher.update(chunk) + return hasher.hexdigest() diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py index 4bd74163..3f0f6911 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py @@ -1,3 +1,4 @@ +import hashlib import os import tempfile from pathlib import Path @@ -20,18 +21,17 @@ def temp_dir(): with tempfile.TemporaryDirectory() as tmpdir: yield tmpdir - @pytest.fixture def server(temp_dir): server = Tftp(root_dir=temp_dir, host="127.0.0.1") yield server server.close() - @pytest.mark.anyio async def test_tftp_file_operations(server): filename = "test.txt" test_data = b"Hello" + client_checksum = hashlib.sha256(test_data).hexdigest() send_stream, receive_stream = create_memory_object_stream(max_buffer_size=10) @@ -46,8 +46,7 @@ async def send_data(): async with anyio.create_task_group() as tg: tg.start_soon(send_data) - - await server.put_file(filename, resource_handle) + await server.put_file(filename, resource_handle, client_checksum) files = server.list_files() assert filename in files @@ -61,20 +60,48 @@ async def send_data(): with pytest.raises(FileNotFound): server.delete_file("nonexistent.txt") - def test_tftp_host_config(temp_dir): custom_host = "192.168.1.1" server = Tftp(root_dir=temp_dir, host=custom_host) assert server.get_host() == custom_host - def test_tftp_root_directory_creation(temp_dir): new_dir = os.path.join(temp_dir, "new_tftp_root") server = Tftp(root_dir=new_dir) assert os.path.exists(new_dir) server.close() +@pytest.mark.anyio +async def test_tftp_detect_corrupted_file(server): + filename = "corrupted.txt" + original_data = b"Original Data" + client_checksum = hashlib.sha256(original_data).hexdigest() + + await _upload_file(server, filename, original_data) + + assert server.check_file_checksum(filename, client_checksum) + + file_path = Path(server.root_dir, filename) + file_path.write_bytes(b"corrupted Data") + + assert not server.check_file_checksum(filename, client_checksum) @pytest.fixture def anyio_backend(): return "asyncio" + +async def _upload_file(server, filename: str, data: bytes) -> str: + send_stream, receive_stream = create_memory_object_stream() + resource_uuid = uuid4() + server.resources[resource_uuid] = receive_stream + resource_handle = ClientStreamResource(uuid=resource_uuid).model_dump(mode="json") + + async def send_data(): + await send_stream.send(data) + await send_stream.aclose() + + async with anyio.create_task_group() as tg: + tg.start_soon(send_data) + await server.put_file(filename, resource_handle, hashlib.sha256(data).hexdigest()) + + return hashlib.sha256(data).hexdigest() diff --git a/packages/jumpstarter-driver-tftp/pyproject.toml b/packages/jumpstarter-driver-tftp/pyproject.toml index 830c47bb..50dfe17e 100644 --- a/packages/jumpstarter-driver-tftp/pyproject.toml +++ b/packages/jumpstarter-driver-tftp/pyproject.toml @@ -32,13 +32,12 @@ raw-options = { 'root' = '../../'} Homepage = "https://jumpstarter.dev" source_archive = "https://github.com/jumpstarter-dev/repo/archive/{commit_hash}.zip" -# [tool.pytest.ini_options] -# #addopts = "--cov --cov-report=html --cov-report=xml" -# log_cli = true -# log_cli_level = "INFO" -# # testpaths = ["src"] -# asyncio_mode = "auto" +[tool.pytest.ini_options] +log_cli = true +log_cli_level = "INFO" +testpaths = ["jumpstarter_driver_tftp"] +asyncio_mode = "auto" [build-system] requires = ["hatchling", "hatch-vcs"] -build-backend = "hatchling.build" \ No newline at end of file +build-backend = "hatchling.build" diff --git a/packages/jumpstarter-driver-ustreamer/jumpstarter_driver_ustreamer/driver.py b/packages/jumpstarter-driver-ustreamer/jumpstarter_driver_ustreamer/driver.py index a4de3571..4a6b3d11 100644 --- a/packages/jumpstarter-driver-ustreamer/jumpstarter_driver_ustreamer/driver.py +++ b/packages/jumpstarter-driver-ustreamer/jumpstarter_driver_ustreamer/driver.py @@ -34,7 +34,8 @@ def client(cls) -> str: return "jumpstarter_driver_ustreamer.client.UStreamerClient" def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() cmdline = [self.executable] diff --git a/packages/jumpstarter/jumpstarter/client/core.py b/packages/jumpstarter/jumpstarter/client/core.py index df413ce7..f8e81ab9 100644 --- a/packages/jumpstarter/jumpstarter/client/core.py +++ b/packages/jumpstarter/jumpstarter/client/core.py @@ -45,7 +45,8 @@ class AsyncDriverClient( logger: logging.Logger = field(init=False) def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() jumpstarter_pb2_grpc.ExporterServiceStub.__init__(self, self.channel) router_pb2_grpc.RouterServiceStub.__init__(self, self.channel) self.logger = logging.getLogger(self.__class__.__name__) diff --git a/packages/jumpstarter/jumpstarter/client/lease.py b/packages/jumpstarter/jumpstarter/client/lease.py index d7f66d49..d57aab5e 100644 --- a/packages/jumpstarter/jumpstarter/client/lease.py +++ b/packages/jumpstarter/jumpstarter/client/lease.py @@ -31,6 +31,9 @@ class Lease(AbstractContextManager, AbstractAsyncContextManager): tls_config: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1) def __post_init__(self): + if hasattr(super(), "__post_init__"): + super().__post_init__() + self.controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel) self.manager = self.portal.wrap_async_context_manager(self) diff --git a/packages/jumpstarter/jumpstarter/common/metadata.py b/packages/jumpstarter/jumpstarter/common/metadata.py index 4f8ec424..fbacc96e 100644 --- a/packages/jumpstarter/jumpstarter/common/metadata.py +++ b/packages/jumpstarter/jumpstarter/common/metadata.py @@ -9,9 +9,6 @@ class Metadata: uuid: UUID = field(default_factory=uuid4) labels: dict[str, str] = field(default_factory=dict) - def __post_init__(self): - pass - @property def name(self): return self.labels.get("jumpstarter.dev/name", "unknown") @@ -20,6 +17,3 @@ def name(self): @dataclass(kw_only=True, slots=True) class MetadataFilter: labels: dict[str, str] = field(default_factory=dict) - - def __post_init__(self): - pass diff --git a/packages/jumpstarter/jumpstarter/driver/base.py b/packages/jumpstarter/jumpstarter/driver/base.py index 6934bf89..d99e0e34 100644 --- a/packages/jumpstarter/jumpstarter/driver/base.py +++ b/packages/jumpstarter/jumpstarter/driver/base.py @@ -60,7 +60,9 @@ class Driver( logger: logging.Logger = field(init=False) def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() + self.logger = logging.getLogger(self.__class__.__name__) self.logger.setLevel(self.log_level)