Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Be careful about freeing callback trampolines #64

Merged
merged 3 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions src/systemd_ctypes/bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,7 @@ class Slot(libsystemd.sd_bus_slot):
def __init__(self, callback: Callable[[BusMessage], bool]):
def handler(message: WeakReference, _data: object, _err: object) -> int:
return 1 if callback(BusMessage.ref(message)) else 0
self.callback = libsystemd.sd_bus_message_handler_t(handler)
self.userdata = None

def cancel(self) -> None:
self._unref()
self.value = None
self.trampoline = libsystemd.sd_bus_message_handler_t(handler)


if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -363,7 +358,7 @@ async def call_async(
timeout: Optional[int] = None
) -> BusMessage:
pending = PendingCall()
self._call_async(byref(pending), message, pending.callback, pending.userdata, timeout or 0)
self._call_async(byref(pending), message, pending.trampoline, pending.userdata, timeout or 0)
return await pending.future

async def call_method_async(
Expand All @@ -384,12 +379,12 @@ async def call_method_async(

def add_match(self, rule: str, handler: Callable[[BusMessage], bool]) -> Slot:
slot = Slot(handler)
self._add_match(byref(slot), rule, slot.callback, slot.userdata)
self._add_match(byref(slot), rule, slot.trampoline, slot.userdata)
return slot

def add_object(self, path: str, obj: 'BaseObject') -> Slot:
slot = Slot(obj.message_received)
self._add_object(byref(slot), path, slot.callback, slot.userdata)
self._add_object(byref(slot), path, slot.trampoline, slot.userdata)
obj.registered_on_bus(self, path)
return slot

Expand Down
21 changes: 15 additions & 6 deletions src/systemd_ctypes/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,11 @@
from typing import Callable, ClassVar, Coroutine, List, Optional, Tuple

from . import inotify, libsystemd
from .librarywrapper import Callback, Reference, UserData, byref
from .librarywrapper import Reference, UserData, byref


class Event(libsystemd.sd_event):
class Source(libsystemd.sd_event_source):
callback: Callback
userdata: UserData = None

def cancel(self) -> None:
self._unref()
self.value = None
Expand All @@ -52,11 +49,11 @@ def callback(source: libsystemd.sd_event_source,
event = _event.contents
handler(inotify.Event(event.mask), event.cookie, event.name)
return 0
self.callback = libsystemd.sd_event_inotify_handler_t(callback)
self.trampoline = libsystemd.sd_event_inotify_handler_t(callback)

def add_inotify(self, path: str, mask: inotify.Event, handler: InotifyHandler) -> InotifySource:
source = Event.InotifySource(handler)
self._add_inotify(byref(source), path, mask, source.callback, source.userdata)
self._add_inotify(byref(source), path, mask, source.trampoline, source.userdata)
return source

def add_inotify_fd(self, fd: int, mask: inotify.Event, handler: InotifyHandler) -> InotifySource:
Expand All @@ -78,6 +75,14 @@ def __init__(self, event: Optional[Event] = None) -> None:
def select(
self, timeout: Optional[float] = None
) -> List[Tuple[selectors.SelectorKey, int]]:
# It's common to drop the last reference to a Source or Slot object on
# a dispatch of that same source/slot from the main loop. If we happen
# to garbage collect before returning, the trampoline could be
# destroyed before we're done using it. Provide a mechanism to defer
# the destruction of trampolines for as long as we might be
# dispatching. This gets cleared again at the bottom, before return.
libsystemd.Trampoline.deferred = []

while self.sd_event.prepare():
self.sd_event.dispatch()
ready = super().select(timeout)
Expand All @@ -87,6 +92,10 @@ def select(
self.sd_event.dispatch()
while self.sd_event.prepare():
self.sd_event.dispatch()

# We can be sure we're not dispatching callbacks anymore
libsystemd.Trampoline.deferred = None
Comment on lines +96 to +97
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, so that's the final cleanup of the deferred trampolines. Is there only ever one instance of Selector? This doesn't feel correct if there could be multiple ones

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed. This is why I'm not super happy about this fix. If we tried to run independent mainloops in separate threads, this would indeed be incorrect. Trying to do something "more correct" here is hard, though, and we never have anything but the default loop running in the main thread, so ...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fine. This could do with an assertion that there's only a single instance, but good enough now!


# This could return zero events with infinite timeout, but nobody seems to mind.
return [(key, events) for (key, events) in ready if key != self.key]

Expand Down
3 changes: 0 additions & 3 deletions src/systemd_ctypes/librarywrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,6 @@ class ReferenceType(ctypes.c_void_p):
def _install_cfuncs(cls, cdll: ctypes.CDLL) -> None:
logger.debug('Installing stubs for %s:', cls)
stubs = tuple(cls.__dict__.items())
if cls.__name__ == 'sd_bus':
assert True, stubs

for name, stub in stubs:
if name.startswith("__"):
continue
Expand Down
25 changes: 22 additions & 3 deletions src/systemd_ctypes/libsystemd.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import ctypes
import os
import sys
from typing import List, Optional, Tuple, Union
from typing import ClassVar, List, Optional, Tuple, Union

from .inotify import inotify_event
from .librarywrapper import (
Expand All @@ -33,6 +33,25 @@
from .typing import Annotated


class Trampoline(ReferenceType):
deferred: 'ClassVar[list[Callback] | None]' = None
trampoline: Callback
userdata: UserData = None

def cancel(self) -> None:
self._unref()
self.value = None

def __del__(self) -> None:
# This might be the currently-dispatching callback — make sure we don't
# destroy the trampoline before we return. We drop the deferred list
# from the event loop when we're sure we're not doing any dispatches.
if Trampoline.deferred is not None:
Trampoline.deferred.append(self.trampoline)
if self.value is not None:
self._unref()


class sd_bus_error(ctypes.Structure):
# This is ABI, so we are safe to assume it doesn't change.
# Unfortunately, we lack anything like sd_bus_error_new().
Expand Down Expand Up @@ -65,7 +84,7 @@ class sd_id128(ctypes.Structure):
)


class sd_event_source(ReferenceType):
class sd_event_source(Trampoline):
...


Expand Down Expand Up @@ -105,7 +124,7 @@ def _default(ret: Reference['sd_event']) -> Union[None, Errno]:
...


class sd_bus_slot(ReferenceType):
class sd_bus_slot(Trampoline):
...


Expand Down
40 changes: 22 additions & 18 deletions test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import unittest

import dbusmock # type: ignore[import] # not typed
import pytest

import systemd_ctypes
from systemd_ctypes import bus, introspection
Expand Down Expand Up @@ -141,13 +142,13 @@ def test_int_async(self):
def test_int_error(self):
# int overflow
self.add_method('', 'Inc', 'i', 'i', 'ret = args[0] + 1')
with self.assertRaisesRegex(systemd_ctypes.BusError, 'OverflowError'):
with pytest.raises(systemd_ctypes.BusError, match='OverflowError'):
self.bus_user.call_method(*TEST_ADDR, 'Inc', 'i', 0x7FFFFFFF)

# uint underflow
self.add_method('', 'Dec', 'u', 'u', 'ret = args[0] - 1')
with self.assertRaisesRegex(systemd_ctypes.BusError,
"OverflowError: can't convert negative value to unsigned int"):
with pytest.raises(systemd_ctypes.BusError,
match="OverflowError: can't convert negative value to unsigned int"):
self.bus_user.call_method(*TEST_ADDR, 'Dec', 'u', 0)

def test_float(self):
Expand Down Expand Up @@ -213,35 +214,35 @@ def test_base64_binary_decode(self):
self.assertEqual(result, ['R8OkbnNlZsO8w59jaGVu'])

def test_unknown_method_sync(self):
with self.assertRaisesRegex(systemd_ctypes.BusError, '.*org.freedesktop.DBus.Error.UnknownMethod:.*'
'Do is not a valid method of interface org.freedesktop.Test.Main'):
with pytest.raises(systemd_ctypes.BusError, match='.*org.freedesktop.DBus.Error.UnknownMethod:.*'
'Do is not a valid method of interface org.freedesktop.Test.Main'):
self.bus_user.call_method(*TEST_ADDR, 'Do')

def test_unknown_method_async(self):
message = self.bus_user.message_new_method_call(*TEST_ADDR, 'Do')
with self.assertRaisesRegex(systemd_ctypes.BusError, '.*org.freedesktop.DBus.Error.UnknownMethod:.*'
'Do is not a valid method of interface org.freedesktop.Test.Main'):
with pytest.raises(systemd_ctypes.BusError, match='.*org.freedesktop.DBus.Error.UnknownMethod:.*'
'Do is not a valid method of interface org.freedesktop.Test.Main'):
self.async_call(message).get_body()

def test_call_signature_mismatch(self):
self.add_method('', 'Inc', 'i', 'i', 'ret = args[0] + 1')
# specified signature does not match server, but locally consistent args
with self.assertRaisesRegex(systemd_ctypes.BusError,
'(InvalidArgs|TypeError).*Fewer items.*signature.*arguments'):
with pytest.raises(systemd_ctypes.BusError,
match='(InvalidArgs|TypeError).*Fewer items.*signature.*arguments'):
self.bus_user.call_method(*TEST_ADDR, 'Inc', 'ii', 1, 2)
with self.assertRaisesRegex(systemd_ctypes.BusError, 'InvalidArgs|TypeError'):
with pytest.raises(systemd_ctypes.BusError, match='InvalidArgs|TypeError'):
self.bus_user.call_method(*TEST_ADDR, 'Inc', 's', 'hello.*dbus.String.*integer')

# specified signature does not match arguments
with self.assertRaisesRegex(AssertionError, r'call args \(1, 2\) have different length than signature.*'):
with pytest.raises(AssertionError, match=r'call args \(1, 2\) have different length than signature.*'):
self.bus_user.call_method(*TEST_ADDR, 'Inc', 'i', 1, 2)
with self.assertRaisesRegex(TypeError, r'.*str.* as.* integer|int.*str'):
with pytest.raises(TypeError, match=r'.*str.* as.* integer|int.*str'):
self.bus_user.call_method(*TEST_ADDR, 'Inc', 'i', 'hello')

def test_custom_error(self):
self.add_method('', 'Boom', '', '',
'raise dbus.exceptions.DBusException("no good", name="com.example.Error.NoGood")')
with self.assertRaisesRegex(systemd_ctypes.BusError, 'no good'):
with pytest.raises(systemd_ctypes.BusError, match='no good'):
self.bus_user.call_method(*TEST_ADDR, 'Boom')

def test_introspect(self):
Expand Down Expand Up @@ -283,17 +284,20 @@ def test_service_replace(self):

def test_request_name_errors(self):
# name already exists
self.assertRaises(FileExistsError, self.bus_user.request_name, TEST_ADDR[0], bus.Bus.NameFlags.DEFAULT)
with pytest.raises(FileExistsError):
self.bus_user.request_name(TEST_ADDR[0], bus.Bus.NameFlags.DEFAULT)

# invalid name
self.assertRaisesRegex(OSError, '.*Invalid argument',
self.bus_user.request_name, '', bus.Bus.NameFlags.DEFAULT)
with pytest.raises(OSError, match='.*Invalid argument'):
self.bus_user.request_name('', bus.Bus.NameFlags.DEFAULT)

# invalid flag
self.assertRaisesRegex(OSError, '.*Invalid argument', self.bus_user.request_name, TEST_ADDR[0], 0xFF)
with pytest.raises(OSError, match='.*Invalid argument'):
self.bus_user.request_name(TEST_ADDR[0], 0xFF)

# name not taken
self.assertRaises(ProcessLookupError, self.bus_user.release_name, 'com.example.NotThis')
with pytest.raises(ProcessLookupError):
self.bus_user.release_name('com.example.NotThis')


if __name__ == '__main__':
Expand Down
9 changes: 4 additions & 5 deletions test/test_p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,13 @@ async def test():

def test_method_throws(self):
async def test():
with self.assertRaisesRegex(BusError, 'cockpit.Error.ZeroDivisionError: Divide by zero'):
with pytest.raises(BusError, match='cockpit.Error.ZeroDivisionError: Divide by zero'):
await self.client.call_method_async(None, '/test', 'cockpit.Test', 'Divide', 'ii', 1554, 0)
run_async(test())

def test_method_throws_oserror(self):
async def test():
with self.assertRaisesRegex(BusError, 'org.freedesktop.DBus.Error.FileNotFound: .*notthere.*'):
with pytest.raises(BusError, match='org.freedesktop.DBus.Error.FileNotFound: .*notthere.*'):
await self.client.call_method_async(None, '/test', 'cockpit.Test', 'ReadFile', 's', 'notthere')
run_async(test())

Expand All @@ -206,7 +206,7 @@ async def test():

def test_async_method_throws(self):
async def test():
with self.assertRaisesRegex(BusError, 'cockpit.Error.ZeroDivisionError: Divide by zero'):
with pytest.raises(BusError, match='cockpit.Error.ZeroDivisionError: Divide by zero'):
await self.client.call_method_async(None, '/test', 'cockpit.Test', 'DivideSlowly', 'ii', 1554, 0)
run_async(test())

Expand Down Expand Up @@ -249,8 +249,7 @@ async def test():
# Make sure that dropping the slot results in the object being un-exported
self.test_object_slot = None

with self.assertRaisesRegex(
BusError, "org.freedesktop.DBus.Error.UnknownObject: Unknown object '/test'."):
with pytest.raises(BusError, match="org.freedesktop.DBus.Error.UnknownObject: Unknown object '/test'."):
await self.client.call_method_async(None, '/test', 'cockpit.Test', 'Divide', 'ii', 1554, 37)
run_async(test())

Expand Down
Loading