
Commit 4fccbc8

initial commit

File tree

2 files changed: +204 -0 lines changed

loadCOCO.py

+119
@@ -0,0 +1,119 @@
#!/usr/bin/env python

import numpy as np
import torch
from os import path
from scipy.misc import imread, imresize
from scipy.io import loadmat


class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (int or tuple): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, image, labels):
        assert image.shape[:2] == labels.shape

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = imresize(image, (new_h, new_w))
        lbls = imresize(labels, (new_h, new_w), interp="nearest")

        return (img, lbls)


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, image, labels):
        assert image.shape[:2] == labels.shape

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        image = image[top: top + new_h,
                      left: left + new_w]

        labels = labels[top: top + new_h,
                        left: left + new_w]

        return (image, labels)


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, image, labels):
        assert image.shape[:2] == labels.shape

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return (torch.from_numpy(image),
                torch.from_numpy(labels))


def loadCOCO(dataset_folder):
    resc = Rescale(650)
    crop = RandomCrop(640)

    namespath = path.join(dataset_folder, "imageLists/train.txt")
    names = np.loadtxt(namespath, dtype=str, delimiter="\n")

    images = []
    labels = []
    for imgName in names:
        im = imread(path.join(dataset_folder, "images/" + imgName + ".jpg"), mode="RGB")
        mat = loadmat(path.join(dataset_folder, "annotations/" + imgName + ".mat"))
        lbl = mat["S"]

        im, lbl = resc(im, lbl)
        im, lbl = crop(im, lbl)
        images.append(im)
        labels.append(lbl)

    images = np.array(images, dtype='float32')
    images /= 255.0            # span 0 ~ 1
    images = (images * 2) - 1  # span -1 ~ 1

    return (images, np.array(labels))


if __name__ == '__main__':
    DATASET_FOLDER = "/home/toni/Data/ssegmentation/COCO"
    loadCOCO(DATASET_FOLDER)
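
Note: loadCOCO returns plain NumPy arrays and the ToTensor transform defined above is never used, while the training loop in unet.py below iterates over a trainloader that this commit never constructs (it also relies on scipy.misc.imread/imresize, which only exist in older SciPy releases). A minimal sketch of how the arrays could be wrapped into a DataLoader follows; it is not part of the commit, and the dataset path, batch size, and the int64 cast for the labels are assumptions.

# Sketch only -- not part of this commit.
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

from loadCOCO import loadCOCO

images, labels = loadCOCO("/path/to/COCO")   # images: (N, 640, 640, 3) float32 in [-1, 1]
images = torch.from_numpy(np.ascontiguousarray(images.transpose((0, 3, 1, 2))))  # N x H x W x C -> N x C x H x W, as ToTensor does per image
labels = torch.from_numpy(labels.astype(np.int64))  # assumed cast: NLLLoss2d expects LongTensor targets

trainloader = DataLoader(TensorDataset(images, labels), batch_size=1, shuffle=True)

One mismatch worth flagging: these images have three channels, while Net in unet.py declares its first convolution as nn.Conv2d(1, 64, 3, padding=1), so either the images would have to be converted to grayscale or that layer changed to accept three channels.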

unet.py

+85
@@ -0,0 +1,85 @@
#!/usr/bin/env python


import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv64 = nn.Conv2d(1, 64, 3, padding=1)
        self.conv128 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv256 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv512 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv1024 = nn.Conv2d(512, 1024, 3, padding=1)
        self.upconv1024 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dconv1024 = nn.Conv2d(1024, 512, 3, padding=1)
        self.upconv512 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dconv512 = nn.Conv2d(512, 256, 3, padding=1)
        self.upconv256 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dconv256 = nn.Conv2d(256, 128, 3, padding=1)
        self.upconv128 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dconv128 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv1 = nn.Conv2d(64, 2, 1)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        # Contracting path: conv + relu, with 2x2 max-pooling between levels.
        x1 = F.relu(self.conv64(x))
        x2 = F.relu(self.conv128(self.pool(x1)))
        x3 = F.relu(self.conv256(self.pool(x2)))
        x4 = F.relu(self.conv512(self.pool(x3)))
        x5 = F.relu(self.conv1024(self.pool(x4)))
        # Expanding path: transposed conv, concatenate the matching encoder
        # feature map (skip connection), then conv + relu.
        ux5 = self.upconv1024(x5)
        cc5 = torch.cat([ux5, x4], 1)
        dx4 = F.relu(self.dconv1024(cc5))
        ux4 = self.upconv512(dx4)
        cc4 = torch.cat([ux4, x3], 1)
        dx3 = F.relu(self.dconv512(cc4))
        ux3 = self.upconv256(dx3)
        cc3 = torch.cat([ux3, x2], 1)
        dx2 = F.relu(self.dconv256(cc3))
        ux2 = self.upconv128(dx2)
        cc2 = torch.cat([ux2, x1], 1)
        dx1 = F.relu(self.dconv128(cc2))  # no relu?
        last = self.conv1(dx1)
        return F.log_softmax(last)  # sigmoid if classes aren't mutually exclusive


################
# Load Dataset #
################

# NOTE: `trainloader` is not defined anywhere in this commit; it is expected to
# be built from the (images, labels) arrays that loadCOCO.py produces.

net = Net()
criterion = nn.NLLLoss2d()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data

        # wrap them in Variable
        inputs, labels = Variable(inputs), Variable(labels)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.data[0]
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')
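
As declared, the network takes a single-channel input and returns two log-softmax channels at the input resolution; because of the four pooling/up-convolution stages, the height and width must be divisible by 16 for the skip concatenations to line up. A quick shape sanity check, as a sketch only: it assumes the Net class is already in scope, since importing unet.py as a module would also run the training loop at the bottom of the file.

# Sketch only -- not part of this commit; assumes Net (defined above) is in scope.
import torch
from torch.autograd import Variable

net = Net()
x = Variable(torch.randn(1, 1, 64, 64))  # one single-channel 64x64 image (H and W divisible by 16)
out = net(x)
print(out.size())  # torch.Size([1, 2, 64, 64]): per-pixel log-probabilities for 2 classes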
