Skip to content

Commit ec22957

Browse files
committed
Fix visor_oldnet and make it work
1 parent 2adc0ff commit ec22957

File tree

1 file changed

+51
-48
lines changed

1 file changed

+51
-48
lines changed

Diff for: visor/visor_oldnets.py

+51-48
Original file line numberDiff line numberDiff line change
@@ -9,38 +9,46 @@
99
import argparse
1010
import torch
1111
import torchvision.transforms as transforms
12+
import torch.legacy.nn as nn
13+
#load_lua does not recognize SpatialConvolutionMM
14+
nn.SpatialConvolutionMM = nn.SpatialConvolution
1215
from torch.utils.serialization import load_lua
13-
from torch.legacy import nn # import torch.nn as nn
14-
1516

1617
def define_and_parse_args():
1718
# argument Checking
1819
parser = argparse.ArgumentParser(description="Visor Demo")
19-
# parser.add_argument('network', help='CNN model file')
20-
parser.add_argument('categories', help='text file with categories')
21-
parser.add_argument('-i', '--input', type=int, default='0', help='camera device index or file name, default 0')
20+
parser.add_argument('model', help='model directory')
21+
parser.add_argument('-i', '--input', default='0', help='camera device index or file name, default 0')
2222
parser.add_argument('-s', '--size', type=int, default=224, help='network input size')
23-
# parser.add_argument('-S', '--stat', help='stat.txt file')
2423
return parser.parse_args()
2524

2625

2726
def cat_file():
2827
# load classes file
2928
categories = []
30-
if hasattr(args, 'categories') and args.categories:
31-
try:
32-
f = open(args.categories, 'r')
33-
for line in f:
34-
cat = line.split(',')[0].split('\n')[0]
35-
if cat != 'classes':
36-
categories.append(cat)
37-
f.close()
38-
print('Number of categories:', len(categories))
39-
except:
40-
print('Error opening file ' + args.categories)
41-
quit()
29+
try:
30+
f = open(args.model + '/categories.txt', 'r')
31+
for line in f:
32+
cat = line.split(',')[0].split('\n')[0]
33+
if cat != 'classes':
34+
categories.append(cat)
35+
f.close()
36+
print('Number of categories:', len(categories))
37+
except:
38+
print('Error opening file ' + args.model + '/categories.txt')
39+
quit()
4240
return categories
4341

42+
def patch(m):
43+
s = str(type(m))
44+
s = s[str.rfind(s, '.')+1:-2]
45+
if s == 'Padding' and hasattr(m, 'nInputDim') and m.nInputDim == 3:
46+
m.dim = m.dim + 1
47+
if s == 'View' and len(m.size) == 1:
48+
m.size = torch.Size([1,m.size[0]])
49+
if hasattr(m, 'modules'):
50+
for m in m.modules:
51+
patch(m)
4452

4553
print("Visor demo e-Lab - older Torch7 networks")
4654
xres = 640
@@ -51,35 +59,25 @@ def cat_file():
5159

5260

5361
# setup camera input:
54-
cam = cv2.VideoCapture(args.input)
55-
cam.set(3, xres)
56-
cam.set(4, yres)
62+
if args.input[0] >= '0' and args.input[0] <= '9':
63+
cam = cv2.VideoCapture(int(args.input))
64+
cam.set(3, xres)
65+
cam.set(4, yres)
66+
usecam = True
67+
else:
68+
image = cv2.imread(args.input)
69+
xres = image.shape[1]
70+
yres = image.shape[0]
71+
usecam = False
5772

5873
# load old-pre-trained Torch7 CNN model:
5974

6075
# https://www.dropbox.com/sh/l0rurgbx4k6j2a3/AAA223WOrRRjpe9bzO8ecpEpa?dl=0
61-
model = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-alexowt-46/model.net')
62-
stat = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-alexowt-46/stat.t7')
63-
model.modules[13] = nn.View(1,9216)
64-
65-
# https://www.dropbox.com/sh/xcm8xul3udwo72o/AAC8RChVSOmgN61nQ0cyfdava?dl=0
66-
# model = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-alextiny-46/model.net')
67-
# stat = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-alextiny-46/stat.t7')
68-
# model.modules[13] = nn.View(1,64)
69-
70-
# https://www.dropbox.com/sh/anklohs9g49z1o4/AAChA9rl0FEGixT75eT38Dqra?dl=0
71-
# model = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-enet-demo-46/model.net')
72-
# stat = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-enet-demo-46/stat.t7')
73-
# model.modules[41] = nn.View(1,1024)
76+
model = load_lua(args.model + '/model.net')
77+
stat = load_lua(args.model + '/stat.t7')
7478

75-
# https://www.dropbox.com/sh/s0hwugnmhwkk9ow/AAD_abZ2LOav9GeMETt5VGvGa?dl=0
76-
# model = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/enet128-demo-46/model.net')
77-
# stat = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/enet128-demo-46/stat.t7')
78-
# model.modules[32] = nn.View(1,512)
79-
80-
# print(model)
81-
# this now should work:
82-
# model.forward(torch.Tensor(1,3,224,224)) # test
79+
#Patch Torch model to 4D
80+
patch(model)
8381

8482
# image pre-processing functions:
8583
transformsImage = transforms.Compose([
@@ -92,9 +90,12 @@ def cat_file():
9290

9391
while True:
9492
startt = time.time()
95-
ret, frame = cam.read()
96-
if not ret:
97-
break
93+
if usecam:
94+
ret, frame = cam.read()
95+
if not ret:
96+
break
97+
else:
98+
frame = image.copy()
9899

99100
if xres > yres:
100101
frame = frame[:,int((xres - yres)/2):int((xres+yres)/2),:]
@@ -104,9 +105,10 @@ def cat_file():
104105
pframe = cv2.resize(frame, dsize=(args.size, args.size))
105106

106107
# prepare and normalize frame for processing:
107-
pframe = np.swapaxes(pframe, 0, 2)
108-
pframe = np.expand_dims(pframe, axis=0)
108+
pframe = pframe[...,[2,1,0]]
109+
pframe = np.transpose(pframe, (2,0,1))
109110
pframe = transformsImage(pframe)
111+
pframe = pframe.view(1, pframe.size(0), pframe.size(1), pframe.size(2))
110112

111113
# process via CNN model:
112114
output = model.forward(pframe)
@@ -139,5 +141,6 @@ def cat_file():
139141
break
140142

141143
# end program:
142-
cam.release()
144+
if usecam:
145+
cam.release()
143146
cv2.destroyAllWindows()

0 commit comments

Comments
 (0)