diff --git a/app/database/models.py b/app/database/models.py index 390896c4..7cdab467 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -79,8 +79,7 @@ class User(Base): back_populates="user", ) comments = relationship("Comment", back_populates="user") - tasks = relationship( - "Task", cascade="all, delete", back_populates="owner") + tasks = relationship("Task", cascade="all, delete", back_populates="owner") features = relationship("Feature", secondary=UserFeature.__tablename__) oauth_credentials = relationship( @@ -535,6 +534,19 @@ def __repr__(self): ) +class UserMenstrualPeriodLength(Base): + __tablename__ = "user_menstrual_period_length" + + id = Column(Integer, primary_key=True, index=True) + user_id = Column( + Integer, + ForeignKey("users.id"), + nullable=False, + unique=True, + ) + period_length = Column(Integer, nullable=False) + + class SharedListItem(Base): __tablename__ = "shared_list_item" diff --git a/app/internal/menstrual_predictor_utils.py b/app/internal/menstrual_predictor_utils.py new file mode 100644 index 00000000..06298c80 --- /dev/null +++ b/app/internal/menstrual_predictor_utils.py @@ -0,0 +1,147 @@ +import datetime +from datetime import timedelta +from typing import List, Union + +from fastapi import Depends +from loguru import logger +from sqlalchemy import asc +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session + +from app.database.models import Event, UserMenstrualPeriodLength +from app.dependencies import get_db +from app.internal.security.dependencies import current_user +from app.internal.security.schema import CurrentUser +from app.routers.event import create_event + +MENSTRUAL_PERIOD_CATEGORY_ID = 111 +N_MONTHS_GENERATED = 3 +GAP_IN_CASE_NO_PERIODS = 30 + + +def get_avg_period_gap(db: Session, user_id: int) -> int: + period_days = get_all_period_days(db, user_id) + gaps_list = [] + + if len(period_days) <= 1: + return GAP_IN_CASE_NO_PERIODS + + for i in range(len(period_days) - 1): + gap = get_date_diff(period_days[i].start, period_days[i + 1].start) + gaps_list.append(gap.days) + return get_list_avg(gaps_list) + + +def get_date_diff(date_1: datetime, date_2: datetime) -> timedelta: + return date_2 - date_1 + + +def get_list_avg(received_list: List) -> int: + return sum(received_list) // len(received_list) + + +def remove_existing_period_dates(db: Session, user_id: int) -> None: + ( + db.query(Event) + .filter(Event.owner_id == user_id) + .filter(Event.category_id == MENSTRUAL_PERIOD_CATEGORY_ID) + .filter(Event.start > datetime.datetime.now()) + .delete() + ) + db.commit() + logger.info("Removed all period predictions to create new ones") + + +def generate_predicted_period_dates( + db: Session, + period_length: str, + period_start_date: datetime, + user_id: int, +) -> Event: + delta = datetime.timedelta(int(period_length)) + period_end_date = period_start_date + delta + period_event = create_event( + db, + "period", + period_start_date, + period_end_date, + user_id, + category_id=MENSTRUAL_PERIOD_CATEGORY_ID, + ) + return period_event + + +def add_n_month_predictions( + db: Session, + period_length: str, + period_start_date: datetime, + user_id: int, +) -> List[Event]: + avg_gap = get_avg_period_gap(db, user_id) + avg_gap_delta = datetime.timedelta(avg_gap) + generated_months = [] + for _ in range(N_MONTHS_GENERATED + 1): + generated_period = generate_predicted_period_dates( + db, + period_length, + period_start_date, + user_id, + ) + generated_months.append(generated_period) + period_start_date += avg_gap_delta + logger.info(f"Generated predictions: {generated_months}") + return generated_months + + +def add_prediction_events_if_valid( + period_start_date: datetime, + db: Session = Depends(get_db), + user: CurrentUser = Depends(current_user), +) -> None: + current_user_id = user.user_id + user_period_length = is_user_signed_up_to_menstrual_predictor( + db, + current_user_id, + ) + + remove_existing_period_dates(db, current_user_id) + if user_period_length: + add_n_month_predictions( + db, + user_period_length, + period_start_date, + current_user_id, + ) + + +def get_all_period_days(session: Session, user_id: int) -> List[Event]: + """Returns all period days filtered by user id.""" + + try: + period_days = ( + session.query(Event) + .filter(Event.owner_id == user_id) + .filter(Event.category_id == MENSTRUAL_PERIOD_CATEGORY_ID) + .order_by(asc(Event.start)) + .all() + ) + + except SQLAlchemyError as err: + logger.exception(err) + return [] + + return period_days + + +def is_user_signed_up_to_menstrual_predictor( + session: Session, + user_id: int, +) -> Union[bool, int]: + user_menstrual_period_length = ( + session.query(UserMenstrualPeriodLength) + .filter(user_id == user_id) + .first() + ) + if user_menstrual_period_length: + return user_menstrual_period_length.period_length + return False diff --git a/app/main.py b/app/main.py index 67136680..90320c85 100644 --- a/app/main.py +++ b/app/main.py @@ -79,6 +79,7 @@ def create_tables(engine, psql_environment): login, logout, meds, + menstrual_predictor, notification, profile, register, @@ -133,6 +134,7 @@ async def swagger_ui_redirect(): login.router, logout.router, meds.router, + menstrual_predictor.router, notes.router, notification.router, profile.router, diff --git a/app/routers/menstrual_predictor.py b/app/routers/menstrual_predictor.py new file mode 100644 index 00000000..c351b805 --- /dev/null +++ b/app/routers/menstrual_predictor.py @@ -0,0 +1,102 @@ +import datetime + +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import RedirectResponse, Response +from loguru import logger +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session +from starlette.status import HTTP_302_FOUND, HTTP_400_BAD_REQUEST + +from app.database.models import UserMenstrualPeriodLength +from app.dependencies import get_db, templates +from app.internal.menstrual_predictor_utils import ( + add_prediction_events_if_valid, + generate_predicted_period_dates, + is_user_signed_up_to_menstrual_predictor, +) +from app.internal.security.dependencies import current_user +from app.internal.security.schema import CurrentUser +from app.internal.utils import create_model + +router = APIRouter( + prefix="/menstrual-predictor", + tags=["menstrual-predictor"], + dependencies=[Depends(get_db)], +) + + +@router.get("/") +def join_menstrual_predictor( + request: Request, + db: Session = Depends(get_db), + user: CurrentUser = Depends(current_user), +) -> Response: + current_user_id = user.user_id + + if is_user_signed_up_to_menstrual_predictor(db, current_user_id): + return RedirectResponse(url="/", status_code=HTTP_302_FOUND) + + return templates.TemplateResponse( + "join_menstrual_predictor.html", + { + "request": request, + }, + ) + + +@router.get("/add/{start_date}") +def add_period_start( + request: Request, + start_date: str, + db: Session = Depends(get_db), + user: CurrentUser = Depends(current_user), +) -> RedirectResponse: + try: + period_start_date = datetime.datetime.strptime(start_date, "%Y-%m-%d") + except ValueError as err: + logger.exception(err) + raise HTTPException( + status_code=HTTP_400_BAD_REQUEST, + detail="The given date doesn't match a date format YYYY-MM-DD", + ) + else: + add_prediction_events_if_valid(period_start_date, db, user) + logger.info("Adding menstrual start date") + return RedirectResponse("/", status_code=HTTP_302_FOUND) + + +@router.post("/") +async def submit_join_form( + request: Request, + db: Session = Depends(get_db), + user: CurrentUser = Depends(current_user), +) -> RedirectResponse: + + data = await request.form() + print(data) + user_menstrual_period_length = { + "user_id": user.user_id, + "period_length": data["avg-period-length"], + } + last_period_date = datetime.datetime.strptime( + data["last-period-date"], + "%Y-%m-%d", + ) + try: + create_model( + session=db, + model_class=UserMenstrualPeriodLength, + **user_menstrual_period_length, + ) + except SQLAlchemyError: + logger.info("Current user already signed up to the service, hurray!!") + db.rollback() + url = "/" + generate_predicted_period_dates( + db, + data["avg-period-length"], + last_period_date, + user.user_id, + ) + + return RedirectResponse(url=url, status_code=HTTP_302_FOUND) diff --git a/app/static/grid_style.css b/app/static/grid_style.css index 026474a0..94bb43fd 100644 --- a/app/static/grid_style.css +++ b/app/static/grid_style.css @@ -2,8 +2,6 @@ --backgroundcol: #F7F7F7; --textcolor: #222831; --start-of-month: #E9ECEf; - --primary-variant: #FFDE4D; - --secondary: #EF5454; --borders: #E7E7E7; --borders-variant: #F7F7F7; } @@ -12,7 +10,6 @@ --backgroundcol: #000000; --textcolor: #EEEEEE; --start-of-month: #8C28BF; - --secondary: #EF5454; --borders: #E7E7E7; --borders-variant: #F7F7F7; } diff --git a/app/static/js/menstrual_predictor.js b/app/static/js/menstrual_predictor.js new file mode 100644 index 00000000..e657e77c --- /dev/null +++ b/app/static/js/menstrual_predictor.js @@ -0,0 +1,13 @@ +function change_max_to_today_date(el) { + const today = new Date(); + const today_str = today.toISOString().substring(0, 10); + el.max = el.dataset.maxDate = today_str; +} +function validate_date_older_than_today(received_date) { + return received_date < new Date(); +} + +document.addEventListener("DOMContentLoaded", () => { + const last_period_date_element = document.getElementById("last-period-date"); + change_max_to_today_date(last_period_date_element); +}); diff --git a/app/static/js/settings.js b/app/static/js/settings.js index 77ee2d13..6311790e 100644 --- a/app/static/js/settings.js +++ b/app/static/js/settings.js @@ -1,23 +1,39 @@ -document.addEventListener('DOMContentLoaded', () => { - const tabBtn = document.getElementsByClassName("tab"); - for (let i = 0; i < tabBtn.length; i++) { - const btn = document.getElementById("tab" + i); - btn.addEventListener('click', () => { - tabClick(btn.id, tabBtn); - }); - } -}); +document.addEventListener("DOMContentLoaded", () => { + const tabBtn = document.getElementsByClassName("tab"); + for (let i = 0; i < tabBtn.length; i++) { + const btn = document.getElementById("tab" + i); + btn.addEventListener("click", () => { + tabClick(btn.id, tabBtn); + }); + } + const menstrualSubscriptionSwitch = document.getElementById("switch3"); + menstrualSubscriptionSwitch.addEventListener("click", () => { + const btnState = menstrualSubscriptionSwitch.checked; + if (btnState) { + fetch('/menstrual-predictor/') + .then(response => { + let subscriptionContainer = document.getElementById('menstrual-prediction-container'); + subscriptionContainer.innerHTML = response; + }) -function tabClick(tab_id, tabBtn) { - let shownTab = document.querySelector(".tab-show"); - let selectedTabContent = document.querySelector(`#${tab_id}-content`); - shownTab.classList.remove("tab-show"); - shownTab.classList.add("tab-hide"); - for (btn of tabBtn) { - btn.children[0].classList.remove("active"); + console.log(menstrualSubscriptionSwitch.checked); } - document.getElementById(tab_id).classList.add("active"); - selectedTabContent.classList.remove("tab-hide"); - selectedTabContent.classList.add("tab-show"); + }); +}); +async function loadSubscriptionPage(response){ + const data = await response.text(); + return data; +} +function tabClick(tab_id, tabBtn) { + let shownTab = document.querySelector(".tab-show"); + let selectedTabContent = document.querySelector(`#${tab_id}-content`); + shownTab.classList.remove("tab-show"); + shownTab.classList.add("tab-hide"); + for (btn of tabBtn) { + btn.children[0].classList.remove("active"); + } + document.getElementById(tab_id).classList.add("active"); + selectedTabContent.classList.remove("tab-hide"); + selectedTabContent.classList.add("tab-show"); } diff --git a/app/templates/join_menstrual_predictor.html b/app/templates/join_menstrual_predictor.html new file mode 100644 index 00000000..064a9742 --- /dev/null +++ b/app/templates/join_menstrual_predictor.html @@ -0,0 +1,28 @@ +{% extends "base.html" %} +{% block content %} +
+

Please fill in your details

+
+
+ +
+
+
+
+ + Must be above 1 day. + +
+
+
+ +
+
+
+
+ + +
+
+ +{% endblock %} diff --git a/app/templates/partials/calendar/navigation.html b/app/templates/partials/calendar/navigation.html index 7042e37f..41a2e3e5 100644 --- a/app/templates/partials/calendar/navigation.html +++ b/app/templates/partials/calendar/navigation.html @@ -39,6 +39,9 @@
+
+ +
@@ -49,4 +52,4 @@
- \ No newline at end of file + diff --git a/app/templates/settings.html b/app/templates/settings.html index ee741658..2579ef14 100644 --- a/app/templates/settings.html +++ b/app/templates/settings.html @@ -49,6 +49,9 @@
  • Event settings
  • +
  • + Menstrual Predictor +
  • @@ -115,8 +118,23 @@

    View options

    + +
    +

    Menstrual Predictor settings

    +
    + + +
    +
    + + +
    +
    +
    +
    +
    {% endblock content %} diff --git a/tests/test_menstrual_predictor.py b/tests/test_menstrual_predictor.py new file mode 100644 index 00000000..6e8cce3b --- /dev/null +++ b/tests/test_menstrual_predictor.py @@ -0,0 +1,30 @@ +from app.routers.menstrual_predictor import router +from tests.test_login import test_login_successfull + + +class TestMenstrualPredictor: + @staticmethod + def test_menstrual_predictor_page_not_signed_up(client, session): + resp = client.get(router.url_path_for("join_menstrual_predictor")) + assert resp.ok + + @staticmethod + def test_menstrual_predictor_sign_up(security_test_client, session): + test_login_successfull(session, security_test_client) + resp = security_test_client.post( + router.url_path_for("submit_join_form"), + data={"avg-period-length": 8, "last-period-date": "2020-11-07"}, + ) + assert resp.ok + + url = router.url_path_for("add_period_start", start_date="2020-12-11") + resp = security_test_client.get(url) + + assert resp.ok + + @staticmethod + def test_add_period_date(security_test_client, session): + test_login_successfull(session, security_test_client) + url = router.url_path_for("add_period_start", start_date="2020-12-11") + resp = security_test_client.get(url) + assert resp.ok