-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
63 lines (46 loc) · 2.15 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
from model import UNet
import os
def preprocess_image(image_path, target_size=(256, 256)):
    """Load an image as a single-channel tensor ready to feed to the model.

    Args:
        image_path: Path of the image file to open.
        target_size: ``(height, width)`` the image is resized to before it
            is converted to a tensor. Defaults to ``(256, 256)``.

    Returns:
        A ``(tensor, original_size)`` tuple. ``tensor`` is the resized
        grayscale image with a leading batch dimension added;
        ``original_size`` is the PIL ``(width, height)`` of the image as
        stored on disk, so callers can scale results back up.
    """
    # The context manager releases the underlying file as soon as the pixel
    # data has been converted, instead of leaving it to garbage collection.
    with Image.open(image_path) as source:
        image = source.convert('L')
    original_size = image.size

    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ])
    processed_image = transform(image).unsqueeze(0)  # add batch dimension
    return processed_image, original_size
def load_model(model_path, device=None):
    """Build a ``UNet`` and restore its weights from a checkpoint file.

    Args:
        model_path: Path to a checkpoint whose top-level object is a mapping
            with a ``'state_dict'`` entry.
        device: Optional device (string or ``torch.device``) on which to
            place the model. Defaults to the CPU.

    Returns:
        The ``UNet`` instance with the checkpoint weights loaded, switched
        to evaluation mode.

    Raises:
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
    """
    # Load straight onto the device the model will actually live on. The
    # previous code mapped the checkpoint to CUDA whenever it was available
    # even though the model itself never left the CPU, which only cost a GPU
    # allocation plus a cross-device copy.
    target = 'cpu' if device is None else device

    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(model_path, map_location=target)

    model = UNet()
    model.load_state_dict(checkpoint['state_dict'])
    model.to(target)
    model.eval()
    return model
def make_prediction(model, input_tensor):
    """Run ``model`` on ``input_tensor`` with gradient tracking disabled.

    Args:
        model: A callable taking a tensor, typically a ``torch.nn.Module``.
        input_tensor: The tensor passed to ``model``.

    Returns:
        Whatever ``model`` returns for the input. When the model exposes
        parameters, the input is first moved to the device of its first
        parameter, so the result lives on that device.
    """
    # Align the input with the model's device so a model placed on an
    # accelerator does not fail on a CPU input. Plain callables and
    # parameter-free modules are passed the input unchanged.
    parameters = getattr(model, 'parameters', None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)

    with torch.no_grad():
        return model(input_tensor)
def _prediction_to_mask(prediction, threshold):
    """Convert a raw model output into a 2-D uint8 mask of 0/255 values."""
    if isinstance(prediction, torch.Tensor):
        prediction = prediction.detach().cpu().numpy()
    prediction = np.asarray(prediction)
    if prediction.ndim < 2:
        raise ValueError(
            f"expected a prediction with at least 2 dimensions, got shape {prediction.shape}"
        )

    # Min-max normalise over the whole output. A constant output has zero
    # range; treat it as all-background instead of dividing by zero.
    low = prediction.min()
    span = prediction.max() - low
    if span > 0:
        normalized = (prediction - low) / span
    else:
        normalized = np.zeros(prediction.shape, dtype=np.float32)

    # Keep only the first 2-D plane (first batch item, first channel).
    # assumes the model output is laid out as (..., H, W) — TODO confirm
    # against UNet. Indexing explicitly avoids the squeeze()+[0] pitfall,
    # which picks the first *row* when the output has a single channel.
    plane = normalized.reshape(-1, *normalized.shape[-2:])[0]
    return (plane > threshold).astype(np.uint8) * 255


def save_prediction(input_image_path, model_path, output_image_path, threshold=0.55):
    """Segment one image and write the binary mask to disk.

    Args:
        input_image_path: Path of the image to segment.
        model_path: Path of the checkpoint handed to ``load_model``.
        output_image_path: Destination file for the mask; its parent
            directory is created when missing.
        threshold: Cut-off applied to the min-max normalised prediction;
            values strictly above it become white (255). Defaults to 0.55.

    Returns:
        None. The mask is saved at ``output_image_path`` and the path is
        printed.
    """
    input_image, original_size = preprocess_image(input_image_path)
    model = load_model(model_path)
    prediction = make_prediction(model, input_image)

    mask = _prediction_to_mask(prediction, threshold)
    binary_image = Image.fromarray(mask)
    # Nearest-neighbour keeps the mask strictly binary when scaling back to
    # the source image size.
    binary_image = binary_image.resize(original_size, Image.NEAREST)

    # dirname() is '' for a bare filename, and os.makedirs('') raises.
    output_dir = os.path.dirname(output_image_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    binary_image.save(output_image_path)
    print(f"Prediction saved at: {output_image_path}")
# Script configuration: the image to segment, the checkpoint to load, and
# the file the resulting mask is written to.
input_image_path = r"/teamspace/studios/this_studio/u_net_implementation/data/test/test/0a2637c772c5_03.jpg"
model_path = r"/teamspace/studios/this_studio/u_net_implementation/checkpoints/4.pth"
output_image_path = r"/teamspace/studios/this_studio/u_net_implementation/predictions/output.png"

# NOTE(review): this call sits at module level, so it also runs whenever the
# file is imported — consider an `if __name__ == "__main__":` guard.
save_prediction(
    input_image_path=input_image_path,
    model_path=model_path,
    output_image_path=output_image_path,
)