
Commit f14b559 (1 parent: 3b41b82)

Continue optimizing the code

5 files changed: +43, -51 lines

data.py (+9, -26)
@@ -2,37 +2,21 @@
 import cv2 as cv
 import mxnet as mx
 import numpy as np
-import gluoncv as gcv
-from gluoncv import data
-from gluoncv.model_zoo import get_model
+from gluoncv.data.transforms.presets.ssd import transform_test
 from gluoncv.data.transforms.pose import detector_to_simple_pose, heatmap_to_coord
 from gluoncv.utils.viz import cv_plot_image, cv_plot_keypoints
 from mxnet.gluon.data.vision import transforms
+from model import ctx, detector, estimator
 from angle import AngeleCal
 
 # Parse arguments
 parser = argparse.ArgumentParser()
-parser.add_argument('--video')
+parser.add_argument('--input')
+parser.add_argument('--output', required=True)
 args = parser.parse_args()
 
-fps_time = 0
-
-# Set up the models
-ctx = mx.gpu()
-
-detector_name = "ssd_512_mobilenet1.0_coco"
-detector = get_model(detector_name, pretrained=True, ctx=ctx)
-
-estimator_name = "simple_pose_resnet18_v1b"
-estimator = get_model(estimator_name, pretrained='ccd24037', ctx=ctx)
-
-detector.reset_class(classes=['person'], reuse_weights={'person':'person'})
-
-detector.hybridize()
-estimator.hybridize()
-
 # Read the video
-cap = cv.VideoCapture(args.video)
+cap = cv.VideoCapture(args.input)
 
 ret, frame = cap.read()
 features = []
@@ -41,14 +25,13 @@
     # Object detection
     frame = mx.nd.array(cv.cvtColor(frame, cv.COLOR_BGR2RGB)).astype('uint8')
 
-    x, img = gcv.data.transforms.presets.ssd.transform_test(frame, short=512)
+    x, img = transform_test(frame, short=512)
     x = x.as_in_context(ctx)
     class_IDs, scores, bounding_boxs = detector(x)
 
     pose_input, upscale_bbox = detector_to_simple_pose(img, class_IDs, scores, bounding_boxs, output_shape=(128, 96), ctx=ctx)
 
-    # Only estimate a single person's pose
-    if len(upscale_bbox) == 1:
+    if len(upscale_bbox) > 0:
         predicted_heatmap = estimator(pose_input)
         pred_coords, confidence = heatmap_to_coord(predicted_heatmap, upscale_bbox)
         img = cv_plot_keypoints(img, pred_coords, confidence, class_IDs, bounding_boxs, scores)
@@ -57,7 +40,7 @@
         print(X)
         features.append(X)
     else:
-        # Insert nan when the person count is wrong
+        # Insert nan when no person is detected
         print(np.nan)
         features.append(np.nan)
 
@@ -66,4 +49,4 @@
 cap.release()
 
 # Save this video's features to a file
-np.savetxt(os.path.join('data', 'demo.tsv'), np.array(features), delimiter='\t', fmt='%4f')
+np.savetxt(args.output, np.array(features), delimiter='\t', fmt='%4f')
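A note on the new data.py CLI: --output is now required, but --input is not, so forgetting it hands None to cv.VideoCapture and the script only fails later. A small hardening sketch on top of this commit (the help strings are assumptions, not part of the diff):

    import argparse

    # Fail fast at argument parsing if either path is missing
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', required=True, help='video to extract features from')
    parser.add_argument('--output', required=True, help='path of the feature TSV to write')
    args = parser.parse_args()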

fps.py (+11)
@@ -0,0 +1,11 @@
+import time
+
+class FPS():
+    __fps_time = 0
+
+    @staticmethod
+    def fps():
+        fps = f"FPS:{(1.0 / (time.time() - FPS.__fps_time)):.2f}"
+        FPS.__fps_time = time.time()
+
+        return fps
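A note on fps.py: the class attribute __fps_time starts at 0, so the very first FPS.fps() call divides by the full Unix timestamp and reports a near-zero frame rate before the second call produces a real one. A sketch of one way to make that first reading explicit (a variant, not part of this commit):

    import time

    class FPS():
        __fps_time = None

        @staticmethod
        def fps():
            now = time.time()
            if FPS.__fps_time is None:
                # No previous frame yet, nothing to measure
                FPS.__fps_time = now
                return "FPS:--"
            text = f"FPS:{1.0 / (now - FPS.__fps_time):.2f}"
            FPS.__fps_time = now
            return text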

model.py (+15)
@@ -0,0 +1,15 @@
+import mxnet as mx
+from gluoncv.model_zoo import get_model
+
+ctx = mx.gpu()
+
+detector_name = "ssd_512_mobilenet1.0_coco"
+detector = get_model(detector_name, pretrained=True, ctx=ctx)
+
+estimator_name = "simple_pose_resnet18_v1b"
+estimator = get_model(estimator_name, pretrained='ccd24037', ctx=ctx)
+
+detector.reset_class(classes=['person'], reuse_weights={'person':'person'})
+
+detector.hybridize()
+estimator.hybridize()
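model.py pins the context to mx.gpu(); on a machine without a CUDA-enabled MXNet build, copying the pretrained weights to that context will fail. A minimal fallback sketch (an assumption layered on top of this commit, not something it contains):

    import mxnet as mx

    # Use the GPU only when MXNet can actually see one, otherwise stay on the CPU
    ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()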

requirements.txt (1.12 KB, binary file not shown)

run.py (+8, -25)
@@ -1,13 +1,13 @@
-import argparse, time, os
+import argparse, os
 import cv2 as cv
 import mxnet as mx
 import numpy as np
-import gluoncv as gcv
-from gluoncv import data
-from gluoncv.model_zoo import get_model
+from gluoncv.data.transforms.presets.ssd import transform_test
 from gluoncv.data.transforms.pose import detector_to_simple_pose, heatmap_to_coord
 from gluoncv.utils.viz import cv_plot_image, cv_plot_keypoints
 from mxnet.gluon.data.vision import transforms
+from model import ctx, detector, estimator
+from fps import FPS
 from angle import AngeleCal
 
 # Parse arguments
@@ -17,38 +17,23 @@
 parser.add_argument('--data', required=True)
 args = parser.parse_args()
 
-fps_time = 0
-
-# Set up the models
-ctx = mx.gpu()
-
-detector_name = "ssd_512_mobilenet1.0_coco"
-detector = get_model(detector_name, pretrained=True, ctx=ctx)
-
-estimator_name = "simple_pose_resnet18_v1b"
-estimator = get_model(estimator_name, pretrained='ccd24037', ctx=ctx)
-
-detector.reset_class(classes=['person'], reuse_weights={'person':'person'})
-
-detector.hybridize()
-estimator.hybridize()
-
 # Read the videos
+# 1 is the input video, 2 is the demo video
 cap1 = cv.VideoCapture(args.input)
 cap2 = cv.VideoCapture(args.demo)
 
 # Reference features
 angeleCal = AngeleCal(args.data)
-pos = 0
 
 ret1, frame1 = cap1.read()
 ret2, frame2 = cap2.read()
+
 while ret1 and ret2:
 
     # Object detection
     frame = mx.nd.array(cv.cvtColor(frame1, cv.COLOR_BGR2RGB)).astype('uint8')
 
-    x, img = gcv.data.transforms.presets.ssd.transform_test(frame, short=512)
+    x, img = transform_test(frame, short=512)
     x = x.as_in_context(ctx)
     class_IDs, scores, bounding_boxs = detector(x)
 
@@ -66,11 +51,9 @@
     else:
        results = ['NaN']
 
-    print('result', results)
     cv_plot_image(img,
-        upperleft_txt=f"FPS:{(1.0 / (time.time() - fps_time)):.2f}", upperleft_txt_corner=(10,25),
+        upperleft_txt=FPS.fps(), upperleft_txt_corner=(10,25),
         left_txt_list=results, canvas_name='pose')
-    fps_time = time.time()
     cv.imshow('demo', frame2)
 
     # Press ESC to exit
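The run.py hunk ends at the ESC-exit comment without showing the key handling itself. For readers following along, this is the usual OpenCV pattern behind such a comment (a sketch of the common idiom, not the file's actual lines):

    # Typical OpenCV exit check inside the frame loop
    if cv.waitKey(1) == 27:  # 27 is the ESC key code
        break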
