@@ -54,6 +54,18 @@ def forward(self, x):
54
54
return F .log_softmax (last ) # sigmoid if classes arent mutually exclusv
55
55
56
56
57
+ def save_checkpoint (model , epoch , iteration , loss , vloss ):
58
+ checkpoint = {}
59
+ checkpoint ["model" ] = model
60
+ checkpoint ["epoch" ] = epoch
61
+ checkpoint ["iteration" ] = iteration
62
+ checkpoint ["loss" ] = loss
63
+ checkpoint ["vloss" ] = vloss
64
+ fname = "checkpoint_" + str (epoch ) + "_" + str (iteration ) + ".dat"
65
+ torch .save (checkpoint , fname )
66
+ return
67
+
68
+
57
69
def train ():
58
70
###########
59
71
# Load Dataset #
@@ -89,12 +101,13 @@ def train():
89
101
net .cuda ()
90
102
91
103
criterion = nn .NLLLoss2d ()
92
- optimizer = optim .SGD (net .parameters (), lr = 0.001 , momentum = 0.9 )
104
+ # optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
105
+ optimizer = optim .Adam (net .parameters (), lr = 0.001 )
93
106
94
- for epoch in range (2 ): # loop over the dataset multiple times
107
+ checkpoint_rate = 500
108
+ for epoch in range (12 ): # loop over the dataset multiple times
95
109
running_loss = 0.0
96
- steps = len (imsTrain ) # Batch size = 1
97
- for i , data in enumerate (tqdm (trainloader , total = steps ), start = 0 ):
110
+ for i , data in enumerate (trainloader , start = 0 ):
98
111
# get the inputs
99
112
inputs , labels = data
100
113
@@ -116,21 +129,19 @@ def train():
116
129
117
130
# print statistics
118
131
running_loss += loss .data [0 ]
119
- checkpoint_rate = 500
120
132
if i % checkpoint_rate == checkpoint_rate - 1 : # print every N mini-batches
121
133
print ('[%d, %5d] loss: %.3f' %
122
134
(epoch + 1 , i + 1 , running_loss / checkpoint_rate ))
123
- running_loss = 0.0
124
135
125
136
# Validation test
126
137
running_valid_loss = 0.0
127
- for i , data in enumerate (tqdm ( validloader , total = len ( imsValid )) , 0 ):
138
+ for j , data in enumerate (validloader , 0 ):
128
139
inputs , labels = data
129
140
130
141
# wrap them in Variable
131
142
if torch .cuda .is_available ():
132
143
inputs , labels = Variable (inputs .cuda ()),\
133
- Variable (labels .cuda ())
144
+ Variable (labels .cuda ())
134
145
else :
135
146
inputs , labels = Variable (inputs ), Variable (labels )
136
147
@@ -144,8 +155,15 @@ def train():
144
155
optimizer .step ()
145
156
# print statistics
146
157
running_valid_loss += loss .data [0 ]
147
- print ('[Validation loss: %.3f' %
148
- (running_valid_loss / len (imsValid )))
158
+ print ('[Validation loss]: %.3f' %
159
+ (running_valid_loss / len (imsValid )))
160
+ save_checkpoint (
161
+ net .state_dict (),
162
+ epoch + 1 ,
163
+ i + 1 ,
164
+ running_loss / checkpoint_rate ,
165
+ running_valid_loss / len (imsValid ))
166
+ running_loss = 0.0
149
167
150
168
print ('Finished Training' )
151
169
0 commit comments