"""Score a live pose stream against pre-computed reference joint angles.

Frames are read from ``--input`` (camera index or video file) and, in
lockstep, from the ``--demo`` video. A person detector and a pose estimator
run on every input frame, joint angles are derived with ``CalAngle`` and
compared by R^2 with the matching row of the ``--data`` file. The annotated
frame is shown in an OpenCV window together with the current FPS and one
score per detected person. Press ESC to quit.
"""
import argparse
import os
import time

import cv2 as cv
import gluoncv as gcv
import mxnet as mx
import numpy as np
from gluoncv import data
from gluoncv.data.transforms.pose import detector_to_simple_pose, heatmap_to_coord
from gluoncv.model_zoo import get_model
from gluoncv.utils.viz import cv_plot_image, cv_plot_keypoints
from mxnet.gluon.data.vision import transforms
from sklearn.metrics import r2_score

from angle import CalAngle


def _capture_source(value):
    """Return an int camera index for all-digit strings, else *value* unchanged.

    ``cv.VideoCapture("0")`` is treated as a file name, so a camera index
    given on the command line has to be converted to ``int`` first.
    """
    text = str(value)
    return int(text) if text.isdigit() else value


def _score_people(angles, reference_row):
    """Return one display string per entry of *angles*.

    Each entry is compared with *reference_row* on the positions where the
    reference is not NaN. The string is the R^2 value with four decimals, or
    ``'NaN'`` when the entry still has missing values on those positions or
    fewer than two positions remain to compare.
    """
    visible = ~np.isnan(reference_row)  # positions present in the reference
    reference_v = reference_row[visible]
    texts = []
    for angle in angles:
        angle_v = angle[visible]  # drop positions missing in the reference
        print(angle_v)
        if np.size(angle_v) < 2 or np.isnan(angle_v).any():
            texts.append('NaN')
        else:
            # NOTE(review): r2_score's signature is (y_true, y_pred) and the
            # metric is not symmetric; the reference row is passed as y_pred
            # here, as in the original code. Confirm this is intended.
            texts.append('{:.4f}'.format(r2_score(angle_v, reference_v)))
    return texts


# Command-line arguments.
parser = argparse.ArgumentParser()
parser.add_argument('--input', default=0, type=_capture_source)
parser.add_argument('--demo', required=True)
parser.add_argument('--data', required=True)
args = parser.parse_args()

# Model setup: use the GPU when one is available, otherwise fall back to CPU.
ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()

detector_name = "ssd_512_mobilenet1.0_coco"
detector = get_model(detector_name, pretrained=True, ctx=ctx)

estimator_name = "simple_pose_resnet18_v1b"
estimator = get_model(estimator_name, pretrained='ccd24037', ctx=ctx)

# Only the 'person' class is needed from the detector.
detector.reset_class(classes=['person'], reuse_weights={'person': 'person'})

detector.hybridize()
estimator.hybridize()

# Video sources.
cap1 = cv.VideoCapture(args.input)
cap2 = cv.VideoCapture(args.demo)

try:
    if not cap1.isOpened():
        raise SystemExit(f"Cannot open input source: {args.input!r}")
    if not cap2.isOpened():
        raise SystemExit(f"Cannot open demo video: {args.demo!r}")

    # Reference features. ndmin=2 keeps row indexing valid for one-row files.
    std_angles = np.loadtxt(args.data, delimiter='\t', ndmin=2)
    # Index of the reference row used for the current frame.
    # Assumes one row of the reference file per demo frame -- TODO confirm.
    pos = 0

    fps_time = time.time()

    ret1, frame1 = cap1.read()
    # The demo frame is not displayed; reading it keeps both streams in
    # lockstep and ends the loop when the demo video ends.
    ret2, frame2 = cap2.read()
    while ret1 and ret2 and pos < len(std_angles):

        # Person detection.
        frame = mx.nd.array(cv.cvtColor(frame1, cv.COLOR_BGR2RGB)).astype('uint8')

        x, img = gcv.data.transforms.presets.ssd.transform_test(frame, short=512)
        x = x.as_in_context(ctx)
        class_IDs, scores, bounding_boxs = detector(x)

        pose_input, upscale_bbox = detector_to_simple_pose(
            img, class_IDs, scores, bounding_boxs,
            output_shape=(128, 96), ctx=ctx)

        # Text lines for the overlay; stays empty when nobody is detected.
        score_texts = []

        # Pose estimation.
        if len(upscale_bbox) > 0:
            predicted_heatmap = estimator(pose_input)
            pred_coords, confidence = heatmap_to_coord(predicted_heatmap, upscale_bbox)
            img = cv_plot_keypoints(img, pred_coords, confidence,
                                    class_IDs, bounding_boxs, scores)

            # Compare with the reference.
            angles = CalAngle(pred_coords, confidence)
            score_texts = _score_people(angles, std_angles[pos])

        # Advance once per frame so the row index follows the demo stream.
        pos += 1

        # Guard against a zero interval on coarse clocks.
        elapsed = max(time.time() - fps_time, 1e-6)
        cv_plot_image(img,
                      upperleft_txt=f"FPS:{(1.0 / elapsed):.2f} ",
                      upperleft_txt_corner=(10, 25),
                      left_txt_list=score_texts, canvas_name='pose')
        fps_time = time.time()

        # ESC quits.
        if cv.waitKey(1) == 27:
            break

        ret1, frame1 = cap1.read()
        ret2, frame2 = cap2.read()
finally:
    # Always close the window and release both captures, even on errors.
    cv.destroyAllWindows()

    cap1.release()
    cap2.release()
0 commit comments