5
5
from SSBUDigitClassifier import SSBUDigitClassifier
6
6
from SSBUNameRecognizer import SSBUNameRecognizer
7
7
from SSBUCharaClassifier import SSBUCharaClassifier
8
+ from SSBUStockClassifier import SSBUStockClassifier
9
+ import SSBUBoundingBoxUtil
10
+
8
11
9
12
class SSBUFrameAnalyzer :
10
- def __init__ (self , digit_classifier , name_recognizer , chara_classifier ):
13
+ def __init__ (self , digit_classifier , name_recognizer , chara_classifier , stock_classifier ):
11
14
self .digit_classifier = digit_classifier
12
15
self .name_recognizer = name_recognizer
13
16
self .chara_classifier = chara_classifier
17
+ self .stock_classifier = stock_classifier
14
18
15
19
def __call__ (self , frame , fighter_num = 2 ):
16
-
17
20
dmgs = self .analyze_damage (frame , fighter_num = fighter_num )
18
21
names = self .analyze_name (frame , fighter_num = fighter_num )
19
22
charas = self .analyze_chara (frame , fighter_num = fighter_num )
23
+ stocks = self .analyze_stock (frame , fighter_num = fighter_num )
20
24
21
25
fighters = {}
22
26
for fighter_idx in range (fighter_num ):
23
27
fighters [fighter_idx ] = {
24
28
'chara_name' : charas [fighter_idx ],
25
29
'name' : names [fighter_idx ],
26
30
'damage' : dmgs [fighter_idx ],
31
+ 'stocks' : stocks [fighter_idx ],
27
32
}
28
33
29
34
result = {
@@ -32,69 +37,8 @@ def __call__(self, frame, fighter_num=2):
32
37
33
38
return result
34
39
35
- def fighters_info_bbox (self , fighter_num ):
36
- assert fighter_num in (2 , 3 , 4 ), 'Not implemented'
37
-
38
- # FIXME: magic number
39
- if fighter_num == 2 :
40
- return [
41
- # Fighter 1
42
- [ 245 ,560 , 240 ,160 ],
43
- # Fighter 2
44
- [ 245 + 495 ,560 , 240 ,160 ],
45
- ]
46
- elif fighter_num == 3 :
47
- return [
48
- # Fighter 1
49
- [ 75 ,560 , 240 ,160 ],
50
- # Fighter 2
51
- [ 75 + 416 ,560 , 240 ,160 ],
52
- # Fighter 3
53
- [ 75 + 832 ,560 , 240 ,160 ],
54
- ]
55
- elif fighter_num == 4 :
56
- return [
57
- # Fighter 1
58
- [ 98 ,560 , 240 ,160 ],
59
- # Fighter 2
60
- [ 98 + 272 ,560 , 240 ,160 ],
61
- # Fighter 3
62
- [ 98 + 545 ,560 , 240 ,160 ],
63
- # Fighter 4
64
- [ 98 + 817 ,560 , 240 ,160 ],
65
- ]
66
- def fighters_damage_bboxes (self , fighter_num ):
67
- info_bboxes = self .fighters_info_bbox (fighter_num = fighter_num )
68
-
69
- ret = []
70
- for fighter_idx , bbox in enumerate (info_bboxes ):
71
- left = bbox [0 ] + 85
72
- top = bbox [1 ] + 50
73
-
74
- ret .append ([
75
- [ left ,top , 35 ,55 ],
76
- [ left + 30 ,top , 35 ,55 ],
77
- [ left + 60 ,top , 35 ,55 ],
78
- [ left + 97 ,top + 28 , 18 ,25 ],
79
- ])
80
- return ret
81
- def fighters_name_bbox (self , fighter_num ):
82
- info_bboxes = self .fighters_info_bbox (fighter_num = fighter_num )
83
-
84
- ret = []
85
- for fighter_idx , bbox in enumerate (info_bboxes ):
86
- ret .append ([ bbox [0 ]+ 105 ,bbox [1 ]+ 110 , 120 ,16 ])
87
- return ret
88
- def fighters_chara_bbox (self , fighter_num ):
89
- info_bboxes = self .fighters_info_bbox (fighter_num = fighter_num )
90
-
91
- ret = []
92
- for fighter_idx , bbox in enumerate (info_bboxes ):
93
- ret .append ([ bbox [0 ]+ 10 ,bbox [1 ]+ 28 , 110 ,110 ])
94
- return ret
95
-
96
40
def analyze_damage (self , frame , fighter_num ):
97
- fighters_dmg_bboxes = self .fighters_damage_bboxes (fighter_num = fighter_num )
41
+ fighters_dmg_bboxes = SSBUBoundingBoxUtil .fighters_damage_bboxes (fighter_num = fighter_num )
98
42
assert fighter_num == len (fighters_dmg_bboxes )
99
43
100
44
dc = self .digit_classifier
@@ -136,7 +80,7 @@ def predict_digit(img):
136
80
return result
137
81
138
82
def analyze_name (self , frame , fighter_num ):
139
- fighters_name_bbox = self .fighters_name_bbox (fighter_num = fighter_num )
83
+ fighters_name_bbox = SSBUBoundingBoxUtil .fighters_name_bbox (fighter_num = fighter_num )
140
84
assert fighter_num == len (fighters_name_bbox )
141
85
142
86
nr = self .name_recognizer
@@ -156,15 +100,15 @@ def analyze_name(self, frame, fighter_num):
156
100
return result
157
101
158
102
def analyze_chara (self , frame , fighter_num ):
159
- fighters_chara_bbox = self .fighters_chara_bbox (fighter_num = fighter_num )
103
+ fighters_chara_bbox = SSBUBoundingBoxUtil .fighters_chara_bbox (fighter_num = fighter_num )
160
104
assert fighter_num == len (fighters_chara_bbox )
161
105
162
106
cc = self .chara_classifier
163
107
def predict_chara (img ):
164
108
names , dists = cc (img , k = 3 )
165
109
min_dist = dists [0 ]
166
110
167
- print (names , dists )
111
+ # print(names, dists)
168
112
thresh_dist = 10.
169
113
170
114
name = names [0 ] if min_dist < thresh_dist else None
@@ -184,6 +128,41 @@ def predict_chara(img):
184
128
185
129
return result
186
130
131
+ def analyze_stock (self , frame , fighter_num ):
132
+ fighters_stock_bboxes = SSBUBoundingBoxUtil .fighters_stock_bboxes (fighter_num = fighter_num , stock_num = 5 )
133
+ assert fighter_num == len (fighters_stock_bboxes )
134
+
135
+ sc = self .stock_classifier
136
+ def predict_stock (img ):
137
+ stocks , dists = sc (img , k = 3 )
138
+ min_dist = dists [0 ]
139
+
140
+ # print(stocks, dists)
141
+ thresh_dist = 0.6
142
+
143
+ stock = stocks [0 ] if min_dist < thresh_dist else None
144
+ return stock
145
+
146
+ result = {}
147
+ for fighter_idx in range (fighter_num ):
148
+ bboxes = fighters_stock_bboxes [fighter_idx ]
149
+
150
+ stocks = []
151
+ for bbox_idx , bbox in enumerate (bboxes ):
152
+ simg = frame [bbox [1 ]:bbox [1 ]+ bbox [3 ], bbox [0 ]:bbox [0 ]+ bbox [2 ]] # RGB
153
+ simg = cv2 .cvtColor (simg , cv2 .COLOR_BGR2GRAY ) # GRAY
154
+
155
+ # cv2.imwrite('fighter-stock-%d-%d.png' % (fighter_idx, bbox_idx, ), simg)
156
+
157
+ stock = predict_stock (simg )
158
+ if stock is None :
159
+ break
160
+ stocks .append (stock )
161
+
162
+ result [fighter_idx ] = stocks
163
+
164
+ return result
165
+
187
166
188
167
189
168
if __name__ == '__main__' :
@@ -194,6 +173,7 @@ def predict_chara(img):
194
173
parser = argparse .ArgumentParser ()
195
174
parser .add_argument ('digit_dictionary' , type = str )
196
175
parser .add_argument ('chara_dictionary' , type = str )
176
+ parser .add_argument ('stock_dictionary' , type = str )
197
177
parser .add_argument ('input' , type = str )
198
178
parser .add_argument ('fighter_num' , type = int ) # TODO: predict
199
179
args = parser .parse_args ()
@@ -210,7 +190,11 @@ def predict_chara(img):
210
190
chara_classifier = SSBUCharaClassifier (feature_json = args .chara_dictionary )
211
191
print ('loaded chara classifier' )
212
192
213
- analyzer = SSBUFrameAnalyzer (digit_classifier = digit_classifier , name_recognizer = name_recognizer , chara_classifier = chara_classifier )
193
+ print ('loading stock classifier...' )
194
+ stock_classifier = SSBUStockClassifier (feature_json = args .stock_dictionary )
195
+ print ('loaded stock classifier' )
196
+
197
+ analyzer = SSBUFrameAnalyzer (digit_classifier = digit_classifier , name_recognizer = name_recognizer , chara_classifier = chara_classifier , stock_classifier = stock_classifier )
214
198
215
199
frame = cv2 .imread (args .input , 1 ) # RGB
216
200
frame = cv2 .resize (frame , (1280 , 720 ))
0 commit comments