|
2 | 2 | #include "activation_distribution_observer.h"
|
3 | 3 | #include "caffe2_dnnlowp_utils.h"
|
4 | 4 |
|
| 5 | +namespace caffe2 { |
| 6 | +namespace python { |
| 7 | +// defined in caffe2/python/pybind_state.cc |
| 8 | +Workspace* GetCurrentWorkspace(); |
| 9 | +} // namespace python |
| 10 | +} // namespace caffe2 |
| 11 | + |
5 | 12 | PYBIND11_MODULE(dnnlowp_pybind11, m) {
|
6 | 13 | using namespace std;
|
7 | 14 | using namespace caffe2;
|
@@ -33,6 +40,61 @@ PYBIND11_MODULE(dnnlowp_pybind11, m) {
|
33 | 40 | pybind11::arg("dump_freq") = -1,
|
34 | 41 | pybind11::arg("mul_nets") = false);
|
35 | 42 |
|
| 43 | + m.def( |
| 44 | + "AddHistogramObserver", |
| 45 | + [](const string& net_name, |
| 46 | + const string& out_file_name, |
| 47 | + int dump_freq, |
| 48 | + bool mul_nets) { |
| 49 | + Workspace* gWorkspace = caffe2::python::GetCurrentWorkspace(); |
| 50 | + CAFFE_ENFORCE(gWorkspace); |
| 51 | + CAFFE_ENFORCE( |
| 52 | + gWorkspace->GetNet(net_name), "Can't find net ", net_name); |
| 53 | + pybind11::gil_scoped_release g; |
| 54 | + |
| 55 | + NetBase* net = gWorkspace->GetNet(net_name); |
| 56 | + const Observable<NetBase>::Observer* observer = nullptr; |
| 57 | + |
| 58 | + observer = net->AttachObserver(make_unique<HistogramNetObserver>( |
| 59 | + net, out_file_name, 2048, dump_freq, mul_nets)); |
| 60 | + |
| 61 | + CAFFE_ENFORCE(observer != nullptr); |
| 62 | + return pybind11::cast(observer); |
| 63 | + }, |
| 64 | + pybind11::arg("net_name"), |
| 65 | + pybind11::arg("out_file_name"), |
| 66 | + pybind11::arg("dump_freq") = -1, |
| 67 | + pybind11::arg("mul_nets") = false); |
| 68 | + |
| 69 | + m.def( |
| 70 | + "ChooseQuantizationParams", |
| 71 | + [](const std::string& blob_name) { |
| 72 | + Workspace* gWorkspace = caffe2::python::GetCurrentWorkspace(); |
| 73 | + CAFFE_ENFORCE(gWorkspace); |
| 74 | + pybind11::gil_scoped_release g; |
| 75 | + |
| 76 | + const auto* blob = gWorkspace->GetBlob(blob_name); |
| 77 | + if (blob == nullptr) { |
| 78 | + LOG(WARNING) << "Can't find blob " << blob_name; |
| 79 | + } else if (BlobIsTensorType(*blob, CPU)) { |
| 80 | + LOG(WARNING) << "Blob " << blob_name << " is not a tensor"; |
| 81 | + } else { |
| 82 | + const auto& tensor = blob->template Get<Tensor>(); |
| 83 | + if (tensor.IsType<float>()) { |
| 84 | + dnnlowp::QuantizationFactory* qfactory = |
| 85 | + dnnlowp::QuantizationFactory::GetDefaultInstance(); |
| 86 | + dnnlowp::TensorQuantizationParams qparams = |
| 87 | + qfactory->ChooseQuantizationParams( |
| 88 | + tensor.data<float>(), tensor.size(), true /*weight*/); |
| 89 | + return std::tuple<float, int>(qparams.scale, qparams.zero_point); |
| 90 | + } else { |
| 91 | + LOG(WARNING) << "Blob " << blob_name << " is not a float tensor"; |
| 92 | + } |
| 93 | + } |
| 94 | + return std::tuple<float, int>(1.0, 0); |
| 95 | + }, |
| 96 | + pybind11::arg("blob_name")); |
| 97 | + |
36 | 98 | m.def(
|
37 | 99 | "RegisterQuantizationParams",
|
38 | 100 | [](const string& min_max_file_name,
|
@@ -66,6 +128,33 @@ PYBIND11_MODULE(dnnlowp_pybind11, m) {
|
66 | 128 | pybind11::arg("is_weight") = false,
|
67 | 129 | pybind11::arg("qparams_output_file_name") = "");
|
68 | 130 |
|
| 131 | + m.def( |
| 132 | + "AddRegisterQuantizationParamsWithHistogramObserver", |
| 133 | + [](const string& net_name, |
| 134 | + const string& histogram_file_name, |
| 135 | + int is_weight, |
| 136 | + const string& qparams_output_file_name) { |
| 137 | + Workspace* gWorkspace = caffe2::python::GetCurrentWorkspace(); |
| 138 | + CAFFE_ENFORCE(gWorkspace); |
| 139 | + CAFFE_ENFORCE( |
| 140 | + gWorkspace->GetNet(net_name), "Can't find net ", net_name); |
| 141 | + pybind11::gil_scoped_release g; |
| 142 | + |
| 143 | + NetBase* net = gWorkspace->GetNet(net_name); |
| 144 | + const Observable<NetBase>::Observer* observer = nullptr; |
| 145 | + |
| 146 | + observer = net->AttachObserver( |
| 147 | + make_unique<RegisterQuantizationParamsWithHistogramNetObserver>( |
| 148 | + net, histogram_file_name, is_weight, qparams_output_file_name)); |
| 149 | + |
| 150 | + CAFFE_ENFORCE(observer != nullptr); |
| 151 | + return pybind11::cast(observer); |
| 152 | + }, |
| 153 | + pybind11::arg("net_name"), |
| 154 | + pybind11::arg("histogram_file_name"), |
| 155 | + pybind11::arg("is_weight") = false, |
| 156 | + pybind11::arg("qparams_output_file_name") = ""); |
| 157 | + |
69 | 158 | m.def(
|
70 | 159 | "AddScaleZeroOffsetArgumentsWithHistogram",
|
71 | 160 | [](const pybind11::bytes& net_def_bytes,
|
|
0 commit comments