Skip to content

Commit

Permalink
Merge pull request #243 from jumpstarter-dev/post-init
Browse files Browse the repository at this point in the history
Treewide: harmonize __post_init__
  • Loading branch information
mangelajo authored Feb 3, 2025
2 parents ded89ae + 89bfcc1 commit 675b96c
Show file tree
Hide file tree
Showing 14 changed files with 49 additions and 26 deletions.
3 changes: 2 additions & 1 deletion __templates__/driver/jumpstarter_driver/driver.py.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ class ${DRIVER_CLASS}(Driver):
some_other_config: int = 69

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()
# some initialization here.

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ class CanClient(DriverClient, can.BusABC):
"""

def __post_init__(self):
if hasattr(super(), "__post_init__"):
super().__post_init__()

self._periodic_tasks: List[_SelfRemovingCyclicTask] = []
self._filters = None
self._is_shutdown: bool = False

super().__post_init__()

@property
@validate_call(validate_return=True)
def state(self) -> can.BusState:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ def client(cls) -> str:
return "jumpstarter_driver_can.client.CanClient"

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()

self.bus = can.Bus(channel=self.channel, interface=self.interface)

@export
Expand Down Expand Up @@ -195,7 +197,9 @@ def client(cls) -> str:
return "jumpstarter_driver_can.client.IsoTpClient"

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()

self.bus = can.Bus(channel=self.channel, interface=self.interface)
self.notifier = can.Notifier(self.bus, [])
self.stack = isotp.NotifierBasedCanStack(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class DutlinkConfig:
tty: str | None = field(init=False, default=None)

def __post_init__(self):
if hasattr(super(), "__post_init__"):
super().__post_init__()

for dev in usb.core.find(idVendor=0x2B23, idProduct=0x1012, find_all=True):
serial = usb.util.get_string(dev, dev.iSerialNumber)
if serial == self.serial or self.serial is None:
Expand Down Expand Up @@ -80,15 +83,17 @@ def control(self, direction, ty, actions, action, value):


@dataclass(kw_only=True)
class DutlinkSerial(DutlinkConfig, PySerial):
url: str | None = field(init=False, default=None)

class DutlinkSerialConfig(DutlinkConfig, Driver):
def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()

self.url = self.tty

super(PySerial, self).__post_init__()

@dataclass(kw_only=True)
class DutlinkSerial(PySerial, DutlinkSerialConfig):
url: str | None = field(init=False, default=None)


@dataclass(kw_only=True)
Expand Down Expand Up @@ -247,7 +252,8 @@ class Dutlink(DutlinkConfig, CompositeInterface, Driver):
"""

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()

self.children["power"] = DutlinkPower(serial=self.serial, timeout_s=self.timeout_s)
self.children["storage"] = DutlinkStorageMux(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ class HttpServer(Driver):
runner: Optional[web.AppRunner] = field(init=False, default=None)

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()

os.makedirs(self.root_dir, exist_ok=True)
self.app.router.add_routes(
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ class PySerial(Driver):
baudrate: int = field(default=115200)

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()

self.device = serial_for_url(self.url, baudrate=self.baudrate)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def client(cls) -> str:
return "jumpstarter_driver_raspberrypi.client.DigitalOutputClient"

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()
# Initialize as InputDevice first
self.device = InputDevice(pin=self.pin)

Expand Down Expand Up @@ -49,7 +50,8 @@ def client(cls) -> str:
return "jumpstarter_driver_raspberrypi.client.DigitalInputClient"

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()
self.device = DigitalInputDevice(pin=self.pin)

@export
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ class SDWire(StorageMuxInterface, Driver):
storage_device: str | None = field(default=None)

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()

for dev in usb.core.find(idVendor=0x04E8, idProduct=0x6001, find_all=True):
if self.storage_device is None:
context = pyudev.Context()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ class Tftp(Driver):
_loop: Optional[asyncio.AbstractEventLoop] = field(init=False, default=None)

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()

os.makedirs(self.root_dir, exist_ok=True)
if self.host is None:
self.host = self.get_default_ip()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def client(cls) -> str:
return "jumpstarter_driver_ustreamer.client.UStreamerClient"

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()

cmdline = [self.executable]

Expand Down
3 changes: 2 additions & 1 deletion packages/jumpstarter/jumpstarter/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class AsyncDriverClient(
logger: logging.Logger = field(init=False)

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()
jumpstarter_pb2_grpc.ExporterServiceStub.__init__(self, self.channel)
router_pb2_grpc.RouterServiceStub.__init__(self, self.channel)
self.logger = logging.getLogger(self.__class__.__name__)
Expand Down
3 changes: 3 additions & 0 deletions packages/jumpstarter/jumpstarter/client/lease.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ class Lease(AbstractContextManager, AbstractAsyncContextManager):
tls_config: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1)

def __post_init__(self):
if hasattr(super(), "__post_init__"):
super().__post_init__()

self.controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel)
self.manager = self.portal.wrap_async_context_manager(self)

Expand Down
6 changes: 0 additions & 6 deletions packages/jumpstarter/jumpstarter/common/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ class Metadata:
uuid: UUID = field(default_factory=uuid4)
labels: dict[str, str] = field(default_factory=dict)

def __post_init__(self):
pass

@property
def name(self):
return self.labels.get("jumpstarter.dev/name", "unknown")
Expand All @@ -20,6 +17,3 @@ def name(self):
@dataclass(kw_only=True, slots=True)
class MetadataFilter:
labels: dict[str, str] = field(default_factory=dict)

def __post_init__(self):
pass
4 changes: 3 additions & 1 deletion packages/jumpstarter/jumpstarter/driver/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ class Driver(
logger: logging.Logger = field(init=False)

def __post_init__(self):
super().__post_init__()
if hasattr(super(), "__post_init__"):
super().__post_init__()

self.logger = logging.getLogger(self.__class__.__name__)
self.logger.setLevel(self.log_level)

Expand Down

0 comments on commit 675b96c

Please sign in to comment.