Skip to content

Commit 2adc0ff

Browse files
committed
load old nets
1 parent 48615cd commit 2adc0ff

File tree

3 files changed

+154
-8
lines changed

3 files changed

+154
-8
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11

22
*.ipynb
3+
4+
*.pyc

visor/visor.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,6 @@ def cat_file():
3939
quit()
4040
return categories
4141

42-
# image pre-processing functions:
43-
transformsImage = transforms.Compose([
44-
# transforms.ToPILImage(),
45-
# transforms.Scale(256),
46-
# transforms.CenterCrop(224),
47-
transforms.ToTensor(),
48-
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # needed for PyTorch ZOO models on ImageNet (stats)
49-
])
5042

5143
print("Visor demo e-Lab")
5244
xres = 640
@@ -66,6 +58,15 @@ def cat_file():
6658
# print model
6759
softMax = nn.Softmax() # to get probabilities out of CNN
6860

61+
# image pre-processing functions:
62+
transformsImage = transforms.Compose([
63+
# transforms.ToPILImage(),
64+
# transforms.Scale(256),
65+
# transforms.CenterCrop(224),
66+
transforms.ToTensor(),
67+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # needed for PyTorch ZOO models on ImageNet (stats)
68+
])
69+
6970
while True:
7071
startt = time.time()
7172
ret, frame = cam.read()

visor/visor_oldnets.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
#! /usr/local/bin/python3
2+
3+
import sys
4+
import os
5+
import time
6+
import cv2 # install cv3, python3: http://seeb0h.github.io/howto/howto-install-homebrew-python-opencv-osx-el-capitan/
7+
# add to profile: export PYTHONPATH=$PYTHONPATH:/usr/local/Cellar/opencv3/3.2.0/lib/python3.6/site-packages/
8+
import numpy as np
9+
import argparse
10+
import torch
11+
import torchvision.transforms as transforms
12+
from torch.utils.serialization import load_lua
13+
from torch.legacy import nn # import torch.nn as nn
14+
15+
16+
def define_and_parse_args():
    """Define and parse the demo's command-line arguments.

    Returns:
        argparse.Namespace with:
            categories (str): path to the text file listing class names.
            input (int): camera device index, default 0.
            size (int): square network input size in pixels, default 224.
    """
    # argument Checking
    parser = argparse.ArgumentParser(description="Visor Demo")
    # parser.add_argument('network', help='CNN model file')
    parser.add_argument('categories', help='text file with categories')
    # was default='0' (a string): argparse only happens to coerce string
    # defaults through type=int; the int 0 states the intent directly.
    # NOTE(review): the help text mentions "file name", but type=int rejects
    # any non-numeric value — confirm whether file input is still supported.
    parser.add_argument('-i', '--input', type=int, default=0, help='camera device index or file name, default 0')
    parser.add_argument('-s', '--size', type=int, default=224, help='network input size')
    # parser.add_argument('-S', '--stat', help='stat.txt file')
    return parser.parse_args()
25+
26+
27+
def cat_file(categories_file=None):
    """Load the category labels used to name the network outputs.

    Each line of the file contributes the text before its first comma as one
    category; a 'classes' header row is skipped.

    Args:
        categories_file: path to the categories text file. Defaults to the
            global ``args.categories`` so the original zero-argument call
            keeps working unchanged.

    Returns:
        List of category-name strings (possibly empty).

    Exits the program via ``quit()`` if the file cannot be opened.
    """
    if categories_file is None:
        # backward-compatible fallback to the parsed command-line arguments
        categories_file = getattr(args, 'categories', None)
    categories = []
    if categories_file:
        try:
            # 'with' guarantees the handle is closed even if parsing raises
            # (the original open/close pair leaked the file on any error)
            with open(categories_file, 'r') as f:
                for line in f:
                    # text before the first comma, trailing newline stripped
                    cat = line.split(',')[0].split('\n')[0]
                    if cat != 'classes':  # skip header row
                        categories.append(cat)
            print('Number of categories:', len(categories))
        except OSError:
            # narrowed from a bare 'except:' so programming errors and
            # KeyboardInterrupt are no longer silently swallowed
            print('Error opening file ' + categories_file)
            quit()
    return categories
43+
44+
45+
# ---------------------------------------------------------------------------
# Main script: grabs frames from a camera, runs them through an old Torch7
# network loaded with torch.utils.serialization.load_lua, and overlays the
# top-5 category names/scores on a live OpenCV window.
# ---------------------------------------------------------------------------
print("Visor demo e-Lab - older Torch7 networks")
xres = 640  # requested capture width (pixels)
yres = 480  # requested capture height (pixels)
args = define_and_parse_args()
categories = cat_file() # load category file
print(categories)

# setup camera input:
cam = cv2.VideoCapture(args.input)
# property ids 3 and 4 — presumably CAP_PROP_FRAME_WIDTH / CAP_PROP_FRAME_HEIGHT;
# confirm against the cv2.CAP_PROP_* constants.
cam.set(3, xres)
cam.set(4, yres)

# load old-pre-trained Torch7 CNN model:
# Each group below is one alternative model (Dropbox link above it); exactly
# one group should be uncommented. The nn.View assignment replaces one module
# so the flattened shape works when driven from Python (see the
# "this now should work" note below).
# NOTE(review): paths are hard-coded to one user's home directory — this will
# fail on any other machine; consider a command-line option.

# https://www.dropbox.com/sh/l0rurgbx4k6j2a3/AAA223WOrRRjpe9bzO8ecpEpa?dl=0
model = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-alexowt-46/model.net')
stat = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-alexowt-46/stat.t7')
model.modules[13] = nn.View(1,9216)

# https://www.dropbox.com/sh/xcm8xul3udwo72o/AAC8RChVSOmgN61nQ0cyfdava?dl=0
# model = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-alextiny-46/model.net')
# stat = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-alextiny-46/stat.t7')
# model.modules[13] = nn.View(1,64)

# https://www.dropbox.com/sh/anklohs9g49z1o4/AAChA9rl0FEGixT75eT38Dqra?dl=0
# model = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-enet-demo-46/model.net')
# stat = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-enet-demo-46/stat.t7')
# model.modules[41] = nn.View(1,1024)

# https://www.dropbox.com/sh/s0hwugnmhwkk9ow/AAD_abZ2LOav9GeMETt5VGvGa?dl=0
# model = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/enet128-demo-46/model.net')
# stat = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/enet128-demo-46/stat.t7')
# model.modules[32] = nn.View(1,512)

# print(model)
# this now should work:
# model.forward(torch.Tensor(1,3,224,224)) # test

# image pre-processing functions:
# normalizes with the per-model mean/std saved next to the network in stat.t7
transformsImage = transforms.Compose([
    # transforms.ToPILImage(),
    # transforms.Scale(256),
    # transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(stat.mean, stat.std)
])

while True:
    startt = time.time()  # frame-start timestamp for the fps readout below
    ret, frame = cam.read()
    if not ret:
        break

    # center-crop the frame to a square along its longer dimension
    # (frame is height x width x channels, so columns are axis 1)
    if xres > yres:
        frame = frame[:,int((xres - yres)/2):int((xres+yres)/2),:]
    else:
        frame = frame[int((yres - xres)/2):int((yres+xres)/2),:,:]

    pframe = cv2.resize(frame, dsize=(args.size, args.size))

    # prepare and normalize frame for processing:
    # NOTE(review): swapaxes(0, 2) yields C,W,H (not the usual C,H,W) and the
    # batch dim is added BEFORE the torchvision transforms run, while
    # ToTensor normally expects a single HWC image — confirm this pipeline
    # against the working visor.py before relying on it.
    pframe = np.swapaxes(pframe, 0, 2)
    pframe = np.expand_dims(pframe, axis=0)
    pframe = transformsImage(pframe)

    # process via CNN model:
    output = model.forward(pframe)
    if output is None:
        print('no output from CNN model file')
        break

    # print(output)
    output = output.numpy()[0]  # drop the batch dim; 1-D array of class scores

    # process output and print results:
    order = output.argsort()  # ascending, so the best class index is last
    last = len(categories)-1
    text = ''
    # top-5 labels (or fewer when there are fewer categories); scores are
    # shown x100 with a '%' sign — NOTE(review): no softmax is applied in
    # this script, so these are raw net outputs unless the saved model
    # already ends in a SoftMax layer — confirm.
    for i in range(min(5, last+1)):
        text += categories[order[last-i]] + ' (' + '{0:.2f}'.format(output[order[last-i]]*100) + '%) '

    # overlay on GUI frame
    # cv2.displayOverlay('win', text, 1000) # if Qt is present!
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(frame, text, (10, yres-20), font, 0.5, (255, 255, 255), 1)
    cv2.imshow('win', frame)

    endt = time.time()
    # sys.stdout.write("\r"+text+"fps: "+'{0:.2f}'.format(1/(endt-startt))) # text output
    sys.stdout.write("\rfps: "+'{0:.2f}'.format(1/(endt-startt)))
    sys.stdout.flush()

    if cv2.waitKey(1) == 27: # ESC to stop
        break

# end program:
cam.release()
cv2.destroyAllWindows()

0 commit comments

Comments
 (0)