|
| 1 | +#! /usr/local/bin/python3 |
| 2 | + |
| 3 | +import sys |
| 4 | +import os |
| 5 | +import time |
| 6 | +import cv2 # install cv3, python3: http://seeb0h.github.io/howto/howto-install-homebrew-python-opencv-osx-el-capitan/ |
| 7 | +# add to profile: export PYTHONPATH=$PYTHONPATH:/usr/local/Cellar/opencv3/3.2.0/lib/python3.6/site-packages/ |
| 8 | +import numpy as np |
| 9 | +import argparse |
| 10 | +import torch |
| 11 | +import torchvision.transforms as transforms |
| 12 | +from torch.utils.serialization import load_lua |
| 13 | +from torch.legacy import nn # import torch.nn as nn |
| 14 | + |
| 15 | + |
def define_and_parse_args(argv=None):
    """Define and parse the demo's command-line arguments.

    Args:
        argv: optional list of argument strings; ``None`` (the default)
              parses ``sys.argv`` exactly as before, so existing callers
              are unaffected.

    Returns:
        argparse.Namespace with ``categories`` (str), ``input`` (int)
        and ``size`` (int).
    """
    # argument Checking
    parser = argparse.ArgumentParser(description="Visor Demo")
    # parser.add_argument('network', help='CNN model file')
    parser.add_argument('categories', help='text file with categories')
    # default was the string '0'; an int default avoids relying on
    # argparse's implicit type-conversion of string defaults.
    # NOTE(review): help mentions "file name" but type=int rejects
    # non-numeric values — confirm whether file input is actually used.
    parser.add_argument('-i', '--input', type=int, default=0,
                        help='camera device index or file name, default 0')
    parser.add_argument('-s', '--size', type=int, default=224,
                        help='network input size')
    # parser.add_argument('-S', '--stat', help='stat.txt file')
    return parser.parse_args(argv)
| 25 | + |
| 26 | + |
def cat_file(path=None):
    """Load category names from a comma-separated text file.

    Each line's first comma-separated field is taken as a category
    name; a line whose name is exactly 'classes' (a header line) is
    skipped.

    Args:
        path: file to read; ``None`` (the default) falls back to the
              global ``args.categories``, preserving the original
              no-argument call used by this script.

    Returns:
        list[str]: category names in file order (empty if no file is
        configured).
    """
    if path is None:
        # original behavior: silently return [] when no file was given
        if not (hasattr(args, 'categories') and args.categories):
            return []
        path = args.categories
    categories = []
    try:
        # `with` guarantees the file is closed even if a read fails
        # (the original leaked the handle on error)
        with open(path, 'r') as f:
            for line in f:
                # first comma field, trailing newline stripped
                cat = line.split(',')[0].rstrip('\n')
                if cat != 'classes':
                    categories.append(cat)
        print('Number of categories:', len(categories))
    except OSError:
        # narrowed from a bare `except:` so unrelated bugs (KeyboardInterrupt,
        # NameError, ...) are no longer swallowed
        print('Error opening file ' + path)
        sys.exit(1)
    return categories
| 43 | + |
| 44 | + |
# ---- script entry: banner, CLI arguments, categories, camera setup ----
print("Visor demo e-Lab - older Torch7 networks")
xres = 640  # requested capture width (pixels)
yres = 480  # requested capture height (pixels)
args = define_and_parse_args()
categories = cat_file() # load category file
print(categories)


# setup camera input:
# NOTE(review): args.input is declared type=int, so only a numeric device
# index works here despite the help text mentioning a file name — confirm.
cam = cv2.VideoCapture(args.input)
cam.set(3, xres)  # property 3 == CAP_PROP_FRAME_WIDTH
cam.set(4, yres)  # property 4 == CAP_PROP_FRAME_HEIGHT
| 57 | + |
# load old-pre-trained Torch7 CNN model:
# Each commented-out group below is an alternative pre-trained network;
# only one group should be active at a time. The nn.View index and its
# flattened feature size differ per architecture, so they are replaced
# together with the model/stat paths.

# https://www.dropbox.com/sh/l0rurgbx4k6j2a3/AAA223WOrRRjpe9bzO8ecpEpa?dl=0
model = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-alexowt-46/model.net')
stat = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-alexowt-46/stat.t7')
# replace module 13 with a flattening view: batch of 1, 9216 features
model.modules[13] = nn.View(1,9216)

# https://www.dropbox.com/sh/xcm8xul3udwo72o/AAC8RChVSOmgN61nQ0cyfdava?dl=0
# model = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-alextiny-46/model.net')
# stat = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-alextiny-46/stat.t7')
# model.modules[13] = nn.View(1,64)

# https://www.dropbox.com/sh/anklohs9g49z1o4/AAChA9rl0FEGixT75eT38Dqra?dl=0
# model = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-enet-demo-46/model.net')
# stat = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/elab-enet-demo-46/stat.t7')
# model.modules[41] = nn.View(1,1024)

# https://www.dropbox.com/sh/s0hwugnmhwkk9ow/AAD_abZ2LOav9GeMETt5VGvGa?dl=0
# model = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/enet128-demo-46/model.net')
# stat = load_lua('/Users/eugenioculurciello/Dropbox/shared/models/enet128-demo-46/stat.t7')
# model.modules[32] = nn.View(1,512)

# print(model)
# this now should work:
# model.forward(torch.Tensor(1,3,224,224)) # test

# image pre-processing functions:
transformsImage = transforms.Compose([
    # transforms.ToPILImage(),
    # transforms.Scale(256),
    # transforms.CenterCrop(224),
    transforms.ToTensor(),
    # presumably stat.mean / stat.std are the per-channel training
    # statistics stored alongside the model — verify against stat.t7
    transforms.Normalize(stat.mean, stat.std)
    ])
| 92 | + |
# main loop: grab a frame, classify it, overlay the top-5 labels,
# and print the frame rate until ESC is pressed or the stream ends.
while True:
    startt = time.time()  # frame timer for the fps readout
    ret, frame = cam.read()
    if not ret:
        # camera disconnected or end of video stream
        break

    # center-crop the frame to a square (short side) before resizing
    if xres > yres:
        frame = frame[:,int((xres - yres)/2):int((xres+yres)/2),:]
    else:
        frame = frame[int((yres - xres)/2):int((yres+xres)/2),:,:]

    pframe = cv2.resize(frame, dsize=(args.size, args.size))

    # prepare and normalize frame for processing:
    # NOTE(review): swapaxes(0, 2) turns HxWxC into CxWxH (spatially
    # transposed) and keeps OpenCV's BGR channel order — confirm this
    # matches the layout the Torch7 model was trained on.
    pframe = np.swapaxes(pframe, 0, 2)
    pframe = np.expand_dims(pframe, axis=0)  # add batch dimension
    pframe = transformsImage(pframe)

    # process via CNN model:
    output = model.forward(pframe)
    if output is None:
        print('no output from CNN model file')
        break

    # print(output)
    output = output.numpy()[0]  # first (only) batch entry as a numpy array

    # process output and print results:
    order = output.argsort()  # ascending, so the best scores are at the end
    last = len(categories)-1
    text = ''
    # walk the top-5 scores from highest down, formatting "label (xx.xx%)"
    for i in range(min(5, last+1)):
        text += categories[order[last-i]] + ' (' + '{0:.2f}'.format(output[order[last-i]]*100) + '%) '

    # overlay on GUI frame
    # cv2.displayOverlay('win', text, 1000) # if Qt is present!
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(frame, text, (10, yres-20), font, 0.5, (255, 255, 255), 1)
    cv2.imshow('win', frame)

    endt = time.time()
    # sys.stdout.write("\r"+text+"fps: "+'{0:.2f}'.format(1/(endt-startt))) # text output
    # \r rewrites the same console line each frame instead of scrolling
    sys.stdout.write("\rfps: "+'{0:.2f}'.format(1/(endt-startt)))
    sys.stdout.flush()

    if cv2.waitKey(1) == 27: # ESC to stop
        break

# end program: release the camera and close the display window
cam.release()
cv2.destroyAllWindows()
0 commit comments