
Commit d232928

server done
1 parent 6333ae4 commit d232928

12 files changed: +398 -205 lines changed


Diff for: .env

+4-1
@@ -7,7 +7,7 @@ DB_NAME=iris
 DB_USER=user
 DB_PASSWORD=password
 DB_PORT=5432
-POSTGRES_DB_URL=postgresql://${DB_USER}:${DB_PASSWORD}@localhost/${DB_NAME}
+POSTGRES_DB_URL=postgresql://${DB_USER}:${DB_PASSWORD}@localhost:5432/${DB_NAME}
 
 # MetaBase
 METABASE_TAG=v0.48.0
@@ -20,10 +20,13 @@ CLOUDBEAVER_PORT=8978
 # Redis
 REDIS_TAG=7.2.2-bookworm
 REDIS_PORT=6379
+REDIS_INTERNAL_HOST=redis
+REDIS_INTERNAL_PORT=6379
 
 # ELK
 ELASTICSEARCH_TAG=8.11.0
 ELASTICSEARCH_PORT=9200
+ELASTICSEARCH_HOST=http://elasticsearch:9200
 
 # Adminer
 ADMINER_TAG=4.7.9-standalone
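
The new variables are read in the server with python-decouple's config(), as the resolvers.py diff below shows. A minimal sketch of how they might be loaded, assuming the same key names; the defaults and the int cast for the port are illustration only, since the committed resolver passes the raw string through:

    from decouple import config

    # Fall back to the values defined in .env above if a variable is unset.
    redis_host = config("REDIS_INTERNAL_HOST", default="redis")
    redis_port = config("REDIS_INTERNAL_PORT", default="6379", cast=int)
    es_host = config("ELASTICSEARCH_HOST", default="http://elasticsearch:9200")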

Diff for: server/gql/resolvers.py

+49-19
@@ -1,33 +1,63 @@
+from decouple import config
+import json
+from ml import get_predictions
 from model import models
 from model.database import DBSession
-from strawberry import ID
 from .schemas import Iris, PaginationInput
+from services import Cache
 from typing import List
 
+def get_data(pagination: PaginationInput) -> List[Iris]:
+    db = DBSession()
+
+    try:
+        query = db.query(models.Iris)
+        if pagination is not None:
+            query = (
+                query
+                .offset(pagination.offset)
+                .limit(pagination.limit)
+            )
+        tasks = query.all()
+
+    finally:
+        db.close()
+
+    return tasks
+
 class QueryResolver:
     @staticmethod
     def get_name() -> str:
         return "Iris"
 
     @staticmethod
     def get_data(pagination: PaginationInput) -> List[Iris]:
-        db = DBSession()
-
-        try:
-            query = db.query(models.Iris)
-            if pagination is not None:
-                query = (
-                    query
-                    .offset(pagination.offset)
-                    .limit(pagination.limit)
-                )
-            tasks = query.all()
-
-        finally:
-            db.close()
-
-        return tasks
-
+        return get_data(pagination)
+
     @staticmethod
     def get_predictions(pagination: PaginationInput) -> List[Iris]:
-        return []
+        tasks = get_data(pagination)
+
+        cache_host = config("REDIS_INTERNAL_HOST")
+        cache_port = config("REDIS_INTERNAL_PORT")
+        cache = Cache(
+            host=cache_host, port=cache_port
+        )
+
+        task_list = cache.get(cache.k)
+
+        if task_list is None:
+            task_list = []
+            for task in tasks:
+                task_dict = task.__dict__
+                del task_dict["_sa_instance_state"]
+                task_list.append(task_dict)
+
+            cache.set(
+                cache.k, json.dumps(task_list)
+            )
+        else:
+            task_list = json.loads(task_list)
+
+        pred_tasks = get_predictions(task_list, Iris)
+        return pred_tasks
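
The rewritten get_predictions serves rows from Redis when it can: on a cache miss it serializes the SQLAlchemy rows (dropping the non-JSON-serializable _sa_instance_state attribute) and stores the list as JSON; on a hit it loads the cached JSON and skips the database query entirely. The Cache class comes from a services package that is not part of this diff, so its interface is only inferred from the calls above (a host/port constructor, a key attribute k, and get/set). A sketch of what such a wrapper could look like on top of redis-py; the key name and TTL are assumptions:

    import redis


    class Cache:
        def __init__(self, host: str, port, key: str = "iris", ttl: int = 300):
            # decode_responses=True so get() returns text that json.loads accepts
            self.client = redis.Redis(
                host=host, port=int(port), decode_responses=True
            )
            self.k = key      # cache key read by the resolver as cache.k
            self.ttl = ttl    # expiry in seconds (assumed)

        def get(self, key: str):
            # returns the cached string, or None on a miss
            return self.client.get(key)

        def set(self, key: str, value: str) -> None:
            self.client.set(key, value, ex=self.ttl)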

Diff for: server/gql/schemas.py

+3-2
@@ -2,6 +2,7 @@
 
 @strawberry.type
 class Iris:
+    id: int
     sepal_length: float | None
     sepal_width: float | None
     petal_length: float | None
@@ -10,5 +11,5 @@ class Iris:
 
 @strawberry.input
 class PaginationInput:
-    offset: int = 50
-    limit: int = 100
+    offset: int = 0
+    limit: int = 250
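
With the new defaults, a query that relies on PaginationInput starts at the first row (offset 0) and returns up to 250 rows, instead of skipping the first 50. A hypothetical client call showing the input in use; the /graphql URL, the getPredictions field name, and the camelCase field names are assumptions based on Strawberry's default snake_case-to-camelCase conversion:

    import requests

    query = """
    {
      getPredictions(pagination: { offset: 0, limit: 50 }) {
        id
        sepalLength
        sepalWidth
        petalLength
        petalWidth
        species
      }
    }
    """

    resp = requests.post("http://localhost:8000/graphql", json={"query": query})
    print(resp.json()["data"]["getPredictions"][:3])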

Diff for: server/ml/__init__.py

+1
@@ -0,0 +1 @@
+from .logistic_regression import get_predictions

Diff for: server/ml/logistic_regression.py

+48
@@ -0,0 +1,48 @@
+from decouple import config
+from sklearn.model_selection import train_test_split
+import pandas as pd
+from services import Search
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import (
+    accuracy_score, average_precision_score, f1_score,
+    precision_score, recall_score
+)
+from sklearn.preprocessing import LabelEncoder
+
+estimator = LogisticRegression(
+    penalty = None, solver = "newton-cg", max_iter = 250, multi_class = "ovr"
+)
+
+encoder = LabelEncoder()
+
+def get_predictions(tasks, data_model):
+    df = pd.DataFrame(tasks).drop(columns=["id"])
+    resp = "species"
+    X = df.drop(columns=resp)
+    y = encoder.fit_transform(df[resp])
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, test_size=0.2
+    )
+
+    estimator.fit(X_train, y_train)
+    preds = estimator.predict(X_test).reshape(-1, 1)
+    document = {
+        "accuracy": accuracy_score(y_test, preds),
+        "average_precision": average_precision_score(y_test, preds),
+        "f1": f1_score(y_test, preds, average="weighted"),
+        "precision": precision_score(y_test, preds, average="weighted"),
+        "recall": recall_score(y_test, preds, average="weighted"),
+    }
+
+    search_host = config("ELASTICSEARCH_HOST")
+    search = Search(search_host)
+    search.index(index=search.indx, document=document)
+    search.close()
+
+    preds = estimator.predict(X).reshape(-1, 1)
+    pred_tasks = []
+    for i, task_item in enumerate(tasks):
+        task_item[resp] = preds[i]
+        pred_task = data_model(**task_item)
+        pred_tasks.append(pred_task)
+    return pred_tasks
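
get_predictions trains a one-vs-rest logistic regression on an 80/20 split, indexes the held-out metrics into Elasticsearch, then predicts a species for every row and rebuilds the rows as instances of the passed data_model (the Iris GraphQL type). Like Cache, the Search class lives in the services package and is not shown in this diff; the sketch below infers its interface (URL constructor, an indx attribute, index and close) from the calls above and assumes the official elasticsearch client plus an index name:

    from elasticsearch import Elasticsearch


    class Search:
        def __init__(self, host: str, index: str = "iris-metrics"):
            # host is a full URL, e.g. the ELASTICSEARCH_HOST value from .env
            self.client = Elasticsearch(host)
            self.indx = index   # index name is an assumption

        def index(self, index: str, document: dict) -> None:
            # one metrics document per training run
            self.client.index(index=index, document=document)

        def close(self) -> None:
            self.client.close()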

Diff for: server/model/models.py

+2-7
@@ -1,16 +1,11 @@
-from sqlalchemy import Column, String, Float
+from sqlalchemy import Column, String, Float, Integer
 from .database import Base
 
 class Iris(Base):
     __tablename__ = "iris"
+    id = Column(Integer, primary_key = True)
     sepal_length = Column(Float)
     sepal_width = Column(Float)
     petal_length = Column(Float)
     petal_width = Column(Float)
     species = Column(String)
-
-    def __repr__(self):
-        return
-        f"""Iris(sepal_length={self.sepal_length},
-        sepal_width={self.sepal_width}, petal_length={self.petal_length},
-        petal_width={self.petal_width}, species={self.species})"""
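
The deleted __repr__ was a no-op: the bare return exited the method with None and the triple-quoted f-string below it was unreachable, so removing it loses nothing. If a readable representation is still wanted, a working version might look like this (a sketch, not part of the commit):

    def __repr__(self) -> str:
        return (
            f"Iris(id={self.id}, sepal_length={self.sepal_length}, "
            f"sepal_width={self.sepal_width}, petal_length={self.petal_length}, "
            f"petal_width={self.petal_width}, species={self.species})"
        )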
