Skip to content

Commit 28a27f9

Browse files
committed
Add inference code.
1 parent 982ced6 commit 28a27f9

File tree

6 files changed

+1116
-0
lines changed

6 files changed

+1116
-0
lines changed

Diff for: inference-cpp/cnn-classification/CMakeLists.txt

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
2+
project(cnn-classification)
3+
4+
find_package(Torch REQUIRED)
5+
find_package(OpenCV REQUIRED)
6+
7+
add_executable(cnn-inference infer.cc ../../utils/opencvutils.cc ../../utils/torchutils.cc)
8+
9+
target_link_libraries(cnn-inference "${TORCH_LIBRARIES}")
10+
target_link_libraries(cnn-inference "${OpenCV_LIBS}")
11+
12+
set_property(TARGET cnn-inference PROPERTY CXX_STANDARD 11)
13+
set_property(TARGET cnn-inference PROPERTY OUTPUT_NAME predict)

Diff for: inference-cpp/cnn-classification/build.sh

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
rm -rf build
2+
3+
mkdir -p build
4+
cd build
5+
cmake -DCMAKE_PREFIX_PATH=../../libtorch ..
6+
make -j4
7+
cd ..
8+
9+
mv build/predict .
10+
11+
rm -rf build

Diff for: inference-cpp/cnn-classification/image.jpeg

12.4 KB
Loading

Diff for: inference-cpp/cnn-classification/infer.cc

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#include <iostream>
2+
#include <vector>
3+
#include <tuple>
4+
#include <chrono>
5+
#include <fstream>
6+
#include <random>
7+
#include <string>
8+
#include <memory>
9+
10+
#include <torch/script.h>
11+
#include <torch/tensor.h>
12+
#include <torch/serialize.h>
13+
14+
#include <opencv2/core/core.hpp>
15+
#include <opencv2/highgui/highgui.hpp>
16+
#include <opencv2/imgproc/imgproc.hpp>
17+
18+
#include "../../utils/torchutils.h"
19+
#include "../../utils/opencvutils.h"
20+
21+
std::tuple<std::string, std::string> infer(
22+
cv::Mat image,
23+
int image_height, int image_width,
24+
std::vector<double> mean, std::vector<double> std,
25+
std::vector<std::string> labels,
26+
std::shared_ptr<torch::jit::script::Module> model) {
27+
28+
if (image.empty()) {
29+
std::cout << "WARNING: Cannot read image!" << std::endl;
30+
}
31+
32+
std::string pred = "";
33+
std::string prob = "0.0";
34+
35+
// Predict if image is not empty
36+
if (!image.empty()) {
37+
38+
// Preprocess image
39+
image = preprocess(image, image_height, image_width,
40+
mean, std);
41+
42+
// Forward
43+
std::vector<float> probs = forward({image, }, model);
44+
45+
// Postprocess
46+
tie(pred, prob) = postprocess(probs, labels);
47+
}
48+
49+
return std::make_tuple(pred, prob);
50+
}
51+
52+
int main(int argc, char **argv) {
53+
54+
if (argc != 4) {
55+
std::cerr << "usage: predict <path-to-image> <path-to-exported-script-module> <path-to-labels-file> \n";
56+
return -1;
57+
}
58+
59+
std::string image_path = argv[1];
60+
std::string model_path = argv[2];
61+
std::string labels_path = argv[3];
62+
63+
int image_height = 224;
64+
int image_width = 224;
65+
66+
// Read labels
67+
std::vector<std::string> labels;
68+
std::string label;
69+
std::ifstream labelsfile (labels_path);
70+
if (labelsfile.is_open())
71+
{
72+
while (getline(labelsfile, label))
73+
{
74+
labels.push_back(label);
75+
}
76+
labelsfile.close();
77+
}
78+
79+
std::vector<double> mean = {0.485, 0.456, 0.406};
80+
std::vector<double> std = {0.229, 0.224, 0.225};
81+
82+
cv::Mat image = cv::imread(image_path);
83+
std::shared_ptr<torch::jit::script::Module> model = read_model(model_path);
84+
85+
std::string pred, prob;
86+
tie(pred, prob) = infer(image, image_height, image_width, mean, std, labels, model);
87+
88+
std::cout << "PREDICTION : " << pred << std::endl;
89+
std::cout << "PROBABILITY : " << prob << std::endl;
90+
91+
return 0;
92+
}

0 commit comments

Comments
 (0)