-
Notifications
You must be signed in to change notification settings - Fork 52
/
Copy pathmenstrual_predictor_utils.py
147 lines (120 loc) · 3.96 KB
/
menstrual_predictor_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import datetime
from datetime import timedelta
from typing import List, Union
from fastapi import Depends
from loguru import logger
from sqlalchemy import asc
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from app.database.models import Event, UserMenstrualPeriodLength
from app.dependencies import get_db
from app.internal.security.dependencies import current_user
from app.internal.security.schema import CurrentUser
from app.routers.event import create_event
# Event.category_id value that marks an event as a menstrual-period entry.
MENSTRUAL_PERIOD_CATEGORY_ID: int = 111
# Number of cycles predicted after the first generated period
# (add_n_month_predictions creates N_MONTHS_GENERATED + 1 events).
N_MONTHS_GENERATED: int = 3
# Fallback cycle length, in days, used when fewer than two periods are
# recorded for the user.
GAP_IN_CASE_NO_PERIODS: int = 30
def get_avg_period_gap(db: Session, user_id: int) -> int:
    """Return the user's average cycle length in whole days.

    The gap is measured between the start dates of consecutive period
    events, as returned by ``get_all_period_days``. The average is floored.

    Args:
        db: Database session.
        user_id: Owner of the period events.

    Returns:
        The floored average gap in days. ``GAP_IN_CASE_NO_PERIODS`` is
        returned instead when fewer than two periods are recorded or when
        the computed average is not positive (e.g. duplicate start dates).
    """
    period_days = get_all_period_days(db, user_id)
    if len(period_days) < 2:
        return GAP_IN_CASE_NO_PERIODS
    # Pair each period with the one that immediately follows it.
    gaps_list = [
        get_date_diff(earlier.start, later.start).days
        for earlier, later in zip(period_days, period_days[1:])
    ]
    avg_gap = get_list_avg(gaps_list)
    # A zero/negative gap would place every predicted period on the same
    # date, so fall back to the default cycle length instead.
    if avg_gap <= 0:
        return GAP_IN_CASE_NO_PERIODS
    return avg_gap
def get_date_diff(
    date_1: datetime.datetime,
    date_2: datetime.datetime,
) -> timedelta:
    """Return the signed time span from ``date_1`` to ``date_2``.

    The result is ``date_2 - date_1``. It is negative when ``date_2`` is
    earlier than ``date_1``.
    """
    return date_2 - date_1
def get_list_avg(received_list: List) -> int:
    """Return the floor of the arithmetic mean of ``received_list``.

    Raises:
        ZeroDivisionError: If ``received_list`` is empty.
    """
    quotient, _remainder = divmod(sum(received_list), len(received_list))
    return quotient
def remove_existing_period_dates(db: Session, user_id: int) -> None:
    """Delete the user's period events that start after the current time.

    Period events whose start is at or before ``datetime.datetime.now()``
    (naive local time) are kept. The session is committed after the delete.
    """
    upcoming_periods = db.query(Event).filter(
        Event.owner_id == user_id,
        Event.category_id == MENSTRUAL_PERIOD_CATEGORY_ID,
        Event.start > datetime.datetime.now(),
    )
    upcoming_periods.delete()
    db.commit()
    logger.info("Removed all period predictions to create new ones")
def generate_predicted_period_dates(
    db: Session,
    period_length: Union[str, int],
    period_start_date: datetime.datetime,
    user_id: int,
) -> Event:
    """Create and return a single "period" event for ``user_id``.

    Args:
        db: Database session passed through to ``create_event``.
        period_length: Period duration in days; any value ``int()`` accepts.
        period_start_date: Start of the period.
        user_id: Owner of the new event.

    Returns:
        The event returned by ``create_event``. It runs from
        ``period_start_date`` to ``period_start_date + period_length`` days
        and is tagged with ``MENSTRUAL_PERIOD_CATEGORY_ID``.
    """
    period_end_date = period_start_date + datetime.timedelta(
        days=int(period_length),
    )
    return create_event(
        db,
        "period",
        period_start_date,
        period_end_date,
        user_id,
        category_id=MENSTRUAL_PERIOD_CATEGORY_ID,
    )
def add_n_month_predictions(
    db: Session,
    period_length: Union[str, int],
    period_start_date: datetime.datetime,
    user_id: int,
) -> List[Event]:
    """Create period events starting at ``period_start_date`` and after it.

    ``N_MONTHS_GENERATED + 1`` events are created. The first starts at
    ``period_start_date``. Each subsequent one starts one average cycle
    (``get_avg_period_gap``, in days) after the previous one.

    Args:
        db: Database session.
        period_length: Duration in days of each generated period.
        period_start_date: Start date of the first generated period.
        user_id: Owner of the generated events.

    Returns:
        The created events, in the order they were created.
    """
    # The gap is computed once, before any new events are added.
    avg_gap_delta = datetime.timedelta(days=get_avg_period_gap(db, user_id))
    generated_months = []
    for cycle_index in range(N_MONTHS_GENERATED + 1):
        generated_period = generate_predicted_period_dates(
            db,
            period_length,
            period_start_date + cycle_index * avg_gap_delta,
            user_id,
        )
        generated_months.append(generated_period)
    logger.info(f"Generated predictions: {generated_months}")
    return generated_months
def add_prediction_events_if_valid(
    period_start_date: datetime.datetime,
    db: Session = Depends(get_db),
    user: CurrentUser = Depends(current_user),
) -> None:
    """Regenerate the current user's period prediction events.

    The user's future period events are always deleted first. New events
    starting at ``period_start_date`` are created only when the user has a
    stored period length, i.e. when
    ``is_user_signed_up_to_menstrual_predictor`` returns a truthy value.

    Args:
        period_start_date: Start date of the first generated period.
        db: Database session (FastAPI dependency).
        user: The authenticated user (FastAPI dependency).
    """
    current_user_id = user.user_id
    user_period_length = is_user_signed_up_to_menstrual_predictor(
        db,
        current_user_id,
    )
    remove_existing_period_dates(db, current_user_id)
    if user_period_length:
        add_n_month_predictions(
            db,
            user_period_length,
            period_start_date,
            current_user_id,
        )
def get_all_period_days(session: Session, user_id: int) -> List[Event]:
    """Fetch every menstrual-period event owned by ``user_id``.

    Events are ordered by start date, earliest first. A ``SQLAlchemyError``
    is logged and an empty list is returned instead of propagating.
    """
    try:
        found = (
            session.query(Event)
            .filter(
                Event.owner_id == user_id,
                Event.category_id == MENSTRUAL_PERIOD_CATEGORY_ID,
            )
            .order_by(asc(Event.start))
            .all()
        )
    except SQLAlchemyError as err:
        logger.exception(err)
        found = []
    return found
def is_user_signed_up_to_menstrual_predictor(
    session: Session,
    user_id: int,
) -> Union[bool, int]:
    """Return the user's stored period length, or ``False`` if there is none.

    A truthy return means the user is signed up to the menstrual predictor.
    The value is the ``period_length`` of the user's
    ``UserMenstrualPeriodLength`` row.

    Args:
        session: Database session.
        user_id: The user to look up.
    """
    user_menstrual_period_length = (
        session.query(UserMenstrualPeriodLength)
        # Filter on the model column. Comparing the bare ``user_id``
        # argument to itself is always True and would match any user's row.
        .filter(UserMenstrualPeriodLength.user_id == user_id)
        .first()
    )
    if user_menstrual_period_length is None:
        return False
    return user_menstrual_period_length.period_length