diff --git a/exareme2/algorithms/flower/xgboost.json b/exareme2/algorithms/flower/xgboost.json new file mode 100644 index 000000000..ef3eee99d --- /dev/null +++ b/exareme2/algorithms/flower/xgboost.json @@ -0,0 +1,38 @@ +{ + "name": "xgboost", + "desc": "xgboost", + "label": "XGBoost on Flower", + "enabled": true, + "type": "flower", + "inputdata": { + "y": { + "label": "Variable (dependent)", + "desc": "A unique nominal variable. The variable is converted to binary by assigning 1 to the positive class and 0 to all other classes. ", + "types": [ + "int", + "text" + ], + "stattypes": [ + "nominal" + ], + "notblank": true, + "multiple": false + }, + "x": { + "label": "Covariates (independent)", + "desc": "One or more variables. Can be numerical or nominal. For nominal variables dummy encoding is used.", + "types": [ + "real", + "int", + "text" + ], + "stattypes": [ + "numerical", + "nominal" + ], + "notblank": true, + "multiple": true + }, + "validation": true + } +} diff --git a/exareme2/algorithms/flower/xgboost/__init__.py b/exareme2/algorithms/flower/xgboost/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/exareme2/algorithms/flower/xgboost/client.py b/exareme2/algorithms/flower/xgboost/client.py new file mode 100644 index 000000000..e32426c71 --- /dev/null +++ b/exareme2/algorithms/flower/xgboost/client.py @@ -0,0 +1,170 @@ +import os +import time +import warnings +from math import log2 + +import flwr as fl +import xgboost as xgb +from flwr.common import Code +from flwr.common import EvaluateIns +from flwr.common import EvaluateRes +from flwr.common import FitIns +from flwr.common import FitRes +from flwr.common import GetParametersIns +from flwr.common import GetParametersRes +from flwr.common import Parameters +from flwr.common import Status +from flwr.common.logger import FLOWER_LOGGER + +from exareme2.algorithms.flower.inputdata_preprocessing import fetch_data +from exareme2.algorithms.flower.inputdata_preprocessing import get_input 
+from exareme2.algorithms.flower.inputdata_preprocessing import preprocess_data
+
+warnings.filterwarnings("ignore", category=UserWarning)
+
+
+def transform_dataset_to_dmatrix(x, y) -> xgb.core.DMatrix:
+    new_data = xgb.DMatrix(x, label=y)
+    return new_data
+
+
+# Hyper-parameters shared by every local xgboost boosting round on this client
+num_local_round = 1
+params = {
+    "objective": "binary:logistic",
+    "eta": 0.1,  # Learning rate
+    "max_depth": 8,
+    "eval_metric": "auc",
+    "nthread": 16,
+    "num_parallel_tree": 1,
+    "subsample": 1,
+    "tree_method": "hist",
+}
+
+
+# Flower client that trains and evaluates a local xgboost booster
+class XgbClient(fl.client.Client):
+    def __init__(self, train_dmatrix, valid_dmatrix, num_train, num_val):
+        self.bst = None
+        self.config = None
+
+        self.train_dmatrix = train_dmatrix
+        self.valid_dmatrix = valid_dmatrix
+
+        self.num_train = num_train
+        self.num_val = num_val
+
+    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
+        _ = (self, ins)
+        return GetParametersRes(
+            status=Status(
+                code=Code.OK,
+                message="OK",
+            ),
+            parameters=Parameters(tensor_type="", tensors=[]),
+        )
+
+    def _local_boost(self):
+        # Update trees based on this client's own training data.
+        for i in range(num_local_round):
+            # Boost against this client's OWN data; the original read the
+            # module-global train_dmatrix instead of the instance attribute
+            # stored by __init__, which only worked by accident in __main__.
+            self.bst.update(self.train_dmatrix, self.bst.num_boosted_rounds())
+
+        # Extract the last N=num_local_round trees for server aggregation
+        bst = self.bst[
+            self.bst.num_boosted_rounds()
+            - num_local_round : self.bst.num_boosted_rounds()
+        ]
+
+        return bst
+
+    def fit(self, ins: FitIns) -> FitRes:
+        if not self.bst:
+            # First round local training
+            FLOWER_LOGGER.info("Start training at round 1")
+            bst = xgb.train(
+                params,
+                self.train_dmatrix,
+                num_boost_round=num_local_round,
+                evals=[(self.valid_dmatrix, "validate"), (self.train_dmatrix, "train")],
+            )
+            self.config = bst.save_config()
+            self.bst = bst
+        else:
+            for item in ins.parameters.tensors:
+                global_model = bytearray(item)
+
+            # Load global model into booster
+            self.bst.load_model(global_model)
+            self.bst.load_config(self.config)
+
+            bst = self._local_boost()
+
+        local_model = bst.save_raw("json")
+        local_model_bytes = bytes(local_model)
+
+        return FitRes(
+            status=Status(
+                code=Code.OK,
+                message="OK",
+            ),
+            parameters=Parameters(tensor_type="", tensors=[local_model_bytes]),
+            num_examples=self.num_train,
+            metrics={},
+        )
+
+    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
+        eval_results = self.bst.eval_set(
+            evals=[(self.valid_dmatrix, "valid")],
+            iteration=self.bst.num_boosted_rounds() - 1,
+        )
+        auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)
+
+        return EvaluateRes(
+            status=Status(
+                code=Code.OK,
+                message="OK",
+            ),
+            loss=0.0,
+            num_examples=self.num_val,
+            metrics={"AUC": auc},
+        )
+
+
+# Start Flower client
+# fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient().to_client())
+
+if __name__ == "__main__":
+    inputdata = get_input()
+    full_data = fetch_data(inputdata)
+    X_train, y_train = preprocess_data(inputdata, full_data)
+    # hard coded for now, later we can split X_train and y_train
+    X_valid, y_valid = X_train, y_train
+
+    # Reformat data to DMatrix for xgboost
+    FLOWER_LOGGER.info("Reformatting data...")
+    train_dmatrix =
transform_dataset_to_dmatrix(X_train, y=y_train)
+    valid_dmatrix = transform_dataset_to_dmatrix(X_valid, y=y_valid)
+
+    num_train = X_train.shape[0]
+    num_val = X_valid.shape[0]
+
+    client = XgbClient(train_dmatrix, valid_dmatrix, num_train, num_val)
+
+    attempts = 0
+    max_attempts = int(log2(int(os.environ["TIMEOUT"])))
+    while True:
+        try:
+            fl.client.start_client(
+                server_address=os.environ["SERVER_ADDRESS"], client=client.to_client()
+            )
+            FLOWER_LOGGER.debug("Connection successful on attempt")
+            break
+        except Exception as e:
+            FLOWER_LOGGER.warning(
+                f"Connection with the server failed. Attempt {attempts + 1} failed: {e}"
+            )
+            attempts += 1
+            if attempts >= max_attempts:
+                FLOWER_LOGGER.error("Could not establish connection to the server.")
+                raise e
+            time.sleep(pow(2, attempts - 1))
diff --git a/exareme2/algorithms/flower/xgboost/server.py b/exareme2/algorithms/flower/xgboost/server.py
new file mode 100644
index 000000000..13e415927
--- /dev/null
+++ b/exareme2/algorithms/flower/xgboost/server.py
@@ -0,0 +1,80 @@
+import copy
+import os
+
+import flwr as fl
+from flwr.common.logger import FLOWER_LOGGER
+from flwr.server.strategy import FedXgbBagging
+
+from exareme2.algorithms.flower.inputdata_preprocessing import post_result
+
+# FL experimental settings
+pool_size = 2
+NUM_OF_ROUNDS = 5
+num_clients_per_round = 2
+num_evaluate_clients = 2
+
+
+def evaluate_metrics_aggregation(eval_metrics):
+    """Return an aggregated metric (AUC) for evaluation."""
+    total_num = sum([num for num, _ in eval_metrics])
+    auc_aggregated = (
+        sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num
+    )
+    metrics_aggregated = {"AUC": auc_aggregated}
+    return metrics_aggregated
+
+
+class CustomFedXgbBagging(FedXgbBagging):
+    def __init__(self, num_rounds, **kwargs):
+        super().__init__(**kwargs)
+        self.num_rounds = num_rounds
+        self.initial_auc = 0.0
+
+    def aggregate_evaluate(self, rnd,
results, failures)
+        # aggregated_metrics is the (loss, metrics-dict) pair; read it directly.
+        curr_auc = aggregated_metrics[1]["AUC"]
+
+        if rnd == 1:
+            # First evaluation round: remember the starting AUC as the
+            # baseline against which the final round's AUC is compared
+            # (curr_auc was just read from aggregated_metrics above).
+            self.initial_auc = curr_auc
+
+        if rnd == self.num_rounds:
+            FLOWER_LOGGER.debug("aggregated metrics is " + str(aggregated_metrics))
+
+            auc_diff = curr_auc - self.initial_auc
+            auc_ascending = ""
+            if auc_diff >= -0.05:
+                auc_ascending = "correct"
+            else:
+                auc_ascending = "not_correct"
+
+            post_result(
+                {
+                    "AUC": curr_auc,
+                    "auc_ascending": auc_ascending,
+                    "initial_auc": self.initial_auc,
+                }
+            )
+        return aggregated_metrics
+
+
+if __name__ == "__main__":
+    # Define strategy
+    strategy = CustomFedXgbBagging(
+        num_rounds=NUM_OF_ROUNDS,
+        fraction_fit=(float(num_clients_per_round) / pool_size),
+        min_fit_clients=num_clients_per_round,
+        min_available_clients=pool_size,
+        min_evaluate_clients=num_evaluate_clients,
+        fraction_evaluate=1.0,
+        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation,
+    )
+
+    fl.server.start_server(
+        server_address=os.environ["SERVER_ADDRESS"],
+        strategy=strategy,
+        config=fl.server.ServerConfig(num_rounds=NUM_OF_ROUNDS),
+    )
diff --git a/poetry.lock b/poetry.lock
index 39f71279b..7d10230c9 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1575,6 +1575,17 @@ files = [
     {file = "numpy-1.24.1.tar.gz", hash = "sha256:2386da9a471cc00a1f47845e27d916d5ec5346ae9696e01a8a34760858fe9dd2"},
 ]
 
+[[package]]
+name = "nvidia-nccl-cu12"
+version = "2.23.4"
+description = "NVIDIA Collective Communication Library (NCCL) Runtime"
+optional = false
+python-versions = ">=3"
+files = [
+    {file = "nvidia_nccl_cu12-2.23.4-py3-none-manylinux2014_aarch64.whl", hash = "sha256:aa946c8327e22ced28e7cef508a334673abc42064ec85f02d005ba1785ea4cec"},
+    {file = "nvidia_nccl_cu12-2.23.4-py3-none-manylinux2014_x86_64.whl", hash = "sha256:b097258d9aab2fa9f686e33c6fe40ae57b27df60cedbd15d139701bb5509e0c1"},
+]
+
 [[package]]
 name = "packaging"
 version = "24.1"
@@ -2990,6
+3001,36 @@ files = [ [package.dependencies] h11 = ">=0.9.0,<1" +[[package]] +name = "xgboost" +version = "2.1.1" +description = "XGBoost Python Package" +optional = false +python-versions = ">=3.8" +files = [ + {file = "xgboost-2.1.1-py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.macosx_12_0_x86_64.whl", hash = "sha256:4163ab55118628f605cfccf950e2d667150640f6fc746bb5a173bddfd935950f"}, + {file = "xgboost-2.1.1-py3-none-macosx_12_0_arm64.whl", hash = "sha256:40d1f647022f497c1b0f69073765baf50ff5802ca77c6bb1aca55a6bc65df00d"}, + {file = "xgboost-2.1.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4c534818aa08ab327ac2239ef211ef78db65a8573d069bc9898f824830fa2308"}, + {file = "xgboost-2.1.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:deef471e8d353afa99e5cc0e2af7d99ace7013f40684fcf3eed9124de033265d"}, + {file = "xgboost-2.1.1-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:8f3246a6d839dceb4553d3e5ea64ed718f9c692f072ee8275eeb895b58e283e6"}, + {file = "xgboost-2.1.1-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:6475ca35dede1f87d1dc485b362caba08f69f6020f4440e97b167676a533850e"}, + {file = "xgboost-2.1.1-py3-none-win_amd64.whl", hash = "sha256:fcf8413f3c621e97fdaaa45abb7ae808319c88eff5447328eff14c419c7c6ae0"}, + {file = "xgboost-2.1.1.tar.gz", hash = "sha256:4b1729837f9f1ba88a32ef1be3f8efb860fee6454a68719b196dc88032c23d97"}, +] + +[package.dependencies] +numpy = "*" +nvidia-nccl-cu12 = {version = "*", markers = "platform_system == \"Linux\" and platform_machine != \"aarch64\""} +scipy = "*" + +[package.extras] +dask = ["dask", "distributed", "pandas"] +datatable = ["datatable"] +pandas = ["pandas (>=1.2)"] +plotting = ["graphviz", "matplotlib"] +pyspark = ["cloudpickle", "pyspark", "scikit-learn"] +scikit-learn = ["scikit-learn"] + [[package]] name = "zipp" version = "3.19.2" @@ -3008,4 +3049,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = "~3.8" -content-hash = 
"72b48e2af97e04691afbe14e1ed311b8e4709902fbf04f420bdaf0a5b82db9d3"
+content-hash = "f6000f15d29fd5dd53fe6e5b4e90f61cf1f29b7511ba0519b00437f2d14aad8c"
diff --git a/pyproject.toml b/pyproject.toml
index 5f307b39e..0afb4df97 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,7 @@ eventlet = "~0.33"
 patsy = "~0.5.3"
 flwr = "1.7.0"
 psutil = "^5.9.8"
+xgboost = "^2.1.1"
 
 [tool.poetry.dev-dependencies]
 pytest = "~7.4"
diff --git a/tests/algorithm_validation_tests/flower/test_xgboost.py b/tests/algorithm_validation_tests/flower/test_xgboost.py
new file mode 100644
index 000000000..b806b03ac
--- /dev/null
+++ b/tests/algorithm_validation_tests/flower/test_xgboost.py
@@ -0,0 +1,33 @@
+def test_xgboost(get_algorithm_result):
+    # Named `request` rather than `input` to avoid shadowing the builtin.
+    request = {
+        "inputdata": {
+            "y": ["gender"],
+            "x": ["lefthippocampus"],
+            "data_model": "dementia:0.1",
+            "datasets": [
+                "ppmi0",
+                "ppmi1",
+                "ppmi2",
+                "ppmi3",
+                "ppmi5",
+                "ppmi6",
+                "edsd6",
+                "ppmi7",
+                "ppmi8",
+                "ppmi9",
+            ],
+            "validation_datasets": ["ppmi_test"],
+            "filters": None,
+        },
+        "parameters": None,
+        "test_case_num": 99,
+    }
+    request["type"] = "flower"
+    algorithm_result = get_algorithm_result("xgboost", request)
+    # {'metrics_aggregated': {'AUC': 0.7575790087463558}}
+    print(algorithm_result)
+    auc_aggregated = algorithm_result["AUC"]
+    auc_ascending = algorithm_result["auc_ascending"]
+
+    assert auc_aggregated > 0.0
+    assert auc_ascending == "correct"