Commit 6750e1e

smessmer authored and facebook-github-bot committed
C10_REGISTER_CAFFE2_OPERATOR: Macro for registering c2 kernels (pytorch#16548)
Summary:
Pull Request resolved: pytorch#16548

With this macro, a caffe2 operator can now directly be registered with c10. No need to write custom wrapper kernels anymore.

Differential Revision: D13877076
fbshipit-source-id: e56846238c5bb4b1989b79855fd44d5ecf089c9c
1 parent ac4f66c commit 6750e1e

9 files changed, +127 -102 lines
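For context, this is the registration call the new macro enables, as it appears in the layer_norm_op.cc hunk of this commit (comments added here for illustration):

// How the new macro is used (mirrors the layer_norm_op.cc hunk below).
// The generated schema is namespaced "_caffe2::", so the operator becomes
// reachable as torch.ops._caffe2.LayerNorm from Python and TorchScript.
C10_REGISTER_CAFFE2_OPERATOR(
    LayerNorm,                                        // registered as _caffe2::LayerNorm
    (std::vector<c10::Argument>{                      // operator inputs
      c10::Argument("input"),
      c10::Argument("axis", c10::IntType::get()),
      c10::Argument("epsilon", c10::FloatType::get())
    }),
    (std::vector<c10::Argument>{                      // operator outputs
      c10::Argument("output"),
      c10::Argument("mean"),
      c10::Argument("stdev")
    }),
    caffe2::LayerNormOp                               // the caffe2 operator template
)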

aten/src/ATen/core/dispatch/DispatchTable.h

Lines changed: 1 addition & 1 deletion
@@ -174,7 +174,7 @@ class DispatchTable final {
  private:
   static size_t get_index_of_first_tensor_arg_(const FunctionSchema& schema) {
     for (size_t i = 0; i < schema.arguments().size(); ++i) {
-      if (schema.arguments()[i].type()->isSubtypeOf(TensorType::get())) { // DynamicType means it's a tensor
+      if (schema.arguments()[i].type()->isSubtypeOf(TensorType::get())) {
         return i;
       }
     }

aten/src/ATen/core/stack.h

Lines changed: 9 additions & 0 deletions
@@ -52,6 +52,15 @@ static inline IValue pop(Stack& stack) {
   stack.pop_back();
   return r;
 }
+static inline std::vector<IValue> pop(Stack& stack, size_t n) {
+  std::vector<IValue> result;
+  result.reserve(n);
+  for (size_t i = 0; i < n; ++i) {
+    result.push_back(std::move(peek(stack, i, n)));
+  }
+  drop(stack, n);
+  return result;
+}
 
 // variadic pop:
 // int64_t a; at::Tensor b;
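A small usage sketch of the new n-element pop overload (everything around the pop call is illustrative; only pop(stack, n) itself comes from this hunk). It removes the top n entries and returns them in their original push order, so element 0 of the result is the deepest of the n popped values:

// Illustrative only: push two values, then pop both at once.
torch::jit::Stack stack;
torch::jit::push(stack, int64_t(3), at::ones({2, 2}));  // int pushed first, tensor ends up on top
auto vals = torch::jit::pop(stack, 2);                   // returns std::vector<IValue>
// vals[0].toInt() == 3, vals[1].toTensor() is the 2x2 tensor; the stack is empty again.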

c10/util/flat_hash_map.h

Lines changed: 3 additions & 2 deletions
@@ -5,6 +5,7 @@
 // - https://github.com/skarupke/flat_hash_map/pull/26
 // - replace size_t with uint64_t to fix it for 32bit
 // - add "GCC diagnostic" pragma to ignore -Wshadow
+// - make sherwood_v3_table::convertible_to_iterator public because GCC5 seems to have issues with it otherwise
 
 // Copyright Malte Skarupke 2017.
 // Distributed under the Boost Software License, Version 1.0.
@@ -293,9 +294,9 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal
     using Entry = detailv3::sherwood_v3_entry<T>;
     using AllocatorTraits = std::allocator_traits<EntryAlloc>;
     using EntryPointer = typename AllocatorTraits::pointer;
-    struct convertible_to_iterator;
 
 public:
+    struct convertible_to_iterator;
 
     using value_type = T;
     using size_type = uint64_t;
@@ -924,7 +925,7 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal
         return static_cast<Equal &>(*this)(lhs, rhs);
     }
 
-private:
+public:
     struct convertible_to_iterator
     {
         EntryPointer it;

caffe2/core/c10_operator.h

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+#pragma once
+
+#include <vector>
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
+#include <ATen/core/dispatch/KernelRegistration.h>
+#include <ATen/core/function_schema.h>
+
+namespace caffe2 {
+namespace detail {
+
+template<class Caffe2Operator> const c10::OperatorHandle& c10_op_handle_for_c2_op();
+template <class Caffe2Operator>
+void call_caffe2_op_from_c10(c10::Stack* stack, c10::KernelCache* cache) { // TODO Pass in correct cache type
+  // precondition: on the stack, there's an IValue for each input and an IValue for each output.
+  // The output ones could either be a preallocated tensor or ivalue::None.
+
+  const auto& schema = c10_op_handle_for_c2_op<Caffe2Operator>().schema();
+  const size_t num_outputs = schema.returns().size();
+  const size_t total_num_arguments = schema.arguments().size();
+  const size_t num_inputs = total_num_arguments - num_outputs;
+
+  // TODO Avoid vector allocation. One idea would be to keep the std::vector instances in the cache.
+  auto outputs = torch::jit::pop(*stack, num_outputs);
+  auto inputs = torch::jit::pop(*stack, num_inputs);
+
+  const auto device = at::Device(at::DeviceType::CPU); // TODO Handle GPU devices
+
+  for (auto& output : outputs) {
+    if (output.isNone() || (output.isTensor() && !output.toTensor().defined())) {
+      output = at::Tensor(c10::C10Tensor(caffe2::empty({0}, device)));
+    }
+  }
+
+  std::vector<c10::IValue*> outputPtrs;
+  outputPtrs.reserve(outputs.size());
+  for (auto& output : outputs) {
+    outputPtrs.push_back(&output);
+  }
+
+  Caffe2Operator(schema, std::move(inputs), std::move(outputPtrs)).Run();
+
+  for (auto& output : outputs) {
+    torch::jit::push(*stack, std::move(output));
+  }
+
+  // postcondition: All inputs are cleared from the stack, there's now one
+  // IValue for each output which holds the result. This
+  // might reuse one of the preallocated tensors but doesn't have to.
+}
+
+inline c10::FunctionSchema make_function_schema_for_c10(const char* OperatorName, std::vector<c10::Argument> inputs, std::vector<c10::Argument> outputs) {
+  // actual_inputs is the real inputs plus an optional tensor argument for each output.
+  // this can be used to pass in a preallocated output tensor.
+  std::vector<c10::Argument> actual_inputs = std::move(inputs);
+  actual_inputs.reserve(actual_inputs.size() + outputs.size());
+  for (const auto& elem : outputs) {
+    AT_ASSERT(elem.type()->isSubtypeOf(c10::TensorType::get()));
+    actual_inputs.push_back(c10::Argument(elem.name(), c10::OptionalType::create(elem.type()), nullopt, IValue()));
+  }
+
+  return c10::FunctionSchema(
+    std::string("_caffe2::") + OperatorName,
+    std::move(actual_inputs), std::move(outputs));
+}
+
+}
+}
+
+/**
+ * Call this macro to register a caffe2 operator with the c10 dispatcher.
+ */
+// TODO This macro should take a JIT schema string instead of a vector of inputs and outputs.
+#define C10_REGISTER_CAFFE2_OPERATOR(OperatorName, Inputs, Outputs, OperatorClass) \
+  /* Register the op schema with the c10 dispatcher */ \
+  namespace caffe2 { \
+    C10_DEFINE_OP_SCHEMA(OperatorName, \
+      caffe2::detail::make_function_schema_for_c10( \
+        #OperatorName, Inputs, Outputs)); \
+  } \
+  /* Store the c10 operator handle so call_caffe2_op_from_c10 can access it */ \
+  namespace caffe2 { namespace detail { \
+    template<> \
+    const c10::OperatorHandle& c10_op_handle_for_c2_op<OperatorClass<caffe2::CPUContext>>() { \
+      return caffe2::OperatorName(); \
+    } \
+  }} \
+  /* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */ \
+  namespace c10 { \
+    C10_REGISTER_KERNEL(caffe2::OperatorName) \
+      /*.withCache<Cache>()*/ \
+      .kernel<&caffe2::detail::call_caffe2_op_from_c10<OperatorClass<caffe2::CPUContext>>>() \
+      .dispatchKey(CPUTensorId()); \
+  }
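To make the calling convention concrete, here is a hedged caller-side sketch of the stack layout that call_caffe2_op_from_c10 expects. In practice the c10 dispatcher sets this up; the direct call and the helper function name below are purely illustrative and assume the LayerNorm registration from layer_norm_op.cc in this commit:

#include "caffe2/core/operator.h"            // now also pulls in c10_operator.h (see the operator.h hunk below)
#include "caffe2/operators/layer_norm_op.h"

void example_direct_call() {
  // precondition: one IValue per input, then one IValue per output.
  // None (a default-constructed IValue) means "allocate a fresh output tensor".
  c10::Stack stack;
  torch::jit::push(stack, at::rand({2, 3}), int64_t(1), 1e-5);   // input, axis, epsilon
  torch::jit::push(stack, c10::IValue(), c10::IValue(), c10::IValue());

  // The kernel pops inputs and outputs, runs the caffe2 op, and pushes the results back.
  // The cache argument is unused by this kernel (see the TODO above), so nullptr is passed.
  caffe2::detail::call_caffe2_op_from_c10<caffe2::LayerNormOp<caffe2::CPUContext>>(&stack, nullptr);

  // postcondition: exactly one IValue per output remains (output, mean, stdev).
  auto results = torch::jit::pop(stack, 3);
}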

caffe2/core/operator.h

Lines changed: 2 additions & 0 deletions
@@ -1356,4 +1356,6 @@ std::function<void(const OperatorDef&)> GetOperatorLogger();
 
 } // namespace caffe2
 
+#include "caffe2/core/c10_operator.h"
+
 #endif // CAFFE2_CORE_OPERATOR_H_

caffe2/operators/layer_norm_op.cc

Lines changed: 13 additions & 94 deletions
@@ -182,99 +182,18 @@ to the end.)
     .Output(1, "mean", "Mean values for each feature vector")
     .Output(2, "stddev", "Standard deviations for each feature vector");
 
-C10_DEFINE_OP_SCHEMA(LayerNorm, FunctionSchema(
-    "caffe2::layer_norm_dont_use_this_op_yet",
-    (std::vector<c10::Argument>{
-      c10::Argument("input"),
-      c10::Argument("axis", IntType::get()),
-      c10::Argument("epsilon", FloatType::get()),
-      c10::Argument("output", OptionalType::ofTensor(), nullopt, IValue()),
-      c10::Argument("output_mean", OptionalType::ofTensor(), nullopt, IValue()),
-      c10::Argument("output_stdev", OptionalType::ofTensor(), nullopt, IValue())
-    }), (std::vector<c10::Argument>{
-      c10::Argument("output"),
-      c10::Argument("mean"),
-      c10::Argument("stdev")
-    })
-));
-
 } // namespace caffe2
 
-
-// Register layer norm with c10
-namespace {
-struct Cache final : public c10::KernelCache {
-  at::optional<at::Tensor> scale = at::nullopt;
-  at::optional<at::Tensor> bias = at::nullopt;
-};
-
-template <class DataType>
-void layer_norm_c10(c10::Stack* stack, c10::KernelCache* cache_) { // TODO Pass in correct cache type
-  c10::ArrayRef<c10::IValue> inputs = torch::jit::peekSlice(*stack, 0, 3, 6);
-  c10::ArrayRef<c10::IValue> outputs = torch::jit::peekSlice(*stack, 3, 3, 6);
-
-  caffe2::Tensor X{inputs[0].toTensor()};
-  int64_t axis = inputs[1].toInt();
-  float epsilon = inputs[2].toDouble();
-
-  auto device = X.GetDevice();
-
-  caffe2::Tensor Y, mean, sig;
-  if (outputs[0].isTensor()) {
-    Y = caffe2::Tensor(std::move(torch::jit::peek(*stack, 0, 3)).toTensor());
-  }
-  if (outputs[1].isTensor()) {
-    mean = caffe2::Tensor(std::move(torch::jit::peek(*stack, 1, 3)).toTensor());
-  }
-  if (outputs[2].isTensor()) {
-    sig = caffe2::Tensor(std::move(torch::jit::peek(*stack, 2, 3)).toTensor());
-  }
-  if (!Y.defined()) {
-    Y = caffe2::empty({0}, device);
-  }
-  if (!mean.defined()) {
-    mean = caffe2::empty({0}, device);
-  }
-  if (!sig.defined()) {
-    sig = caffe2::empty({0}, device);
-  }
-
-  caffe2::CPUContext context;
-  Cache* cache = static_cast<Cache*>(cache_);
-  if (!cache->scale.has_value()) {
-    cache->scale = at::Tensor(caffe2::empty({0}, at::dtype<float>()));
-  }
-  if (!cache->bias.has_value()) {
-    cache->bias = at::Tensor(caffe2::empty({0}, at::dtype<float>()));
-  }
-  caffe2::Tensor scale(*cache->scale);
-  caffe2::Tensor bias(*cache->bias);
-
-  const int canonical_axis = X.canonical_axis_index(axis);
-  std::vector<int64_t> moments_dims(
-      X.sizes().cbegin(), X.sizes().cbegin() + canonical_axis);
-  moments_dims.push_back(1);
-  mean.Resize(moments_dims);
-  sig.Resize(moments_dims);
-  caffe2::LayerNormOp<caffe2::CPUContext>::runLayerNorm<DataType>(
-    X, &Y, &mean, &sig, canonical_axis, epsilon, &scale, &bias, static_cast<caffe2::CPUContext*>(&context)
-  );
-
-  torch::jit::drop(*stack, 6);
-  torch::jit::push(*stack,
-    at::Tensor(std::move(Y)),
-    at::Tensor(std::move(mean)),
-    at::Tensor(std::move(sig))
-  );
-
-  return;
-}
-
-}
-namespace c10 {
-C10_REGISTER_KERNEL(caffe2::LayerNorm)
-  .withCache<Cache>()
-  .kernel<&layer_norm_c10<float>>()
-  .dispatchKey(CPUTensorId());
-} // namespace c10
+C10_REGISTER_CAFFE2_OPERATOR(
+    LayerNorm,
+    (std::vector<c10::Argument>{
+      c10::Argument("input"),
+      c10::Argument("axis", c10::IntType::get()),
+      c10::Argument("epsilon", c10::FloatType::get())
+    }), (std::vector<c10::Argument>{
+      c10::Argument("output"),
+      c10::Argument("mean"),
+      c10::Argument("stdev")
+    }),
+    caffe2::LayerNormOp
+)

caffe2/operators/layer_norm_op.h

Lines changed: 3 additions & 2 deletions
@@ -19,8 +19,9 @@ class LayerNormOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
 
-  LayerNormOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<Context>(operator_def, ws),
+  template<class... Args>
+  explicit LayerNormOp(Args&&... args)
+      : Operator<Context>(std::forward<Args>(args)...),
         OP_SINGLE_ARG(int, "axis", axis_, 1),
         OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5f) {}
 
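This forwarding constructor is what lets the same operator class serve both construction paths; a brief sketch of the two call sites (the c10-side arguments are the ones passed by call_caffe2_op_from_c10 in c10_operator.h above):

// Classic caffe2 construction, unchanged in behavior:
//   caffe2::LayerNormOp<caffe2::CPUContext> op(operator_def, workspace);
//
// New c10 construction path, as used inside call_caffe2_op_from_c10:
//   Caffe2Operator(schema, std::move(inputs), std::move(outputPtrs)).Run();
//
// Both forward through the single
//   template<class... Args> explicit LayerNormOp(Args&&... args)
// constructor into the matching Operator<Context> base-class overload.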

caffe2/python/operator_test/layer_norm_op_test.py

Lines changed: 2 additions & 2 deletions
@@ -143,7 +143,7 @@ def test_layer_norm_op_pytorch(self, X, gc, dc):
         epsilon = 1e-4
 
         expected_norm, expected_mean, expected_stdev = _layer_norm_ref(axis, epsilon, X)
-        actual_norm, actual_mean, actual_stdev = torch.ops.caffe2.layer_norm_dont_use_this_op_yet(torch.tensor(X), axis, epsilon)
+        actual_norm, actual_mean, actual_stdev = torch.ops._caffe2.LayerNorm(torch.tensor(X), axis, epsilon)
 
         torch.testing.assert_allclose(expected_norm, actual_norm)
         torch.testing.assert_allclose(expected_mean, actual_mean)
@@ -154,7 +154,7 @@ def test_layer_norm_op_jit(self, X, gc, dc):
         @torch.jit.script
         def jit_layer_norm(tensor, axis, epsilon):
             # type: (Tensor, int, float) -> Tuple[Tensor, Tensor, Tensor]
-            norm, mean, stdev = torch.ops.caffe2.layer_norm_dont_use_this_op_yet(tensor, axis, epsilon)
+            norm, mean, stdev = torch.ops._caffe2.LayerNorm(tensor, axis, epsilon)
             return norm, mean, stdev
 
         axis = np.random.randint(0, len(X.shape))

test/test_torch.py

Lines changed: 1 addition & 1 deletion
@@ -10072,7 +10072,7 @@ def test_c10_layer_norm(self):
 
         expected_norm = torch.nn.functional.layer_norm(X, X.size()[1:], eps=epsilon)
         actual_norm, actual_mean, actual_stdev = \
-            torch.ops.caffe2.layer_norm_dont_use_this_op_yet(torch.tensor(X), 1, epsilon)
+            torch.ops._caffe2.LayerNorm(torch.tensor(X), 1, epsilon)
         torch.testing.assert_allclose(expected_norm, actual_norm)
 
         # Functions to test negative dimension wrapping
