Skip to content

Commit 1fe1504

Browse files
committed
separated dataset into train/valid groups
1 parent eb126b5 commit 1fe1504

File tree

1 file changed

+116
-35
lines changed

1 file changed

+116
-35
lines changed

unet.py

+116-35
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
#!/usr/bin/env python
22

3-
3+
import numpy as np
44
import torch
55
from torch.autograd import Variable
66
import torch.nn as nn
77
import torch.nn.functional as F
88
import torch.optim as optim
99
from torch.utils.data import TensorDataset
10+
from scipy.misc import imshow
11+
from tqdm import tqdm
1012

1113
from loadCOCO import loadCOCO
1214

@@ -27,7 +29,7 @@ def __init__(self):
2729
self.dconv256 = nn.Conv2d(256, 128, 3, padding=1)
2830
self.upconv128 = nn.ConvTranspose2d(128, 64, 2, stride=2)
2931
self.dconv128 = nn.Conv2d(128, 64, 3, padding=1)
30-
self.conv1 = nn.Conv2d(64, 182, 1)
32+
self.conv1 = nn.Conv2d(64, 183, 1)
3133
self.pool = nn.MaxPool2d(2, 2)
3234

3335
def forward(self, x):
@@ -51,48 +53,127 @@ def forward(self, x):
5153
last = self.conv1(dx1)
5254
return F.log_softmax(last) # sigmoid if classes arent mutually exclusv
5355

54-
###########
55-
# Load Dataset #
56-
###########
57-
ims, labs = loadCOCO("/home/toni/Data/COCOstuff/")
58-
imsT = torch.Tensor(ims)
59-
labsT = torch.ByteTensor(labs)
60-
trainset = TensorDataset(imsT, labsT)
61-
trainloader = torch.utils.data.DataLoader(
56+
57+
def train():
    """Train the U-Net on the COCO-stuff dataset.

    Loads the dataset, splits it 95/5 into train/validation groups,
    trains for 2 epochs with SGD on 2-D NLL loss (batch size 1), and
    reports the average validation loss after each epoch.
    """
    ################
    # Load Dataset #
    ################
    ims, labs = loadCOCO("/home/toni/Data/COCOstuff/")

    # 95/5 train/validation split. Assumes len(ims) == len(labs)
    # (TensorDataset requires it anyway).
    split = int(0.95 * len(ims))
    imsTrain, imsValid = ims[:split], ims[split:]
    labsTrain, labsValid = labs[:split], labs[split:]

    imsTrainT = torch.Tensor(imsTrain)
    labsTrainT = torch.ByteTensor(labsTrain)
    imsValidT = torch.Tensor(imsValid)
    labsValidT = torch.ByteTensor(labsValid)
    trainset = TensorDataset(imsTrainT, labsTrainT)
    validset = TensorDataset(imsValidT, labsValidT)
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=1,
        shuffle=True,
        num_workers=2
    )
    validloader = torch.utils.data.DataLoader(
        validset,
        batch_size=1,
        shuffle=True,
        num_workers=2
    )

    net = Net()
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()

    criterion = nn.NLLLoss2d()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    checkpoint_rate = 500  # print running loss every N mini-batches

    for epoch in range(2):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(tqdm(trainloader, total=len(imsTrain)), 0):
            # get the inputs
            inputs, labels = data

            # move to GPU first when available, then wrap in Variable
            if use_cuda:
                inputs, labels = inputs.cuda(), labels.cuda()
            inputs, labels = Variable(inputs), Variable(labels)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels.long())
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.data[0]
            if i % checkpoint_rate == checkpoint_rate - 1:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / checkpoint_rate))
                running_loss = 0.0

        # Validation pass: evaluate ONLY. The original also called
        # zero_grad()/backward()/step() here, which trained the network
        # on the validation split and made the reported score meaningless.
        running_valid_loss = 0.0
        for data in tqdm(validloader, total=len(imsValid)):
            inputs, labels = data
            if use_cuda:
                inputs, labels = inputs.cuda(), labels.cuda()
            # volatile=True: inference mode, no autograd graph is retained
            inputs = Variable(inputs, volatile=True)
            labels = Variable(labels, volatile=True)

            outputs = net(inputs)
            loss = criterion(outputs, labels.long())
            running_valid_loss += loss.data[0]
        print('Validation loss: %.3f' %
              (running_valid_loss / len(imsValid)))

    print('Finished Training')
151+
152+
153+
def test_image(paramsPath, img, label=None, showim=False):
    """Run a single image through a trained U-Net and return the raw output.

    Parameters
    ----------
    paramsPath : str
        Path to a saved ``state_dict`` checkpoint produced by training.
    img : array-like
        Input image in HWC layout with 0-255 channel values
        (presumably RGB -- confirm against loadCOCO's preprocessing).
    label : optional
        Ground-truth label map, passed through resc/crop alongside the image.
    showim : bool
        When True, display the network output with ``imshow``.

    Returns
    -------
    numpy.ndarray
        The network's output scores for the (single-image) batch.
    """
    # Preprocess: rescale/crop, HWC -> CHW, normalize to [-1, 1],
    # and add a leading batch dimension.
    im, lbl = resc(img, label)
    im, lbl = crop(im, lbl)
    im = np.transpose(im, (2, 0, 1))
    im = np.array(im, dtype='float32')
    im /= 255.0
    im = (im * 2) - 1
    im = np.expand_dims(im, axis=0)
    imT = torch.Tensor(im)
    imV = Variable(imT, volatile=True)  # inference only, no grad graph

    net = Net()
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
        # Fix: the original left the input on the CPU while the model was
        # on the GPU, which raises a device-mismatch error at net(imV).
        imV = imV.cuda()

    # Fix: load the requested checkpoint; the original ignored paramsPath
    # and always read the hard-coded (and typo'd) 'model_paramms.dat'.
    # map_location keeps storages on CPU so CPU-only machines can load
    # GPU-trained checkpoints.
    par = torch.load(paramsPath, map_location=lambda storage, loc: storage)
    net.load_state_dict(par)

    out = net(imV)
    # .cpu() is a no-op for CPU tensors but required before .numpy()
    # when the output lives on the GPU.
    ouim = out.data.cpu().numpy()

    if showim:
        imshow(ouim[0])

    return ouim

0 commit comments

Comments
 (0)