From a53786ef424f234954cd8c25eb5b77f9c35558ed Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Wed, 29 Jan 2025 10:26:49 -0500 Subject: [PATCH 1/2] Treewide: harmonize __post_init__ --- .../driver/jumpstarter_driver/driver.py.tmpl | 3 ++- .../jumpstarter_driver_can/client.py | 5 +++-- .../jumpstarter_driver_can/driver.py | 8 ++++++-- .../jumpstarter_driver_dutlink/driver.py | 11 +++++++---- .../jumpstarter_driver_http/driver.py | 4 +++- .../jumpstarter_driver_pyserial/driver.py | 4 +++- .../jumpstarter_driver_raspberrypi/driver.py | 6 ++++-- .../jumpstarter_driver_sdwire/driver.py | 4 +++- .../jumpstarter_driver_tftp/driver.py | 4 +++- .../jumpstarter_driver_ustreamer/driver.py | 3 ++- packages/jumpstarter/jumpstarter/client/core.py | 3 ++- packages/jumpstarter/jumpstarter/client/lease.py | 3 +++ packages/jumpstarter/jumpstarter/common/metadata.py | 6 ------ packages/jumpstarter/jumpstarter/driver/base.py | 4 +++- 14 files changed, 44 insertions(+), 24 deletions(-) diff --git a/__templates__/driver/jumpstarter_driver/driver.py.tmpl b/__templates__/driver/jumpstarter_driver/driver.py.tmpl index 3e11ab6e..cbd3f1a6 100644 --- a/__templates__/driver/jumpstarter_driver/driver.py.tmpl +++ b/__templates__/driver/jumpstarter_driver/driver.py.tmpl @@ -10,7 +10,8 @@ class ${DRIVER_CLASS}(Driver): some_other_config: int = 69 def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() # some initialization here. @classmethod diff --git a/packages/jumpstarter-driver-can/jumpstarter_driver_can/client.py b/packages/jumpstarter-driver-can/jumpstarter_driver_can/client.py index a19738c6..c90d6fc5 100644 --- a/packages/jumpstarter-driver-can/jumpstarter_driver_can/client.py +++ b/packages/jumpstarter-driver-can/jumpstarter_driver_can/client.py @@ -33,12 +33,13 @@ class CanClient(DriverClient, can.BusABC): """ def __post_init__(self): + if hasattr(super(), "__post_init__"): + super().__post_init__() + self._periodic_tasks: List[_SelfRemovingCyclicTask] = [] self._filters = None self._is_shutdown: bool = False - super().__post_init__() - @property @validate_call(validate_return=True) def state(self) -> can.BusState: diff --git a/packages/jumpstarter-driver-can/jumpstarter_driver_can/driver.py b/packages/jumpstarter-driver-can/jumpstarter_driver_can/driver.py index fbcec8af..b8b04edc 100644 --- a/packages/jumpstarter-driver-can/jumpstarter_driver_can/driver.py +++ b/packages/jumpstarter-driver-can/jumpstarter_driver_can/driver.py @@ -45,7 +45,9 @@ def client(cls) -> str: return "jumpstarter_driver_can.client.CanClient" def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() + self.bus = can.Bus(channel=self.channel, interface=self.interface) @export @@ -195,7 +197,9 @@ def client(cls) -> str: return "jumpstarter_driver_can.client.IsoTpClient" def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() + self.bus = can.Bus(channel=self.channel, interface=self.interface) self.notifier = can.Notifier(self.bus, []) self.stack = isotp.NotifierBasedCanStack( diff --git a/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py index 552e0ca5..36980e3a 100644 --- a/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py +++ b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py @@ -29,6 +29,9 @@ class DutlinkConfig: tty: str | None = field(init=False, default=None) def __post_init__(self): + if hasattr(super(), "__post_init__"): + super().__post_init__() + for dev in usb.core.find(idVendor=0x2B23, idProduct=0x1012, find_all=True): serial = usb.util.get_string(dev, dev.iSerialNumber) if serial == self.serial or self.serial is None: @@ -84,12 +87,11 @@ class DutlinkSerial(DutlinkConfig, PySerial): url: str | None = field(init=False, default=None) def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() self.url = self.tty - super(PySerial, self).__post_init__() - @dataclass(kw_only=True) class DutlinkPower(DutlinkConfig, PowerInterface, Driver): @@ -247,7 +249,8 @@ class Dutlink(DutlinkConfig, CompositeInterface, Driver): """ def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() self.children["power"] = DutlinkPower(serial=self.serial, timeout_s=self.timeout_s) self.children["storage"] = DutlinkStorageMux( diff --git a/packages/jumpstarter-driver-http/jumpstarter_driver_http/driver.py b/packages/jumpstarter-driver-http/jumpstarter_driver_http/driver.py index 0b41f259..649e7bdb 100644 --- a/packages/jumpstarter-driver-http/jumpstarter_driver_http/driver.py +++ b/packages/jumpstarter-driver-http/jumpstarter_driver_http/driver.py @@ -29,7 +29,9 @@ class HttpServer(Driver): runner: Optional[web.AppRunner] = field(init=False, default=None) def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() + os.makedirs(self.root_dir, exist_ok=True) self.app.router.add_routes( [ diff --git a/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/driver.py b/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/driver.py index 3de0a3b5..44a0d757 100644 --- a/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/driver.py +++ b/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/driver.py @@ -33,7 +33,9 @@ class PySerial(Driver): baudrate: int = field(default=115200) def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() + self.device = serial_for_url(self.url, baudrate=self.baudrate) @classmethod diff --git a/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/driver.py b/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/driver.py index 20c4a0cb..f159d234 100644 --- a/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/driver.py +++ b/packages/jumpstarter-driver-raspberrypi/jumpstarter_driver_raspberrypi/driver.py @@ -15,7 +15,8 @@ def client(cls) -> str: return "jumpstarter_driver_raspberrypi.client.DigitalOutputClient" def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() # Initialize as InputDevice first self.device = InputDevice(pin=self.pin) @@ -49,7 +50,8 @@ def client(cls) -> str: return "jumpstarter_driver_raspberrypi.client.DigitalInputClient" def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() self.device = DigitalInputDevice(pin=self.pin) @export diff --git a/packages/jumpstarter-driver-sdwire/jumpstarter_driver_sdwire/driver.py b/packages/jumpstarter-driver-sdwire/jumpstarter_driver_sdwire/driver.py index f7c6d826..bcbf9906 100644 --- a/packages/jumpstarter-driver-sdwire/jumpstarter_driver_sdwire/driver.py +++ b/packages/jumpstarter-driver-sdwire/jumpstarter_driver_sdwire/driver.py @@ -22,7 +22,9 @@ class SDWire(StorageMuxInterface, Driver): storage_device: str | None = field(default=None) def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() + for dev in usb.core.find(idVendor=0x04E8, idProduct=0x6001, find_all=True): if self.storage_device is None: context = pyudev.Context() diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py index f30d3a99..43013b3a 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py @@ -45,7 +45,9 @@ class Tftp(Driver): _loop: Optional[asyncio.AbstractEventLoop] = field(init=False, default=None) def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() + os.makedirs(self.root_dir, exist_ok=True) if self.host is None: self.host = self.get_default_ip() diff --git a/packages/jumpstarter-driver-ustreamer/jumpstarter_driver_ustreamer/driver.py b/packages/jumpstarter-driver-ustreamer/jumpstarter_driver_ustreamer/driver.py index a4de3571..4a6b3d11 100644 --- a/packages/jumpstarter-driver-ustreamer/jumpstarter_driver_ustreamer/driver.py +++ b/packages/jumpstarter-driver-ustreamer/jumpstarter_driver_ustreamer/driver.py @@ -34,7 +34,8 @@ def client(cls) -> str: return "jumpstarter_driver_ustreamer.client.UStreamerClient" def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() cmdline = [self.executable] diff --git a/packages/jumpstarter/jumpstarter/client/core.py b/packages/jumpstarter/jumpstarter/client/core.py index df413ce7..f8e81ab9 100644 --- a/packages/jumpstarter/jumpstarter/client/core.py +++ b/packages/jumpstarter/jumpstarter/client/core.py @@ -45,7 +45,8 @@ class AsyncDriverClient( logger: logging.Logger = field(init=False) def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() jumpstarter_pb2_grpc.ExporterServiceStub.__init__(self, self.channel) router_pb2_grpc.RouterServiceStub.__init__(self, self.channel) self.logger = logging.getLogger(self.__class__.__name__) diff --git a/packages/jumpstarter/jumpstarter/client/lease.py b/packages/jumpstarter/jumpstarter/client/lease.py index d7f66d49..d57aab5e 100644 --- a/packages/jumpstarter/jumpstarter/client/lease.py +++ b/packages/jumpstarter/jumpstarter/client/lease.py @@ -31,6 +31,9 @@ class Lease(AbstractContextManager, AbstractAsyncContextManager): tls_config: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1) def __post_init__(self): + if hasattr(super(), "__post_init__"): + super().__post_init__() + self.controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel) self.manager = self.portal.wrap_async_context_manager(self) diff --git a/packages/jumpstarter/jumpstarter/common/metadata.py b/packages/jumpstarter/jumpstarter/common/metadata.py index 4f8ec424..fbacc96e 100644 --- a/packages/jumpstarter/jumpstarter/common/metadata.py +++ b/packages/jumpstarter/jumpstarter/common/metadata.py @@ -9,9 +9,6 @@ class Metadata: uuid: UUID = field(default_factory=uuid4) labels: dict[str, str] = field(default_factory=dict) - def __post_init__(self): - pass - @property def name(self): return self.labels.get("jumpstarter.dev/name", "unknown") @@ -20,6 +17,3 @@ def name(self): @dataclass(kw_only=True, slots=True) class MetadataFilter: labels: dict[str, str] = field(default_factory=dict) - - def __post_init__(self): - pass diff --git a/packages/jumpstarter/jumpstarter/driver/base.py b/packages/jumpstarter/jumpstarter/driver/base.py index 6934bf89..d99e0e34 100644 --- a/packages/jumpstarter/jumpstarter/driver/base.py +++ b/packages/jumpstarter/jumpstarter/driver/base.py @@ -60,7 +60,9 @@ class Driver( logger: logging.Logger = field(init=False) def __post_init__(self): - super().__post_init__() + if hasattr(super(), "__post_init__"): + super().__post_init__() + self.logger = logging.getLogger(self.__class__.__name__) self.logger.setLevel(self.log_level) From 89bfcc17b435e9c99f7c1ab1b3681bbf071c9589 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Wed, 29 Jan 2025 10:50:22 -0500 Subject: [PATCH 2/2] Fixup DutlinkSerial initialization --- .../jumpstarter_driver_dutlink/driver.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py index 36980e3a..39de02db 100644 --- a/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py +++ b/packages/jumpstarter-driver-dutlink/jumpstarter_driver_dutlink/driver.py @@ -83,9 +83,7 @@ def control(self, direction, ty, actions, action, value): @dataclass(kw_only=True) -class DutlinkSerial(DutlinkConfig, PySerial): - url: str | None = field(init=False, default=None) - +class DutlinkSerialConfig(DutlinkConfig, Driver): def __post_init__(self): if hasattr(super(), "__post_init__"): super().__post_init__() @@ -93,6 +91,11 @@ def __post_init__(self): self.url = self.tty +@dataclass(kw_only=True) +class DutlinkSerial(PySerial, DutlinkSerialConfig): + url: str | None = field(init=False, default=None) + + @dataclass(kw_only=True) class DutlinkPower(DutlinkConfig, PowerInterface, Driver): last_action: str | None = field(default=None)