Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 27 additions & 116 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@

import openfl.callbacks as callbacks_module
from openfl.component.aggregator.straggler_handling import StragglerPolicy, WaitForAllPolicy
from openfl.component.assigner.assigner import Assigner
from openfl.databases import PersistentTensorDB, TensorDB
from openfl.interface.aggregation_functions import SecureWeightedAverage, WeightedAverage
from openfl.pipelines import NoCompressionPipeline, TensorCodec
from openfl.protocols import base_pb2, utils
from openfl.protocols import utils
from openfl.protocols.base_pb2 import NamedTensor
from openfl.utilities import TaskResultKey, TensorKey, change_tags

Expand Down Expand Up @@ -76,47 +77,21 @@ def __init__(
init_state_path,
best_state_path,
last_state_path,
assigner,
assigner: Assigner,
connector=None,
use_delta_updates=True,
straggler_handling_policy: StragglerPolicy = WaitForAllPolicy,
rounds_to_train=256,
single_col_cert_common_name=None,
compression_pipeline=None,
db_store_rounds=1,
initial_tensor_dict=None,
log_memory_usage=False,
write_logs=False,
callbacks: Optional[List] = [],
persist_checkpoint=True,
persistent_db_path=None,
secure_aggregation=False,
):
"""Initializes the Aggregator.

Args:
aggregator_uuid (int): Aggregation ID.
federation_uuid (str): Federation ID.
authorized_cols (list of str): The list of IDs of enrolled
collaborators.
init_state_path (str): The location of the initial weight file.
best_state_path (str): The file location to store the weight of
the best model.
last_state_path (str): The file location to store the latest
weight.
assigner: Assigner object.
straggler_handling_policy (optional): Straggler handling policy.
rounds_to_train (int, optional): Number of rounds to train.
Defaults to 256.
single_col_cert_common_name (str, optional): Common name for single
collaborator certificate. Defaults to None.
compression_pipeline (optional): Compression pipeline. Defaults to
NoCompressionPipeline.
db_store_rounds (int, optional): Rounds to store in TensorDB.
Defaults to 1.
initial_tensor_dict (dict, optional): Initial tensor dictionary.
callbacks: List of callbacks to be used during the experiment.
"""
self.round_number = 0
self.next_model_round_number = 0

Expand Down Expand Up @@ -203,21 +178,11 @@ def __init__(
origin="aggregator",
last_state_path=self.last_state_path,
)

if initial_tensor_dict:
self._load_initial_tensors_from_dict(initial_tensor_dict)
self.model = utils.construct_model_proto(
tensor_dict=initial_tensor_dict,
round_number=0,
tensor_pipe=self.compression_pipeline,
)
else:
if self.connector:
# The model definition will be handled by the respective framework
self.model = {}
else:
self.model: base_pb2.ModelProto = utils.load_proto(self.init_state_path)
self._load_initial_tensors() # keys are TensorKeys

self.model = None
if not self.connector:
self.model = utils.load_proto(self.init_state_path)
self._load_initial_tensors() # keys are TensorKeys

self._secure_aggregation_enabled = secure_aggregation
if self._secure_aggregation_enabled:
Expand Down Expand Up @@ -334,23 +299,6 @@ def _load_initial_tensors(self):
self.tensor_db.cache_tensor(tensor_key_dict)
logger.debug("This is the initial tensor_db: %s", self.tensor_db)

def _load_initial_tensors_from_dict(self, tensor_dict):
"""Load all of the tensors required to begin federated learning.
Comment on lines -337 to -338
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remnant of the interactive API.


Required tensors are: \
1. Initial model.

Returns:
None
"""
tensor_key_dict = {
TensorKey(k, self.uuid, self.round_number, False, ("model",)): v
for k, v in tensor_dict.items()
}
# all initial model tensors are loaded here
self.tensor_db.cache_tensor(tensor_key_dict)
logger.debug("This is the initial tensor_db: %s", self.tensor_db)

def _save_model(self, round_number, file_path):
"""Save the best or latest model.

Expand Down Expand Up @@ -476,74 +424,35 @@ def get_tasks(self, collaborator_name):
sleep_time (int): Sleep time.
time_to_quit (bool): Whether it's time to quit.
"""
logger.debug(
f"Aggregator GetTasks function reached from collaborator {collaborator_name}..."
)
time_to_quit = False
sleep_time = Aggregator._get_sleep_time()

# first, if it is time to quit, inform the collaborator
# If time to quit, inform collaborator.
if self._time_to_quit():
logger.info(
"Sending signal to collaborator %s to shutdown...",
collaborator_name,
)
logger.info("Sending signal to collaborator %s to shutdown...", collaborator_name)
self.quit_job_sent_to.append(collaborator_name)

tasks = None
tasks = []
sleep_time = 0
time_to_quit = True

return tasks, self.round_number, sleep_time, time_to_quit

time_to_quit = False
# otherwise, get the tasks from our task assigner
# Fetch tasks for the collaborator.
tasks = self.assigner.get_tasks_for_collaborator(collaborator_name, self.round_number)

# if no tasks, tell the collaborator to sleep
if len(tasks) == 0:
tasks = None
sleep_time = Aggregator._get_sleep_time()

return tasks, self.round_number, sleep_time, time_to_quit

# if we do have tasks, remove any that we already have results for
if isinstance(tasks[0], str):
# backward compatibility
tasks = [
t
for t in tasks
if not self._collaborator_task_completed(collaborator_name, t, self.round_number)
]
if collaborator_name in self.stragglers:
tasks = []

else:
tasks = [
t
for t in tasks
if not self._collaborator_task_completed(
collaborator_name, t.name, self.round_number
)
]
if collaborator_name in self.stragglers:
tasks = []

# Do the check again because it's possible that all tasks have
# been completed
if len(tasks) == 0:
tasks = None
sleep_time = Aggregator._get_sleep_time()
# Filter out tasks that have already been completed by the collaborator.
tasks = [
t
for t in tasks
if not self._collaborator_task_completed(collaborator_name, t, self.round_number)
]

return tasks, self.round_number, sleep_time, time_to_quit

logger.info(
f"Sending tasks to collaborator {collaborator_name} for round {self.round_number}"
)
sleep_time = 0
if collaborator_name in self.stragglers:
tasks = []

# Start straggler handling policy for timer based callback is required
# for %age based policy callback is not required
self.straggler_handling_policy.start_policy(callback=self._straggler_cutoff_time_elapsed)

return tasks, self.round_number, sleep_time, time_to_quit

def _straggler_cutoff_time_elapsed(self) -> None:
Expand Down Expand Up @@ -1087,7 +996,6 @@ def _compute_validation_related_task_metrics(self, task_name) -> dict:
)
# Leave out straggler for the round even if they've partially
# completed given tasks
collaborators_for_task = []
collaborators_for_task = [
c for c in all_collaborators_for_task if c in self.collaborators_done
]
Expand Down Expand Up @@ -1213,14 +1121,17 @@ def _end_of_round_check(self):
# todo handle case when aggregator restarted before callback was successful
self.callbacks.on_round_end(self.round_number, logs)

# Once all of the task results have been processed
self._end_of_round_check_done[self.round_number] = True
self.round_number += 1

# resetting stragglers for task for a new round
# Reset for next round
self.stragglers = []
# resetting collaborators_done for next round
self.collaborators_done = []
self.collaborator_tasks_results = {}
self.collaborator_task_weight = {}
self.tensor_db.clean_up(self.db_store_rounds)
self.straggler_handling_policy.reset_policy_for_round()

# TODO This needs to be fixed!
if self._time_to_quit():
Expand Down
4 changes: 2 additions & 2 deletions openfl/component/assigner/random_grouped_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def define_task_assignments(self):
col_idx += num_col_in_group
assert col_idx == col_list_size, "Task groups were not divided properly"

def get_tasks_for_collaborator(self, collaborator_name, round_number):
def get_tasks_for_collaborator(self, collaborator_name, round_number) -> list:
"""Get tasks for a specific collaborator in a specific round.

Args:
Expand All @@ -112,7 +112,7 @@ def get_tasks_for_collaborator(self, collaborator_name, round_number):
"""
return self.collaborator_tasks[collaborator_name][round_number]

def get_collaborators_for_task(self, task_name, round_number):
def get_collaborators_for_task(self, task_name, round_number) -> list:
"""Get collaborators for a specific task in a specific round.

Args:
Expand Down
11 changes: 3 additions & 8 deletions openfl/federated/plan/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,16 +364,11 @@ def get_tasks(self):
tasks[task]["aggregation_type"] = aggregation_type
return tasks

def get_aggregator(self, tensor_dict=None):
def get_aggregator(self):
"""Get federation aggregator.

This method retrieves the federation aggregator. If the aggregator
does not exist, it is built using the configuration settings and the
provided tensor dictionary.

Args:
tensor_dict (dict, optional): The initial tensor dictionary to use
when building the aggregator. Defaults to None.
does not exist, it is built using the configuration settings.

Returns:
self.aggregator_ (Aggregator): The federation aggregator.
Expand Down Expand Up @@ -401,7 +396,7 @@ def get_aggregator(self, tensor_dict=None):
# TODO: Load callbacks from plan.

if self.aggregator_ is None:
self.aggregator_ = Plan.build(**defaults, initial_tensor_dict=tensor_dict)
self.aggregator_ = Plan.build(**defaults)
Comment on lines -404 to +399
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remnant of the interactive API.


return self.aggregator_

Expand Down
22 changes: 1 addition & 21 deletions openfl/transport/grpc/aggregator_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,27 +175,7 @@ def GetTasks(self, request, context): # NOQA:N802
tasks, round_number, sleep_time, time_to_quit = self.aggregator.get_tasks(
request.header.sender
)
if tasks:
if isinstance(tasks[0], str):
# backward compatibility
tasks_proto = [
aggregator_pb2.Task(
name=task,
)
for task in tasks
]
else:
tasks_proto = [
aggregator_pb2.Task(
name=task.name,
function_name=task.function_name,
task_type=task.task_type,
apply_local=task.apply_local,
)
for task in tasks
]
else:
tasks_proto = []
tasks_proto = [aggregator_pb2.Task(name=task) for task in tasks]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These 20 lines are equivalent to the single replacement line, which covers all edge cases. (Note that the internal if-else was removed partly because the interactive API is defunct.)


header = create_header(
sender=self.aggregator.uuid,
Expand Down
4 changes: 2 additions & 2 deletions tests/openfl/component/aggregator/test_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ def test_time_to_quit(agg, round_number, rounds_to_train, expected):

@pytest.mark.parametrize(
'col_name,tasks,time_to_quit,exp_tasks,exp_sleep_time,exp_time_to_quit', [
('col1', ['task_name'], True, None, 0, True),
('col1', [], False, None, 10, False),
('col1', ['task_name'], True, [], 0, True),
('col1', [], False, [], 10, False),
Comment on lines -126 to +127
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tasks are now returned as lists. It is strange to expect `None` for an empty list, although in Python both evaluate to `False` in a boolean context.

Now this is consistent across the entire component.

('col1', ['task_name'], False, ['task_name'], 0, False),
])
def test_get_tasks(agg, col_name, tasks, time_to_quit,
Expand Down
Loading