Commit 325476c
Client load data directly from csvs. (#484)
* Added a retry mechanism for the connection of Flower clients to the Flower server.

* Updated the client-side data processing so that clients load data from CSV files instead of from the database.
KFilippopolitis authored Jul 9, 2024
1 parent 1385732 commit 325476c
Showing 28 changed files with 1,796 additions and 1,250 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/algorithm_validation_tests.yml
@@ -245,8 +245,8 @@ jobs:
       with:
         run: cat /tmp/exareme2/localworker1.out

-      - name: Run Flower algorithm validation tests
-        run: poetry run pytest tests/algorithm_validation_tests/flower/test_logistic_regression.py -n 2 --verbosity=4 --reruns 6 --reruns-delay 5
-
       - name: Run Exareme2 algorithm validation tests
         run: poetry run pytest tests/algorithm_validation_tests/exareme2/ --verbosity=4 -n 16 -k "input1 and not input1-" # run tests 10-19
+
+      - name: Run Flower algorithm validation tests
+        run: poetry run pytest tests/algorithm_validation_tests/flower/test_logistic_regression.py -n 2 --verbosity=4 --reruns 6 --reruns-delay 5
@@ -5,7 +5,6 @@
 from typing import Optional

 import pandas as pd
-import pymonetdb
 import requests
 from flwr.common.logger import FLOWER_LOGGER
 from pydantic import BaseModel
@@ -29,37 +28,30 @@ class Inputdata(BaseModel):
     x: Optional[List[str]]


-def fetch_data(data_model, datasets, from_db=False) -> pd.DataFrame:
-    return (
-        _fetch_data_from_db(data_model, datasets)
-        if from_db
-        else _fetch_data_from_csv(data_model, datasets)
-    )
+def fetch_client_data(inputdata) -> pd.DataFrame:
+    FLOWER_LOGGER.debug(f"Loading CSV paths: {os.getenv('CSV_PATHS')}")
+    dataframes = [
+        pd.read_csv(f"{os.getenv('DATA_PATH')}{csv_path}")
+        for csv_path in os.getenv("CSV_PATHS").split(",")
+    ]
+    df = pd.concat(dataframes, ignore_index=True)
+    df = df[df["dataset"].isin(inputdata.datasets)]
+    return df[inputdata.x + inputdata.y]


-def _fetch_data_from_db(data_model, datasets) -> pd.DataFrame:
-    query = f'SELECT * FROM "{data_model}"."primary_data"'
-    conn = pymonetdb.connect(
-        hostname=os.getenv("MONETDB_IP"),
-        port=int(os.getenv("MONETDB_PORT")),
-        username=os.getenv("MONETDB_USERNAME"),
-        password=os.getenv("MONETDB_PASSWORD"),
-        database=os.getenv("MONETDB_DB"),
+def fetch_server_data(inputdata) -> pd.DataFrame:
+    data_folder = Path(
+        f"{os.getenv('DATA_PATH')}/{inputdata.data_model.split(':')[0]}_v_0_1"
     )
-    df = pd.read_sql(query, conn)
-    conn.close()
-    df = df[df["dataset"].isin(datasets)]
-    return df
-
-
-def _fetch_data_from_csv(data_model, datasets) -> pd.DataFrame:
-    data_folder = Path(f"{os.getenv('DATA_PATH')}/{data_model.split(':')[0]}_v_0_1")
     print(f"Loading data from folder: {data_folder}")
     dataframes = [
         pd.read_csv(data_folder / f"{dataset}.csv")
-        for dataset in datasets
+        for dataset in inputdata.datasets
         if (data_folder / f"{dataset}.csv").exists()
     ]
-    return pd.concat(dataframes, ignore_index=True)
+    df = pd.concat(dataframes, ignore_index=True)
+    df = df[df["dataset"].isin(inputdata.datasets)]
+    return df[inputdata.x + inputdata.y]


 def preprocess_data(inputdata, full_data):
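Editor's note (not part of the diff): a minimal sketch of how the two new loaders are meant to be driven, assuming the Inputdata model exposes the data_model, datasets, y and x fields used above and that the worker exports DATA_PATH and CSV_PATHS as environment variables; all paths and dataset names below are hypothetical.

```python
import os

from exareme2.algorithms.flower.inputdata_preprocessing import Inputdata
from exareme2.algorithms.flower.inputdata_preprocessing import fetch_client_data
from exareme2.algorithms.flower.inputdata_preprocessing import fetch_server_data

# Hypothetical values; in production the controller assigns the CSV paths.
os.environ["DATA_PATH"] = "/opt/exareme2/csvs/"
os.environ["CSV_PATHS"] = "dementia_v_0_1/edsd0.csv,dementia_v_0_1/edsd1.csv"

inputdata = Inputdata(
    data_model="dementia:0.1",
    datasets=["edsd0", "edsd1"],
    y=["alzheimerbroadcategory"],
    x=["lefthippocampus"],
)

# A client reads exactly the CSVs assigned to it via CSV_PATHS, while the
# server scans <DATA_PATH>/<data_model>_v_0_1/ for <dataset>.csv files.
client_df = fetch_client_data(inputdata)  # columns restricted to x + y
server_df = fetch_server_data(inputdata)
```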
33 changes: 26 additions & 7 deletions exareme2/algorithms/flower/logistic_regression/client.py
@@ -1,16 +1,19 @@
 import os
+import time
 import warnings
+from math import log2

 import flwr as fl
+from flwr.common.logger import FLOWER_LOGGER
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import log_loss
 from utils import get_model_parameters
 from utils import set_initial_params
 from utils import set_model_params

-from exareme2.algorithms.flower.flower_data_processing import fetch_data
-from exareme2.algorithms.flower.flower_data_processing import get_input
-from exareme2.algorithms.flower.flower_data_processing import preprocess_data
+from exareme2.algorithms.flower.inputdata_preprocessing import fetch_client_data
+from exareme2.algorithms.flower.inputdata_preprocessing import get_input
+from exareme2.algorithms.flower.inputdata_preprocessing import preprocess_data


 class LogisticRegressionClient(fl.client.NumPyClient):
@@ -39,11 +42,27 @@ def evaluate(self, parameters, config):
 if __name__ == "__main__":
     model = LogisticRegression(penalty="l2", max_iter=1, warm_start=True)
     inputdata = get_input()
-    full_data = fetch_data(inputdata.data_model, inputdata.datasets, from_db=True)
+    full_data = fetch_client_data(inputdata)
     X_train, y_train = preprocess_data(inputdata, full_data)
     set_initial_params(model, X_train, full_data, inputdata)

     client = LogisticRegressionClient(model, X_train, y_train)
-    fl.client.start_client(
-        server_address=os.environ["SERVER_ADDRESS"], client=client.to_client()
-    )
+
+    attempts = 0
+    max_attempts = int(log2(int(os.environ["TIMEOUT"])))
+    while True:
+        try:
+            fl.client.start_client(
+                server_address=os.environ["SERVER_ADDRESS"], client=client.to_client()
+            )
+            FLOWER_LOGGER.debug(f"Connection successful on attempt {attempts + 1}")
+            break
+        except Exception as e:
+            FLOWER_LOGGER.warning(
+                f"Connection with the server failed. Attempt {attempts + 1} failed: {e}"
+            )
+            time.sleep(pow(2, attempts))
+            attempts += 1
+            if attempts >= max_attempts:
+                FLOWER_LOGGER.error("Could not establish connection to the server.")
+                raise e
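Editor's note on the retry budget (not part of the diff): since the client sleeps 2**attempts seconds after each failed attempt, the total sleep before giving up is 2**0 + 2**1 + ... + 2**(max_attempts - 1) = 2**max_attempts - 1 seconds, and max_attempts = int(log2(TIMEOUT)) keeps that total just under TIMEOUT. A small sketch of the arithmetic, with a hypothetical TIMEOUT value:

```python
from math import log2


def worst_case_wait(timeout_seconds: int) -> int:
    """Total seconds slept across all retries before the client gives up."""
    max_attempts = int(log2(timeout_seconds))
    # Geometric series: 2**0 + 2**1 + ... + 2**(max_attempts - 1)
    return 2**max_attempts - 1


# e.g. TIMEOUT=600 gives int(log2(600)) == 9 attempts and a worst-case
# wait of 2**9 - 1 == 511 seconds, which stays below the 600 s budget.
print(worst_case_wait(600))
```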
10 changes: 5 additions & 5 deletions exareme2/algorithms/flower/logistic_regression/server.py
@@ -6,10 +6,10 @@
 from utils import set_initial_params
 from utils import set_model_params

-from exareme2.algorithms.flower.flower_data_processing import fetch_data
-from exareme2.algorithms.flower.flower_data_processing import get_input
-from exareme2.algorithms.flower.flower_data_processing import post_result
-from exareme2.algorithms.flower.flower_data_processing import preprocess_data
+from exareme2.algorithms.flower.inputdata_preprocessing import fetch_server_data
+from exareme2.algorithms.flower.inputdata_preprocessing import get_input
+from exareme2.algorithms.flower.inputdata_preprocessing import post_result
+from exareme2.algorithms.flower.inputdata_preprocessing import preprocess_data

 # TODO: NUM_OF_ROUNDS should become a parameter of the algorithm and be set on the AlgorithmRequestDTO
 NUM_OF_ROUNDS = 5
@@ -35,7 +35,7 @@ def evaluate(server_round, parameters, config):
 if __name__ == "__main__":
     model = LogisticRegression()
     inputdata = get_input()
-    full_data = fetch_data(inputdata.data_model, inputdata.datasets)
+    full_data = fetch_server_data(inputdata)
     X_train, y_train = preprocess_data(inputdata, full_data)
     set_initial_params(model, X_train, full_data, inputdata)
     strategy = fl.server.strategy.FedAvg(
@@ -5,7 +5,7 @@
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import log_loss

-from exareme2.algorithms.flower.flower_data_processing import post_result
+from exareme2.algorithms.flower.inputdata_preprocessing import post_result
 from exareme2.algorithms.flower.mnist_logistic_regression import utils

 NUM_OF_ROUNDS = 5
4 changes: 3 additions & 1 deletion exareme2/controller/celery/tasks_handler.py
@@ -298,13 +298,15 @@ def queue_healthcheck_task(
         )

     def start_flower_client(
-        self, request_id, algorithm_name, server_address
+        self, request_id, algorithm_name, server_address, csv_paths, execution_timeout
     ) -> WorkerTaskResult:
         return self._queue_task(
             task_signature=TASK_SIGNATURES["start_flower_client"],
             request_id=request_id,
             algorithm_name=algorithm_name,
             server_address=server_address,
+            csv_paths=csv_paths,
+            execution_timeout=execution_timeout,
         )

     def start_flower_server(
9 changes: 8 additions & 1 deletion exareme2/controller/quart/endpoints.py
@@ -32,7 +32,14 @@ async def get_datasets() -> dict:

 @algorithms.route("/datasets_locations", methods=["GET"])
 async def get_datasets_locations() -> dict:
-    return get_worker_landscape_aggregator().get_datasets_locations().datasets_locations
+    return {
+        data_model: {
+            dataset: info.worker_id for dataset, info in datasets_location.items()
+        }
+        for data_model, datasets_location in get_worker_landscape_aggregator()
+        .get_datasets_locations()
+        .datasets_locations.items()
+    }


 @algorithms.route("/cdes_metadata", methods=["GET"])
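For illustration (hypothetical payload, not part of the diff): the rewritten endpoint serializes only the worker id per dataset instead of the full location objects, so a GET /datasets_locations response now has this shape:

```python
# Hypothetical response body for GET /datasets_locations after this change:
response = {
    "dementia:0.1": {"edsd0": "localworker1", "edsd1": "localworker2"},
    "tbi:0.1": {"dummy_tbi0": "localworker2"},
}
```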
28 changes: 14 additions & 14 deletions exareme2/controller/services/flower/controller.py
@@ -1,6 +1,9 @@
 import asyncio
 import warnings
+from typing import Dict
+from typing import List

+from exareme2.controller import config as ctrl_config
 from exareme2.controller import logger as ctrl_logger
 from exareme2.controller.federation_info_logs import log_experiment_execution
 from exareme2.controller.services.flower.tasks_handler import FlowerTasksHandler
@@ -52,10 +55,16 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
         request_id = algorithm_request_dto.request_id
         context_id = UIDGenerator().get_a_uid()
         logger = ctrl_logger.get_request_logger(request_id)
-        workers_info = self._get_workers_info_by_dataset(
+        csv_paths_per_worker_id: Dict[
+            str, List[str]
+        ] = self.worker_landscape_aggregator.get_csv_paths_per_worker_id(
             algorithm_request_dto.inputdata.data_model,
             algorithm_request_dto.inputdata.datasets,
         )
+        workers_info = [
+            self.worker_landscape_aggregator.get_worker_info(worker_id)
+            for worker_id in csv_paths_per_worker_id
+        ]
         task_handlers = [
             self._create_worker_tasks_handler(request_id, worker)
             for worker in workers_info
@@ -87,7 +96,10 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
         )
         clients_pids = {
             handler.start_flower_client(
-                algorithm_name, str(server_address)
+                algorithm_name,
+                str(server_address),
+                csv_paths_per_worker_id[handler.worker_id],
+                ctrl_config.flower_execution_timeout,
             ): handler
             for handler in task_handlers
         }
@@ -127,15 +139,3 @@ async def _cleanup(
             server_task_handler.stop_flower_server(server_pid, algorithm_name)
         for pid, handler in clients_pids.items():
             handler.stop_flower_client(pid, algorithm_name)
-
-    def _get_workers_info_by_dataset(self, data_model, datasets) -> List[WorkerInfo]:
-        """Retrieves worker information for those handling the specified datasets."""
-        worker_ids = (
-            self.worker_landscape_aggregator.get_worker_ids_with_any_of_datasets(
-                data_model, datasets
-            )
-        )
-        return [
-            self.worker_landscape_aggregator.get_worker_info(worker_id)
-            for worker_id in worker_ids
-        ]
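Editor's sketch (hypothetical values): get_csv_paths_per_worker_id now yields both the set of participating workers and each worker's CSV assignment in one call, which is why the _get_workers_info_by_dataset helper could be dropped:

```python
from typing import Dict, List

# Assumed shape of get_csv_paths_per_worker_id(data_model, datasets):
csv_paths_per_worker_id: Dict[str, List[str]] = {
    "localworker1": ["dementia_v_0_1/edsd0.csv"],
    "localworker2": ["dementia_v_0_1/edsd1.csv", "dementia_v_0_1/edsd2.csv"],
}

# Iterating the mapping gives the worker ids for the handler setup, and each
# client is started with only its own worker's paths.
for worker_id, csv_paths in csv_paths_per_worker_id.items():
    print(worker_id, csv_paths)
```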
10 changes: 8 additions & 2 deletions exareme2/controller/services/flower/tasks_handler.py
@@ -29,9 +29,15 @@ def worker_id(self) -> str:
     def worker_data_address(self) -> str:
         return self._db_address

-    def start_flower_client(self, algorithm_name, server_address) -> int:
+    def start_flower_client(
+        self, algorithm_name, server_address, csv_paths, execution_timeout
+    ) -> int:
         return self._worker_tasks_handler.start_flower_client(
-            self._request_id, algorithm_name, server_address
+            self._request_id,
+            algorithm_name,
+            server_address,
+            csv_paths,
+            execution_timeout,
         ).get(timeout=self._tasks_timeout)

     def start_flower_server(
@@ -4,6 +4,7 @@
 from exareme2.controller.celery.tasks_handler import WorkerTasksHandler
 from exareme2.worker_communication import CommonDataElements
 from exareme2.worker_communication import DataModelAttributes
+from exareme2.worker_communication import DatasetsInfoPerDataModel
 from exareme2.worker_communication import WorkerInfo

@@ -23,10 +24,11 @@ def get_worker_info_task(self) -> WorkerInfo:
         ).get(self._tasks_timeout)
         return WorkerInfo.parse_raw(result)

-    def get_worker_datasets_per_data_model_task(self) -> Dict[str, Dict[str, str]]:
-        return self._worker_tasks_handler.queue_worker_datasets_per_data_model_task(
+    def get_worker_datasets_per_data_model_task(self) -> DatasetsInfoPerDataModel:
+        result = self._worker_tasks_handler.queue_worker_datasets_per_data_model_task(
             self._request_id
         ).get(self._tasks_timeout)
+        return DatasetsInfoPerDataModel.parse_raw(result)

     def get_data_model_cdes_task(self, data_model: str) -> CommonDataElements:
         result = self._worker_tasks_handler.queue_data_model_cdes_task(
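The task result now travels as a serialized string, mirroring get_worker_info_task above, and is rebuilt with pydantic's parse_raw. A minimal sketch of the pattern, using a stand-in model since DatasetsInfoPerDataModel's actual fields live in exareme2.worker_communication:

```python
from typing import Dict, List

from pydantic import BaseModel


class DatasetsInfo(BaseModel):  # stand-in for DatasetsInfoPerDataModel
    datasets_per_data_model: Dict[str, List[str]]


# Celery transports the result as a JSON string; parse_raw (pydantic v1)
# validates it and rebuilds the typed object on the controller side.
raw = '{"datasets_per_data_model": {"dementia:0.1": ["edsd0", "edsd1"]}}'
info = DatasetsInfo.parse_raw(raw)
assert info.datasets_per_data_model["dementia:0.1"] == ["edsd0", "edsd1"]
```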