Skip to content

Commit 54bce48

Browse files
committed
动作对比功能已完成
1 parent 26753a0 commit 54bce48

File tree

4 files changed

+211
-0
lines changed

4 files changed

+211
-0
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,6 @@ dmypy.json
127127

128128
# Pyre type checker
129129
.pyre/
130+
131+
.vscode
132+
data/*

angle.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import numpy as np
2+
from mxnet import ndarray
3+
4+
# Body parts whose angle is measured. Each entry is
# (joint, neighbour_a, neighbour_b): the angle is taken at the first index,
# between the segments joint->neighbour_a and joint->neighbour_b.
# NOTE(review): the indices presumably follow the 17-keypoint layout emitted
# by the pose estimator (5/6 shoulders, 7/8 elbows, ...) -- confirm.
# NOTE: the right-shoulder entry used to be (6, 7, 8); it now mirrors the
# left-shoulder entry like every other left/right pair. Feature files written
# with the old table must be regenerated.
KeyPoints = [
    (5, 6, 7),     # left shoulder
    (6, 5, 8),     # right shoulder
    (7, 5, 9),     # left arm
    (8, 6, 10),    # right arm
    (11, 5, 13),   # left hip
    (12, 6, 14),   # right hip
    (13, 11, 15),  # left knee
    (14, 12, 16),  # right knee
]


def _as_numpy(array):
    """Return *array* as a NumPy array (MXNet NDArrays expose ``asnumpy``)."""
    if hasattr(array, "asnumpy"):
        return array.asnumpy()
    return np.asarray(array)


def CalAngle(coords, confidence, keypoint_thresh=0.2):
    """Compute the cosine of the angle at every entry of ``KeyPoints``.

    Args:
        coords: array indexed as ``[person, joint, axis]`` where axis 0 is x
            and axis 1 is y. MXNet NDArrays and NumPy arrays are accepted.
        confidence: array indexed as ``[person, joint, 0]`` holding the
            per-joint confidence.
        keypoint_thresh: a joint counts as detected only when its confidence
            is strictly greater than this value.

    Returns:
        A float array of shape ``(num_people, len(KeyPoints))``. An entry is
        NaN when any of the three joints involved is not detected, or when
        one of the two segments has zero length (the angle is undefined).
    """
    points = _as_numpy(coords).astype(float)
    visible = _as_numpy(confidence)[:, :, 0] > keypoint_thresh

    # Start from "missing" everywhere and fill in only what can be computed.
    angles = np.full((points.shape[0], len(KeyPoints)), np.nan)

    for j, (joint, first, second) in enumerate(KeyPoints):
        # Vectors from the joint to each neighbour, for all people at once.
        v1 = points[:, first, :2] - points[:, joint, :2]
        v2 = points[:, second, :2] - points[:, joint, :2]
        norms = np.linalg.norm(v1, axis=1) * np.linalg.norm(v2, axis=1)

        usable = (
            visible[:, joint]
            & visible[:, first]
            & visible[:, second]
            & (norms > 0)  # avoid 0/0 when two keypoints coincide
        )
        cosines = (v1[usable] * v2[usable]).sum(axis=1) / norms[usable]
        # Rounding can push a cosine marginally outside [-1, 1].
        angles[usable, j] = np.clip(cosines, -1.0, 1.0)

    return angles

data.py

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import argparse, time, os
2+
import cv2 as cv
3+
import mxnet as mx
4+
import numpy as np
5+
import gluoncv as gcv
6+
from gluoncv import data
7+
from gluoncv.model_zoo import get_model
8+
from gluoncv.data.transforms.pose import detector_to_simple_pose, heatmap_to_coord
9+
from gluoncv.utils.viz import cv_plot_image, cv_plot_keypoints
10+
from mxnet.gluon.data.vision import transforms
11+
from angle import CalAngle
12+
13+
from angle import KeyPoints

# Parse command-line arguments.
parser = argparse.ArgumentParser()
parser.add_argument('--video', required=True)
parser.add_argument('--output', default=os.path.join('data', 'demo.tsv'))
args = parser.parse_args()

# Model setup. Use the GPU when one is present, otherwise fall back to CPU so
# the script still runs on machines without CUDA.
ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()

detector_name = "ssd_512_mobilenet1.0_coco"
detector = get_model(detector_name, pretrained=True, ctx=ctx)

estimator_name = "simple_pose_resnet18_v1b"
estimator = get_model(estimator_name, pretrained='ccd24037', ctx=ctx)

# Restrict the detector to the 'person' class only.
detector.reset_class(classes=['person'], reuse_weights={'person': 'person'})

detector.hybridize()
estimator.hybridize()

# Open the video.
cap = cv.VideoCapture(args.video)
if not cap.isOpened():
    raise SystemExit(f"cannot open video: {args.video}")

# Row stored for frames that do not contain exactly one person. It has the
# same width as a real feature row so the final array stays rectangular and
# can be written by np.savetxt.
missing_row = np.full(len(KeyPoints), np.nan)

features = []
try:
    ret, frame = cap.read()
    while ret:
        # Object detection (OpenCV delivers BGR; the model preset takes RGB).
        frame = mx.nd.array(cv.cvtColor(frame, cv.COLOR_BGR2RGB)).astype('uint8')

        x, img = gcv.data.transforms.presets.ssd.transform_test(frame, short=512)
        x = x.as_in_context(ctx)
        class_IDs, scores, bounding_boxs = detector(x)

        pose_input, upscale_bbox = detector_to_simple_pose(
            img, class_IDs, scores, bounding_boxs, output_shape=(128, 96), ctx=ctx)

        # Only frames with exactly one detected person yield a pose.
        if len(upscale_bbox) == 1:
            predicted_heatmap = estimator(pose_input)
            pred_coords, confidence = heatmap_to_coord(predicted_heatmap, upscale_bbox)
            X = CalAngle(pred_coords, confidence)[0]
        else:
            # Wrong number of people: record an all-NaN row for this frame.
            X = missing_row

        print(X)
        features.append(X)

        ret, frame = cap.read()
finally:
    cap.release()

# Save the features of the whole video, one row per frame.
out_dir = os.path.dirname(args.output)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
# reshape keeps the array 2-D even when no frame was read.
# NOTE: '%4f' means minimum width 4 with the default 6 decimals.
np.savetxt(args.output, np.array(features).reshape(-1, len(KeyPoints)), delimiter='\t', fmt='%4f')

run.py

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import argparse, time, os
2+
import cv2 as cv
3+
import mxnet as mx
4+
import numpy as np
5+
import gluoncv as gcv
6+
from gluoncv import data
7+
from gluoncv.model_zoo import get_model
8+
from gluoncv.data.transforms.pose import detector_to_simple_pose, heatmap_to_coord
9+
from gluoncv.utils.viz import cv_plot_image, cv_plot_keypoints
10+
from mxnet.gluon.data.vision import transforms
11+
from angle import CalAngle
12+
from sklearn.metrics import r2_score
13+
14+
# Parse command-line arguments.
parser = argparse.ArgumentParser()
parser.add_argument('--input', default=0)
parser.add_argument('--demo', required=True)
parser.add_argument('--data', required=True)
args = parser.parse_args()

# argparse hands over command-line values as strings, but a camera must be
# given to cv.VideoCapture as an int index; convert purely numeric input.
source = args.input
if isinstance(source, str) and source.isdigit():
    source = int(source)

fps_time = 0

# Model setup. Use the GPU when one is present, otherwise fall back to CPU.
ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()

detector_name = "ssd_512_mobilenet1.0_coco"
detector = get_model(detector_name, pretrained=True, ctx=ctx)

estimator_name = "simple_pose_resnet18_v1b"
estimator = get_model(estimator_name, pretrained='ccd24037', ctx=ctx)

# Restrict the detector to the 'person' class only.
detector.reset_class(classes=['person'], reuse_weights={'person': 'person'})

detector.hybridize()
estimator.hybridize()

# Video sources: the live/input stream and the demonstration video.
cap1 = cv.VideoCapture(source)
cap2 = cv.VideoCapture(args.demo)

# Reference features, one row per demo frame. ndmin=2 keeps a single-row
# file two-dimensional so that stdAngle[pos] is always a row.
stdAngle = np.loadtxt(args.data, delimiter='\t', ndmin=2)
pos = 0

try:
    ret1, frame1 = cap1.read()
    ret2, frame2 = cap2.read()
    while ret1 and ret2:

        # Object detection (OpenCV delivers BGR; the model preset takes RGB).
        frame = mx.nd.array(cv.cvtColor(frame1, cv.COLOR_BGR2RGB)).astype('uint8')

        x, img = gcv.data.transforms.presets.ssd.transform_test(frame, short=512)
        x = x.as_in_context(ctx)
        class_IDs, scores, bounding_boxs = detector(x)

        pose_input, upscale_bbox = detector_to_simple_pose(
            img, class_IDs, scores, bounding_boxs, output_shape=(128, 96), ctx=ctx)

        # Text lines drawn on the frame, one per detected person. Kept apart
        # from the detector's `scores` so that a frame without any person
        # shows no labels instead of the raw detector output.
        labels = []

        # Pose estimation.
        if len(upscale_bbox) > 0:
            predicted_heatmap = estimator(pose_input)
            pred_coords, confidence = heatmap_to_coord(predicted_heatmap, upscale_bbox)
            img = cv_plot_keypoints(img, pred_coords, confidence, class_IDs, bounding_boxs, scores)

            # Pose comparison; skipped once the reference data is exhausted.
            if pos < len(stdAngle):
                reference = stdAngle[pos]
                visibles = ~np.isnan(reference)  # features present in the reference
                for angle in CalAngle(pred_coords, confidence):
                    angle_v = angle[visibles]  # drop features missing in the reference
                    print(angle_v)
                    # R^2 needs at least two samples and no missing values.
                    if angle_v.size < 2 or np.isnan(angle_v).any():
                        labels.append('NaN')
                    else:
                        # NOTE(review): r2_score is (y_true, y_pred) and is not
                        # symmetric; here the live pose is passed as y_true --
                        # confirm this is the intended direction.
                        labels.append('{:.4f}'.format(r2_score(angle_v, reference[visibles])))

        # One reference row per frame, in step with the demo video, which is
        # also read once per iteration.
        pos += 1

        # Guard against a zero interval between two very fast iterations.
        elapsed = max(time.time() - fps_time, 1e-6)
        cv_plot_image(img,
                      upperleft_txt=f"FPS:{(1.0 / elapsed):.2f}", upperleft_txt_corner=(10, 25),
                      left_txt_list=labels, canvas_name='pose')
        fps_time = time.time()
        # cv.imshow('demo', frame2)

        # ESC quits.
        if cv.waitKey(1) == 27:
            break

        ret1, frame1 = cap1.read()
        ret2, frame2 = cap2.read()
finally:
    # Always free the windows and capture devices, even on error.
    cv.destroyAllWindows()

    cap1.release()
    cap2.release()

0 commit comments

Comments
 (0)