Skip to content

Commit b46322b

Browse files
committed
stock classifier
1 parent 8d959d8 commit b46322b

File tree

3 files changed

+219
-68
lines changed

3 files changed

+219
-68
lines changed

SSBUBoundingBoxUtil.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
2+
def fighters_info_bbox(fighter_num):
3+
assert fighter_num in (2, 3, 4), 'Not implemented'
4+
5+
# FIXME: magic number
6+
if fighter_num == 2:
7+
return [
8+
# Fighter 1
9+
[ 245 ,560, 240,160 ],
10+
# Fighter 2
11+
[ 245+495,560, 240,160 ],
12+
]
13+
elif fighter_num == 3:
14+
return [
15+
# Fighter 1
16+
[ 75 ,560, 240,160 ],
17+
# Fighter 2
18+
[ 75+416,560, 240,160 ],
19+
# Fighter 3
20+
[ 75+832,560, 240,160 ],
21+
]
22+
elif fighter_num == 4:
23+
return [
24+
# Fighter 1
25+
[ 98 ,560, 240,160 ],
26+
# Fighter 2
27+
[ 98+272,560, 240,160 ],
28+
# Fighter 3
29+
[ 98+544,560, 240,160 ],
30+
# Fighter 4
31+
[ 98+816,560, 240,160 ],
32+
]
33+
34+
35+
def fighters_damage_bboxes(fighter_num):
36+
info_bboxes = fighters_info_bbox(fighter_num=fighter_num)
37+
38+
ret = []
39+
for fighter_idx, bbox in enumerate(info_bboxes):
40+
left = bbox[0] + 85
41+
top = bbox[1] + 50
42+
43+
ret.append([
44+
[ left ,top, 35,55 ],
45+
[ left+30,top, 35,55 ],
46+
[ left+60,top, 35,55 ],
47+
[ left+97,top+28, 18,25 ],
48+
])
49+
return ret
50+
51+
def fighters_name_bbox(fighter_num):
52+
info_bboxes = fighters_info_bbox(fighter_num=fighter_num)
53+
54+
ret = []
55+
for fighter_idx, bbox in enumerate(info_bboxes):
56+
ret.append([ bbox[0]+105,bbox[1]+110, 120,16 ])
57+
return ret
58+
59+
def fighters_chara_bbox(fighter_num):
60+
info_bboxes = fighters_info_bbox(fighter_num=fighter_num)
61+
62+
ret = []
63+
for fighter_idx, bbox in enumerate(info_bboxes):
64+
ret.append([ bbox[0]+10,bbox[1]+28, 110,110 ])
65+
return ret
66+
67+
def fighters_stock_bboxes(fighter_num, stock_num=3):
68+
info_bboxes = fighters_info_bbox(fighter_num=fighter_num)
69+
70+
ret = []
71+
for fighter_idx, bbox in enumerate(info_bboxes):
72+
left = bbox[0] + 73
73+
top = bbox[1] + 131
74+
75+
ret.append(
76+
[ [ left + 17*k, top, 16, 16, ] for k in range(stock_num) ]
77+
)
78+
return ret

SSBUFrameAnalyzer.py

Lines changed: 52 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,30 @@
55
from SSBUDigitClassifier import SSBUDigitClassifier
66
from SSBUNameRecognizer import SSBUNameRecognizer
77
from SSBUCharaClassifier import SSBUCharaClassifier
8+
from SSBUStockClassifier import SSBUStockClassifier
9+
import SSBUBoundingBoxUtil
10+
811

912
class SSBUFrameAnalyzer:
10-
def __init__(self, digit_classifier, name_recognizer, chara_classifier):
13+
def __init__(self, digit_classifier, name_recognizer, chara_classifier, stock_classifier):
1114
self.digit_classifier = digit_classifier
1215
self.name_recognizer = name_recognizer
1316
self.chara_classifier = chara_classifier
17+
self.stock_classifier = stock_classifier
1418

1519
def __call__(self, frame, fighter_num=2):
16-
1720
dmgs = self.analyze_damage(frame, fighter_num=fighter_num)
1821
names = self.analyze_name(frame, fighter_num=fighter_num)
1922
charas = self.analyze_chara(frame, fighter_num=fighter_num)
23+
stocks = self.analyze_stock(frame, fighter_num=fighter_num)
2024

2125
fighters = {}
2226
for fighter_idx in range(fighter_num):
2327
fighters[fighter_idx] = {
2428
'chara_name': charas[fighter_idx],
2529
'name': names[fighter_idx],
2630
'damage': dmgs[fighter_idx],
31+
'stocks': stocks[fighter_idx],
2732
}
2833

2934
result = {
@@ -32,69 +37,8 @@ def __call__(self, frame, fighter_num=2):
3237

3338
return result
3439

35-
def fighters_info_bbox(self, fighter_num):
36-
assert fighter_num in (2, 3, 4), 'Not implemented'
37-
38-
# FIXME: magic number
39-
if fighter_num == 2:
40-
return [
41-
# Fighter 1
42-
[ 245 ,560, 240,160 ],
43-
# Fighter 2
44-
[ 245+495,560, 240,160 ],
45-
]
46-
elif fighter_num == 3:
47-
return [
48-
# Fighter 1
49-
[ 75 ,560, 240,160 ],
50-
# Fighter 2
51-
[ 75+416,560, 240,160 ],
52-
# Fighter 3
53-
[ 75+832,560, 240,160 ],
54-
]
55-
elif fighter_num == 4:
56-
return [
57-
# Fighter 1
58-
[ 98 ,560, 240,160 ],
59-
# Fighter 2
60-
[ 98+272,560, 240,160 ],
61-
# Fighter 3
62-
[ 98+545,560, 240,160 ],
63-
# Fighter 4
64-
[ 98+817,560, 240,160 ],
65-
]
66-
def fighters_damage_bboxes(self, fighter_num):
67-
info_bboxes = self.fighters_info_bbox(fighter_num=fighter_num)
68-
69-
ret = []
70-
for fighter_idx, bbox in enumerate(info_bboxes):
71-
left = bbox[0] + 85
72-
top = bbox[1] + 50
73-
74-
ret.append([
75-
[ left ,top, 35,55 ],
76-
[ left+30,top, 35,55 ],
77-
[ left+60,top, 35,55 ],
78-
[ left+97,top+28, 18,25 ],
79-
])
80-
return ret
81-
def fighters_name_bbox(self, fighter_num):
82-
info_bboxes = self.fighters_info_bbox(fighter_num=fighter_num)
83-
84-
ret = []
85-
for fighter_idx, bbox in enumerate(info_bboxes):
86-
ret.append([ bbox[0]+105,bbox[1]+110, 120,16 ])
87-
return ret
88-
def fighters_chara_bbox(self, fighter_num):
89-
info_bboxes = self.fighters_info_bbox(fighter_num=fighter_num)
90-
91-
ret = []
92-
for fighter_idx, bbox in enumerate(info_bboxes):
93-
ret.append([ bbox[0]+10,bbox[1]+28, 110,110 ])
94-
return ret
95-
9640
def analyze_damage(self, frame, fighter_num):
97-
fighters_dmg_bboxes = self.fighters_damage_bboxes(fighter_num=fighter_num)
41+
fighters_dmg_bboxes = SSBUBoundingBoxUtil.fighters_damage_bboxes(fighter_num=fighter_num)
9842
assert fighter_num == len(fighters_dmg_bboxes)
9943

10044
dc = self.digit_classifier
@@ -136,7 +80,7 @@ def predict_digit(img):
13680
return result
13781

13882
def analyze_name(self, frame, fighter_num):
139-
fighters_name_bbox = self.fighters_name_bbox(fighter_num=fighter_num)
83+
fighters_name_bbox = SSBUBoundingBoxUtil.fighters_name_bbox(fighter_num=fighter_num)
14084
assert fighter_num == len(fighters_name_bbox)
14185

14286
nr = self.name_recognizer
@@ -156,15 +100,15 @@ def analyze_name(self, frame, fighter_num):
156100
return result
157101

158102
def analyze_chara(self, frame, fighter_num):
159-
fighters_chara_bbox = self.fighters_chara_bbox(fighter_num=fighter_num)
103+
fighters_chara_bbox = SSBUBoundingBoxUtil.fighters_chara_bbox(fighter_num=fighter_num)
160104
assert fighter_num == len(fighters_chara_bbox)
161105

162106
cc = self.chara_classifier
163107
def predict_chara(img):
164108
names, dists = cc(img, k=3)
165109
min_dist = dists[0]
166110

167-
print(names, dists)
111+
# print(names, dists)
168112
thresh_dist = 10.
169113

170114
name = names[0] if min_dist < thresh_dist else None
@@ -184,6 +128,41 @@ def predict_chara(img):
184128

185129
return result
186130

131+
def analyze_stock(self, frame, fighter_num):
132+
fighters_stock_bboxes = SSBUBoundingBoxUtil.fighters_stock_bboxes(fighter_num=fighter_num, stock_num=5)
133+
assert fighter_num == len(fighters_stock_bboxes)
134+
135+
sc = self.stock_classifier
136+
def predict_stock(img):
137+
stocks, dists = sc(img, k=3)
138+
min_dist = dists[0]
139+
140+
# print(stocks, dists)
141+
thresh_dist = 0.6
142+
143+
stock = stocks[0] if min_dist < thresh_dist else None
144+
return stock
145+
146+
result = {}
147+
for fighter_idx in range(fighter_num):
148+
bboxes = fighters_stock_bboxes[fighter_idx]
149+
150+
stocks = []
151+
for bbox_idx, bbox in enumerate(bboxes):
152+
simg = frame[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]] # RGB
153+
simg = cv2.cvtColor(simg, cv2.COLOR_BGR2GRAY) # GRAY
154+
155+
# cv2.imwrite('fighter-stock-%d-%d.png' % (fighter_idx, bbox_idx, ), simg)
156+
157+
stock = predict_stock(simg)
158+
if stock is None:
159+
break
160+
stocks.append(stock)
161+
162+
result[fighter_idx] = stocks
163+
164+
return result
165+
187166

188167

189168
if __name__ == '__main__':
@@ -194,6 +173,7 @@ def predict_chara(img):
194173
parser = argparse.ArgumentParser()
195174
parser.add_argument('digit_dictionary', type=str)
196175
parser.add_argument('chara_dictionary', type=str)
176+
parser.add_argument('stock_dictionary', type=str)
197177
parser.add_argument('input', type=str)
198178
parser.add_argument('fighter_num', type=int) # TODO: predict
199179
args = parser.parse_args()
@@ -210,7 +190,11 @@ def predict_chara(img):
210190
chara_classifier = SSBUCharaClassifier(feature_json=args.chara_dictionary)
211191
print('loaded chara classifier')
212192

213-
analyzer = SSBUFrameAnalyzer(digit_classifier=digit_classifier, name_recognizer=name_recognizer, chara_classifier=chara_classifier)
193+
print('loading stock classifier...')
194+
stock_classifier = SSBUStockClassifier(feature_json=args.stock_dictionary)
195+
print('loaded stock classifier')
196+
197+
analyzer = SSBUFrameAnalyzer(digit_classifier=digit_classifier, name_recognizer=name_recognizer, chara_classifier=chara_classifier, stock_classifier=stock_classifier)
214198

215199
frame = cv2.imread(args.input, 1) # RGB
216200
frame = cv2.resize(frame, (1280, 720))

SSBUStockClassifier.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from skimage.feature import hog
2+
import numpy as np
3+
import json
4+
import cv2
5+
6+
class SSBUStockClassifier:
7+
def __init__(self, feature_json):
8+
self.feature_json = feature_json
9+
10+
with open(feature_json, 'r') as fp:
11+
data = json.load(fp)
12+
13+
chara_id2name = sorted(list(data.keys()))
14+
15+
charas = []
16+
features = []
17+
for chara_id, chara_name in enumerate(chara_id2name):
18+
fts = data[chara_name]
19+
20+
# n = 0
21+
for feature in fts:
22+
charas.append(int(chara_id))
23+
features.append(np.asarray(feature, dtype=np.float32))
24+
# n += 1
25+
# if n == 4:
26+
# break
27+
28+
self.categories = set(charas)
29+
self.datacount = len(charas)
30+
31+
self.chara_id2name = chara_id2name
32+
self.charas = np.asarray(charas, dtype=np.int32)
33+
self.features = np.asarray(features, dtype=np.float32)
34+
35+
36+
def __call__(self, img, k=3):
37+
# img isinstance of np.ndarray
38+
# print(img.shape)
39+
assert len(img.shape) == 2 # gray
40+
assert img.shape[1] == 16 and img.shape[0] == 16
41+
img = cv2.resize(img, (30, 30)) # hog requirement
42+
h0 = hog(img)
43+
44+
dists = np.linalg.norm(self.features - h0, axis=1)
45+
46+
sarg = np.argsort(dists) # sorted-arg
47+
topKarg = sarg[:k]
48+
49+
charas = self.charas[topKarg].tolist()
50+
names = []
51+
for i in range(len(charas)):
52+
chara_id = int(charas[i])
53+
name = self.chara_id2name[chara_id]
54+
names.append(name)
55+
56+
return names, dists[topKarg].tolist()
57+
58+
if __name__ == '__main__':
59+
import argparse
60+
import time
61+
import cv2
62+
import SSBUBoundingBoxUtil
63+
64+
parser = argparse.ArgumentParser()
65+
parser.add_argument('dictionary', type=str)
66+
parser.add_argument('input', type=str)
67+
parser.add_argument('fighter_num', type=int)
68+
args = parser.parse_args()
69+
70+
print('loading...')
71+
classifier = SSBUStockClassifier(feature_json=args.dictionary)
72+
print('loaded')
73+
74+
img = cv2.imread(args.input, 0)
75+
stock_bboxes = SSBUBoundingBoxUtil.fighters_stock_bboxes(fighter_num=args.fighter_num, stock_num=3)
76+
77+
t = time.time()
78+
# img = cv2.resize(img, (110, 110))
79+
80+
for fighter_idx, stock_bboxes in enumerate(stock_bboxes):
81+
for stock_idx, bbox in enumerate(stock_bboxes):
82+
simg = img[bbox[1]:bbox[1]+bbox[3], bbox[0]:bbox[0]+bbox[2]]
83+
84+
ret = classifier(simg)
85+
print(ret)
86+
87+
elapsed = time.time() - t
88+
89+
print('FPS: %f (%f s)' % (1/elapsed, elapsed, ))

0 commit comments

Comments
 (0)