Skip to content

Commit 0a20e4c

Browse files
committed
type hints and favicons
1 parent 5f37537 commit 0a20e4c

File tree

14 files changed

+318
-2288
lines changed

14 files changed

+318
-2288
lines changed

pyproject.toml

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,34 +5,55 @@ description = "Add your description here"
55
readme = "README.md"
66
requires-python = ">=3.12"
77
dependencies = [
8-
"fastapi[standard]>=0.115.5",
9-
"geopy>=2.4.1",
10-
"httpx>=0.27.2",
11-
"numpy<2",
8+
"numpy==1.26.4",
129
"pillow==10.4.0",
13-
"sentence-transformers>=3.3.1",
14-
"sqlmodel>=0.0.22",
15-
"streamlit>=1.40.1",
10+
]
11+
12+
[project.optional-dependencies]
13+
frontend = [
14+
"geopy==2.4.1",
15+
"httpx==0.27.2",
16+
"streamlit==1.40.1",
17+
]
18+
19+
backend = [
20+
"fastapi[standard]==0.115.5",
21+
"sqlmodel==0.0.22",
22+
"sentence-transformers==3.3.1",
23+
"torch==2.5.1+cpu",
1624
]
1725

1826
[[tool.uv.index]]
1927
name = "pytorch-cpu"
2028
url = "https://download.pytorch.org/whl/cpu"
2129
explicit = true
2230

31+
[tool.uv.sources]
32+
torch = [{ index = "pytorch-cpu"}]
33+
2334
[project.scripts]
2435
moin-moin = "moin_moin:main"
2536

2637
[build-system]
2738
requires = ["hatchling"]
2839
build-backend = "hatchling.build"
2940

30-
[dependency-groups]
31-
dev = [
32-
"fastbook>=0.0.29",
33-
"ipdb>=0.13.13",
34-
"ipyplot>=1.1.2",
35-
"ipython>=8.29.0",
36-
"jupyterlab>=4.3.1",
37-
"jupyterlab-vim>=4.1.4",
41+
42+
[tool.ruff]
43+
line-length = 120
44+
target-version = "py312"
45+
fix = true
46+
47+
lint.select=["ALL"]
48+
49+
lint.ignore = [
50+
"COM812",
51+
"ISC001",
52+
"D203",
53+
"D211",
54+
"D213",
3855
]
56+
57+
[tool.ruff.lint.isort]
58+
force-single-line = true
59+
required-imports = ["from __future__ import annotations"]

src/moin_moin/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
1-
def main() -> None:
2-
print("Hello from moin-moin!")
1+
"""moin-moin app.
2+
3+
backend: contains fastapi endpoint as well as machine learning and database components.
4+
frontend: streamlit app to render webapp.
5+
"""

src/moin_moin/backend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Backend: API and DB for app."""

src/moin_moin/backend/_db.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from __future__ import annotations
2+
3+
from sqlmodel import Field
4+
from sqlmodel import SQLModel
5+
6+
7+
class UserUploadData(SQLModel, table=True):
8+
id: int | None = Field(default=None, primary_key=True)
9+
image: bytes
10+
latitude: float
11+
longitude: float
12+
notes: str
13+
tags: str
14+
15+
16+
class Prediction(SQLModel, table=True):
17+
id: int | None = Field(default=None, primary_key=True)
18+
record_id: int | None = Field(default=None, foreign_key="useruploaddata.id")
19+
prediction: str
20+
21+
22+
class PublicRecord(SQLModel):
23+
latitude: float
24+
longitude: float
25+
notes: str
26+
tags: str
27+
prediction: str

src/moin_moin/backend/_ml.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
from typing import Self
5+
6+
from sentence_transformers import SentenceTransformer
7+
8+
if TYPE_CHECKING:
9+
from PIL import Image
10+
11+
12+
class ClipModel:
13+
"""Wrapper of SentenceTransformer, clip-ViT-B-32 model.
14+
15+
Process is to store embeddings of categories we can match images with, and then at
16+
prediction time to return the category closest to the given image.
17+
"""
18+
19+
def __init__(self: Self, text_options: dict[str, str]) -> None:
20+
self.model = SentenceTransformer("clip-ViT-B-32")
21+
self.model.eval()
22+
self.labels = list(text_options.keys())
23+
self.text_embedding = self.model.encode(list(text_options.values()))
24+
25+
def predict(self: Self, image: Image) -> str:
26+
"""Embed the image and return the closest text embedding."""
27+
img_emb = self.model.encode(image)
28+
similarity_array = self.model.similarity(img_emb, self.text_embedding)
29+
return self.labels[similarity_array.argmax().item()]

src/moin_moin/backend/api.py

Lines changed: 66 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,42 @@
1+
"""Module that builds FastAPI application with endpoints to save, predict and gather data."""
2+
3+
from __future__ import annotations
4+
15
from contextlib import asynccontextmanager
6+
from io import BytesIO
7+
from typing import TYPE_CHECKING
28
from typing import Annotated
9+
from typing import Final
310

4-
from fastapi import FastAPI, File, UploadFile, Form, Depends
11+
from fastapi import Depends
12+
from fastapi import FastAPI
13+
from fastapi import File
14+
from fastapi import Form
15+
from fastapi import UploadFile
516
from PIL import Image
6-
from io import BytesIO
7-
from sqlmodel import Session, select
8-
from moin_moin.backend.ml import ClipModel
9-
from moin_moin.backend.db import Record, Prediction, engine, PublicRecord
17+
from sqlmodel import Session
18+
from sqlmodel import SQLModel
19+
from sqlmodel import create_engine
20+
from sqlmodel import select
1021

22+
from moin_moin.backend._db import Prediction
23+
from moin_moin.backend._db import PublicRecord
24+
from moin_moin.backend._db import UserUploadData
25+
from moin_moin.backend._ml import ClipModel
1126

12-
ML_MODEL = {}
27+
if TYPE_CHECKING:
28+
from collections.abc import AsyncGenerator
29+
from collections.abc import Generator
1330

31+
from sqlmodel import Engine
1432

15-
def get_session():
16-
with Session(engine) as session:
17-
yield session
1833

34+
ML_MODEL: dict[str, ClipModel] = {}
1935

20-
SessionDep = Annotated[Session, Depends(get_session)]
36+
DB_NAME: Final[str] = "sqlite:///moin-moin.db"
37+
ENGINE: Final[Engine] = create_engine(DB_NAME)
2138

22-
institutions = {
39+
INSTITUTIONS = {
2340
"Police Department": "The police deals with crime and violence related topics.",
2441
"Fire Department": "The fire department deals with fire and other emergency situations.",
2542
"Hospital": "The hospital deals with health and medical related topics.",
@@ -28,42 +45,48 @@ def get_session():
2845
}
2946

3047

48+
def get_session() -> Generator[Session, None, None]:
49+
"""Yield SQLModel session."""
50+
with Session(ENGINE) as session:
51+
yield session
52+
53+
54+
SessionDep = Annotated[Session, Depends(get_session)]
55+
3156

3257
@asynccontextmanager
33-
async def lifespan(app: FastAPI):
34-
# Load the ML model
35-
ML_MODEL["similarity_model"] = ClipModel(text_options=institutions)
58+
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
59+
"""FastAPI Lifespan: loads and instantiate the model at startup, delete at shutdown."""
60+
SQLModel.metadata.create_all(ENGINE)
61+
ML_MODEL["similarity_model"] = ClipModel(text_options=INSTITUTIONS)
62+
3663
yield
37-
# Clean up the ML models and release the resources
64+
3865
ML_MODEL.clear()
3966

4067

4168
app = FastAPI(lifespan=lifespan)
4269

4370

44-
def _predict(img_array: Image):
45-
"""This function decoupled from the API exists for debugging purposes."""
46-
return ML_MODEL["similarity_model"].predict(img_array)
47-
48-
4971
@app.get("/")
50-
def root():
72+
def root() -> dict[str, str]:
5173
"""Root endpoint."""
5274
return {"message": "Health"}
5375

5476

5577
@app.post("/save")
56-
async def save(
78+
async def save( # noqa: PLR0913
5779
session: SessionDep,
58-
latitude: float = Form(...),
59-
longitude: float = Form(...),
60-
notes: str = Form(...),
61-
tags: str = Form(...),
62-
image_bytes: UploadFile = File(...),
63-
):
80+
latitude: Annotated[float, Form()],
81+
longitude: Annotated[float, Form()],
82+
notes: Annotated[str, Form()],
83+
tags: Annotated[str, Form()],
84+
image_bytes: Annotated[UploadFile, File()],
85+
) -> dict[str, int | None]:
86+
"""Save the input data into database records into user upload data table."""
6487
image_bytes = await image_bytes.read()
65-
record = Record(
66-
image=image_bytes,
88+
record = UserUploadData(
89+
image=image_bytes, # type: ignore[assignment]
6790
latitude=latitude,
6891
longitude=longitude,
6992
notes=notes,
@@ -79,12 +102,15 @@ async def save(
79102

80103
@app.post("/predict")
81104
async def predict(
82-
session: SessionDep, record_id: int = Form(), file: UploadFile = File(...)
83-
):
105+
session: SessionDep,
106+
record_id: Annotated[int, Form()],
107+
file: Annotated[UploadFile, File(...)],
108+
) -> dict[str, str]:
109+
"""Predict image category via ML model."""
84110
file_bytes = await file.read()
85111
buffer = BytesIO(file_bytes)
86112
image = Image.open(buffer)
87-
prediction = _predict(image)
113+
prediction = ML_MODEL["similarity_model"].predict(image)
88114

89115
pred_record = Prediction(record_id=record_id, prediction=prediction)
90116

@@ -95,13 +121,14 @@ async def predict(
95121

96122

97123
@app.get("/load-records", response_model=list[PublicRecord])
98-
async def load_records(session: SessionDep):
124+
async def load_records(session: SessionDep) -> list[PublicRecord]:
125+
"""Load all the records with their predictions, ignores associated image."""
99126
statement = select(
100-
Record.latitude,
101-
Record.longitude,
102-
Record.notes,
103-
Record.tags,
127+
UserUploadData.latitude,
128+
UserUploadData.longitude,
129+
UserUploadData.notes,
130+
UserUploadData.tags,
104131
Prediction.prediction,
105132
).join(Prediction)
106133
records = session.exec(statement).all()
107-
return [PublicRecord(**row._mapping) for row in records]
134+
return [PublicRecord(**row._mapping) for row in records] # noqa: SLF001

src/moin_moin/backend/db.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

src/moin_moin/backend/ml.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

src/moin_moin/frontend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Frontend: Streamlit app."""

src/moin_moin/frontend/app_conf.py renamed to src/moin_moin/frontend/_conf.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
HOST = "http://127.0.0.1"
2-
PORT = 8000
1+
from __future__ import annotations
32

4-
INSTITUTION_MAPPING = {
3+
from typing import Final
4+
5+
HOST: Final[str] = "http://127.0.0.1"
6+
PORT: Final[int] = 8081
7+
INSTITUTION_MAPPING: Final[dict[str, str]] = {
58
"Police Department": "#48b5a5",
69
"Fire Department": "#7b64ab",
710
"Hospital": "#20be64",

0 commit comments

Comments
 (0)