Skip to content

Commit

Permalink
Refactor Flower integration to support dynamic algorithm folder loadi…
Browse files Browse the repository at this point in the history
…ng (#497)

* Refactor Flower to support dynamic algorithm folder loading

- Updated the system to dynamically find algorithms for Flower.
- Instead of a static file, algorithms can now reside in any folder specified by the environment variable FLOWER_ALGORITHM_FOLDERS.
- This allows for greater flexibility in managing and executing algorithms in dev and production mode.
- Introduced a `connect_with_retries` function to handle the retry logic for client connections.
  • Loading branch information
KFilippopolitis authored Oct 7, 2024
1 parent 39a249d commit dbf1078
Show file tree
Hide file tree
Showing 51 changed files with 313 additions and 176 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@
monetdb_nclients = 128
monetdb_memory_limit = 2048 # MB
algorithm_folders = "./exareme2/algorithms/exareme2,./exareme2/algorithms/flower,./tests/algorithms"
exareme2_algorithm_folders = "./exareme2/algorithms/exareme2,./tests/algorithms/exareme2"
flower_algorithm_folders = "./exareme2/algorithms/flower,./tests/algorithms/flower"
worker_landscape_aggregator_update_interval = 30
flower_execution_timeout = 30
Expand Down
70 changes: 55 additions & 15 deletions exareme2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,20 @@
__all__ = [
"DType",
"AttrDict",
"ALGORITHM_FOLDERS_ENV_VARIABLE",
"ALGORITHM_FOLDERS",
"algorithm_classes",
"EXAREME2_ALGORITHM_FOLDERS_ENV_VARIABLE",
"EXAREME2_ALGORITHM_FOLDERS",
"exareme2_algorithm_classes",
"DATA_TABLE_PRIMARY_KEY",
"FLOWER_ALGORITHM_FOLDERS_ENV_VARIABLE",
"FLOWER_ALGORITHM_FOLDERS",
]

DATA_TABLE_PRIMARY_KEY = "row_id"

ALGORITHM_FOLDERS_ENV_VARIABLE = "ALGORITHM_FOLDERS"
ALGORITHM_FOLDERS = "./exareme2/algorithms/exareme2,./exareme2/algorithms/flower"
if algorithm_folders := os.getenv(ALGORITHM_FOLDERS_ENV_VARIABLE):
ALGORITHM_FOLDERS = algorithm_folders
EXAREME2_ALGORITHM_FOLDERS_ENV_VARIABLE = "EXAREME2_ALGORITHM_FOLDERS"
EXAREME2_ALGORITHM_FOLDERS = "./exareme2/algorithms/exareme2"
if exareme2_algorithm_folders := os.getenv(EXAREME2_ALGORITHM_FOLDERS_ENV_VARIABLE):
EXAREME2_ALGORITHM_FOLDERS = exareme2_algorithm_folders


class AlgorithmNamesMismatchError(Exception):
Expand All @@ -46,13 +48,13 @@ def __init__(self, mismatches, algorithm_classes, algorithm_data_loaders):
self.message = message


def import_algorithm_modules() -> Dict[str, ModuleType]:
def import_exareme2_algorithm_modules() -> Dict[str, ModuleType]:
# Import all algorithm modules
# Import all .py modules in the algorithm folder paths
# https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path?page=1&tab=votes#tab-top

all_modules = {}
for algorithm_folder in ALGORITHM_FOLDERS.split(","):
for algorithm_folder in EXAREME2_ALGORITHM_FOLDERS.split(","):
all_module_paths = glob.glob(f"{algorithm_folder}/*.py")
algorithm_module_paths = [
module
Expand Down Expand Up @@ -84,14 +86,14 @@ def import_algorithm_modules() -> Dict[str, ModuleType]:
return all_modules


import_algorithm_modules()
import_exareme2_algorithm_modules()


def get_algorithm_classes() -> Dict[str, type]:
def get_exareme2_algorithm_classes() -> Dict[str, type]:
return {cls.algname: cls for cls in Algorithm.__subclasses__()}


def get_algorithm_data_loaders() -> Dict[str, type]:
def get_exareme2_algorithm_data_loaders() -> Dict[str, type]:
return {cls.algname: cls for cls in AlgorithmDataLoader.__subclasses__()}


Expand All @@ -103,8 +105,46 @@ def _check_algo_naming_matching(algo_classes: dict, algo_data_loaders: dict):
raise AlgorithmNamesMismatchError(sym_diff, algo_classes, algo_data_loaders)


algorithm_classes = get_algorithm_classes()
algorithm_data_loaders = get_algorithm_data_loaders()
exareme2_algorithm_classes = get_exareme2_algorithm_classes()
exareme2_algorithm_data_loaders = get_exareme2_algorithm_data_loaders()
_check_algo_naming_matching(
algo_classes=algorithm_classes, algo_data_loaders=algorithm_data_loaders
algo_classes=exareme2_algorithm_classes,
algo_data_loaders=exareme2_algorithm_data_loaders,
)


def find_flower_algorithm_folder_paths(algorithm_folders):
# Split the input string into a list of folder paths
folder_paths = algorithm_folders.split(",")

# Initialize an empty dictionary to store the result
algorithm_folder_paths = {}

# Iterate over each folder path
for folder_path in folder_paths:
if not os.path.isdir(folder_path):
continue # Skip if the path is not a valid directory

# List all files and folders in the current folder path
items = os.listdir(folder_path)

# Filter for .json files and corresponding folders
for item in items:
if item.endswith(".json"):
algorithm_name = item[:-5] # Remove '.json' to get the algorithm name
algorithm_folder = os.path.join(folder_path, algorithm_name)
if os.path.isdir(algorithm_folder):
# Store the algorithm name and the complete folder path in the dictionary
algorithm_folder_paths[algorithm_name] = algorithm_folder

return algorithm_folder_paths


FLOWER_ALGORITHM_FOLDERS_ENV_VARIABLE = "FLOWER_ALGORITHM_FOLDERS"
FLOWER_ALGORITHM_FOLDERS = "./exareme2/algorithms/flower"
if flower_algorithm_folders := os.getenv(FLOWER_ALGORITHM_FOLDERS_ENV_VARIABLE):
FLOWER_ALGORITHM_FOLDERS = flower_algorithm_folders

flower_algorithm_folder_paths = find_flower_algorithm_folder_paths(
FLOWER_ALGORITHM_FOLDERS
)
38 changes: 37 additions & 1 deletion exareme2/algorithms/flower/inputdata_preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import json
import os
from pathlib import Path
import time
from math import log2
from math import pow
from typing import List
from typing import Optional

import flwr as fl
import pandas as pd
import requests
from flwr.common.logger import FLOWER_LOGGER
Expand Down Expand Up @@ -110,3 +113,36 @@ def get_enumerations(data_model: str, variable_name: str) -> list:
raise KeyError(f"'enumerations' key not found in {variable_name}")
except (requests.RequestException, KeyError, json.JSONDecodeError) as e:
error_handling(str(e))


def connect_with_retries(client, client_name):
"""
Attempts to connect the client to the Flower server with retries.
Args:
client: The client instance to connect.
client_name: The name of the client (for logging purposes).
"""
attempts = 0
max_attempts = int(log2(int(os.environ["TIMEOUT"])))

while True:
try:
fl.client.start_client(
server_address=os.environ["SERVER_ADDRESS"], client=client.to_client()
)
FLOWER_LOGGER.debug(
f"{client_name} - Connection successful on attempt: {attempts + 1}"
)
break
except Exception as e:
FLOWER_LOGGER.warning(
f"{client_name} - Connection with the server failed. Attempt {attempts + 1} failed: {e}"
)
time.sleep(pow(2, attempts)) # Exponential backoff
attempts += 1
if attempts >= max_attempts:
FLOWER_LOGGER.error(
f"{client_name} - Could not establish connection to the server."
)
raise e
25 changes: 3 additions & 22 deletions exareme2/algorithms/flower/logistic_regression/client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import os
import time
import warnings
from math import log2

import flwr as fl
from flwr.common.logger import FLOWER_LOGGER
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from utils import get_model_parameters
from utils import set_initial_params
from utils import set_model_params

from exareme2.algorithms.flower.inputdata_preprocessing import connect_with_retries
from exareme2.algorithms.flower.inputdata_preprocessing import fetch_data
from exareme2.algorithms.flower.inputdata_preprocessing import get_input
from exareme2.algorithms.flower.inputdata_preprocessing import preprocess_data
Expand Down Expand Up @@ -48,21 +45,5 @@ def evaluate(self, parameters, config):

client = LogisticRegressionClient(model, X_train, y_train)

attempts = 0
max_attempts = int(log2(int(os.environ["TIMEOUT"])))
while True:
try:
fl.client.start_client(
server_address=os.environ["SERVER_ADDRESS"], client=client.to_client()
)
FLOWER_LOGGER.debug(f"Connection successful on attempt: {attempts + 1}")
break
except Exception as e:
FLOWER_LOGGER.warning(
f"Connection with the server failed. Attempt {attempts + 1} failed: {e}"
)
time.sleep(pow(2, attempts))
attempts += 1
if attempts >= max_attempts:
FLOWER_LOGGER.error("Could not establish connection to the server.")
raise e
logistic_regression_client = LogisticRegressionClient(model, X_train, y_train)
connect_with_retries(logistic_regression_client, "LogisticRegressionClient")
5 changes: 1 addition & 4 deletions exareme2/algorithms/flower/process_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

import psutil

ALGORITHMS_ROOT = Path(__file__).parent


def process_status(proc):
"""Check the status of a process."""
Expand Down Expand Up @@ -108,9 +106,8 @@ def start(self, logger):
if self.proc is not None:
logger.error("Process already started!")
raise RuntimeError("Process already started!")
flower_executable = ALGORITHMS_ROOT / self.file
env = {**os.environ, **{k: str(v) for k, v in self.env_vars.items()}}
command = ["poetry", "run", "python", str(flower_executable), *self.parameters]
command = ["poetry", "run", "python", str(self.file), *self.parameters]
logger.info(f"Executing command: {command}")
self.proc = subprocess.Popen(
command, env=env, stdout=self.stdout, stderr=self.stderr
Expand Down
18 changes: 14 additions & 4 deletions exareme2/controller/celery/tasks_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,24 +298,34 @@ def queue_healthcheck_task(
)

def start_flower_client(
self, request_id, algorithm_name, server_address, csv_paths, execution_timeout
self,
request_id,
algorithm_folder_path,
server_address,
csv_paths,
execution_timeout,
) -> WorkerTaskResult:
return self._queue_task(
task_signature=TASK_SIGNATURES["start_flower_client"],
request_id=request_id,
algorithm_name=algorithm_name,
algorithm_folder_path=algorithm_folder_path,
server_address=server_address,
csv_paths=csv_paths,
execution_timeout=execution_timeout,
)

def start_flower_server(
self, request_id, algorithm_name, number_of_clients, server_address, csv_paths
self,
request_id,
algorithm_folder_path,
number_of_clients,
server_address,
csv_paths,
) -> WorkerTaskResult:
return self._queue_task(
task_signature=TASK_SIGNATURES["start_flower_server"],
request_id=request_id,
algorithm_name=algorithm_name,
algorithm_folder_path=algorithm_folder_path,
number_of_clients=number_of_clients,
server_address=server_address,
csv_paths=csv_paths,
Expand Down
9 changes: 7 additions & 2 deletions exareme2/controller/services/api/algorithm_spec_dtos.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

from pydantic import BaseModel

from exareme2 import ALGORITHM_FOLDERS
from exareme2 import EXAREME2_ALGORITHM_FOLDERS
from exareme2 import FLOWER_ALGORITHM_FOLDERS
from exareme2.algorithms.specifications import AlgorithmSpecification
from exareme2.algorithms.specifications import AlgorithmType
from exareme2.algorithms.specifications import InputDataSpecification
Expand Down Expand Up @@ -299,7 +300,11 @@ def load_and_parse_specifications(self):

@staticmethod
def get_specs_paths():
return [Path(specs_path.strip()) for specs_path in ALGORITHM_FOLDERS.split(",")]
return [
Path(specs_path.strip())
for specs_path in EXAREME2_ALGORITHM_FOLDERS.split(",")
+ FLOWER_ALGORITHM_FOLDERS.split(",")
]

def parse_specifications(self, specs_path, all_algorithms, all_transformers):
for spec_property_path in specs_path.glob("*.json"):
Expand Down
10 changes: 5 additions & 5 deletions exareme2/controller/services/exareme2/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from typing import List
from typing import Optional

from exareme2 import algorithm_classes
from exareme2 import algorithm_data_loaders
from exareme2 import exareme2_algorithm_classes
from exareme2 import exareme2_algorithm_data_loaders
from exareme2.algorithms.exareme2.algorithm import AlgorithmDataLoader
from exareme2.algorithms.exareme2.algorithm import (
InitializationParams as AlgorithmInitParams,
Expand Down Expand Up @@ -531,7 +531,7 @@ def __init__(
):
self._algorithm_name = algorithm_name
self._variables = variables
self._algorithm_data_loader = algorithm_data_loaders[algorithm_name](
self._algorithm_data_loader = exareme2_algorithm_data_loaders[algorithm_name](
variables=variables
)
self._algorithm_request_dto = algorithm_request_dto
Expand Down Expand Up @@ -598,7 +598,7 @@ async def run(self, data, metadata):
X = data_transformed[0]
y = data_transformed[1]
alg_vars = Variables(x=X.columns, y=y.columns)
algorithm_data_loader = algorithm_data_loaders[self._algorithm_name](
algorithm_data_loader = exareme2_algorithm_data_loaders[self._algorithm_name](
variables=alg_vars
)

Expand Down Expand Up @@ -692,7 +692,7 @@ async def run(self, data, metadata):
algorithm_parameters=self._params,
datasets=self._datasets,
)
algorithm = algorithm_classes[self._algorithm_name](
algorithm = exareme2_algorithm_classes[self._algorithm_name](
initialization_params=init_params,
data_loader=self._algorithm_data_loader,
engine=self._engine,
Expand Down
7 changes: 4 additions & 3 deletions exareme2/controller/services/flower/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Dict
from typing import List

from exareme2 import flower_algorithm_folder_paths
from exareme2.controller import config as ctrl_config
from exareme2.controller import logger as ctrl_logger
from exareme2.controller.federation_info_logs import log_experiment_execution
Expand Down Expand Up @@ -91,10 +92,10 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
server_pid = None
clients_pids = {}
server_address = f"{server_ip}:{FLOWER_SERVER_PORT}"

algorithm_folder_path = flower_algorithm_folder_paths[algorithm_name]
try:
server_pid = server_task_handler.start_flower_server(
algorithm_name,
algorithm_folder_path,
len(task_handlers),
str(server_address),
csv_paths_per_worker_id[server_id]
Expand All @@ -103,7 +104,7 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
)
clients_pids = {
handler.start_flower_client(
algorithm_name,
algorithm_folder_path,
str(server_address),
csv_paths_per_worker_id[handler.worker_id],
ctrl_config.flower_execution_timeout,
Expand Down
12 changes: 8 additions & 4 deletions exareme2/controller/services/flower/tasks_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,26 @@ def worker_data_address(self) -> str:
return self._db_address

def start_flower_client(
self, algorithm_name, server_address, csv_paths, execution_timeout
self, algorithm_folder_path, server_address, csv_paths, execution_timeout
) -> int:
return self._worker_tasks_handler.start_flower_client(
self._request_id,
algorithm_name,
algorithm_folder_path,
server_address,
csv_paths,
execution_timeout,
).get(timeout=self._tasks_timeout)

def start_flower_server(
self, algorithm_name: str, number_of_clients: int, server_address, csv_paths
self,
algorithm_folder_path: str,
number_of_clients: int,
server_address,
csv_paths,
) -> int:
return self._worker_tasks_handler.start_flower_server(
self._request_id,
algorithm_name,
algorithm_folder_path,
number_of_clients,
server_address,
csv_paths,
Expand Down
Loading

0 comments on commit dbf1078

Please sign in to comment.