Commit 12fcb2e

Added saving checkpoint and now using Adam optimizer

1 parent 1fe1504 · commit 12fcb2e

File tree

1 file changed (+28, -10 lines)

unet.py

@@ -54,6 +54,18 @@ def forward(self, x):
         return F.log_softmax(last)  # sigmoid if classes aren't mutually exclusive


+def save_checkpoint(model, epoch, iteration, loss, vloss):
+    checkpoint = {}
+    checkpoint["model"] = model
+    checkpoint["epoch"] = epoch
+    checkpoint["iteration"] = iteration
+    checkpoint["loss"] = loss
+    checkpoint["vloss"] = vloss
+    fname = "checkpoint_" + str(epoch) + "_" + str(iteration) + ".dat"
+    torch.save(checkpoint, fname)
+    return
+
+
 def train():
     ###########
     # Load Dataset #
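
Note on the new helper: save_checkpoint() stores whatever is passed as
"model" (train() below passes net.state_dict()) plus bookkeeping fields,
via torch.save. A minimal sketch of restoring such a file; the file name
and the UNet() construction are illustrative assumptions, not part of
this commit:

    import torch

    # Load a checkpoint written by save_checkpoint() above.
    checkpoint = torch.load("checkpoint_1_500.dat")  # assumed file name

    net = UNet()  # hypothetical: build the same network defined in unet.py
    net.load_state_dict(checkpoint["model"])  # "model" holds a state dict here

    # Bookkeeping fields saved alongside the weights:
    print(checkpoint["epoch"], checkpoint["iteration"],
          checkpoint["loss"], checkpoint["vloss"])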
@@ -89,12 +101,13 @@ def train():
     net.cuda()

     criterion = nn.NLLLoss2d()
-    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
+    # optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
+    optimizer = optim.Adam(net.parameters(), lr=0.001)

-    for epoch in range(2):  # loop over the dataset multiple times
+    checkpoint_rate = 500
+    for epoch in range(12):  # loop over the dataset multiple times
         running_loss = 0.0
-        steps = len(imsTrain)  # Batch size = 1
-        for i, data in enumerate(tqdm(trainloader, total=steps), start=0):
+        for i, data in enumerate(trainloader, start=0):
             # get the inputs
             inputs, labels = data

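
The optimizer swap keeps the same learning rate but drops the explicit
momentum term, since Adam maintains its own per-parameter moment
estimates; checkpoint_rate is also hoisted out of the inner loop (see the
next hunk). For reference, the new line with torch.optim.Adam's
documented defaults spelled out; a sketch equivalent to the one-liner
above:

    optimizer = optim.Adam(net.parameters(), lr=0.001,
                           betas=(0.9, 0.999),  # decay rates for the moment estimates
                           eps=1e-8,            # numerical-stability term
                           weight_decay=0)      # no L2 penalty by default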
@@ -116,21 +129,19 @@ def train():

            # print statistics
            running_loss += loss.data[0]
-            checkpoint_rate = 500
            if i % checkpoint_rate == checkpoint_rate-1:  # print every N mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / checkpoint_rate))
-                running_loss = 0.0

                # Validation test
                running_valid_loss = 0.0
-                for i, data in enumerate(tqdm(validloader, total=len(imsValid)), 0):
+                for j, data in enumerate(validloader, 0):
                    inputs, labels = data

                    # wrap them in Variable
                    if torch.cuda.is_available():
                        inputs, labels = Variable(inputs.cuda()),\
-                            Variable(labels.cuda())
+                                         Variable(labels.cuda())
                    else:
                        inputs, labels = Variable(inputs), Variable(labels)

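
Renaming the validation index to j matters: the old loop reused i, which
the checkpoint call in the next hunk needs intact for the file name and
the printed iteration. Note also that validation data is still wrapped in
plain Variable, so autograd graphs are built (and, given the
optimizer.step() that remains in the next hunk, weights appear to be
updated on validation batches too). In the pre-0.4 Variable API this code
targets, a gradient-free evaluation would look roughly like this sketch,
which is not part of this commit:

    # volatile=True marks the graph inference-only, so no gradients are kept
    inputs = Variable(inputs.cuda(), volatile=True)
    labels = Variable(labels.cuda(), volatile=True)
    outputs = net(inputs)
    vloss = criterion(outputs, labels)  # read the loss; no backward()/step()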
@@ -144,8 +155,15 @@ def train():
                    optimizer.step()
                    # print statistics
                    running_valid_loss += loss.data[0]
-                print('[Validation loss: %.3f' %
-                      (running_valid_loss / len(imsValid)))
+                print('[Validation loss]: %.3f' %
+                      (running_valid_loss / len(imsValid)))
+                save_checkpoint(
+                    net.state_dict(),
+                    epoch+1,
+                    i + 1,
+                    running_loss / checkpoint_rate,
+                    running_valid_loss / len(imsValid))
+                running_loss = 0.0

    print('Finished Training')

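With checkpoint_rate = 500 this writes a file such as checkpoint_1_500.dat
after each validation pass; moving the running_loss reset after the save
(it previously happened right after the print) keeps the saved training
loss equal to the printed one. A rough sketch of scanning such files for
the best validation loss; the glob pattern simply mirrors the naming
scheme above:

    import glob
    import torch

    best = None
    for fname in sorted(glob.glob("checkpoint_*.dat")):
        ckpt = torch.load(fname)
        print("%s  loss=%.3f  vloss=%.3f" % (fname, ckpt["loss"], ckpt["vloss"]))
        if best is None or ckpt["vloss"] < best[1]:
            best = (fname, ckpt["vloss"])
    if best:
        print("best checkpoint: %s (vloss %.3f)" % best)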