Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
865edec
Make RUNNING transition atomic across replicas
psfoley Mar 18, 2026
52345c8
Fix format and tests
psfoley Mar 18, 2026
8a9d066
Potential fix for pull request finding
psfoley Mar 18, 2026
02e9fc9
Potential fix for pull request finding
psfoley Mar 18, 2026
138f96a
Merge branch 'main' into atomic-run-claim-fix
psfoley Mar 18, 2026
aab7b1c
Address Github Copilot suggestions
psfoley Mar 18, 2026
221f062
Fix import order
psfoley Mar 18, 2026
6581d87
Fix failing tests
psfoley Mar 18, 2026
4adc658
Fix CI failure
psfoley Mar 19, 2026
bc023d4
Merge branch 'main' into atomic-run-claim-fix
msheller Mar 23, 2026
712d7d8
Merge branch 'main' into atomic-run-claim-fix
msheller Mar 25, 2026
08e7675
Removed message claiming race condition fixes that appear to have bee…
msheller Mar 25, 2026
a705542
Added race condition test for atomic run claiming
msheller Mar 25, 2026
33942a6
Merge branch 'main' into atomic-run-claim-fix
msheller Mar 26, 2026
f43cee2
Consolidate test files; Apply new 'test_ha_*' naming convention
psfoley Mar 26, 2026
6e566cc
Fix too many lines error
psfoley Mar 26, 2026
8cb865e
Merge branch 'main' into atomic-run-claim-fix
psfoley Mar 26, 2026
210d3e3
Merge branch 'main' into atomic-run-claim-fix
psfoley Mar 27, 2026
5ae1819
Merge branch 'main' into atomic-run-claim-fix
psfoley Mar 27, 2026
c80bf8f
Fix CI tests
psfoley Mar 27, 2026
80e3437
Fix CI tests
psfoley Mar 27, 2026
757c511
revert(e2e): drop heartbeat test changes from branch
psfoley Mar 27, 2026
82f9e42
Switch from multithreading to multiprocessing
psfoley Mar 28, 2026
42f519e
CI fix
psfoley Mar 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions framework/py/flwr/server/superlink/linkstate/linkstate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
# pylint: disable=invalid-name, too-many-lines, R0904, R0913


import multiprocessing
import os
import secrets
import tempfile
import time
import unittest
from abc import abstractmethod
from datetime import datetime, timedelta, timezone
from typing import Any
from unittest.mock import Mock, patch
from uuid import uuid4

Expand Down Expand Up @@ -1878,6 +1881,29 @@ def create_dummy_run( # pylint: disable=too-many-positional-arguments
)


def _claim_running_in_separate_process(
    database_path: str,
    run_id: int,
    start_event: Any,
    result_queue: Any,
) -> None:
    """Attempt the STARTING -> RUNNING claim from a dedicated process.

    The process opens its own `SqlLinkState` on `database_path`, blocks until
    `start_event` is set (at most 5 seconds), and then calls
    `update_run_status` for `run_id`. Exactly one `(result, error)` tuple is
    put on `result_queue`:

    - `(<return value of update_run_status>, None)` when the call completes,
    - `(False, "start-event-timeout")` when the event was never set,
    - `(False, repr(exception))` when the call raised.
    """
    replica = SqlLinkState(
        database_path=database_path,
        federation_manager=NoOpFederationManager(),
        object_store=ObjectStoreFactory().store(),
    )
    replica.initialize()

    # Hold here until the parent releases all workers.
    released = start_event.wait(timeout=5.0)
    if not released:
        result_queue.put((False, "start-event-timeout"))
        return

    try:
        claimed = replica.update_run_status(run_id, RunStatus(Status.RUNNING, "", ""))
        result_queue.put((claimed, None))
    except Exception as err:  # pylint: disable=broad-exception-caught
        # Report the failure to the parent instead of dying silently.
        result_queue.put((False, repr(err)))


class InMemoryStateTest(StateTest):
"""Test InMemoryState implementation."""

Expand Down Expand Up @@ -1933,6 +1959,78 @@ def state_factory(self) -> SqlLinkState:
state.initialize()
return state

@staticmethod
def _create_shared_sql_states(
    database_path: str,
) -> tuple[SqlLinkState, SqlLinkState]:
    """Return two initialized SqlLinkState replicas backed by one SQLite file.

    Both instances are constructed first and then initialized in order, each
    with its own federation manager and object store.
    """

    def _new_replica() -> SqlLinkState:
        # Every replica gets fresh collaborators; only the DB path is shared.
        return SqlLinkState(
            database_path=database_path,
            federation_manager=NoOpFederationManager(),
            object_store=ObjectStoreFactory().store(),
        )

    replicas = (_new_replica(), _new_replica())
    for replica in replicas:
        replica.initialize()
    return replicas

# pylint: disable-next=too-many-locals
def test_update_run_status_running_claim_is_atomic_across_replicas(self) -> None:
    """Ensure only one replica can claim STARTING -> RUNNING transition.

    A run is created and moved to STARTING, then two spawned processes open
    the same SQLite file and each try to move the run to RUNNING once a
    shared start event is set. Each process reports a `(result, error)`
    tuple through a queue; the test requires exactly one `True` and one
    `False` result and no reported errors.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        db_path = os.path.join(tmpdir, "shared.db")
        state_0, _ = self._create_shared_sql_states(db_path)
        run_id = create_dummy_run(state_0)
        assert state_0.update_run_status(run_id, RunStatus(Status.STARTING, "", ""))

        # "spawn" gives each child a fresh interpreter, so nothing is
        # inherited from the parent besides the explicit arguments.
        ctx = multiprocessing.get_context("spawn")
        start_event = ctx.Event()
        result_queue = ctx.Queue()
        timeout = 5.0

        processes = [
            ctx.Process(
                target=_claim_running_in_separate_process,
                args=(db_path, run_id, start_event, result_queue),
            )
            for _ in range(2)
        ]
        for proc in processes:
            proc.start()
        # Release both processes to claim at (roughly) the same time.
        start_event.set()
        for proc in processes:
            proc.join(timeout=timeout)

        alive_processes = [proc for proc in processes if proc.is_alive()]
        if alive_processes:
            # Stop stragglers before failing: otherwise they outlive the test
            # and may still hold the database file while the temporary
            # directory is being removed.
            for proc in alive_processes:
                proc.terminate()
                proc.join(timeout=timeout)
            self.fail(
                f"Concurrent run-claim test timed out; {len(alive_processes)} "
                f"process(es) still alive after {timeout} seconds."
            )
        for proc in processes:
            assert proc.exitcode == 0

        results: list[bool] = []
        errors: list[str] = []
        # Each child puts exactly one tuple on the queue before exiting.
        for _ in processes:
            result, error = result_queue.get(timeout=1.0)
            results.append(result)
            if error is not None:
                errors.append(error)
        if errors:
            self.fail(f"Concurrent run-claim process failed: {errors[0]}")

        # Exactly one winner and one loser.
        assert results.count(True) == 1
        assert results.count(False) == 1


if __name__ == "__main__":
    # Allow running this test module directly, with per-test verbose output.
    unittest.main(verbosity=2)
81 changes: 80 additions & 1 deletion framework/py/flwr/server/superlink/linkstate/sql_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,11 +995,90 @@ def get_federation_config(self, run_id: int) -> SimulationConfig | None:

return simulation_config_from_json(json.loads(fed_config_json))

# pylint: disable=too-many-return-statements,too-many-branches
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
"""Update the status of the run with the specified `run_id`."""
# Clean up expired tokens; this will flag inactive runs as needed
self._cleanup_expired_tokens()

# Atomic claim path for STARTING -> RUNNING across replicas/processes.
if new_status.status == Status.RUNNING:
if not has_valid_sub_status(new_status):
log(
ERROR,
'Invalid run status payload: sub_status="%s" is not valid for '
'status="%s". For non-FINISHED statuses, sub_status must be '
"empty.",
new_status.sub_status,
new_status.status,
)
return False

sint64_run_id = uint64_to_int64(run_id)
query = """
UPDATE run
SET running_at = :timestamp,
sub_status = :sub_status, details = :details
WHERE run_id = :run_id
AND starting_at != ''
AND running_at = ''
AND finished_at = ''
RETURNING run_id
"""
params = {
"timestamp": now().isoformat(),
"sub_status": new_status.sub_status,
"details": new_status.details,
"run_id": sint64_run_id,
}
rows = self.query(query, params)
if rows:
# Successfully claimed STARTING -> RUNNING.
return True

# Claim failed: diagnose why the UPDATE affected zero rows.
diag_rows = self.query(
"SELECT * FROM run WHERE run_id = :run_id",
{"run_id": sint64_run_id},
)
if not diag_rows:
log(ERROR, "`run_id` is invalid")
return False

row = diag_rows[0]
current_status = RunStatus(
status=determine_run_status(row),
sub_status=row["sub_status"],
details=row["details"],
)
if row["finished_at"] != "":
log(
ERROR,
'Invalid status transition: from "%s" to "%s"',
current_status.status,
new_status.status,
)
elif row["starting_at"] == "":
log(
ERROR,
'Invalid status transition: run "%d" is not in STARTING state',
run_id,
)
elif row["running_at"] != "":
log(
ERROR,
'Invalid status transition: run "%d" is already in RUNNING state',
run_id,
)
else:
log(
ERROR,
'Invalid status transition: from "%s" to "%s"',
current_status.status,
new_status.status,
)
return False

with self.session():
# Convert the uint64 value to sint64 for SQLite
sint64_run_id = uint64_to_int64(run_id)
Expand Down Expand Up @@ -1031,7 +1110,7 @@ def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
if not has_valid_sub_status(current_status):
log(
ERROR,
'Invalid sub-status "%s" for status "%s"',
'Invalid sub-status: sub_status="%s" is not valid for status="%s".',
current_status.sub_status,
current_status.status,
)
Expand Down
Loading
Loading