Skip to content

Commit 77a7285

Browse files
jspark1105 and facebook-github-bot
authored and committed
add more Python interface functions to make quantization simpler (pytorch#18246)
Summary: Pull Request resolved: pytorch#18246 Simplifies the histogram collection and quantization process. Histogram collection before this diff looked something like this: ``` from caffe2.quantization.server import dnnlowp_pybind11 ... dnnlowp_pybind11.ObserveHistogramOfOutput(hist_file) for ... workspace.RunNet(predict_net) dnnlowp_pybind11.ClearNetObservers() # This is to trigger the Stop function in the observer to dump out the histogram file, but this can have the unintended consequence of also clearing all the other useful observers we attached ``` After this diff we can write: ``` workspace.CreateNet(predict_net) # Note we need to create the net first so that there is a net to attach the observer to histogram_observer = dnnlowp_pybind11.AddHistogramObserver(predict_net, hist_file) for ... workspace.RunNet(predict_net) predict_net.RemoveObserver(histogram_observer) ``` Choosing quantization parameters of weights before this diff looked something like this: ``` dnnlowp_pybind11.ObserveHistogramOfOutput(weight_hist_file) workspace.RunNetOnce(init_net) dnnlowp_pybind11.ClearNetObservers() # Has the same issue as the histogram collection example above dnnlowp_pybind11.RegisterQuantizationParamsWithHistogram( weight_hist_file, is_weight=True, qparams_output_file_name=qparams_file ) workspace.CreateNet(init_net, overwrite=True) dnnlowp_pybind11.ClearNetObservers() logger.info("Loading quantization params from {}".format(qparams_file)) blobs_to_qparams = {} with open(qparams_file) as f: lines = f.readlines() for line in lines: op_id, op_type, output_id, tensor_name, mini, maxi, scale, zero_point, precision = ( line.split() ) op_id = int(op_id) output_id = int(output_id) op = net.Proto().op[op_id] if op_type != op.type or op.output[output_id] != tensor_name: print( "Corrupt qparams file {} {} {} {} {}".format( qparams_file, op_type, op.type, op.output[output_id], tensor_name ) ) blobs_to_qparams[tensor_name] = QuantizationParam(float(scale), int(zero_point)) ``` After this diff this can be simplified to: ``` blobs_to_qparams = 
{} for op in init_net.Proto().op: for output in op.output: scale, zero_point = dnnlowp_pybind11.ChooseQuantizationParams(output) blobs_to_qparams[output] = QuantizationParam(scale, zero_point) ``` Reviewed By: dskhudia Differential Revision: D14544694 fbshipit-source-id: 4fd06cd63256201e2e9d15c39f503138d1be53c2
1 parent f3cf6ed commit 77a7285

File tree

1 file changed

+89
-0
lines changed

1 file changed

+89
-0
lines changed

caffe2/quantization/server/pybind.cc

+89
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,13 @@
22
#include "activation_distribution_observer.h"
33
#include "caffe2_dnnlowp_utils.h"
44

5+
namespace caffe2 {
6+
namespace python {
7+
// Forward declaration of the function the bindings in this file call to
// obtain the Workspace* they operate on. It is only declared here; it is
// defined in caffe2/python/pybind_state.cc
8+
Workspace* GetCurrentWorkspace();
9+
} // namespace python
10+
} // namespace caffe2
11+
512
PYBIND11_MODULE(dnnlowp_pybind11, m) {
613
using namespace std;
714
using namespace caffe2;
@@ -33,6 +40,61 @@ PYBIND11_MODULE(dnnlowp_pybind11, m) {
3340
pybind11::arg("dump_freq") = -1,
3441
pybind11::arg("mul_nets") = false);
3542

43+
m.def(
44+
"AddHistogramObserver",
45+
[](const string& net_name,
46+
const string& out_file_name,
47+
int dump_freq,
48+
bool mul_nets) {
49+
Workspace* gWorkspace = caffe2::python::GetCurrentWorkspace();
50+
CAFFE_ENFORCE(gWorkspace);
51+
CAFFE_ENFORCE(
52+
gWorkspace->GetNet(net_name), "Can't find net ", net_name);
53+
pybind11::gil_scoped_release g;
54+
55+
NetBase* net = gWorkspace->GetNet(net_name);
56+
const Observable<NetBase>::Observer* observer = nullptr;
57+
58+
observer = net->AttachObserver(make_unique<HistogramNetObserver>(
59+
net, out_file_name, 2048, dump_freq, mul_nets));
60+
61+
CAFFE_ENFORCE(observer != nullptr);
62+
return pybind11::cast(observer);
63+
},
64+
pybind11::arg("net_name"),
65+
pybind11::arg("out_file_name"),
66+
pybind11::arg("dump_freq") = -1,
67+
pybind11::arg("mul_nets") = false);
68+
69+
m.def(
70+
"ChooseQuantizationParams",
71+
[](const std::string& blob_name) {
72+
Workspace* gWorkspace = caffe2::python::GetCurrentWorkspace();
73+
CAFFE_ENFORCE(gWorkspace);
74+
pybind11::gil_scoped_release g;
75+
76+
const auto* blob = gWorkspace->GetBlob(blob_name);
77+
if (blob == nullptr) {
78+
LOG(WARNING) << "Can't find blob " << blob_name;
79+
} else if (BlobIsTensorType(*blob, CPU)) {
80+
LOG(WARNING) << "Blob " << blob_name << " is not a tensor";
81+
} else {
82+
const auto& tensor = blob->template Get<Tensor>();
83+
if (tensor.IsType<float>()) {
84+
dnnlowp::QuantizationFactory* qfactory =
85+
dnnlowp::QuantizationFactory::GetDefaultInstance();
86+
dnnlowp::TensorQuantizationParams qparams =
87+
qfactory->ChooseQuantizationParams(
88+
tensor.data<float>(), tensor.size(), true /*weight*/);
89+
return std::tuple<float, int>(qparams.scale, qparams.zero_point);
90+
} else {
91+
LOG(WARNING) << "Blob " << blob_name << " is not a float tensor";
92+
}
93+
}
94+
return std::tuple<float, int>(1.0, 0);
95+
},
96+
pybind11::arg("blob_name"));
97+
3698
m.def(
3799
"RegisterQuantizationParams",
38100
[](const string& min_max_file_name,
@@ -66,6 +128,33 @@ PYBIND11_MODULE(dnnlowp_pybind11, m) {
66128
pybind11::arg("is_weight") = false,
67129
pybind11::arg("qparams_output_file_name") = "");
68130

131+
m.def(
132+
"AddRegisterQuantizationParamsWithHistogramObserver",
133+
[](const string& net_name,
134+
const string& histogram_file_name,
135+
int is_weight,
136+
const string& qparams_output_file_name) {
137+
Workspace* gWorkspace = caffe2::python::GetCurrentWorkspace();
138+
CAFFE_ENFORCE(gWorkspace);
139+
CAFFE_ENFORCE(
140+
gWorkspace->GetNet(net_name), "Can't find net ", net_name);
141+
pybind11::gil_scoped_release g;
142+
143+
NetBase* net = gWorkspace->GetNet(net_name);
144+
const Observable<NetBase>::Observer* observer = nullptr;
145+
146+
observer = net->AttachObserver(
147+
make_unique<RegisterQuantizationParamsWithHistogramNetObserver>(
148+
net, histogram_file_name, is_weight, qparams_output_file_name));
149+
150+
CAFFE_ENFORCE(observer != nullptr);
151+
return pybind11::cast(observer);
152+
},
153+
pybind11::arg("net_name"),
154+
pybind11::arg("histogram_file_name"),
155+
pybind11::arg("is_weight") = false,
156+
pybind11::arg("qparams_output_file_name") = "");
157+
69158
m.def(
70159
"AddScaleZeroOffsetArgumentsWithHistogram",
71160
[](const pybind11::bytes& net_def_bytes,

0 commit comments

Comments
 (0)