Skip to content

Commit 4c467bf

Browse files
committed
added accuracy measurement
1 parent 12fcb2e commit 4c467bf

File tree

1 file changed

+28
-8
lines changed

1 file changed

+28
-8
lines changed

unet.py

+28-8
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from scipy.misc import imshow
1111
from tqdm import tqdm
1212

13-
from loadCOCO import loadCOCO
13+
from loadCOCO import loadCOCO, Rescale, RandomCrop
1414

1515

1616
class Net(nn.Module):
@@ -66,7 +66,7 @@ def save_checkpoint(model, epoch, iteration, loss, vloss):
6666
return
6767

6868

69-
def train():
69+
def train(resume_from=None):
7070
###########
7171
# Load Dataset #
7272
###########
@@ -102,7 +102,10 @@ def train():
102102

103103
criterion = nn.NLLLoss2d()
104104
# optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
105-
optimizer = optim.Adam(net.parameters(), lr=0.001)
105+
optimizer = optim.Adam(net.parameters(), lr=0.005)
106+
107+
if resume_from is not None:
108+
checkpoint = torch.load(resume_from)
106109

107110
checkpoint_rate = 500
108111
for epoch in range(12): # loop over the dataset multiple times
@@ -135,6 +138,7 @@ def train():
135138

136139
# Validation test
137140
running_valid_loss = 0.0
141+
running_valid_acc = 0.0
138142
for j, data in enumerate(validloader, 0):
139143
inputs, labels = data
140144

@@ -155,8 +159,16 @@ def train():
155159
optimizer.step()
156160
# print statistics
157161
running_valid_loss += loss.data[0]
162+
running_valid_acc += \
163+
((outputs.max(1)[1] == labels.long()).sum()).float() \
164+
/ (labels.size()[1] * labels.size()[2])
165+
158166
print('[Validation loss]: %.3f' %
159167
(running_valid_loss / len(imsValid)))
168+
169+
print('[Validation accuracy]: %.3f' %
170+
((running_valid_acc / len(imsValid)) * 100.0).data[0])
171+
160172
save_checkpoint(
161173
net.state_dict(),
162174
epoch+1,
@@ -169,6 +181,9 @@ def train():
169181

170182

171183
def test_image(paramsPath, img, label=None, showim=False):
184+
resc = Rescale(500)
185+
crop = RandomCrop(480)
186+
172187
im, lbl = resc(img, label)
173188
im, lbl = crop(im, lbl)
174189
im = np.transpose(im, (2, 0, 1))
@@ -184,14 +199,19 @@ def test_image(paramsPath, img, label=None, showim=False):
184199
if torch.cuda.is_available():
185200
net.cuda()
186201

187-
par = torch.load('model_paramms.dat', map_location=lambda storage, loc: storage)
188-
net.load_state_dict(par)
202+
par = torch.load(paramsPath, map_location=lambda storage, loc: storage)
203+
net.load_state_dict(par["model"])
189204

190-
out = net(imV)
191-
ouim = out.data
205+
if torch.cuda.is_available():
206+
out = net(imV.cuda())
207+
ouim = out.data.cpu()
208+
else:
209+
out = net(imV)
210+
ouim = out.data
192211
ouim = ouim.numpy()
193212

194213
if showim:
195214
imshow(ouim[0])
196215

197-
return ouim
216+
return ouim, lbl
217+

0 commit comments

Comments
 (0)