Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Xgboost branch #501

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions exareme2/algorithms/flower/xgboost.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"name": "xgboost",
"desc": "xgboost",
"label": "XGBoost on Flower",
"enabled": true,
"type": "flower",
"inputdata": {
"y": {
"label": "Variable (dependent)",
"desc": "A unique nominal variable. The variable is converted to binary by assigning 1 to the positive class and 0 to all other classes. ",
"types": [
"int",
"text"
],
"stattypes": [
"nominal"
],
"notblank": true,
"multiple": false
},
"x": {
"label": "Covariates (independent)",
"desc": "One or more variables. Can be numerical or nominal. For nominal variables dummy encoding is used.",
"types": [
"real",
"int",
"text"
],
"stattypes": [
"numerical",
"nominal"
],
"notblank": true,
"multiple": true
},
"validation": true
}
}
Empty file.
170 changes: 170 additions & 0 deletions exareme2/algorithms/flower/xgboost/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import os
import time
import warnings
from math import log2

import flwr as fl
import xgboost as xgb
from flwr.common import Code
from flwr.common import EvaluateIns
from flwr.common import EvaluateRes
from flwr.common import FitIns
from flwr.common import FitRes
from flwr.common import GetParametersIns
from flwr.common import GetParametersRes
from flwr.common import Parameters
from flwr.common import Status
from flwr.common.logger import FLOWER_LOGGER

from exareme2.algorithms.flower.inputdata_preprocessing import fetch_data
from exareme2.algorithms.flower.inputdata_preprocessing import get_input
from exareme2.algorithms.flower.inputdata_preprocessing import preprocess_data

warnings.filterwarnings("ignore", category=UserWarning)


def transform_dataset_to_dmatrix(x, y) -> xgb.core.DMatrix:
    """Wrap features ``x`` and labels ``y`` in an xgboost ``DMatrix``."""
    return xgb.DMatrix(x, label=y)


# Hyper-parameters for xgboost training
num_local_round = 1  # boosting rounds performed locally per federated round
params = {
    "objective": "binary:logistic",  # binary classification, probability output
    "eta": 0.1,  # Learning rate
    "max_depth": 8,
    "eval_metric": "auc",  # metric reported by evaluate() and aggregated server-side
    "nthread": 16,
    "num_parallel_tree": 1,
    "subsample": 1,
    "tree_method": "hist",
}


# Define Flower client
class XgbClient(fl.client.Client):
def __init__(self, train_dmatrix, valid_dmatrix, num_train, num_val):
self.bst = None
self.config = None

self.train_dmatrix = train_dmatrix
self.valid_dmatrix = valid_dmatrix

self.num_train = num_train
self.num_val = num_val

def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
_ = (self, ins)
return GetParametersRes(
status=Status(
code=Code.OK,
message="OK",
),
parameters=Parameters(tensor_type="", tensors=[]),
)

def _local_boost(self):
# Update trees based on local training data.
for i in range(num_local_round):
self.bst.update(train_dmatrix, self.bst.num_boosted_rounds())

# Extract the last N=num_local_round trees for sever aggregation
bst = self.bst[
self.bst.num_boosted_rounds()
- num_local_round : self.bst.num_boosted_rounds()
]

return bst

def fit(self, ins: FitIns) -> FitRes:
if not self.bst:
# First round local training
FLOWER_LOGGER.info("Start training at round 1")
bst = xgb.train(
params,
train_dmatrix,
num_boost_round=num_local_round,
evals=[(valid_dmatrix, "validate"), (train_dmatrix, "train")],
)
self.config = bst.save_config()
self.bst = bst
else:
for item in ins.parameters.tensors:
global_model = bytearray(item)

# Load global model into booster
self.bst.load_model(global_model)
self.bst.load_config(self.config)

bst = self._local_boost()

local_model = bst.save_raw("json")
local_model_bytes = bytes(local_model)

return FitRes(
status=Status(
code=Code.OK,
message="OK",
),
parameters=Parameters(tensor_type="", tensors=[local_model_bytes]),
num_examples=self.num_train,
metrics={},
)

def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
eval_results = self.bst.eval_set(
evals=[(valid_dmatrix, "valid")],
iteration=self.bst.num_boosted_rounds() - 1,
)
auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)

return EvaluateRes(
status=Status(
code=Code.OK,
message="OK",
),
loss=0.0,
num_examples=self.num_val,
metrics={"AUC": auc},
)


# Script entry point: build the local datasets, then connect to the Flower
# server, retrying with exponential backoff until TIMEOUT is roughly exhausted.
if __name__ == "__main__":
    inputdata = get_input()
    full_data = fetch_data(inputdata)
    X_train, y_train = preprocess_data(inputdata, full_data)
    # hard coded for now, later we can split X_train and y_train
    X_valid, y_valid = X_train, y_train

    # Reformat data to DMatrix for xgboost
    FLOWER_LOGGER.info("Reformatting data...")
    train_dmatrix = transform_dataset_to_dmatrix(X_train, y=y_train)
    valid_dmatrix = transform_dataset_to_dmatrix(X_valid, y=y_valid)

    num_train = X_train.shape[0]
    num_val = X_valid.shape[0]

    client = XgbClient(train_dmatrix, valid_dmatrix, num_train, num_val)

    # Retry with exponential backoff (1s, 2s, 4s, ...): log2(TIMEOUT) attempts
    # keeps the total sleep time at roughly TIMEOUT seconds.
    attempts = 0
    max_attempts = int(log2(int(os.environ["TIMEOUT"])))
    while True:
        try:
            fl.client.start_client(
                server_address=os.environ["SERVER_ADDRESS"], client=client.to_client()
            )
            FLOWER_LOGGER.debug("Connection successful on attempt")
            break
        except Exception as e:
            FLOWER_LOGGER.warning(
                f"Connection with the server failed. Attempt {attempts + 1} failed: {e}"
            )
            time.sleep(pow(2, attempts))
            attempts += 1
            if attempts >= max_attempts:
                FLOWER_LOGGER.error("Could not establish connection to the server.")
                raise e
80 changes: 80 additions & 0 deletions exareme2/algorithms/flower/xgboost/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import copy
import os

import flwr as fl
from flwr.common.logger import FLOWER_LOGGER
from flwr.server.strategy import FedXgbBagging

from exareme2.algorithms.flower.inputdata_preprocessing import post_result

# FL experimental settings
pool_size = 2  # total number of clients expected in the federation
NUM_OF_ROUNDS = 5  # number of federated boosting rounds
num_clients_per_round = 2  # clients sampled for training each round
num_evaluate_clients = 2  # clients sampled for evaluation each round


def evaluate_metrics_aggregation(eval_metrics):
    """Return the example-count-weighted average AUC across clients.

    Args:
        eval_metrics: list of ``(num_examples, metrics)`` pairs, one per
            client, where ``metrics`` contains an ``"AUC"`` entry.

    Returns:
        dict with a single ``"AUC"`` key holding the weighted mean.
    """
    # Generator expressions avoid materializing throwaway lists inside sum().
    total_num = sum(num for num, _ in eval_metrics)
    auc_aggregated = (
        sum(metrics["AUC"] * num for num, metrics in eval_metrics) / total_num
    )
    return {"AUC": auc_aggregated}


class CustomFedXgbBagging(FedXgbBagging):
    """FedXgbBagging strategy that posts the final aggregated AUC.

    Records the AUC of the first round as a baseline and, after the last
    round, reports whether the metric stayed on a non-degrading trajectory
    (allowing a small 0.05 drop).
    """

    def __init__(self, num_rounds, **kwargs):
        super().__init__(**kwargs)
        self.num_rounds = num_rounds
        self.initial_auc = 0.0  # AUC observed after round 1, used as baseline

    def aggregate_evaluate(self, rnd, results, failures):
        """Aggregate client metrics; publish the result on the final round."""
        aggregated_metrics = super().aggregate_evaluate(rnd, results, failures)
        # aggregate_evaluate returns a (loss, metrics) pair; reading the AUC
        # directly replaces the original's needless copy.deepcopy round-trips.
        curr_auc = aggregated_metrics[1]["AUC"]

        if rnd == 1:
            self.initial_auc = curr_auc

        if rnd == self.num_rounds:
            FLOWER_LOGGER.debug("aggregated metrics is " + str(aggregated_metrics))

            # Tolerate a small (0.05) drop before flagging the run as degraded.
            auc_diff = curr_auc - self.initial_auc
            auc_ascending = "correct" if auc_diff >= -0.05 else "not_correct"

            post_result(
                {
                    "AUC": curr_auc,
                    "auc_ascending": auc_ascending,
                    "initial_auc": self.initial_auc,
                }
            )
        return aggregated_metrics


if __name__ == "__main__":
# Define strategy
strategy = CustomFedXgbBagging(
num_rounds=NUM_OF_ROUNDS,
fraction_fit=(float(num_clients_per_round) / pool_size),
min_fit_clients=num_clients_per_round,
min_available_clients=pool_size,
min_evaluate_clients=num_evaluate_clients,
fraction_evaluate=1.0,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation,
)

fl.server.start_server(
server_address=os.environ["SERVER_ADDRESS"],
strategy=strategy,
config=fl.server.ServerConfig(num_rounds=NUM_OF_ROUNDS),
)
43 changes: 42 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ eventlet = "~0.33"
patsy = "~0.5.3"
flwr = "1.7.0"
psutil = "^5.9.8"
xgboost = "^2.1.1"

[tool.poetry.dev-dependencies]
pytest = "~7.4"
Expand Down
33 changes: 33 additions & 0 deletions tests/algorithm_validation_tests/flower/test_xgboost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
def test_xgboost(get_algorithm_result):
    """End-to-end check: the federated XGBoost run yields a positive
    aggregated AUC and a non-degrading ("correct") AUC trajectory."""
    # Renamed from `input`, which shadowed the builtin of the same name.
    algorithm_input = {
        "inputdata": {
            "y": ["gender"],
            "x": ["lefthippocampus"],
            "data_model": "dementia:0.1",
            "datasets": [
                "ppmi0",
                "ppmi1",
                "ppmi2",
                "ppmi3",
                "ppmi5",
                "ppmi6",
                "edsd6",
                "ppmi7",
                "ppmi8",
                "ppmi9",
            ],
            "validation_datasets": ["ppmi_test"],
            "filters": None,
        },
        "parameters": None,
        "test_case_num": 99,
    }
    algorithm_input["type"] = "flower"
    algorithm_result = get_algorithm_result("xgboost", algorithm_input)
    # Example result: {'AUC': 0.7575790087463558, 'auc_ascending': 'correct'}
    print(algorithm_result)
    auc_aggregated = algorithm_result["AUC"]
    auc_ascending = algorithm_result["auc_ascending"]

    assert auc_aggregated > 0.0
    assert auc_ascending == "correct"
Loading