Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""In-memory LinkState implementation."""


import hashlib
import threading
from bisect import bisect_right
from collections import defaultdict
Expand All @@ -37,7 +38,7 @@
SubStatus,
)
from flwr.common.record import ConfigRecord
from flwr.common.typing import Run, RunStatus
from flwr.common.typing import Fab, Run, RunStatus
from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
from flwr.server.superlink.linkstate.linkstate import LinkState
from flwr.server.utils import validate_message
Expand Down Expand Up @@ -83,6 +84,7 @@ def __init__(
self.run_ids: dict[int, RunRecord] = {}
self.contexts: dict[int, Context] = {}
self.federation_options: dict[int, ConfigRecord] = {}
self.fab_store: dict[str, Fab] = {}
self.message_ins_store: dict[str, Message] = {}
self.message_res_store: dict[str, Message] = {}
self.message_ins_id_to_message_res_id: dict[str, str] = {}
Expand All @@ -101,6 +103,37 @@ def federation_manager(self) -> FederationManager:
"""Get the FederationManager instance."""
return self._federation_manager

def store_fab(self, fab: Fab) -> str:
"""Store a FAB."""
fab_hash = hashlib.sha256(fab.content).hexdigest()
if fab.hash_str and fab.hash_str != fab_hash:
raise ValueError(
f"FAB hash mismatch: provided {fab.hash_str}, computed {fab_hash}"
)
with self.lock:
# Keep launch behavior: last write wins for metadata under the same
# content hash.
self.fab_store[fab_hash] = Fab(
hash_str=fab_hash,
content=fab.content,
verifications=dict(fab.verifications),
)
return fab_hash

def get_fab(self, fab_hash: str) -> Fab | None:
"""Return a FAB by hash."""
with self.lock:
fab = self.fab_store.get(fab_hash)
if fab is None:
return None
# Launch tradeoff: do not recompute content hash on reads; rely on
# write-time validation and hash-addressed lookup.
return Fab(
hash_str=fab.hash_str,
content=fab.content,
verifications=dict(fab.verifications),
)

def store_message_ins(self, message: Message) -> str | None:
"""Store one Message."""
# Validate message
Expand Down
10 changes: 9 additions & 1 deletion framework/py/flwr/server/superlink/linkstate/linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from flwr.app.user_config import UserConfig
from flwr.common import Context, Message
from flwr.common.record import ConfigRecord
from flwr.common.typing import Run, RunStatus
from flwr.common.typing import Fab, Run, RunStatus
from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
from flwr.supercore.corestate import CoreState
from flwr.superlink.federation import FederationManager
Expand All @@ -36,6 +36,14 @@ class LinkState(CoreState): # pylint: disable=R0904
def federation_manager(self) -> FederationManager:
"""Return the FederationManager instance."""

@abc.abstractmethod
def store_fab(self, fab: Fab) -> str:
"""Store a FAB and return its canonical SHA-256 hash."""

@abc.abstractmethod
def get_fab(self, fab_hash: str) -> Fab | None:
"""Return the FAB for the given hash, if present."""

@abc.abstractmethod
def store_message_ins(self, message: Message) -> str | None:
"""Store one Message.
Expand Down
44 changes: 43 additions & 1 deletion framework/py/flwr/server/superlink/linkstate/linkstate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# pylint: disable=invalid-name, too-many-lines, R0904, R0913


import hashlib
import secrets
import tempfile
import time
Expand Down Expand Up @@ -46,7 +47,7 @@
SubStatus,
)
from flwr.common.serde import message_from_proto, message_to_proto
from flwr.common.typing import RunStatus
from flwr.common.typing import Fab, RunStatus

# pylint: disable=E0611
from flwr.proto.message_pb2 import Message as ProtoMessage
Expand Down Expand Up @@ -78,6 +79,47 @@ def create_public_key(self) -> bytes:
_, public_key = generate_key_pairs()
return public_key_to_bytes(public_key)

def test_store_and_get_fab(self) -> None:
"""Test storing and retrieving a FAB."""
state = self.state_factory()
content = b"fab-content"
fab = Fab(hashlib.sha256(content).hexdigest(), content, {"meta": "data"})

fab_hash = state.store_fab(fab)
retrieved = state.get_fab(fab_hash)

self.assertIsNotNone(retrieved)
assert retrieved is not None
self.assertEqual(retrieved.hash_str, fab_hash)
self.assertEqual(retrieved.content, fab.content)
self.assertEqual(retrieved.verifications, fab.verifications)

def test_store_fab_deduplicates_by_hash(self) -> None:
"""Test storing the same FAB content reuses the same hash."""
state = self.state_factory()
content = b"fab-content"
hash_str = hashlib.sha256(content).hexdigest()

fab_hash = state.store_fab(Fab(hash_str, content, {"meta": "data"}))
other_hash = state.store_fab(Fab(hash_str, content, {"meta": "next"}))
retrieved = state.get_fab(fab_hash)

self.assertEqual(fab_hash, other_hash)
self.assertIsNotNone(retrieved)
assert retrieved is not None
self.assertEqual(retrieved.verifications, {"meta": "next"})

def test_get_fab_missing_returns_none(self) -> None:
"""Test missing FAB retrieval."""
state = self.state_factory()
self.assertIsNone(state.get_fab("missing-fab-hash"))

def test_store_fab_rejects_hash_mismatch(self) -> None:
"""Test storing a FAB fails when provided hash doesn't match content."""
state = self.state_factory()
with self.assertRaisesRegex(ValueError, "FAB hash mismatch"):
state.store_fab(Fab("not-the-content-hash", b"fab-content", {}))

def test_create_and_get_run_info(self) -> None:
"""Test if create_run and get_run_info work correctly."""
# Prepare
Expand Down
46 changes: 45 additions & 1 deletion framework/py/flwr/server/superlink/linkstate/sql_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

# pylint: disable=too-many-lines

import hashlib
import json
from collections.abc import Sequence
from datetime import datetime, timezone
Expand All @@ -39,7 +40,7 @@
SubStatus,
)
from flwr.common.record import ConfigRecord
from flwr.common.typing import Run, RunStatus
from flwr.common.typing import Fab, Run, RunStatus
from flwr.proto.node_pb2 import NodeInfo # pylint: disable=E0611
from flwr.server.utils.validator import validate_message
from flwr.supercore.constant import NodeStatus
Expand Down Expand Up @@ -99,6 +100,49 @@ def federation_manager(self) -> FederationManager:
"""Return the FederationManager instance."""
return self._federation_manager

def store_fab(self, fab: Fab) -> str:
"""Store a FAB."""
fab_hash = hashlib.sha256(fab.content).hexdigest()
if fab.hash_str and fab.hash_str != fab_hash:
raise ValueError(
f"FAB hash mismatch: provided {fab.hash_str}, computed {fab_hash}"
)
params = {
"fab_hash": fab_hash,
"content": fab.content,
"verifications": json.dumps(fab.verifications),
}
# Keep launch behavior: last write wins for metadata under the same
# content hash.
query = """
INSERT INTO fab (fab_hash, content, verifications)
VALUES (:fab_hash, :content, :verifications)
ON CONFLICT(fab_hash) DO UPDATE SET
content = excluded.content,
verifications = excluded.verifications
"""
self.query(query, params)
return fab_hash

def get_fab(self, fab_hash: str) -> Fab | None:
"""Return a FAB by hash."""
query = """
SELECT fab_hash, content, verifications
FROM fab
WHERE fab_hash = :fab_hash
"""
rows = self.query(query, {"fab_hash": fab_hash})
if not rows:
return None
row = rows[0]
# Launch tradeoff: do not recompute content hash on reads; rely on
# write-time validation and hash-addressed lookup.
return Fab(
hash_str=row["fab_hash"],
content=row["content"],
verifications=json.loads(row["verifications"]),
)

def store_message_ins(self, message: Message) -> str | None:
"""Store one Message."""
# Validate message
Expand Down
10 changes: 10 additions & 0 deletions framework/py/flwr/supercore/state/alembic/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,16 @@ def test_migrated_schema_matches_metadata(self) -> None:
finally:
engine.dispose()

def test_migrations_create_fab_table(self) -> None:
"""Ensure FAB state storage is present after migrations."""
engine = self.create_engine()
try:
run_migrations(engine)
inspector = inspect(engine)
self.assertTrue(inspector.has_table("fab"))
finally:
engine.dispose()

def test_legacy_database_is_stamped_and_upgraded_successfully(self) -> None:
"""Ensure legacy databases without alembic_version is stamped and upgraded."""
# Prepare
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2026 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Add fab table.

Revision ID: 33e2f70642b1
Revises: c8f4f6e2c1ad
Create Date: 2026-03-21 12:00:00.000000
"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op

# pylint: disable=no-member

# revision identifiers, used by Alembic.
revision: str = "33e2f70642b1"
down_revision: str | Sequence[str] | None = "c8f4f6e2c1ad"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
"""Upgrade schema."""
op.create_table(
"fab",
sa.Column("fab_hash", sa.String(), nullable=False),
sa.Column("content", sa.LargeBinary(), nullable=False),
sa.Column("verifications", sa.String(), nullable=False),
sa.PrimaryKeyConstraint("fab_hash"),
)


def downgrade() -> None:
"""Downgrade schema."""
op.drop_table("fab")
6 changes: 6 additions & 0 deletions framework/py/flwr/supercore/state/schema/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ erDiagram
BLOB context "nullable"
}

fab {
VARCHAR fab_hash PK
BLOB content
VARCHAR verifications
}

logs {
INTEGER run_id FK "nullable"
VARCHAR log "nullable"
Expand Down
11 changes: 11 additions & 0 deletions framework/py/flwr/supercore/state/schema/linkstate_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,17 @@ def create_linkstate_metadata() -> MetaData:
Column("clientapp_runtime", Float, server_default="0.0"),
)

# --------------------------------------------------------------------------
# Table: fab
# --------------------------------------------------------------------------
Table(
"fab",
metadata,
Column("fab_hash", String, primary_key=True),
Column("content", LargeBinary, nullable=False),
Column("verifications", String, nullable=False),
)

# --------------------------------------------------------------------------
# Table: logs
# --------------------------------------------------------------------------
Expand Down
37 changes: 36 additions & 1 deletion framework/py/flwr/supernode/nodestate/in_memory_nodestate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
"""In-memory NodeState implementation."""


import hashlib
from collections.abc import Sequence
from dataclasses import dataclass
from threading import Lock, RLock

from flwr.common import Context, Error, Message, now
from flwr.common.constant import ErrorCode
from flwr.common.typing import Run
from flwr.common.typing import Fab, Run
from flwr.supercore.constant import MESSAGE_TIME_ENTRY_MAX_AGE_SECONDS
from flwr.supercore.corestate.in_memory_corestate import InMemoryCoreState
from flwr.supercore.inflatable.inflatable_object import (
Expand Down Expand Up @@ -69,6 +70,9 @@ def __init__(self, object_store: ObjectStore) -> None:
# Store run ID to Run mapping
self.run_store: dict[int, Run] = {}
self.lock_run_store = Lock()
# Store hash to FAB mapping
self.fab_store: dict[str, Fab] = {}
self.lock_fab_store = Lock()
# Store run ID to Context mapping
self.ctx_store: dict[int, Context] = {}
self.lock_ctx_store = Lock()
Expand All @@ -86,6 +90,37 @@ def get_node_id(self) -> int:
raise ValueError("Node ID not set")
return self.node_id

def store_fab(self, fab: Fab) -> str:
"""Store a FAB."""
fab_hash = hashlib.sha256(fab.content).hexdigest()
if fab.hash_str and fab.hash_str != fab_hash:
raise ValueError(
f"FAB hash mismatch: provided {fab.hash_str}, computed {fab_hash}"
)
with self.lock_fab_store:
# Keep launch behavior: last write wins for metadata under the same
# content hash.
self.fab_store[fab_hash] = Fab(
hash_str=fab_hash,
content=fab.content,
verifications=dict(fab.verifications),
)
return fab_hash

def get_fab(self, fab_hash: str) -> Fab | None:
"""Return a FAB by hash."""
with self.lock_fab_store:
fab = self.fab_store.get(fab_hash)
if fab is None:
return None
# Launch tradeoff: do not recompute content hash on reads; rely on
# write-time validation and hash-addressed lookup.
return Fab(
hash_str=fab.hash_str,
content=fab.content,
verifications=dict(fab.verifications),
)

def store_message(self, message: Message) -> str | None:
"""Store a message."""
# No need to check for expired tokens here
Expand Down
Loading
Loading