Commit b2135b2

smessmer authored and facebook-github-bot committed
Define layer_norm schema in caffe2 (pytorch#16535)
Summary: Pull Request resolved: pytorch#16535

There is no longer any need to define the layer norm schema in a central location; it can be defined in caffe2 next to the kernel implementation.

Reviewed By: ezyang

Differential Revision: D13869503

fbshipit-source-id: c478153f8fd712ff6d507c794500286eb3583149
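For orientation, a minimal sketch of the declare-in-header / define-next-to-the-kernel pattern this commit moves to, assuming only the C10_DECLARE_OP_SCHEMA / C10_DEFINE_OP_SCHEMA macros and the FunctionSchema constructor shape that appear in the diffs below; the MyOp and "caffe2::my_op" names are hypothetical placeholders:

// my_op.h -- declare the schema next to the operator declaration (hypothetical example)
#include <ATen/core/dispatch/OpSchemaRegistration.h>

namespace caffe2 {
C10_DECLARE_OP_SCHEMA(MyOp);
} // namespace caffe2

// my_op.cc -- define the schema right next to the kernel implementation,
// instead of in a central ATen/core/opschema/ location
#include "my_op.h"

namespace caffe2 {
C10_DEFINE_OP_SCHEMA(MyOp, FunctionSchema(
    "caffe2::my_op",
    (std::vector<c10::Argument>{c10::Argument("input")}),
    (std::vector<c10::Argument>{c10::Argument("output")})));
} // namespace caffe2

The kernel is then registered against this locally defined schema with C10_REGISTER_KERNEL, as the layer_norm_op.cc diff below does for caffe2::LayerNorm.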
1 parent 16468a9 · commit b2135b2

File tree

8 files changed (+33, -43 lines)

aten/src/ATen/core/dispatch/DispatchTable.h

+1 -1

@@ -99,7 +99,7 @@ class ThreadsafeOperatorTable_ final {
   }
 
   static std::string dispatch_key_to_string(TensorTypeId id) {
-    return std::string(toString(tensorTypeIdToBackend(id))) + "[" + guts::to_string(id) + "]";
+    return std::string(toString(tensorTypeIdToBackend(id))) + "[" + toString(id) + "]";
   }
 
   LeftRight<ska::flat_hash_map<TensorTypeId, DispatchTableEntry>> map_;

aten/src/ATen/core/opschema/layer_norm.cpp

-25
This file was deleted.

aten/src/ATen/core/opschema/layer_norm.h

-13
This file was deleted.

c10/core/TensorTypeId.cpp

+4

@@ -7,4 +7,8 @@ std::ostream& operator<<(std::ostream& str, c10::TensorTypeId rhs) {
   return str << c10::to_string(rhs.underlyingId());
 }
 
+std::string toString(TensorTypeId id) {
+  return c10::to_string(id.underlyingId());
+}
+
 } // namespace c10

c10/core/TensorTypeId.h

+3

@@ -32,8 +32,11 @@ class C10_API TensorTypeId final
 
   friend class TensorTypeIdCreator;
   friend C10_API std::ostream& operator<<(std::ostream&, TensorTypeId);
+  friend C10_API std::string toString(TensorTypeId);
 };
 
+std::string toString(TensorTypeId id);
+
 C10_API std::ostream& operator<<(std::ostream&, c10::TensorTypeId);
 
 } // namespace c10
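For illustration only, a tiny usage sketch of the helper added above, assuming nothing beyond the declarations shown in this diff (print_type_id is a hypothetical caller; both lines print the same underlying numeric id):

#include <c10/core/TensorTypeId.h>
#include <iostream>

void print_type_id(c10::TensorTypeId id) {
  std::cout << toString(id) << "\n"; // new free function, found via ADL; underlying id as a string
  std::cout << id << "\n";           // existing operator<< prints the same value
}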

caffe2/operators/experimental/c10/schemas/layer_norm.cc

+2 -2

@@ -1,4 +1,4 @@
-#include <ATen/core/opschema/layer_norm.h>
+#include "caffe2/operators/layer_norm_op.h"
 #include "caffe2/core/operator_c10wrapper.h"
 
 namespace {
@@ -26,7 +26,7 @@ struct EpsilonParameter final {
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS(
-    c10::core::opschema::LayerNorm,
+    caffe2::LayerNorm,
     C10LayerNorm_DontUseThisOpYet,
     3,
     ParameterHelper<AxisParameter>,

caffe2/operators/layer_norm_op.cc

+20 -2

@@ -1,8 +1,9 @@
 #include "caffe2/operators/layer_norm_op.h"
 #include "caffe2/utils/eigen_utils.h"
-#include <ATen/core/opschema/layer_norm.h>
+#include "caffe2/core/operator_c10wrapper.h"
 #include <ATen/core/dispatch/KernelRegistration.h>
 #include <c10/core/Tensor.h>
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
 
 namespace caffe2 {
 
@@ -181,6 +182,22 @@ to the end.)
     .Output(1, "mean", "Mean values for each feature vector")
     .Output(2, "stddev", "Standard deviations for each feature vector");
 
+C10_DEFINE_OP_SCHEMA(LayerNorm, FunctionSchema(
+    "caffe2::layer_norm_dont_use_this_op_yet",
+    (std::vector<c10::Argument>{
+      c10::Argument("input"),
+      c10::Argument("axis", IntType::get()),
+      c10::Argument("epsilon", FloatType::get()),
+      c10::Argument("output", OptionalType::ofTensor(), nullopt, IValue()),
+      c10::Argument("output_mean", OptionalType::ofTensor(), nullopt, IValue()),
+      c10::Argument("output_stdev", OptionalType::ofTensor(), nullopt, IValue())
+    }), (std::vector<c10::Argument>{
+      c10::Argument("output"),
+      c10::Argument("mean"),
+      c10::Argument("stdev")
+    })
+));
+
 } // namespace caffe2
 
 
@@ -253,9 +270,10 @@ void layer_norm_c10(c10::Stack* stack, c10::KernelCache* cache_) { // TODO Pass
 
   return;
 }
+
 }
 namespace c10 {
-C10_REGISTER_KERNEL(c10::core::opschema::LayerNorm)
+C10_REGISTER_KERNEL(caffe2::LayerNorm)
     .withCache<Cache>()
     .kernel<&layer_norm_c10<float>>()
     .dispatchKey(CPUTensorId());

caffe2/operators/layer_norm_op.h

+3

@@ -8,9 +8,12 @@
 #include "caffe2/core/operator.h"
 #include "caffe2/core/types.h"
 #include "caffe2/utils/math.h"
+#include <ATen/core/dispatch/OpSchemaRegistration.h>
 
 namespace caffe2 {
 
+C10_DECLARE_OP_SCHEMA(LayerNorm);
+
 template <class Context>
 class LayerNormOp final : public Operator<Context> {
  public:
