10
10
from scipy .misc import imshow
11
11
from tqdm import tqdm
12
12
13
- from loadCOCO import loadCOCO
13
+ from loadCOCO import loadCOCO , Rescale , RandomCrop
14
14
15
15
16
16
class Net (nn .Module ):
@@ -66,7 +66,7 @@ def save_checkpoint(model, epoch, iteration, loss, vloss):
66
66
return
67
67
68
68
69
- def train ():
69
+ def train (resume_from = None ):
70
70
###########
71
71
# Load Dataset #
72
72
###########
@@ -102,7 +102,10 @@ def train():
102
102
103
103
criterion = nn .NLLLoss2d ()
104
104
# optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
105
- optimizer = optim .Adam (net .parameters (), lr = 0.001 )
105
+ optimizer = optim .Adam (net .parameters (), lr = 0.005 )
106
+
107
+ if resume_from is not None :
108
+ checkpoint = torch .load (resume_from )
106
109
107
110
checkpoint_rate = 500
108
111
for epoch in range (12 ): # loop over the dataset multiple times
@@ -135,6 +138,7 @@ def train():
135
138
136
139
# Validation test
137
140
running_valid_loss = 0.0
141
+ running_valid_acc = 0.0
138
142
for j , data in enumerate (validloader , 0 ):
139
143
inputs , labels = data
140
144
@@ -155,8 +159,16 @@ def train():
155
159
optimizer .step ()
156
160
# print statistics
157
161
running_valid_loss += loss .data [0 ]
162
+ running_valid_acc += \
163
+ ((outputs .max (1 )[1 ] == labels .long ()).sum ()).float () \
164
+ / (labels .size ()[1 ] * labels .size ()[2 ])
165
+
158
166
print ('[Validation loss]: %.3f' %
159
167
(running_valid_loss / len (imsValid )))
168
+
169
+ print ('[Validation accuracy]: %.3f' %
170
+ ((running_valid_acc / len (imsValid )) * 100.0 ).data [0 ])
171
+
160
172
save_checkpoint (
161
173
net .state_dict (),
162
174
epoch + 1 ,
@@ -169,6 +181,9 @@ def train():
169
181
170
182
171
183
def test_image (paramsPath , img , label = None , showim = False ):
184
+ resc = Rescale (500 )
185
+ crop = RandomCrop (480 )
186
+
172
187
im , lbl = resc (img , label )
173
188
im , lbl = crop (im , lbl )
174
189
im = np .transpose (im , (2 , 0 , 1 ))
@@ -184,14 +199,19 @@ def test_image(paramsPath, img, label=None, showim=False):
184
199
if torch .cuda .is_available ():
185
200
net .cuda ()
186
201
187
- par = torch .load ('model_paramms.dat' , map_location = lambda storage , loc : storage )
188
- net .load_state_dict (par )
202
+ par = torch .load (paramsPath , map_location = lambda storage , loc : storage )
203
+ net .load_state_dict (par [ "model" ] )
189
204
190
- out = net (imV )
191
- ouim = out .data
205
+ if torch .cuda .is_available ():
206
+ out = net (imV .cuda ())
207
+ ouim = out .data .cpu ()
208
+ else :
209
+ out = net (imV )
210
+ ouim = out .data
192
211
ouim = ouim .numpy ()
193
212
194
213
if showim :
195
214
imshow (ouim [0 ])
196
215
197
- return ouim
216
+ return ouim , lbl
217
+
0 commit comments