
Commit 7397eb7

Dmytro Dzhulgakov authored and facebook-github-bot committed
End to end hack to call server side Caffe2 ops (pytorch#18267)
Summary: Pull Request resolved: pytorch#18267

Motivation: we don't actually want to use this for real under any circumstances. It is a hack to unblock our internal progress and to parallelize workstreams. We can easily define schemas for all of the ops in question and implement them by forwarding to the Caffe2 ops, which is NOT going to be performant. Several things can then happen in parallel:

* move the op code that depends on protobuf out of the Caffe2 operators and into c10
* development of optimization/fusion passes
* building python-level wrappers with a clean API
* improving perf

This PR demonstrates Relu, quantize, and dequantize, which seems to cover all of the necessary use cases (except maybe weights prepacking). Ideally I'd also demonstrate Conv, but that will come later in a separate PR (contributions welcome).

Reviewed By: ezyang

Differential Revision: D14531232

fbshipit-source-id: 4cd4a71ae0cb373c6c0e81f965c442b82a1b4069
1 parent f6df6ae commit 7397eb7
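
The wrapped ops are registered through the custom JIT operator mechanism under the c10 namespace, so they can be called from Python as torch.ops.c10.*. A minimal round-trip sketch, assuming the fbgemm-based DNNLOWP Caffe2 ops are linked into the build (the same condition the new test guards on):

import torch

# quantized tensors are passed around as (uint8 Tensor, scale, zero_point) tuples
x = torch.tensor([-0.01, 0.01, -0.04, 0.05])
q = torch.ops.c10.quantize(x, scale=0.01, zero_point=5)
q_relu = torch.ops.c10.quantized_relu(q)
y = torch.ops.c10.dequantize(q_relu)  # back to a float Tensor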

File tree

5 files changed (+209, -8)


test/test_quantized.py

+51
@@ -0,0 +1,51 @@
import torch
import torch.jit
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import unittest
from caffe2.python import core
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
    skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
    freeze_rng_state, set_rng_seed


def canonical(graph):
    return str(torch._C._jit_pass_canonicalize(graph))


@unittest.skipIf("Relu_ENGINE_DNNLOWP" not in core._REGISTERED_OPERATORS, "fbgemm-based Caffe2 ops are not linked")
class TestQuantized(TestCase):
    def test_relu(self):
        a = (torch.tensor([4, 6, 1, 10], dtype=torch.uint8), 0.01, 5)
        r = torch.ops.c10.quantized_relu(a)
        np.testing.assert_equal(r[0].numpy(), torch.tensor([5, 6, 5, 10], dtype=torch.uint8).numpy())
        np.testing.assert_almost_equal(0.01, r[1])
        self.assertEqual(5, r[2])

    def test_quantize(self):
        a = (torch.tensor([4, 6, 1, 10], dtype=torch.uint8), 0.01, 5)
        r = torch.ops.c10.dequantize(a)
        np.testing.assert_almost_equal(r.numpy(), [-0.01, 0.01, -0.04, 0.05])
        # default args
        q_def = torch.ops.c10.quantize(r)
        # specified
        q = torch.ops.c10.quantize(r, scale=0.01, zero_point=5)
        np.testing.assert_equal(q[0].numpy(), a[0].numpy())
        np.testing.assert_almost_equal(q[1], a[1])
        self.assertEqual(q[2], a[2])

    def test_script(self):
        @torch.jit.script
        def foo(x):
            # type: (Tuple[Tensor, float, int]) -> Tuple[Tensor, float, int]
            return torch.ops.c10.quantized_relu(x)
        self.assertExpectedInline(canonical(foo.graph), '''\
graph(%x : (Tensor, float, int)):
  %1 : (Tensor, float, int) = c10::quantized_relu(%x)
  return (%1)
''')


if __name__ == '__main__':
    run_tests()
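
For reference, the expected values above are consistent with the usual affine quantization convention, real ≈ scale * (q - zero_point), with relu acting as a clamp at the zero point in the quantized domain. A small illustrative numpy check (not part of the PR):

import numpy as np

q = np.array([4, 6, 1, 10], dtype=np.uint8)
scale, zero_point = 0.01, 5
# dequantize: real value = scale * (q - zero_point)
real = scale * (q.astype(np.int32) - zero_point)      # [-0.01, 0.01, -0.04, 0.05]
# relu on the quantized representation clamps at the zero point
q_relu = np.maximum(q, zero_point).astype(np.uint8)   # [5, 6, 5, 10]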

tools/build_variables.py

+2
@@ -89,6 +89,7 @@
     "torch/csrc/jit/passes/utils/memory_dag.cpp",
     "torch/csrc/jit/register_prim_ops.cpp",
     "torch/csrc/jit/register_special_ops.cpp",
+    "torch/csrc/jit/register_quantized_ops.cpp",
     "torch/csrc/jit/scope.cpp",
     "torch/csrc/jit/script/compiler.cpp",
     "torch/csrc/jit/script/edit_distance.cpp",
@@ -199,6 +200,7 @@ def add_torch_libs():
         "//caffe2/aten:ATen-cpu",
         "//caffe2/caffe2:caffe2_cpu",
         "//caffe2/torch/lib/libshm:libshm",
+        "//caffe2/caffe2/quantization/server:dnnlowp_ops",
     ],
     external_deps=[
         ("nanopb", None, "protobuf-nanopb"),

tools/run-clang-tidy-in-ci.sh

+9 -8
@@ -38,12 +38,13 @@ fi
 # Run Clang-Tidy
 # The negative filters below are to exclude files that include onnx_pb.h or
 # caffe2_pb.h, otherwise we'd have to build protos as part of this CI job.
-time python tools/clang_tidy.py \
-  --verbose \
-  --paths torch/csrc/ \
-  --diff "$BASE_BRANCH" \
-  -g"-torch/csrc/distributed/Module.cpp" \
-  -g"-torch/csrc/jit/export.cpp" \
-  -g"-torch/csrc/jit/import.cpp" \
-  -g"-torch/csrc/jit/netdef_converter.cpp" \
+time python tools/clang_tidy.py \
+  --verbose \
+  --paths torch/csrc/ \
+  --diff "$BASE_BRANCH" \
+  -g"-torch/csrc/distributed/Module.cpp" \
+  -g"-torch/csrc/jit/export.cpp" \
+  -g"-torch/csrc/jit/import.cpp" \
+  -g"-torch/csrc/jit/netdef_converter.cpp" \
+  -g"-torch/csrc/jit/register_quantized_ops.cpp" \
   "$@"

torch/CMakeLists.txt

+1
@@ -169,6 +169,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp
   ${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp
   ${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/register_quantized_ops.cpp
   ${TORCH_SRC_DIR}/csrc/jit/scope.cpp
   ${TORCH_SRC_DIR}/csrc/jit/script/compiler.cpp
   ${TORCH_SRC_DIR}/csrc/jit/testing/file_check.cpp

torch/csrc/jit/register_quantized_ops.cpp

+146
@@ -0,0 +1,146 @@
// WARNING! WARNING! WARNING!
// This file is a temporary hack to enable development of pytorch quantization
//
// It effectively wraps Caffe2 ops as-is through the custom jit ops mechanism
// It obviously has terrible performance - a caffe2 operator instance is created
// on each invocation and creation also involves creating a protobuf (sigh...)
//
// Our plan is to implement quantized operators natively in c10 as operators and
// also enforce some additional contracts on operator semantics:
// - explicitly express weights prepacking as a separate operator to signify
//   reliance on weights being constant
// - don't modify arguments of the op (OperatorDef) to store data
// - explicitly model figuring out quantization params for dynamic quantization
//   instead of memorizing the first batch's params

#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/operator.h>

#include <caffe2/core/operator.h>
#include <caffe2/core/tensor_int8.h>
#include <torch/csrc/autograd/variable.h>

namespace torch {
namespace jit {

using caffe2::int8::Int8TensorCPU;

namespace {

caffe2::Tensor from_at_tensor(const c10::IValue& v) {
  return caffe2::Tensor(autograd::Variable(std::move(v).toTensor()).data());
}

Int8TensorCPU from_proxy(const c10::IValue& proxy) {
  auto t = std::move(proxy).toTuple();
  Int8TensorCPU r;
  r.t = from_at_tensor(t->elements()[0]);
  r.scale = t->elements()[1].toDouble();
  r.zero_point = t->elements()[2].toInt();
  return r;
}

at::Tensor to_proxy(const caffe2::Tensor& t) {
  return autograd::make_variable(at::Tensor(t.UnsafeSharedInstance()), false);
}

c10::intrusive_ptr<c10::ivalue::Tuple> to_proxy(const Int8TensorCPU& t) {
  return c10::ivalue::Tuple::create({to_proxy(t.t), t.scale, t.zero_point});
}

// TODO: replace this with c10 registration when it's ready
RegisterOperators reg({
    Operator(
        // NOTE: we put the output in double parens because it's an output of
        // type tuple, not a tuple of multiple outputs
        "c10::quantized_relu((Tensor, float, int) self) -> ((Tensor, float, int))",
        // TODO: can't use C++ inference - doesn't work yet for tuple types
        [](Stack& stack) {
          AT_ASSERT(caffe2::GetRegisteredOperators().count(
              caffe2::OpRegistryKey("Relu", "DNNLOWP")));

          // TODO: refactor the underlying op implementation and inline it in
          // c10 kernel
          caffe2::Workspace ws;
          ws.CreateBlob("X")->Reset(
              new Int8TensorCPU(from_proxy(std::move(peek(stack, 0, 1)))));

          auto def = caffe2::CreateOperatorDef(
              "Relu", "proxy", {"X"}, {"Y"}, caffe2::DeviceOption(), "DNNLOWP");
          auto op = caffe2::CreateOperator(def, &ws);

          op->Run();

          drop(stack, 1);
          pack(stack, to_proxy(ws.GetBlob("Y")->Get<Int8TensorCPU>()));
          return 0;
        }),

    Operator(
        "c10::quantize(Tensor X, float? scale = None, int? zero_point = None) -> ((Tensor, float, int))",
        [](Stack& stack) {
          AT_ASSERT(caffe2::GetRegisteredOperators().count(
              caffe2::OpRegistryKey("Quantize", "DNNLOWP")));

          // TODO: refactor the underlying op implementation and inline it in
          // c10 kernel
          caffe2::Workspace ws;
          ws.CreateBlob("X")->Reset(
              new caffe2::Tensor(from_at_tensor(std::move(peek(stack, 0, 3)))));

          auto def = caffe2::CreateOperatorDef(
              "Quantize",
              "proxy",
              {"X"},
              {"Y"},
              caffe2::DeviceOption(),
              "DNNLOWP");
          auto s = peek(stack, 1, 3).toOptional<float>();
          if (s.has_value()) {
            def.add_arg()->CopyFrom(caffe2::MakeArgument("Y_scale", *s));
          }
          auto zp = peek(stack, 2, 3).toOptional<int32_t>();
          if (zp.has_value()) {
            def.add_arg()->CopyFrom(caffe2::MakeArgument("Y_zero_point", *zp));
          }
          auto op = caffe2::CreateOperator(def, &ws);

          op->Run();

          drop(stack, 3);
          pack(stack, to_proxy(ws.GetBlob("Y")->Get<Int8TensorCPU>()));
          return 0;
        }),

    Operator(
        "c10::dequantize((Tensor, float, int) x_q) -> Tensor",
        // TODO: can't use C++ inference - doesn't work yet for tuple types
        [](Stack& stack) {
          AT_ASSERT(caffe2::GetRegisteredOperators().count(
              caffe2::OpRegistryKey("Dequantize", "DNNLOWP")));

          // TODO: refactor the underlying op implementation and inline it in
          // c10 kernel
          caffe2::Workspace ws;
          ws.CreateBlob("X")->Reset(
              new Int8TensorCPU(from_proxy(std::move(peek(stack, 0, 1)))));

          auto def = caffe2::CreateOperatorDef(
              "Dequantize",
              "proxy",
              {"X"},
              {"Y"},
              caffe2::DeviceOption(),
              "DNNLOWP");
          auto op = caffe2::CreateOperator(def, &ws);

          op->Run();

          drop(stack, 1);
          pack(stack, to_proxy(ws.GetBlob("Y")->Get<caffe2::Tensor>()));
          return 0;
        }),
});
} // namespace
} // namespace jit
} // namespace torch
