
Commit 16468a9

smessmer authored and facebook-github-bot committed
Automatically register c10 ops with JIT (pytorch#16534)
Summary:
Pull Request resolved: pytorch#16534

All c10 ops from the c10 dispatcher are now automatically registered with JIT.

Reviewed By: dzhulgakov

Differential Revision: D13869275

fbshipit-source-id: 5ab5dec5b983fe661f977f9d29d8036768cdcab6
1 parent e5e0bf4 commit 16468a9

File tree

10 files changed: +238 −103 lines
aten/src/ATen/core/dispatch/Dispatcher.cpp

Lines changed: 81 additions & 0 deletions
@@ -1,8 +1,89 @@
 #include <ATen/core/dispatch/Dispatcher.h>
 
 namespace c10 {
+
+namespace detail {
+class RegistrationListenerList final {
+public:
+  void addListener(std::unique_ptr<OpRegistrationListener> listener) {
+    listeners_.push_back(std::move(listener));
+  }
+
+  void callOnOperatorRegistered(const OperatorHandle& op) {
+    for (auto& listener : listeners_) {
+      listener->onOperatorRegistered(op);
+    }
+  }
+
+  void callOnOperatorDeregistered(const OperatorHandle& op) {
+    for (auto& listener : listeners_) {
+      listener->onOperatorDeregistered(op);
+    }
+  }
+private:
+  std::vector<std::unique_ptr<OpRegistrationListener>> listeners_;
+};
+}
+
+OpRegistrationListener::~OpRegistrationListener() {}
+
+Dispatcher::Dispatcher()
+: operators_()
+, listeners_(guts::make_unique<detail::RegistrationListenerList>())
+, mutex_() {}
+
+Dispatcher::~Dispatcher() {}
+
 C10_EXPORT Dispatcher& Dispatcher::singleton() {
   static Dispatcher _singleton;
   return _singleton;
 }
+
+OperatorHandle Dispatcher::registerSchema(FunctionSchema schema) {
+  // we need a lock to avoid concurrent writes
+  std::lock_guard<std::mutex> lock(mutex_);
+
+  operators_.emplace_back(std::move(schema));
+  auto op = OperatorHandle(--operators_.end());
+
+  // note: call listeners *after* operator is added, i.e. dispatcher is already valid for new op
+  listeners_->callOnOperatorRegistered(op);
+
+  return op;
+}
+
+void Dispatcher::deregisterSchema(const OperatorHandle& op) {
+  // we need a lock to avoid concurrent writes
+  std::lock_guard<std::mutex> lock(mutex_);
+
+  if (!op.operatorDefIterator_->dispatchTable.isEmpty()) {
+    AT_ERROR("Tried to deregister op schema that still has kernels registered");
+  }
+
+  // note: call listeners *before* operator is removed, i.e. dispatcher is still valid for removed op
+  listeners_->callOnOperatorDeregistered(op);
+
+  operators_.erase(op.operatorDefIterator_);
+}
+
+void Dispatcher::registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func) {
+  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
+  op.operatorDefIterator_->dispatchTable.registerKernel(std::move(dispatch_key), DispatchTableEntry{kernel_func, cache_creator_func});
+}
+
+void Dispatcher::deregisterKernel(const OperatorHandle& op, TensorTypeId dispatch_key) {
+  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
+  op.operatorDefIterator_->dispatchTable.deregisterKernel(dispatch_key);
+}
+
+void Dispatcher::addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener) {
+  std::lock_guard<std::mutex> lock(mutex_);
+
+  for (auto iter = operators_.begin(); iter != operators_.end(); ++iter) {
+    listener->onOperatorRegistered(OperatorHandle(iter));
+  }
+
+  listeners_->addListener(std::move(listener));
+}
+
 }
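As a quick illustration of the new listener API (a minimal sketch that assumes only the Dispatcher interface added above; this code is not part of the commit): addRegistrationListener() first replays onOperatorRegistered for every operator that is already registered, and then keeps notifying the listener on later registrations and deregistrations.

// Hedged sketch, not from this commit: a listener that just counts live operators.
#include <ATen/core/dispatch/Dispatcher.h>
#include <iostream>
#include <memory>

namespace {
class CountingListener final : public c10::OpRegistrationListener {
 public:
  void onOperatorRegistered(const c10::OperatorHandle& /*op*/) override {
    std::cout << "op registered, live ops: " << ++count_ << "\n";
  }
  void onOperatorDeregistered(const c10::OperatorHandle& /*op*/) override {
    std::cout << "op deregistered, live ops: " << --count_ << "\n";
  }
 private:
  int count_ = 0;
};
} // namespace

void installCountingListener() {
  // Already-registered ops are replayed through onOperatorRegistered first,
  // so the count starts at the number of currently registered operators.
  c10::Dispatcher::singleton().addRegistrationListener(
      std::make_unique<CountingListener>());
}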

aten/src/ATen/core/dispatch/Dispatcher.h

Lines changed: 30 additions & 29 deletions
@@ -51,6 +51,23 @@ class CAFFE2_API OpKernel final {
   mutable std::unique_ptr<c10::KernelCache> cache_;
 };
 
+/**
+ * Implement this interface and register your instance with the dispatcher
+ * to get notified when operators are registered or deregistered with
+ * the dispatcher.
+ */
+class CAFFE2_API OpRegistrationListener {
+public:
+  virtual ~OpRegistrationListener();
+
+  virtual void onOperatorRegistered(const OperatorHandle& op) = 0;
+  virtual void onOperatorDeregistered(const OperatorHandle& op) = 0;
+};
+
+namespace detail {
+class RegistrationListenerList;
+}
+
 /**
  * Top-level dispatch interface for dispatching via the dynamic dispatcher.
  */
@@ -67,6 +84,8 @@ class CAFFE2_API Dispatcher final {
   friend class OperatorHandle;
 
 public:
+  ~Dispatcher();
+
   // Implementation note: this class abstracts over the fact that we have per-operator
   // dispatch tables. This could be easily adjusted to have a single global hash
   // table.
@@ -100,8 +119,19 @@ class CAFFE2_API Dispatcher final {
   */
   OpKernel lookup(const OperatorHandle& op, const Stack* stack) const;
 
+  /**
+   * Add a listener that gets called whenever a new op is registered or an existing
+   * op is deregistered. Immediately after registering, this listener gets called
+   * for all previously registered ops, so it can be used to keep track of ops
+   * registered with this dispatcher.
+   */
+  void addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener);
+
 private:
+  Dispatcher();
+
   std::list<OperatorDef> operators_;
+  std::unique_ptr<detail::RegistrationListenerList> listeners_;
   std::mutex mutex_;
 };
 
@@ -130,35 +160,6 @@ class CAFFE2_API OperatorHandle final {
 };
 
 
-
-inline OperatorHandle Dispatcher::registerSchema(FunctionSchema schema) {
-  // we need a lock to avoid concurrent writes
-  std::lock_guard<std::mutex> lock(mutex_);
-
-  operators_.emplace_back(std::move(schema));
-  return OperatorHandle(--operators_.end());
-}
-
-inline void Dispatcher::deregisterSchema(const OperatorHandle& op) {
-  // we need a lock to avoid concurrent writes
-  std::lock_guard<std::mutex> lock(mutex_);
-
-  if (!op.operatorDefIterator_->dispatchTable.isEmpty()) {
-    AT_ERROR("Tried to deregister op schema that still has kernels registered");
-  }
-  operators_.erase(op.operatorDefIterator_);
-}
-
-inline void Dispatcher::registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction* cache_creator_func) {
-  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
-  op.operatorDefIterator_->dispatchTable.registerKernel(std::move(dispatch_key), DispatchTableEntry{kernel_func, cache_creator_func});
-}
-
-inline void Dispatcher::deregisterKernel(const OperatorHandle& op, TensorTypeId dispatch_key) {
-  // note: this doesn't need the mutex because write operations on the list keep iterators intact.
-  op.operatorDefIterator_->dispatchTable.deregisterKernel(dispatch_key);
-}
-
 inline OpKernel Dispatcher::lookup(const OperatorHandle& op, const Stack* stack) const {
   // note: this doesn't need the mutex because write operations on the list keep iterators intact.
   const DispatchTableEntry& kernel = op.operatorDefIterator_->dispatchTable.lookup(stack);

aten/src/ATen/core/jit_type.h

Lines changed: 1 addition & 1 deletion
@@ -514,7 +514,7 @@ struct CAFFE2_API ListType : public SingleElementType<TypeKind::ListType, ListTy
 
 struct DictType;
 using DictTypePtr = std::shared_ptr<DictType>;
-struct DictType : public Type {
+struct CAFFE2_API DictType : public Type {
   friend struct Type;
   static const TypeKind Kind = TypeKind::DictType;
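The only change here is the CAFFE2_API export macro on DictType, making its symbols (vtable, typeinfo) visible to code in other shared libraries, such as the new JIT registration path. As a rough, hedged illustration of what this kind of export macro typically expands to (EXAMPLE_API is a made-up name; the real macro lives in the c10/caffe2 macro headers and is build-dependent):

// Illustration only; not the actual CAFFE2_API definition.
#if defined(_WIN32)
  #define EXAMPLE_API __declspec(dllexport)            // while building the DLL
#else
  #define EXAMPLE_API __attribute__((visibility("default")))
#endif

// Without the export attribute, a polymorphic class compiled with hidden
// visibility may fail to link (or dynamic_cast may misbehave) when used
// from another shared object.
struct EXAMPLE_API ExportedType {
  virtual ~ExportedType() = default;
};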

c10/util/flat_hash_map.h

Lines changed: 1 addition & 0 deletions
@@ -924,6 +924,7 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal
         return static_cast<Equal &>(*this)(lhs, rhs);
     }
 
+private:
     struct convertible_to_iterator
     {
         EntryPointer it;

caffe2/operators/layer_norm_op.cc

Lines changed: 27 additions & 7 deletions
@@ -196,12 +196,32 @@ void layer_norm_c10(c10::Stack* stack, c10::KernelCache* cache_) { // TODO Pass
   c10::ArrayRef<c10::IValue> inputs = torch::jit::peekSlice(*stack, 0, 3, 6);
   c10::ArrayRef<c10::IValue> outputs = torch::jit::peekSlice(*stack, 3, 3, 6);
 
-  caffe2::Tensor X{c10::C10Tensor(inputs[0].toTensor())};
+
+  caffe2::Tensor X{inputs[0].toTensor()};
   int64_t axis = inputs[1].toInt();
   float epsilon = inputs[2].toDouble();
-  caffe2::Tensor Y{c10::C10Tensor(outputs[0].toTensor())};
-  caffe2::Tensor mean{c10::C10Tensor(outputs[1].toTensor())};
-  caffe2::Tensor sig{c10::C10Tensor(outputs[2].toTensor())};
+
+  auto device = X.GetDevice();
+
+  caffe2::Tensor Y, mean, sig;
+  if (outputs[0].isTensor()) {
+    Y = caffe2::Tensor(std::move(torch::jit::peek(*stack, 0, 3)).toTensor());
+  }
+  if (outputs[1].isTensor()) {
+    mean = caffe2::Tensor(std::move(torch::jit::peek(*stack, 1, 3)).toTensor());
+  }
+  if (outputs[2].isTensor()) {
+    sig = caffe2::Tensor(std::move(torch::jit::peek(*stack, 2, 3)).toTensor());
+  }
+  if (!Y.defined()) {
+    Y = caffe2::empty({0}, device);
+  }
+  if (!mean.defined()) {
+    mean = caffe2::empty({0}, device);
+  }
+  if (!sig.defined()) {
+    sig = caffe2::empty({0}, device);
+  }
 
   caffe2::CPUContext context;
   Cache* cache = static_cast<Cache*>(cache_);
@@ -226,9 +246,9 @@ void layer_norm_c10(c10::Stack* stack, c10::KernelCache* cache_) { // TODO Pass
 
   torch::jit::drop(*stack, 6);
   torch::jit::push(*stack,
-    at::Tensor(c10::C10Tensor(std::move(Y))),
-    at::Tensor(c10::C10Tensor(std::move(mean))),
-    at::Tensor(c10::C10Tensor(std::move(sig)))
+    at::Tensor(std::move(Y)),
+    at::Tensor(std::move(mean)),
+    at::Tensor(std::move(sig))
   );
 
   return;
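A note on the stack helpers used in this kernel (a hedged sketch of the convention with stand-in types, not real IValues): the stack carries all six arguments, three inputs followed by three output slots, and peek(stack, i, N) / peekSlice(stack, i, len, N) index relative to the top-most N entries, so peekSlice(*stack, 0, 3, 6) is the inputs and peek(*stack, 0, 3) is the first output slot.

// Stand-in types for illustration; torch::jit::Stack holds IValues, not ints.
#include <cassert>
#include <cstddef>
#include <vector>

using Value = int;
using Stack = std::vector<Value>;

// Mirrors the indexing of torch::jit::peek(stack, i, N):
// element i of the top-most N stack entries.
Value& peek(Stack& stack, std::size_t i, std::size_t N) {
  return *(stack.end() - N + i);
}

int main() {
  // layer_norm_c10's convention: 3 inputs, then 3 output slots, pushed in order.
  Stack stack = {/*X*/ 10, /*axis*/ 1, /*epsilon*/ 2, /*Y*/ 0, /*mean*/ 0, /*sig*/ 0};
  assert(peek(stack, 0, 6) == 10);  // first input
  assert(peek(stack, 0, 3) == 0);   // first output slot (top-most 3 entries)
  // After computing, the kernel drops all 6 entries and pushes the 3 results.
}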

tools/build_variables.py

Lines changed: 1 addition & 1 deletion
@@ -59,6 +59,7 @@
     "torch/csrc/jit/ir.cpp",
     "torch/csrc/jit/caffe2_operator.cpp",
     "torch/csrc/jit/register_caffe2_ops.cpp",
+    "torch/csrc/jit/register_c10_ops.cpp",
     "torch/csrc/jit/symbolic_script.cpp",
     "torch/csrc/jit/operator.cpp",
     "torch/csrc/jit/passes/alias_analysis.cpp",
@@ -101,7 +102,6 @@
     "torch/csrc/jit/script/lexer.cpp",
     "torch/csrc/jit/script/module.cpp",
     "torch/csrc/jit/tracer.cpp",
-    "torch/csrc/jit/c10_ops/layer_norm.cpp",
     "torch/csrc/utils/tensor_flatten.cpp",
     "torch/csrc/utils/variadic.cpp",
     "torch/csrc/jit/fuser/kernel_cache.cpp",

torch/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -135,6 +135,7 @@ set(TORCH_SRCS
     ${TORCH_SRC_DIR}/csrc/jit/ir.cpp
     ${TORCH_SRC_DIR}/csrc/jit/operator.cpp
     ${TORCH_SRC_DIR}/csrc/jit/caffe2_operator.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/register_c10_ops.cpp
     ${TORCH_SRC_DIR}/csrc/jit/symbolic_script.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/alias_analysis.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/batch_mm.cpp
@@ -179,7 +180,6 @@ set(TORCH_SRCS
     ${TORCH_SRC_DIR}/csrc/jit/script/module.cpp
     ${TORCH_SRC_DIR}/csrc/jit/tracer.cpp
     ${TORCH_SRC_DIR}/csrc/jit/hooks_for_testing.cpp
-    ${TORCH_SRC_DIR}/csrc/jit/c10_ops/layer_norm.cpp
     ${TORCH_SRC_DIR}/csrc/utils/tensor_flatten.cpp
     ${TORCH_SRC_DIR}/csrc/utils/variadic.cpp
     ${TORCH_SRC_DIR}/csrc/jit/fuser/kernel_cache.cpp

torch/csrc/jit/c10_ops/layer_norm.cpp

Lines changed: 0 additions & 64 deletions
This file was deleted.

torch/csrc/jit/operator.h

Lines changed: 30 additions & 0 deletions
@@ -27,6 +27,36 @@ TORCH_API FunctionSchema parseSchema(const std::string& schema);
 
 using OperationCreator = std::function<Operation(const Node*)>;
 
+/*
+ * Note: JIT relies on Operator instances having static lifetime, because
+ * it for example stores a non-owning FunctionSchema* pointer in the Node class,
+ * which points to the function schema stored in the Operator instance.
+ * Also, jit::Operator is meant to store more operator related information like
+ * symbolic derivatives, which also requires them to have static lifetime
+ * so that changes to symbolic derivatives are remembered.
+ *
+ * Now, currently, the c10 operator library doesn't store jit::Operator instances,
+ * but we use a listener pattern that notifies JIT about changes in the
+ * c10 operator library and then registers jit::Operator instances to the JIT
+ * operator registry, acting as wrappers to the c10 operators.
+ *
+ * However, that results in code duplication as JIT and c10 will likely get
+ * their own mechanisms for storing derivatives and other operator related
+ * information, and all of this would have to be wrapped from c10 into JIT.
+ *
+ * We should consider merging the JIT and c10 registries, moving jit::Operator
+ * to c10 and storing these jit::Operator instances in the c10 operator library
+ * instead, allowing us to have these mechanisms only implemented once.
+ * However, the current jit::Operator implementation has additional features
+ * like OperationCreator that aren't needed in c10 (they're only used for
+ * prim ops like If/Else or While which wouldn't be in the c10 operator library),
+ * and which depend on other JIT features which we don't want to move to c10
+ * (notably jit/ir.h). We might, however, be able to split jit::Operator into
+ * a c10::Operator with the core features and a jit::Operator that adds the
+ * JIT-only features like OperationCreator, and then use c10::Operator in the
+ * c10 operator library.
+ */
+
 struct TORCH_API Operator {
   Operator(FunctionSchema schema, OperationCreator op_creator)
     : schema_(std::make_shared<FunctionSchema>(std::move(schema))),
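The note above describes the listener-based bridge this commit adds in torch/csrc/jit/register_c10_ops.cpp (that file is not shown in the scraped diff). A heavily hedged sketch of what such wiring could look like; createJitOperatorFor() is a hypothetical helper standing in for the real conversion logic:

// Hedged sketch only; the real register_c10_ops.cpp differs in detail.
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/jit/operator.h>

namespace torch { namespace jit { namespace {

class RegistrationListener final : public c10::OpRegistrationListener {
 public:
  void onOperatorRegistered(const c10::OperatorHandle& op) override {
    // Wrap the c10 op as a jit::Operator and add it to the JIT registry.
    // createJitOperatorFor() is a hypothetical helper that would build an
    // Operation popping/pushing IValues on the JIT stack and calling into
    // the c10 dispatcher for this OperatorHandle.
    // registerOperator(createJitOperatorFor(op));
  }
  void onOperatorDeregistered(const c10::OperatorHandle& /*op*/) override {
    // jit::Operator instances have static lifetime (see note above),
    // so deregistration would be a no-op or an error in this sketch.
  }
};

struct Registerer final {
  Registerer() {
    // addRegistrationListener() replays ops that were registered before this
    // static initializer ran, so none are missed.
    c10::Dispatcher::singleton().addRegistrationListener(
        c10::guts::make_unique<RegistrationListener>());
  }
};
Registerer registerer;  // hooks the listener up during static initialization

}}} // namespace torch::jit::<anonymous>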
