11"""
22app/app.py
3- ----------
4- Flask entrypoint for PneumoDetect (Week 3 - Day 1 )
5- Adds DICOM + PNG upload support, routes for home and predict ,
6- and renders results safely with cached ResNet-50 model .
3+ -----------
4+ Flask application for PneumoDetect (Week 3, Day 2 )
5+ Extends Day 1 by adding probability-based predictions, risk thresholding ,
6+ and inference timing. Supports .jpg/.png/.dcm uploads .
77"""
88
99from flask import Flask , render_template , request , redirect , url_for
1010from werkzeug .utils import secure_filename
1111from pathlib import Path
1212import torch
13+ import torch .nn .functional as F
1314from torchvision import models , transforms
1415from PIL import Image
1516from pydicom import dcmread
17+ import time
1618import numpy as np
1719import cv2
1820
3941print (f"Model loaded: { MODEL_PATH .name } on { device } " )
4042
4143# -------------------------------------------------------------------
42- # Image loader that handles DICOM and PNG/JPG
44+ # Image loader ( handles .png/.jpg/.dcm)
4345# -------------------------------------------------------------------
4446def load_image (file_path : Path ) -> Image .Image :
4547 ext = file_path .suffix .lower ()
48+ if ext in [".png" , ".jpg" , ".jpeg" ]:
49+ return Image .open (file_path ).convert ("RGB" )
50+
4651 if ext == ".dcm" :
4752 ds = dcmread (str (file_path ))
48- img = ds .pixel_array
49- img = cv2 .normalize (img , None , 0 , 255 , cv2 .NORM_MINMAX )
50- img = cv2 .cvtColor (img .astype ("uint8" ), cv2 .COLOR_GRAY2RGB )
51- return Image .fromarray (img )
52- else :
53- return Image .open (file_path ).convert ("RGB" )
53+ pixel_array = ds .pixel_array .astype (np .float32 )
54+
55+ # Apply DICOM rescale parameters when present
56+ slope = float (getattr (ds , "RescaleSlope" , 1 ) or 1 )
57+ intercept = float (getattr (ds , "RescaleIntercept" , 0 ) or 0 )
58+ pixel_array = pixel_array * slope + intercept
59+
60+ # Normalize to 0–255 and convert to 3-channel RGB
61+ pixel_array -= pixel_array .min ()
62+ max_val = pixel_array .max ()
63+ if max_val > 0 :
64+ pixel_array = pixel_array / max_val
65+ pixel_array = np .clip (pixel_array * 255.0 , 0 , 255 ).astype (np .uint8 )
66+ if pixel_array .ndim == 2 : # grayscale
67+ pixel_array = cv2 .cvtColor (pixel_array , cv2 .COLOR_GRAY2RGB )
68+ elif pixel_array .shape [- 1 ] == 1 : # single-channel with trailing dim
69+ pixel_array = cv2 .cvtColor (pixel_array .squeeze (- 1 ), cv2 .COLOR_GRAY2RGB )
70+
71+ return Image .fromarray (pixel_array ).convert ("RGB" )
72+
73+ raise ValueError ("Unsupported image format. Please upload .jpg, .png, or .dcm." )
5474
5575# -------------------------------------------------------------------
5676# Routes
5777# -------------------------------------------------------------------
5878@app .route ("/" , methods = ["GET" ])
5979def index ():
60- """Home page with upload form."""
80+ """Home page with upload form and threshold slider ."""
6181 return render_template ("index.html" )
6282
6383@app .route ("/predict" , methods = ["POST" ])
6484def predict ():
65- """Handle file upload and run inference."""
85+ """Handle image upload, run inference, and return prediction ."""
6686 if "file" not in request .files :
6787 return redirect (url_for ("index" ))
6888
@@ -74,23 +94,44 @@ def predict():
7494 file_path = UPLOAD_FOLDER / filename
7595 file .save (file_path )
7696
77- # Load image
78- img = load_image (file_path )
97+ try :
98+ # Get threshold (default 0.5)
99+ threshold = float (request .form .get ("threshold" , 0.5 ))
100+
101+ # Preprocess image
102+ img = load_image (file_path )
103+ transform = transforms .Compose ([
104+ transforms .Resize ((224 , 224 )),
105+ transforms .ToTensor (),
106+ transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ],
107+ std = [0.229 , 0.224 , 0.225 ])
108+ ])
109+ tensor = transform (img ).unsqueeze (0 ).to (device )
79110
80- transform = transforms . Compose ([
81- transforms . Resize (( 224 , 224 )),
82- transforms . ToTensor (),
83- transforms . Normalize ( mean = [ 0.485 , 0.456 , 0.406 ],
84- std = [ 0.229 , 0.224 , 0.225 ] )
85- ] )
86- tensor = transform ( img ). unsqueeze ( 0 ). to ( device )
111+ # Run inference with timing
112+ start_time = time . time ()
113+ with torch . no_grad ():
114+ outputs = model ( tensor )
115+ probs = F . softmax ( outputs , dim = 1 )
116+ pneumonia_prob = probs [ 0 , 1 ]. item ( )
117+ elapsed = time . time () - start_time
87118
88- with torch . no_grad ():
89- outputs = model ( tensor )
90- _ , preds = torch . max ( outputs , 1 )
91- label = "Pneumonia Detected" if preds . item () == 1 else "Normal"
119+ # Decision logic
120+ decision = "High Risk" if pneumonia_prob > threshold else "Low Risk"
121+ label = f" { decision } ( { pneumonia_prob :.2f } probability)"
122+ print ( f"Prediction: { label } | Time: { elapsed :.2f } s" )
92123
93- return render_template ("result.html" , prediction = label )
124+ return render_template (
125+ "result.html" ,
126+ prediction = label ,
127+ prob = f"{ pneumonia_prob :.3f} " ,
128+ threshold = threshold ,
129+ elapsed = f"{ elapsed :.2f} s"
130+ )
131+
132+ except Exception as e :
133+ print (f" Error during prediction: { e } " )
134+ return redirect (url_for ("index" ))
94135
95136# -------------------------------------------------------------------
96137# Run Flask App
0 commit comments