9
9
import argparse
10
10
import torch
11
11
import torchvision .transforms as transforms
12
+ import torch .legacy .nn as nn
13
+ #load_lua does not recognize SpatialConvolutionMM
14
+ nn .SpatialConvolutionMM = nn .SpatialConvolution
12
15
from torch .utils .serialization import load_lua
13
- from torch .legacy import nn # import torch.nn as nn
14
-
15
16
16
17
def define_and_parse_args():
    """Build the demo's command-line interface and return the parsed namespace."""
    # argument Checking
    arg_parser = argparse.ArgumentParser(description="Visor Demo")
    arg_parser.add_argument('model', help='model directory')
    arg_parser.add_argument('-i', '--input', default='0',
                            help='camera device index or file name, default 0')
    arg_parser.add_argument('-s', '--size', type=int, default=224,
                            help='network input size')
    return arg_parser.parse_args()
25
24
26
25
27
26
def cat_file(model_dir=None):
    """Load the class labels used to annotate the network's predictions.

    Reads ``<model_dir>/categories.txt``, one label per line; only the text
    before the first comma is kept, and a ``classes`` header line is skipped.

    Args:
        model_dir: directory containing ``categories.txt``.  Defaults to the
            ``model`` directory given on the command line (``args.model``).

    Returns:
        List of category name strings.  On a file error the program prints a
        message and exits (original behavior preserved).
    """
    if model_dir is None:
        model_dir = args.model
    path = model_dir + '/categories.txt'
    # load classes file
    categories = []
    try:
        # 'with' guarantees the file is closed even if parsing fails mid-file
        with open(path, 'r') as f:
            for line in f:
                # keep text before the first comma, minus the trailing newline
                cat = line.split(',')[0].split('\n')[0]
                if cat != 'classes':  # skip the header row
                    categories.append(cat)
        print('Number of categories:', len(categories))
    except IOError:
        # narrowed from a bare 'except:' so programming errors are not hidden
        print('Error opening file ' + path)
        quit()
    return categories
43
41
42
def patch(m):
    """Recursively patch a legacy Torch7 module tree for 4D (batched) input.

    load_lua deserializes networks written for 3D (C,H,W) tensors; this walks
    the module tree and bumps Padding dims and 1-D View sizes so that a
    leading batch dimension of 1 works.

    Args:
        m: a legacy nn module (may contain a ``modules`` list of children).
    """
    # derive the class name from e.g. "<class 'torch.legacy.nn.View.View'>"
    name = str(type(m))
    name = name[name.rfind('.') + 1:-2]
    if name == 'Padding' and hasattr(m, 'nInputDim') and m.nInputDim == 3:
        m.dim = m.dim + 1  # shift the padded axis past the new batch dim
    if name == 'View' and len(m.size) == 1:
        m.size = torch.Size([1, m.size[0]])  # prepend batch dimension
    if hasattr(m, 'modules'):
        # use a distinct loop variable: the original reused 'm', shadowing
        # and clobbering the function parameter
        for child in m.modules:
            patch(child)
44
52
45
53
print ("Visor demo e-Lab - older Torch7 networks" )
46
54
xres = 640
@@ -51,35 +59,25 @@ def cat_file():
51
59
52
60
53
61
# setup camera input: a numeric argument selects a camera device index,
# anything else is treated as an image file path.
# args.input[:1].isdigit() is safe on an empty string, where the original
# args.input[0] comparison raised IndexError.
if args.input[:1].isdigit():
    cam = cv2.VideoCapture(int(args.input))
    cam.set(3, xres)  # 3 = CAP_PROP_FRAME_WIDTH
    cam.set(4, yres)  # 4 = CAP_PROP_FRAME_HEIGHT
    usecam = True
else:
    image = cv2.imread(args.input)
    if image is None:
        # imread returns None instead of raising when the file can't be read
        print('Error opening file ' + args.input)
        quit()
    xres = image.shape[1]
    yres = image.shape[0]
    usecam = False
57
72
58
73
# load old-pre-trained Torch7 CNN model:
# example pre-trained models:
# https://www.dropbox.com/sh/l0rurgbx4k6j2a3/AAA223WOrRRjpe9bzO8ecpEpa?dl=0
model_dir = args.model
model = load_lua(model_dir + '/model.net')  # legacy nn network
stat = load_lua(model_dir + '/stat.t7')     # normalization statistics

# Patch Torch model to 4D
patch(model)
83
81
84
82
# image pre-processing functions:
85
83
transformsImage = transforms .Compose ([
@@ -92,9 +90,12 @@ def cat_file():
92
90
93
91
while True :
94
92
startt = time .time ()
95
- ret , frame = cam .read ()
96
- if not ret :
97
- break
93
+ if usecam :
94
+ ret , frame = cam .read ()
95
+ if not ret :
96
+ break
97
+ else :
98
+ frame = image .copy ()
98
99
99
100
if xres > yres :
100
101
frame = frame [:,int ((xres - yres )/ 2 ):int ((xres + yres )/ 2 ),:]
@@ -104,9 +105,10 @@ def cat_file():
104
105
pframe = cv2 .resize (frame , dsize = (args .size , args .size ))
105
106
106
107
# prepare and normalize frame for processing:
107
- pframe = np . swapaxes ( pframe , 0 , 2 )
108
- pframe = np .expand_dims (pframe , axis = 0 )
108
+ pframe = pframe [...,[ 2 , 1 , 0 ]]
109
+ pframe = np .transpose (pframe , ( 2 , 0 , 1 ) )
109
110
pframe = transformsImage (pframe )
111
+ pframe = pframe .view (1 , pframe .size (0 ), pframe .size (1 ), pframe .size (2 ))
110
112
111
113
# process via CNN model:
112
114
output = model .forward (pframe )
@@ -139,5 +141,6 @@ def cat_file():
139
141
break
140
142
141
143
# end program:
# release the capture device only when one was opened (file-input mode
# never creates 'cam'), then tear down all OpenCV windows
if usecam:
    cam.release()
cv2.destroyAllWindows()
0 commit comments