Skip to content

Commit 8812a25

Browse files
committed
Refactor: Update image upload handling to support threshold-based predictions and improve error handling
1 parent bf55b03 commit 8812a25

File tree

1 file changed

+68
-27
lines changed

1 file changed

+68
-27
lines changed

app/app.py

Lines changed: 68 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
"""
22
app/app.py
3-
----------
4-
Flask entrypoint for PneumoDetect (Week 3 - Day 1)
5-
Adds DICOM + PNG upload support, routes for home and predict,
6-
and renders results safely with cached ResNet-50 model.
3+
-----------
4+
Flask application for PneumoDetect (Week 3, Day 2)
5+
Extends Day 1 by adding probability-based predictions, risk thresholding,
6+
and inference timing. Supports .jpg/.png/.dcm uploads.
77
"""
88

99
from flask import Flask, render_template, request, redirect, url_for
1010
from werkzeug.utils import secure_filename
1111
from pathlib import Path
1212
import torch
13+
import torch.nn.functional as F
1314
from torchvision import models, transforms
1415
from PIL import Image
1516
from pydicom import dcmread
17+
import time
1618
import numpy as np
1719
import cv2
1820

@@ -39,30 +41,48 @@
3941
print(f"Model loaded: {MODEL_PATH.name} on {device}")
4042

4143
# -------------------------------------------------------------------
42-
# Image loader that handles DICOM and PNG/JPG
44+
# Image loader (handles .png/.jpg/.dcm)
4345
# -------------------------------------------------------------------
4446
def load_image(file_path: Path) -> Image.Image:
4547
ext = file_path.suffix.lower()
48+
if ext in [".png", ".jpg", ".jpeg"]:
49+
return Image.open(file_path).convert("RGB")
50+
4651
if ext == ".dcm":
4752
ds = dcmread(str(file_path))
48-
img = ds.pixel_array
49-
img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX)
50-
img = cv2.cvtColor(img.astype("uint8"), cv2.COLOR_GRAY2RGB)
51-
return Image.fromarray(img)
52-
else:
53-
return Image.open(file_path).convert("RGB")
53+
pixel_array = ds.pixel_array.astype(np.float32)
54+
55+
# Apply DICOM rescale parameters when present
56+
slope = float(getattr(ds, "RescaleSlope", 1) or 1)
57+
intercept = float(getattr(ds, "RescaleIntercept", 0) or 0)
58+
pixel_array = pixel_array * slope + intercept
59+
60+
# Normalize to 0–255 and convert to 3-channel RGB
61+
pixel_array -= pixel_array.min()
62+
max_val = pixel_array.max()
63+
if max_val > 0:
64+
pixel_array = pixel_array / max_val
65+
pixel_array = np.clip(pixel_array * 255.0, 0, 255).astype(np.uint8)
66+
if pixel_array.ndim == 2: # grayscale
67+
pixel_array = cv2.cvtColor(pixel_array, cv2.COLOR_GRAY2RGB)
68+
elif pixel_array.shape[-1] == 1: # single-channel with trailing dim
69+
pixel_array = cv2.cvtColor(pixel_array.squeeze(-1), cv2.COLOR_GRAY2RGB)
70+
71+
return Image.fromarray(pixel_array).convert("RGB")
72+
73+
raise ValueError("Unsupported image format. Please upload .jpg, .png, or .dcm.")
5474

5575
# -------------------------------------------------------------------
5676
# Routes
5777
# -------------------------------------------------------------------
5878
@app.route("/", methods=["GET"])
5979
def index():
60-
"""Home page with upload form."""
80+
"""Home page with upload form and threshold slider."""
6181
return render_template("index.html")
6282

6383
@app.route("/predict", methods=["POST"])
6484
def predict():
65-
"""Handle file upload and run inference."""
85+
"""Handle image upload, run inference, and return prediction."""
6686
if "file" not in request.files:
6787
return redirect(url_for("index"))
6888

@@ -74,23 +94,44 @@ def predict():
7494
file_path = UPLOAD_FOLDER / filename
7595
file.save(file_path)
7696

77-
# Load image
78-
img = load_image(file_path)
97+
try:
98+
# Get threshold (default 0.5)
99+
threshold = float(request.form.get("threshold", 0.5))
100+
101+
# Preprocess image
102+
img = load_image(file_path)
103+
transform = transforms.Compose([
104+
transforms.Resize((224, 224)),
105+
transforms.ToTensor(),
106+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
107+
std=[0.229, 0.224, 0.225])
108+
])
109+
tensor = transform(img).unsqueeze(0).to(device)
79110

80-
transform = transforms.Compose([
81-
transforms.Resize((224, 224)),
82-
transforms.ToTensor(),
83-
transforms.Normalize(mean=[0.485, 0.456, 0.406],
84-
std=[0.229, 0.224, 0.225])
85-
])
86-
tensor = transform(img).unsqueeze(0).to(device)
111+
# Run inference with timing
112+
start_time = time.time()
113+
with torch.no_grad():
114+
outputs = model(tensor)
115+
probs = F.softmax(outputs, dim=1)
116+
pneumonia_prob = probs[0, 1].item()
117+
elapsed = time.time() - start_time
87118

88-
with torch.no_grad():
89-
outputs = model(tensor)
90-
_, preds = torch.max(outputs, 1)
91-
label = "Pneumonia Detected" if preds.item() == 1 else "Normal"
119+
# Decision logic
120+
decision = "High Risk" if pneumonia_prob > threshold else "Low Risk"
121+
label = f"{decision} ({pneumonia_prob:.2f} probability)"
122+
print(f"Prediction: {label} | Time: {elapsed:.2f}s")
92123

93-
return render_template("result.html", prediction=label)
124+
return render_template(
125+
"result.html",
126+
prediction=label,
127+
prob=f"{pneumonia_prob:.3f}",
128+
threshold=threshold,
129+
elapsed=f"{elapsed:.2f}s"
130+
)
131+
132+
except Exception as e:
133+
print(f" Error during prediction: {e}")
134+
return redirect(url_for("index"))
94135

95136
# -------------------------------------------------------------------
96137
# Run Flask App

0 commit comments

Comments
 (0)