feat: model training & selection
billsioros committed Jul 29, 2024
1 parent eff4822 commit 6b52c75
Showing 24 changed files with 2,367 additions and 233 deletions.
3 changes: 3 additions & 0 deletions .editorconfig
@@ -30,3 +30,6 @@ indent_size = unset

 [*.txt]
 indent_size = unset
+
+[*.ipynb]
+indent_size = unset
1 change: 1 addition & 0 deletions .env
@@ -3,3 +3,4 @@ POSTGRES_USER='guest'
 POSTGRES_PASSWORD=')M8z*yss$cRxw7(&'

 BACKEND_DATABASE__URI = 'postgresql://guest:)M8z*yss$cRxw7(&@database:5432/heartbeat'
+BACKEND_CHECKPOINT_PATH = './model.joblib'
4 changes: 4 additions & 0 deletions .gitignore
@@ -152,3 +152,7 @@ Thumbs.db #thumbnail cache on Windows
 # End of https://www.toptal.com/developers/gitignore/api/python

 *.csv
+*.pkl
+*.joblib
+
+.vscode
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
@@ -88,6 +88,12 @@ repos:
 - id: poetry-check
 - id: poetry-lock
 args: [--check]
+- repo: https://github.com/nbQA-dev/nbQA
+rev: 1.8.5
+hooks:
+- id: nbqa-check-ast
+- id: nbqa-ruff
+- id: nbqa-mypy
 - repo: meta
 hooks:
 - id: check-hooks-apply
12 changes: 10 additions & 2 deletions README.md
@@ -45,6 +45,14 @@
 </a>
 </p>

+Cardiovascular diseases (CVDs) are the number 1 cause of death globally, taking an estimated 17.9 million lives each year, which accounts for 31% of all deaths worldwide. Four out of five CVD deaths are due to heart attacks and strokes, and one-third of these deaths occur prematurely in people under 70 years of age. Heart failure is a common event caused by CVDs and this dataset contains 11 features that can be used to predict a possible heart disease.
+
+People with cardiovascular disease or who are at high cardiovascular risk (due to the presence of one or more risk factors such as hypertension, diabetes, hyperlipidaemia or already established disease) need early detection and management wherein a machine learning model can be of great help.
+
+> Source: [https://www.kaggle.com/datasets/fedesoriano/heart-failure-prediction](https://www.kaggle.com/datasets/fedesoriano/heart-failure-prediction)
+In this project, we create a complete solution featuring a [`FastAPI`](https://fastapi.tiangolo.com/) backend and a [`React`](https://react.dev/) frontend. We perform Exploratory Data Analysis (EDA) and develop a machine learning model using [`scikit-learn`](https://scikit-learn.org).
+
 ## :rocket: Running the project

 > [!IMPORTANT]
@@ -65,6 +73,7 @@
 POSTGRES_PASSWORD=')M8z*yss$cRxw7(&'
 BACKEND_DATABASE__URI = 'postgresql://guest:)M8z*yss$cRxw7(&@database:5432/heartbeat'
+BACKEND_CHECKPOINT_PATH = './model.joblib'
 ```

 3. **Build and start the services** using Docker Compose.
@@ -104,6 +113,7 @@ Running the backend:

 ```shell
 export BACKEND_DATABASE__URI='postgresql://guest:)M8z*yss$cRxw7(&@localhost:5432/heartbeat'
+export BACKEND_CHECKPOINT_PATH='./model.joblib'
 poetry env use 3.11
 poetry install
 python src/api/app.py
@@ -126,8 +136,6 @@ npm run dev -- --host --port 80

 ## :computer: Deploying to production

-## Deploying to Production
-
 Deploying to production involves several crucial steps, including setting up a server, configuring DNS, and managing SSL certificates. [`Traefik`](https://github.com/traefik/traefik) simplifies many of these tasks, acting as a powerful reverse proxy and load balancer. For a comprehensive guide on deploying your FastAPI application with Traefik, read the full article [here](https://github.com/tiangolo/blog-posts/blob/master/deploying-fastapi-apps-with-https-powered-by-traefik/README.md).

 ## :bookmark_tabs: Contributing
2 changes: 2 additions & 0 deletions docker-compose.yaml
@@ -33,6 +33,8 @@ services:
 interval: 5s
 timeout: 10s
 retries: 10
+volumes:
+- ${PWD}/model.joblib:/home/app/app/model.joblib
 database:
 env_file: .env
 container_name: database
1,283 changes: 1,270 additions & 13 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions pyproject.toml
@@ -44,6 +44,11 @@ pandas = "^2.2.2"
 seaborn = "^0.13.2"
 matplotlib = "^3.9.1"
 scikit-learn = "^1.5.1"
+xgboost = "^2.1.0"
+tqdm = "^4.66.4"
+jupyter = "^1.0.0"
+ipywidgets = "^8.1.3"
+joblib = "^1.4.2"

 [tool.poetry.group.dev.dependencies]
 mypy = "*"
6 changes: 0 additions & 6 deletions src/api/app.py
@@ -1,7 +1,6 @@
 import uvicorn
 from fastapi import APIRouter, FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-
 from api.controllers.heartbeat_controller import api as heartbeat_controller
 from api.controllers.monitor_controller import api as monitor_controller
 from api.settings import Settings
@@ -49,14 +48,9 @@ def register_middlewares(app: FastAPI) -> FastAPI:
     return app


-def register_events(app: FastAPI) -> FastAPI:
-    return app
-
-
 def create_app() -> FastAPI:
     app = initialize_api()
     app = register_configuration(app)
-    app = register_events(app)
     app = register_controllers(app)
     app = register_middlewares(app)

20 changes: 18 additions & 2 deletions src/api/bot/__init__.py
@@ -1,6 +1,22 @@
 from api.schemas.heartbeat import HeartBeatCreateSchema
+from sklearn.pipeline import Pipeline
+import pandas as pd


-class Bot:
+class Bot(object):
+    def __init__(self, model: Pipeline) -> None:
+        self._model = model
+
     def predict(self, heartbeat: HeartBeatCreateSchema) -> bool:
-        return True
+        payload = {
+            "Age": heartbeat.age,
+            "Sex": heartbeat.sex.value,
+            "ChestPain": heartbeat.chest_pain_type.value,
+            "FastingBS": int(heartbeat.fasting_blood_sugar),
+            "MaxHR": heartbeat.max_heart_rate,
+            "ExerciseAngina": int(heartbeat.exercise_angina),
+            "Oldpeak": heartbeat.old_peak,
+            "ST_Slope": heartbeat.st_slope.value,
+        }
+
+        return self._model.predict(pd.DataFrame([payload]))[0]
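The field-to-column mapping that `Bot.predict` performs can be sketched in isolation using only the standard library. Everything below is illustrative: `HeartBeat` is a hypothetical stand-in for `HeartBeatCreateSchema`, the enum members other than `ASYMPTOMATIC`, `UP`, and `FLAT` are assumed, and `StubModel` takes the place of the fitted scikit-learn `Pipeline`. The sketch also wraps the result in `bool(...)`, because scikit-learn estimators return NumPy arrays, so element 0 of the prediction would otherwise be a NumPy scalar rather than the Python `bool` the annotation promises.

```python
from dataclasses import dataclass
from enum import IntEnum, auto


class Sex(IntEnum):  # member names are assumed, not taken from the diff
    MALE = auto()
    FEMALE = auto()


class ChestPain(IntEnum):  # only ASYMPTOMATIC is visible in the diff
    TYPICAL_ANGINA = auto()
    ASYMPTOMATIC = auto()


class StSlope(IntEnum):  # UP and FLAT are visible in the diff
    UP = auto()
    FLAT = auto()


@dataclass
class HeartBeat:
    """Hypothetical stand-in for HeartBeatCreateSchema."""

    age: int
    sex: Sex
    chest_pain_type: ChestPain
    fasting_blood_sugar: bool
    max_heart_rate: int
    exercise_angina: bool
    old_peak: float
    st_slope: StSlope


def to_feature_row(heartbeat: HeartBeat) -> dict:
    """Map schema fields to the column names used in Bot.predict."""
    return {
        "Age": heartbeat.age,
        "Sex": heartbeat.sex.value,
        "ChestPain": heartbeat.chest_pain_type.value,
        "FastingBS": int(heartbeat.fasting_blood_sugar),
        "MaxHR": heartbeat.max_heart_rate,
        "ExerciseAngina": int(heartbeat.exercise_angina),
        "Oldpeak": heartbeat.old_peak,
        "ST_Slope": heartbeat.st_slope.value,
    }


class StubModel:
    """Hypothetical model: flags a record when exercise angina is present."""

    def predict(self, rows: list[dict]) -> list[int]:
        return [row["ExerciseAngina"] for row in rows]


def predict(model: StubModel, heartbeat: HeartBeat) -> bool:
    # bool(...) turns the model's numeric label into a plain Python bool.
    return bool(model.predict([to_feature_row(heartbeat)])[0])
```

Keeping the mapping in a separate function makes it easy to check that the keys match the column names the pipeline was trained on.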
8 changes: 5 additions & 3 deletions src/api/dependencies.py
@@ -1,6 +1,6 @@
 from fastapi import Depends
 from fastapi.requests import Request
-
+import joblib
 from api.bot import Bot
 from api.repositories.heartbeat_repository import HeartBeatRepository
 from api.resources.database import Database
@@ -25,8 +25,10 @@ async def get_heartbeat_repository(
     return HeartBeatRepository(database)


-async def get_bot(request: Request):
-    return Bot()
+async def get_bot(request: Request, settings: Settings = Depends(get_settings)):
+    model = joblib.load(settings.checkpoint_path)
+
+    return Bot(model)


 async def get_heartbeat_service(
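As written, `get_bot` calls `joblib.load` each time the dependency is resolved. If that ever becomes a concern, one option is to cache the loaded object per path. The sketch below illustrates the idea with `functools.lru_cache`; it is not part of the commit, `load_checkpoint` and `demo` are hypothetical names, and `pickle` stands in for `joblib` so that the example needs only the standard library.

```python
import pickle
import tempfile
from functools import lru_cache
from pathlib import Path

LOAD_CALLS = 0  # counts how often a checkpoint file is actually read


@lru_cache(maxsize=1)
def load_checkpoint(path: Path) -> object:
    """Deserialize the checkpoint once per distinct path, then reuse it."""
    global LOAD_CALLS
    LOAD_CALLS += 1
    with path.open("rb") as handle:
        return pickle.load(handle)  # stand-in for joblib.load(path)


def demo() -> tuple:
    """Load the same checkpoint twice and report how many file reads occurred."""
    reads_before = LOAD_CALLS
    with tempfile.TemporaryDirectory() as tmp:
        checkpoint = Path(tmp) / "model.pkl"
        checkpoint.write_bytes(pickle.dumps({"weights": [0.1, 0.2]}))
        first = load_checkpoint(checkpoint)
        second = load_checkpoint(checkpoint)  # served from the cache
    return first, second, LOAD_CALLS - reads_before
```

The trade-off is that a cached model is not refreshed when the file on disk changes, so a process restart (or an explicit `load_checkpoint.cache_clear()`) would be needed after retraining.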
14 changes: 0 additions & 14 deletions src/api/models/heartbeat.py
@@ -18,14 +18,6 @@ class ChestPain(IntEnum):
     ASYMPTOMATIC = auto()


-class RestingElectrocardiogram(IntEnum):
-    NORMAL = auto()
-    STT = auto()  # having ST-T wave abnormality (T wave inversions and/or ST elevation or depression of > 0.05 mV)
-    LVH = (
-        auto()
-    )  # showing probable or definite left ventricular hypertrophy by Estes' criteria
-
-
 class StSlope(IntEnum):
     UP = auto()
     FLAT = auto()
@@ -38,13 +30,7 @@ class HeartBeatModel(PkModel):
     age: Mapped[int] = Column(Integer, nullable=False)
     sex: Mapped[Sex] = Column(Enum(Sex), nullable=False)
     chest_pain_type: Mapped[ChestPain] = Column(Enum(ChestPain), nullable=False)
-    resting_blood_pressure: Mapped[int] = Column(Integer, nullable=False)
-    cholesterol: Mapped[int] = Column(Integer, nullable=False)
     fasting_blood_sugar: Mapped[bool] = Column(Boolean, nullable=False)
-    resting_electrocardiogram: Mapped[RestingElectrocardiogram] = Column(
-        Enum(RestingElectrocardiogram),
-        nullable=False,
-    )
     max_heart_rate: Mapped[int] = Column(Integer, nullable=False)
     exercise_angina: Mapped[bool] = Column(Boolean, nullable=False)
     old_peak: Mapped[float] = Column(Float, nullable=False)
8 changes: 7 additions & 1 deletion src/api/repositories/heartbeat_repository.py
@@ -3,6 +3,8 @@

 from sqlalchemy import exc

+from sqlalchemy import desc
+
 from api.models.heartbeat import HeartBeatModel
 from api.repositories import Repository
 from api.resources.database import Database
@@ -22,7 +24,11 @@ def __init__(self, database: Database) -> None:

     def get_all(self) -> Iterator[HeartBeatModel]:
         with self.session_factory() as session:
-            return session.query(HeartBeatModel).all()
+            return (
+                session.query(HeartBeatModel)
+                .order_by(desc(HeartBeatModel.created_at))
+                .all()
+            )

     def get_by_id(self, id: str) -> HeartBeatModel:
         with self.session_factory() as session:
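The change to `get_all` sorts records by `created_at` in descending order, so the most recent entries come first. The effect can be illustrated with the standard library's `sqlite3` module, which stands in here for the SQLAlchemy session; the `heartbeat` table layout and the `newest_first` helper are assumptions made for this sketch only.

```python
import sqlite3


def newest_first(rows: list[tuple[str, str]]) -> list[str]:
    """Insert (id, created_at) pairs and return the ids ordered newest first."""
    connection = sqlite3.connect(":memory:")
    try:
        connection.execute(
            "CREATE TABLE heartbeat (id TEXT PRIMARY KEY, created_at TEXT NOT NULL)"
        )
        connection.executemany(
            "INSERT INTO heartbeat (id, created_at) VALUES (?, ?)", rows
        )
        # ISO 8601 timestamps sort chronologically when compared as text.
        cursor = connection.execute(
            "SELECT id FROM heartbeat ORDER BY created_at DESC"
        )
        return [row[0] for row in cursor.fetchall()]
    finally:
        connection.close()
```

Without an explicit `ORDER BY`, SQL databases make no guarantee about row order, which is presumably why the commit adds one.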
9 changes: 0 additions & 9 deletions src/api/schemas/heartbeat.py
@@ -2,7 +2,6 @@

 from api.models.heartbeat import (
     ChestPain,
-    RestingElectrocardiogram,
     Sex,
     StSlope,
 )
@@ -15,18 +14,10 @@ class Config:
     age: int = Field(..., ge=0, le=130, description="Age of the patient [years]")
     sex: Sex
     chest_pain_type: ChestPain
-    resting_blood_pressure: int = Field(
-        ...,
-        ge=0,
-        le=250,
-        description="Resting blood pressure [mm Hg]",
-    )
-    cholesterol: int = Field(..., ge=0, le=700, description="Serum cholesterol [mm/dl]")
     fasting_blood_sugar: bool = Field(
         ...,
         description="Fasting blood sugar [1: if FastingBS > 120 mg/dl, 0: otherwise]",
     )
-    resting_electrocardiogram: RestingElectrocardiogram
     max_heart_rate: int = Field(
         ...,
         ge=60,
3 changes: 0 additions & 3 deletions src/api/services/heartbeat_service.py
@@ -34,10 +34,7 @@ def create(
             age=heartbeat_create.age,
             sex=heartbeat_create.sex,
             chest_pain_type=heartbeat_create.chest_pain_type,
-            resting_blood_pressure=heartbeat_create.resting_blood_pressure,
-            cholesterol=heartbeat_create.cholesterol,
             fasting_blood_sugar=heartbeat_create.fasting_blood_sugar,
-            resting_electrocardiogram=heartbeat_create.resting_electrocardiogram,
             max_heart_rate=heartbeat_create.max_heart_rate,
             exercise_angina=heartbeat_create.exercise_angina,
             old_peak=heartbeat_create.old_peak,
2 changes: 2 additions & 0 deletions src/api/settings.py
@@ -1,3 +1,4 @@
+from pathlib import Path
 from pydantic import BaseModel, PostgresDsn
 from pydantic_settings import BaseSettings, SettingsConfigDict

@@ -15,3 +16,4 @@ class Settings(BaseSettings):
     )

     database: DatabaseSettings = DatabaseSettings()
+    checkpoint_path: Path
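Because `checkpoint_path: Path` is declared without a default, it is a required setting, and pydantic-settings raises a validation error when `Settings` is instantiated without a value for it. The `.env` entry `BACKEND_CHECKPOINT_PATH` suggests a `BACKEND_` prefix, although the `SettingsConfigDict` arguments are not visible in this diff. The standard-library sketch below mimics that fail-fast behaviour; `read_checkpoint_path` is a hypothetical helper, not part of the project.

```python
import os
from pathlib import Path
from typing import Mapping, Optional

ENV_NAME = "BACKEND_CHECKPOINT_PATH"  # name taken from the .env file in this commit


def read_checkpoint_path(environ: Optional[Mapping[str, str]] = None) -> Path:
    """Return the configured checkpoint path, failing fast when it is unset."""
    env = os.environ if environ is None else environ
    raw = env.get(ENV_NAME)
    if not raw:
        raise RuntimeError(f"{ENV_NAME} must be set to the model checkpoint path")
    return Path(raw)
```

Failing at startup rather than at the first prediction request makes a missing or misnamed variable easier to diagnose.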
147 changes: 140 additions & 7 deletions src/notebooks/eda.ipynb

Large diffs are not rendered by default.
