|
| 1 | +## OpenCV 调用tensorflow |
| 2 | + |
| 3 | +### 概述 |
| 4 | + |
| 5 | +✔️OpenCV在DNN模块中支持直接调用tensorflow object detection训练导出的模型使用,支持的模型包括 |
| 6 | +- SSD |
| 7 | +- Faster-RCNN |
| 8 | +- Mask-RCNN |
| 9 | + |
| 10 | +✔️ 利用这三种经典的对象检测网络,这样就可以实现从tensorflow模型训练、导出模型、在OpenCV DNN调用模型网络实现自定义对象检测的技术。 |
| 11 | + |
| 12 | +✔️ OpenCV3.4.1以上版本支持tensorflow1.11版本以上的对象检测框架(object detetion)模型导出使用,当前支持的模型包括以下: |
| 13 | + |
| 14 | +Model | Version | -| - |
| 15 | +---|---|---|--- |
| 16 | +MobileNet-SSD v1|2017_11_17| [weights](http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2017_11_17.tar.gz)| [config](https://github.com/opencv/opencv_extra/blob/master/testdata/dnn/ssd_mobilenet_v1_coco_2017_11_17.pbtxt)| |
| 17 | +MobileNet-SSD v1 PPN| 2018_07_03| [weights](http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_ppn_shared_box_predictor_300x300_coco14_sync_2018_07_03.tar.gz)| [config](https://github.com/opencv/opencv_extra/blob/master/testdata/dnn/ssd_mobilenet_v1_ppn_coco.pbtxt)| |
| 18 | +MobileNet-SSD v2| 2018_03_29| [weights](http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz)| [config](https://github.com/opencv/opencv_extra/blob/master/testdata/dnn/ssd_mobilenet_v2_coco_2018_03_29.pbtxt)| |
| 19 | +Inception-SSD v2| 2017_11_17| [weights](http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2017_11_17.tar.gz)| [config](https://github.com/opencv/opencv_extra/blob/master/testdata/dnn/ssd_inception_v2_coco_2017_11_17.pbtxt)| |
| 20 | +Faster-RCNN Inception v2| 2018_01_28| [weights](http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2018_01_28.tar.gz)| [config](https://github.com/opencv/opencv_extra/blob/master/testdata/dnn/faster_rcnn_inception_v2_coco_2018_01_28.pbtxt)| |
| 21 | +Faster-RCNN ResNet-50| 2018_01_28| [weights](http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet50_coco_2018_01_28.tar.gz)| [config](https://github.com/opencv/opencv_extra/blob/master/testdata/dnn/faster_rcnn_resnet50_coco_2018_01_28.pbtxt)| |
| 22 | +Mask-RCNN Inception v2| 2018_01_28| [weights](http://download.tensorflow.org/models/object_detection/mask_rcnn_inception_v2_coco_2018_01_28.tar.gz)| [config](https://github.com/opencv/opencv_extra/blob/master/testdata/dnn/mask_rcnn_inception_v2_coco_2018_01_28.pbtxt)| |
| 23 | + |
| 24 | +✏️ 使用tensorflow object detection API框架进行迁移学习训练模型,导出预测图之后,然后通过OpenCV3.4.1以上版本提供几个python脚本导出graph配置文件,这样就可以在OpenCV DNN模块中使用tensorflow相关的模型了。 |
| 25 | + |
| 26 | +### 使用tensorflow |
| 27 | + |
| 28 | +✔️使用tensorflow预测: |
| 29 | +``` |
| 30 | +import tensorflow as tf |
| 31 | +import cv2 |
| 32 | +
|
| 33 | +# Read the graph. |
| 34 | +model_dir = '../faster_rcnn_resnet50_coco_2018_01_28/frozen_inference_graph.pb' |
| 35 | +with tf.gfile.FastGFile(model_dir, 'rb') as f: |
| 36 | + graph_def = tf.GraphDef() |
| 37 | + graph_def.ParseFromString(f.read()) |
| 38 | +
|
| 39 | +with tf.Session() as sess: |
| 40 | + # Restore session |
| 41 | + sess.graph.as_default() |
| 42 | + tf.import_graph_def(graph_def, name='') |
| 43 | +
|
| 44 | + # Read and preprocess an image. |
| 45 | + img = cv2.imread('cat.jpg') |
| 46 | + rows = img.shape[0] |
| 47 | + cols = img.shape[1] |
| 48 | + inp = cv2.resize(img, (300, 300)) |
| 49 | + inp = inp[:, :, [2, 1, 0]] # BGR2RGB |
| 50 | +
|
| 51 | + # Run the model |
| 52 | + out = sess.run([sess.graph.get_tensor_by_name('num_detections:0'), |
| 53 | + sess.graph.get_tensor_by_name('detection_scores:0'), |
| 54 | + sess.graph.get_tensor_by_name('detection_boxes:0'), |
| 55 | + sess.graph.get_tensor_by_name('detection_classes:0')], |
| 56 | + feed_dict={'image_tensor:0': inp.reshape(1, inp.shape[0], inp.shape[1], 3)}) |
| 57 | +
|
| 58 | + # Visualize detected bounding boxes. |
| 59 | + num_detections = int(out[0][0]) |
| 60 | + for i in range(num_detections): |
| 61 | + classId = int(out[3][0][i]) |
| 62 | + score = float(out[1][0][i]) |
| 63 | + bbox = [float(v) for v in out[2][0][i]] |
| 64 | + if score > 0.8: |
| 65 | + x = bbox[1] * cols |
| 66 | + y = bbox[0] * rows |
| 67 | + right = bbox[3] * cols |
| 68 | + bottom = bbox[2] * rows |
| 69 | + cv2.rectangle(img, (int(x), int(y)), (int(right), int(bottom)), (125, 255, 51), thickness=2) |
| 70 | +``` |
| 71 | +<img src="./cat.jpg"> |
| 72 | +<img src="./cat_tensor.jpg"> |
| 73 | + |
| 74 | +### 调用tensorflow |
| 75 | + |
| 76 | +✔️根据tensorflow中迁移学习或者下载预训练模型不同,OpenCV DNN 模块提供如下可以使用脚本生成对应的模型配置文件 |
| 77 | + |
| 78 | +``` |
| 79 | +tf_text_graph_ssd.py |
| 80 | +
|
| 81 | +tf_text_graph_faster_rcnn.py |
| 82 | +
|
| 83 | +tf_text_graph_mask_rcnn.py |
| 84 | +``` |
| 85 | +✔️这是因为OpenCV DNN需要根据text版本的模型描述文件来解析tensorflow的pb文件,实现网络模型加载。 |
| 86 | + |
| 87 | +✔️对检测模型,生成模型描述文件运行以下命令行: |
| 88 | +``` |
| 89 | +python tf_text_graph_ssd.py |
| 90 | +
|
| 91 | +--input /path/to/model.pb |
| 92 | +
|
| 93 | +--config /path/to/example.config |
| 94 | +
|
| 95 | +--output /path/to/graph.pbtxt |
| 96 | +``` |
| 97 | + |
| 98 | +✔️采用faster_res50目标检测模型生成pbtxt的输出结果: |
| 99 | + |
| 100 | +``` |
| 101 | +python tf_text_graph_faster_rcnn.py \ |
| 102 | +--input faster_rcnn_resnet50_coco_2018_01_28/frozen_inference_graph.pb \ |
| 103 | +--output faster_rcnn_resnet50_coco_2018_01_28/graph.pbtxt \ |
| 104 | +--config faster_rcnn_resnet50_coco_2018_01_28/pipeline.config |
| 105 | +
|
| 106 | +Number of classes: 90 |
| 107 | +Scales: [0.25, 0.5, 1.0, 2.0] |
| 108 | +Aspect ratios: [0.5, 1.0, 2.0] |
| 109 | +Width stride: 16.000000 |
| 110 | +Height stride: 16.000000 |
| 111 | +Features stride: 16.000000 |
| 112 | +``` |
| 113 | +✔️opencv调用tensorflow预测目标: |
| 114 | +```python |
| 115 | +import cv2 |
| 116 | + |
| 117 | +inference_pb = "../faster_rcnn_resnet50_coco_2018_01_28/frozen_inference_graph.pb"; |
| 118 | +graph_text = "../faster_rcnn_resnet50_coco_2018_01_28/graph.pbtxt"; |
| 119 | + |
| 120 | +# load tensorflow model |
| 121 | +net = cv2.dnn.readNetFromTensorflow(inference_pb, graph_text) |
| 122 | +image = cv2.imread("cat.jpg") |
| 123 | +h = image.shape[0] |
| 124 | +w = image.shape[1] |
| 125 | + |
| 126 | +# 检测 |
| 127 | +net.setInput(cv2.dnn.blobFromImage(image, size=(300, 300), swapRB=True, crop=False)) |
| 128 | +cvOut = net.forward() |
| 129 | +for detection in cvOut[0,0,:,:]: |
| 130 | + score = float(detection[2]) |
| 131 | + if score > 0.5: |
| 132 | + left = detection[3]*w |
| 133 | + top = detection[4]*h |
| 134 | + right = detection[5]*w |
| 135 | + bottom = detection[6]*h |
| 136 | + |
| 137 | + # 绘制 |
| 138 | + cv2.rectangle(image, (int(left), int(top)), (int(right), int(bottom)), (0, 255, 0), thickness=2) |
| 139 | + cv2.putText(image, "score:%.2f"%score, (int(left), int(top)-2), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 1) |
| 140 | +``` |
| 141 | +<img src="./result_cat.jpg"> |
0 commit comments