Skip to content

Commit 852d072

Browse files
authored
Refactor _SharedCache to handle context vs non-context ownership (#38620)
* Refactor logging in subprocess_server.py for improved observability
* Introduce context-aware ownership in SubprocessServer cache

  Refactors `_SharedCache` to support an `is_context` flag for registered owners. This prevents independent non-context subprocesses (e.g., prism runner and expansion service) from sharing ownership of each other's keys, while allowing context wrappers to properly track all active subprocesses.
* Address comments.
1 parent 45b79d2 commit 852d072

2 files changed

Lines changed: 78 additions & 19 deletions

File tree

sdks/python/apache_beam/utils/subprocess_server.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class _SharedCache:
7272
def __init__(self, constructor, destructor):
7373
self._constructor = constructor
7474
self._destructor = destructor
75-
self._live_owners = set()
75+
self._live_owners = {}
7676
self._cache = {}
7777
self._lock = threading.Lock()
7878
self._counter = 0
@@ -82,10 +82,10 @@ def _next_id(self):
8282
self._counter += 1
8383
return self._counter
8484

85-
def register(self):
85+
def register(self, is_context=False):
8686
with self._lock:
8787
owner = self._next_id()
88-
self._live_owners.add(owner)
88+
self._live_owners[owner] = is_context
8989
return owner
9090

9191
def purge(self, owner):
@@ -97,7 +97,7 @@ def purge(self, owner):
9797
"shutdown, the subprocess was already cleaned up earlier.",
9898
owner)
9999
return
100-
self._live_owners.remove(owner)
100+
del self._live_owners[owner]
101101
for key, entry in list(self._cache.items()):
102102
if owner in entry.owners:
103103
entry.owners.remove(owner)
@@ -108,14 +108,23 @@ def purge(self, owner):
108108
for value in to_delete:
109109
self._destructor(value)
110110

111-
def get(self, *key):
112-
if not self._live_owners:
113-
raise RuntimeError("At least one owner must be registered.")
111+
def get(self, *key, owner=None):
114112
with self._lock:
113+
if not self._live_owners:
114+
raise RuntimeError("At least one owner must be registered.")
115+
if owner is not None and owner not in self._live_owners:
116+
raise RuntimeError("The requesting owner must be registered.")
117+
115118
if key not in self._cache:
116119
self._cache[key] = _SharedCacheEntry(self._constructor(*key), set())
117-
for owner in self._live_owners:
120+
if owner is not None:
118121
self._cache[key].owners.add(owner)
122+
for live_owner, is_context in self._live_owners.items():
123+
if is_context:
124+
self._cache[key].owners.add(live_owner)
125+
else:
126+
for live_owner in self._live_owners:
127+
self._cache[key].owners.add(live_owner)
119128
return self._cache[key].obj
120129

121130
def force_remove(self, *key):
@@ -180,7 +189,7 @@ def cache_subprocesses(cls):
180189
These subprocesses may be shared with other contexts as well.
181190
"""
182191
try:
183-
unique_id = cls._cache.register()
192+
unique_id = cls._cache.register(is_context=True)
184193
yield
185194
finally:
186195
cls._cache.purge(unique_id)
@@ -214,7 +223,7 @@ def start(self):
214223
channel_ready = grpc.channel_ready_future(self._grpc_channel)
215224
while True:
216225
if process is not None and process.poll() is not None:
217-
_LOGGER.error("Started job service with %s", process.args)
226+
_LOGGER.error("Failed to start job service with %s", process.args)
218227
raise RuntimeError(
219228
'Service failed to start up with error %s' % process.poll())
220229
try:
@@ -235,15 +244,16 @@ def start(self):
235244
def start_process(self):
236245
if self._owner_id is not None:
237246
self._cache.purge(self._owner_id)
238-
self._owner_id = self._cache.register()
239-
return self._cache.get(tuple(self._cmd), self._port, self._logger)
247+
self._owner_id = self._cache.register(is_context=False)
248+
return self._cache.get(
249+
tuple(self._cmd), self._port, self._logger, owner=self._owner_id)
240250

241251
def _really_start_process(cmd, port, logger):
242252
if not port:
243253
port, = pick_port(None)
244254
cmd = [arg.replace('{{PORT}}', str(port)) for arg in cmd] # pylint: disable=not-an-iterable
245255
endpoint = 'localhost:%s' % port
246-
_LOGGER.info("Starting service with %s", str(cmd).replace("',", "'"))
256+
_LOGGER.warning("Really starting service at %s with cmd: %s", endpoint, cmd)
247257
process = subprocess.Popen(
248258
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
249259

@@ -295,9 +305,11 @@ def stop_force(self):
295305
self._grpc_channel = None
296306

297307
def _really_stop_process(process_and_endpoint):
298-
process, _ = process_and_endpoint # pylint: disable=unpacking-non-sequence
308+
process, endpoint = process_and_endpoint # pylint: disable=unpacking-non-sequence
299309
if not process:
300310
return
311+
_LOGGER.warning(
312+
"Really destroying service at %s with cmd: %s", endpoint, process.args)
301313
for _ in range(5):
302314
if process.poll() is not None:
303315
break

sdks/python/apache_beam/utils/subprocess_server_test.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -402,16 +402,16 @@ def mock_unregister(cb):
402402
self.assertEqual(len(registered_callbacks), 1)
403403

404404
def test_concurrent_purge_race_condition(self):
405-
# Concurrent threads attempting to check memebership and call purge for the same owner.
406-
# Here we explicitly define a synchronized set to mimic the behavior of _live_owners.
407-
# This set will block two threads on __contains__, allowing us to test the race condition.
405+
# Concurrent threads attempting to check membership and call purge for the same owner.
406+
# Here we explicitly define a synchronized dict to mimic the behavior of _live_owners.
407+
# This dict will block two threads on __contains__, allowing us to test the race condition.
408408
cache = subprocess_server._SharedCache(lambda x: "obj", lambda x: None)
409409
owner = cache.register()
410410

411411
barrier = threading.Barrier(2)
412412
exceptions = []
413413

414-
class SynchronizedSet(set):
414+
class SynchronizedDict(dict):
415415
def __contains__(self, item):
416416
res = super().__contains__(item)
417417
try:
@@ -421,7 +421,7 @@ def __contains__(self, item):
421421
pass
422422
return res
423423

424-
cache._live_owners = SynchronizedSet(cache._live_owners)
424+
cache._live_owners = SynchronizedDict(cache._live_owners)
425425

426426
def purge_worker():
427427
try:
@@ -551,6 +551,53 @@ def __init__(self):
551551
# Clean up the other owner
552552
cache.purge(other_owner)
553553

554+
def test_non_context_owners_do_not_share_keys(self):
555+
cache = subprocess_server._SharedCache(self.with_prefix, lambda x: None)
556+
# owner1 is a non-context owner (e.g., prism)
557+
owner1 = cache.register(is_context=False)
558+
a = cache.get('a', owner=owner1)
559+
560+
# owner2 is another non-context owner (e.g., short-lived expansion service)
561+
owner2 = cache.register(is_context=False)
562+
b = cache.get('b', owner=owner2)
563+
564+
# Verify that owner1 does not own 'b'
565+
self.assertNotIn(owner1, cache._cache[('b', )].owners)
566+
567+
# Verify that owner2 does not own 'a'
568+
self.assertNotIn(owner2, cache._cache[('a', )].owners)
569+
570+
# Purging owner2 should immediately destroy/remove 'b'
571+
cache.purge(owner2)
572+
self.assertNotIn(('b', ), cache._cache)
573+
574+
# 'a' is still alive because owner1 is still registered
575+
self.assertIn(('a', ), cache._cache)
576+
577+
# Purging owner1 should destroy/remove 'a'
578+
cache.purge(owner1)
579+
self.assertNotIn(('a', ), cache._cache)
580+
581+
def test_context_owner_owns_all_keys(self):
582+
cache = subprocess_server._SharedCache(self.with_prefix, lambda x: None)
583+
# owner1 is a non-context owner (e.g., prism)
584+
owner1 = cache.register(is_context=False)
585+
586+
# owner2 is a context owner (e.g., cache_subprocesses)
587+
owner2 = cache.register(is_context=True)
588+
589+
# owner3 is another non-context owner (e.g., short-lived service)
590+
owner3 = cache.register(is_context=False)
591+
592+
# owner3 requests 'b'
593+
b = cache.get('b', owner=owner3)
594+
595+
# owner2 (context) should own 'b'
596+
self.assertIn(owner2, cache._cache[('b', )].owners)
597+
598+
# owner1 (non-context) should NOT own 'b'
599+
self.assertNotIn(owner1, cache._cache[('b', )].owners)
600+
554601

555602
if __name__ == '__main__':
556603
unittest.main()

0 commit comments

Comments
 (0)