diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index af457c00a7..13e9182b19 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -14,10 +14,11 @@ import openfl.callbacks as callbacks_module from openfl.component.aggregator.straggler_handling import StragglerPolicy, WaitForAllPolicy +from openfl.component.assigner.assigner import Assigner from openfl.databases import PersistentTensorDB, TensorDB from openfl.interface.aggregation_functions import SecureWeightedAverage, WeightedAverage from openfl.pipelines import NoCompressionPipeline, TensorCodec -from openfl.protocols import base_pb2, utils +from openfl.protocols import utils from openfl.protocols.base_pb2 import NamedTensor from openfl.utilities import TaskResultKey, TensorKey, change_tags @@ -76,7 +77,7 @@ def __init__( init_state_path, best_state_path, last_state_path, - assigner, + assigner: Assigner, connector=None, use_delta_updates=True, straggler_handling_policy: StragglerPolicy = WaitForAllPolicy, @@ -84,7 +85,6 @@ def __init__( single_col_cert_common_name=None, compression_pipeline=None, db_store_rounds=1, - initial_tensor_dict=None, log_memory_usage=False, write_logs=False, callbacks: Optional[List] = [], @@ -92,31 +92,6 @@ def __init__( persistent_db_path=None, secure_aggregation=False, ): - """Initializes the Aggregator. - - Args: - aggregator_uuid (int): Aggregation ID. - federation_uuid (str): Federation ID. - authorized_cols (list of str): The list of IDs of enrolled - collaborators. - init_state_path (str): The location of the initial weight file. - best_state_path (str): The file location to store the weight of - the best model. - last_state_path (str): The file location to store the latest - weight. - assigner: Assigner object. - straggler_handling_policy (optional): Straggler handling policy. - rounds_to_train (int, optional): Number of rounds to train. - Defaults to 256. - single_col_cert_common_name (str, optional): Common name for single - collaborator certificate. Defaults to None. - compression_pipeline (optional): Compression pipeline. Defaults to - NoCompressionPipeline. - db_store_rounds (int, optional): Rounds to store in TensorDB. - Defaults to 1. - initial_tensor_dict (dict, optional): Initial tensor dictionary. - callbacks: List of callbacks to be used during the experiment. - """ self.round_number = 0 self.next_model_round_number = 0 @@ -203,21 +178,11 @@ def __init__( origin="aggregator", last_state_path=self.last_state_path, ) - - if initial_tensor_dict: - self._load_initial_tensors_from_dict(initial_tensor_dict) - self.model = utils.construct_model_proto( - tensor_dict=initial_tensor_dict, - round_number=0, - tensor_pipe=self.compression_pipeline, - ) - else: - if self.connector: - # The model definition will be handled by the respective framework - self.model = {} - else: - self.model: base_pb2.ModelProto = utils.load_proto(self.init_state_path) - self._load_initial_tensors() # keys are TensorKeys + + self.model = None + if not self.connector: + self.model = utils.load_proto(self.init_state_path) + self._load_initial_tensors() # keys are TensorKeys self._secure_aggregation_enabled = secure_aggregation if self._secure_aggregation_enabled: @@ -334,23 +299,6 @@ def _load_initial_tensors(self): self.tensor_db.cache_tensor(tensor_key_dict) logger.debug("This is the initial tensor_db: %s", self.tensor_db) - def _load_initial_tensors_from_dict(self, tensor_dict): - """Load all of the tensors required to begin federated learning. - - Required tensors are: \ - 1. Initial model. - - Returns: - None - """ - tensor_key_dict = { - TensorKey(k, self.uuid, self.round_number, False, ("model",)): v - for k, v in tensor_dict.items() - } - # all initial model tensors are loaded here - self.tensor_db.cache_tensor(tensor_key_dict) - logger.debug("This is the initial tensor_db: %s", self.tensor_db) - def _save_model(self, round_number, file_path): """Save the best or latest model. @@ -476,74 +424,35 @@ def get_tasks(self, collaborator_name): sleep_time (int): Sleep time. time_to_quit (bool): Whether it's time to quit. """ - logger.debug( - f"Aggregator GetTasks function reached from collaborator {collaborator_name}..." - ) + time_to_quit = False + sleep_time = Aggregator._get_sleep_time() - # first, if it is time to quit, inform the collaborator + # If time to quit, inform collaborator. if self._time_to_quit(): - logger.info( - "Sending signal to collaborator %s to shutdown...", - collaborator_name, - ) + logger.info("Sending signal to collaborator %s to shutdown...", collaborator_name) self.quit_job_sent_to.append(collaborator_name) - tasks = None + tasks = [] sleep_time = 0 time_to_quit = True - return tasks, self.round_number, sleep_time, time_to_quit - time_to_quit = False - # otherwise, get the tasks from our task assigner + # Fetch tasks for the collaborator. tasks = self.assigner.get_tasks_for_collaborator(collaborator_name, self.round_number) - # if no tasks, tell the collaborator to sleep - if len(tasks) == 0: - tasks = None - sleep_time = Aggregator._get_sleep_time() - - return tasks, self.round_number, sleep_time, time_to_quit - - # if we do have tasks, remove any that we already have results for - if isinstance(tasks[0], str): - # backward compatibility - tasks = [ - t - for t in tasks - if not self._collaborator_task_completed(collaborator_name, t, self.round_number) - ] - if collaborator_name in self.stragglers: - tasks = [] - - else: - tasks = [ - t - for t in tasks - if not self._collaborator_task_completed( - collaborator_name, t.name, self.round_number - ) - ] - if collaborator_name in self.stragglers: - tasks = [] - - # Do the check again because it's possible that all tasks have - # been completed - if len(tasks) == 0: - tasks = None - sleep_time = Aggregator._get_sleep_time() + # Filter out tasks that have already been completed by the collaborator. + tasks = [ + t + for t in tasks + if not self._collaborator_task_completed(collaborator_name, t, self.round_number) + ] - return tasks, self.round_number, sleep_time, time_to_quit - - logger.info( - f"Sending tasks to collaborator {collaborator_name} for round {self.round_number}" - ) - sleep_time = 0 + if collaborator_name in self.stragglers: + tasks = [] # Start straggler handling policy for timer based callback is required # for %age based policy callback is not required self.straggler_handling_policy.start_policy(callback=self._straggler_cutoff_time_elapsed) - return tasks, self.round_number, sleep_time, time_to_quit def _straggler_cutoff_time_elapsed(self) -> None: @@ -1087,7 +996,6 @@ def _compute_validation_related_task_metrics(self, task_name) -> dict: ) # Leave out straggler for the round even if they've partially # completed given tasks - collaborators_for_task = [] collaborators_for_task = [ c for c in all_collaborators_for_task if c in self.collaborators_done ] @@ -1213,14 +1121,17 @@ def _end_of_round_check(self): # todo handle case when aggregator restarted before callback was successful self.callbacks.on_round_end(self.round_number, logs) + # Once all of the task results have been processed + self._end_of_round_check_done[self.round_number] = True self.round_number += 1 - # resetting stragglers for task for a new round + # Reset for next round self.stragglers = [] - # resetting collaborators_done for next round self.collaborators_done = [] self.collaborator_tasks_results = {} self.collaborator_task_weight = {} + self.tensor_db.clean_up(self.db_store_rounds) + self.straggler_handling_policy.reset_policy_for_round() # TODO This needs to be fixed! if self._time_to_quit(): diff --git a/openfl/component/assigner/random_grouped_assigner.py b/openfl/component/assigner/random_grouped_assigner.py index 5c7635eac9..6d73dd6444 100644 --- a/openfl/component/assigner/random_grouped_assigner.py +++ b/openfl/component/assigner/random_grouped_assigner.py @@ -100,7 +100,7 @@ def define_task_assignments(self): col_idx += num_col_in_group assert col_idx == col_list_size, "Task groups were not divided properly" - def get_tasks_for_collaborator(self, collaborator_name, round_number): + def get_tasks_for_collaborator(self, collaborator_name, round_number) -> list: """Get tasks for a specific collaborator in a specific round. Args: @@ -112,7 +112,7 @@ def get_tasks_for_collaborator(self, collaborator_name, round_number): """ return self.collaborator_tasks[collaborator_name][round_number] - def get_collaborators_for_task(self, task_name, round_number): + def get_collaborators_for_task(self, task_name, round_number) -> list: """Get collaborators for a specific task in a specific round. Args: diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index a324934e76..9e0e0a96d3 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -364,16 +364,11 @@ def get_tasks(self): tasks[task]["aggregation_type"] = aggregation_type return tasks - def get_aggregator(self, tensor_dict=None): + def get_aggregator(self): """Get federation aggregator. This method retrieves the federation aggregator. If the aggregator - does not exist, it is built using the configuration settings and the - provided tensor dictionary. - - Args: - tensor_dict (dict, optional): The initial tensor dictionary to use - when building the aggregator. Defaults to None. + does not exist, it is built using the configuration settings. Returns: self.aggregator_ (Aggregator): The federation aggregator. @@ -401,7 +396,7 @@ def get_aggregator(self, tensor_dict=None): # TODO: Load callbacks from plan. if self.aggregator_ is None: - self.aggregator_ = Plan.build(**defaults, initial_tensor_dict=tensor_dict) + self.aggregator_ = Plan.build(**defaults) return self.aggregator_ diff --git a/openfl/transport/grpc/aggregator_server.py b/openfl/transport/grpc/aggregator_server.py index 0d2f82ac98..377bdd92b6 100644 --- a/openfl/transport/grpc/aggregator_server.py +++ b/openfl/transport/grpc/aggregator_server.py @@ -175,27 +175,7 @@ def GetTasks(self, request, context): # NOQA:N802 tasks, round_number, sleep_time, time_to_quit = self.aggregator.get_tasks( request.header.sender ) - if tasks: - if isinstance(tasks[0], str): - # backward compatibility - tasks_proto = [ - aggregator_pb2.Task( - name=task, - ) - for task in tasks - ] - else: - tasks_proto = [ - aggregator_pb2.Task( - name=task.name, - function_name=task.function_name, - task_type=task.task_type, - apply_local=task.apply_local, - ) - for task in tasks - ] - else: - tasks_proto = [] + tasks_proto = [aggregator_pb2.Task(name=task) for task in tasks] header = create_header( sender=self.aggregator.uuid, diff --git a/tests/openfl/component/aggregator/test_aggregator.py b/tests/openfl/component/aggregator/test_aggregator.py index ca103d04ec..ab56ee53bc 100644 --- a/tests/openfl/component/aggregator/test_aggregator.py +++ b/tests/openfl/component/aggregator/test_aggregator.py @@ -123,8 +123,8 @@ def test_time_to_quit(agg, round_number, rounds_to_train, expected): @pytest.mark.parametrize( 'col_name,tasks,time_to_quit,exp_tasks,exp_sleep_time,exp_time_to_quit', [ - ('col1', ['task_name'], True, None, 0, True), - ('col1', [], False, None, 10, False), + ('col1', ['task_name'], True, [], 0, True), + ('col1', [], False, [], 10, False), ('col1', ['task_name'], False, ['task_name'], 0, False), ]) def test_get_tasks(agg, col_name, tasks, time_to_quit,