
Commit 64339db

smessmer authored and facebook-github-bot committed
Fix and re-enable test case (pytorch#16643)
Summary:
Pull Request resolved: pytorch#16643

The test was disabled in D13908117 because it conflicted with another diff that was about to land. The merge conflict is now fixed and the change is being re-landed.

Reviewed By: ezyang

Differential Revision: D13911775

fbshipit-source-id: b790f1c3a3f207916eea41ac93bc104d011f629b
1 parent 6750e1e commit 64339db

File tree: 8 files changed, +54 −11 lines

caffe2/core/blob.h (+8)

@@ -116,5 +116,13 @@ inline const Tensor& BlobGetTensor(const Blob& blob, DeviceType device_type) {
   CAFFE_THROW("Blob didn't contain a Tensor or the device_type doesn't match");
 }
 
+inline Tensor BlobGetTensorOrUndefined(const Blob& blob) {
+  if (blob.IsType<Tensor>()) {
+    return blob.Get<Tensor>().UnsafeSharedInstance();
+  } else {
+    return Tensor();
+  }
+}
+
 } // namespace caffe2
 #endif // CAFFE2_CORE_BLOB_H_
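
For context, a minimal sketch of how a caller might use the new helper. The function inspect_blob is a hypothetical name invented for illustration, not part of this commit; the point is that, unlike BlobGetTensor, the new helper never throws.

#include "caffe2/core/blob.h"

// Hypothetical caller, for illustration only.
void inspect_blob(const caffe2::Blob& blob) {
  // Never throws: yields an undefined Tensor when the blob holds no Tensor.
  caffe2::Tensor t = caffe2::BlobGetTensorOrUndefined(blob);
  if (t.defined()) {
    // The blob held a Tensor; t shares storage with it (UnsafeSharedInstance),
    // so writes through t are visible to the blob's tensor.
  } else {
    // The blob was empty or held a non-Tensor type.
  }
}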

caffe2/core/c10_operator.h (+10 −5)

@@ -66,27 +66,32 @@ inline c10::FunctionSchema make_function_schema_for_c10(const char* OperatorName
 }
 }
 
+#define C10_DECLARE_CAFFE2_OPERATOR(OperatorName) \
+  namespace caffe2 { namespace _c10_ops { \
+    C10_DECLARE_OP_SCHEMA(OperatorName); \
+  }}
+
 /**
  * Call this macro to register a caffe2 operator with the c10 dispatcher.
  */
 // TODO This macro should take a JIT schema string instead of a vector of inputs and outputs.
-#define C10_REGISTER_CAFFE2_OPERATOR(OperatorName, Inputs, Outputs, OperatorClass) \
+#define C10_REGISTER_CAFFE2_OPERATOR_CPU(OperatorName, Inputs, Outputs, OperatorClass) \
   /* Register the op schema with the c10 dispatcher */ \
-  namespace caffe2 { \
+  namespace caffe2 { namespace _c10_ops { \
    C10_DEFINE_OP_SCHEMA(OperatorName, \
      caffe2::detail::make_function_schema_for_c10( \
        #OperatorName, Inputs, Outputs)); \
   } \
   /* Store the c10 operator handle so call_caffe2_op_from_c10 can access it */ \
-  namespace caffe2 { namespace detail { \
+  namespace detail { \
   template<> \
   const c10::OperatorHandle& c10_op_handle_for_c2_op<OperatorClass<caffe2::CPUContext>>() { \
-    return caffe2::OperatorName(); \
+    return caffe2::_c10_ops::OperatorName(); \
   } \
   }} \
   /* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */ \
   namespace c10 { \
-  C10_REGISTER_KERNEL(caffe2::OperatorName) \
+  C10_REGISTER_KERNEL(caffe2::_c10_ops::OperatorName) \
   /*.withCache<Cache>()*/ \
   .kernel<&caffe2::detail::call_caffe2_op_from_c10<OperatorClass<caffe2::CPUContext>>>() \
   .dispatchKey(CPUTensorId()); \
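
To see how the renamed macro pair is meant to be used, compare the layer_norm changes later in this commit: the declare macro goes in the operator's header, the register macro in its .cc file. A condensed sketch, with the argument lists abbreviated to /* ... */ where this commit does not show them:

// In the operator's header (see caffe2/operators/layer_norm_op.h below):
C10_DECLARE_CAFFE2_OPERATOR(LayerNorm)

// In the operator's .cc file (see caffe2/operators/layer_norm_op.cc below):
C10_REGISTER_CAFFE2_OPERATOR_CPU(
    LayerNorm,
    (std::vector<c10::Argument>{c10::Argument("input"), /* ... */}),
    (std::vector<c10::Argument>{/* ... */}),
    LayerNormOp /* operator class template */)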

caffe2/core/operator.h (+4)

@@ -213,6 +213,10 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
     BlobSetTensor(outputs_.at(idx), std::move(tensor));
   }
 
+  Tensor OutputTensorOrUndefined(int idx) {
+    return BlobGetTensorOrUndefined(*outputs_.at(idx));
+  }
+
   inline Tensor*
   OutputTensor(int idx, at::IntArrayRef dims, at::TensorOptions options) {
     if (isLegacyOperator()) {

caffe2/core/operator_c10wrapper.h (+6 −1)

@@ -117,7 +117,12 @@ class C10OperatorWrapper final : public Operator<Context> {
 
   void pushOutputParameters_() {
     for (size_t i = 0; i < num_output_parameters; ++i) {
-      stack_.emplace_back(at::Tensor(C10Tensor(*Output(i))));
+      caffe2::Tensor preallocated_output_tensor = OperatorBase::OutputTensorOrUndefined(i);
+      if (preallocated_output_tensor.defined()) {
+        stack_.emplace_back(at::Tensor(std::move(preallocated_output_tensor)));
+      } else {
+        stack_.emplace_back(IValue());
+      }
     }
   }
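
The behavioral change above, isolated as a standalone sketch: a defined tensor means the caller preallocated the output and it is passed through to the kernel, while an undefined one becomes an empty IValue so the kernel allocates the output itself. push_output is an invented name; the real logic lives inside C10OperatorWrapper::pushOutputParameters_.

#include <vector>
#include <ATen/core/ivalue.h>
#include "caffe2/core/tensor.h"

// Illustrative helper mirroring the new loop body above (not real API).
void push_output(std::vector<c10::IValue>& stack, caffe2::Tensor preallocated) {
  if (preallocated.defined()) {
    // Output blob already holds a tensor (e.g. from an earlier net run):
    // pass it through so the kernel writes into the preallocated storage.
    stack.emplace_back(at::Tensor(std::move(preallocated)));
  } else {
    // No preallocated output: the kernel will allocate its own.
    stack.emplace_back(c10::IValue());
  }
}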

caffe2/operators/experimental/c10/schemas/layer_norm.cc (+1 −1)

@@ -26,7 +26,7 @@ struct EpsilonParameter final {
 
 namespace caffe2 {
 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_WITH_PARAMETERS(
-    caffe2::LayerNorm,
+    caffe2::_c10_ops::LayerNorm,
     C10LayerNorm_DontUseThisOpYet,
     3,
     ParameterHelper<AxisParameter>,

caffe2/operators/layer_norm_op.cc (+1 −1)

@@ -184,7 +184,7 @@ to the end.)
 
 } // namespace caffe2
 
-C10_REGISTER_CAFFE2_OPERATOR(
+C10_REGISTER_CAFFE2_OPERATOR_CPU(
     LayerNorm,
     (std::vector<c10::Argument>{
       c10::Argument("input"),

caffe2/operators/layer_norm_op.h (+2 −2)

@@ -10,9 +10,9 @@
 #include "caffe2/utils/math.h"
 #include <ATen/core/dispatch/OpSchemaRegistration.h>
 
-namespace caffe2 {
+C10_DECLARE_CAFFE2_OPERATOR(LayerNorm)
 
-C10_DECLARE_OP_SCHEMA(LayerNorm);
+namespace caffe2 {
 
 template <class Context>
 class LayerNormOp final : public Operator<Context> {

caffe2/python/operator_test/layer_norm_op_test.py (+22 −1)

@@ -112,7 +112,6 @@ def test_layer_norm_op(self, X, gc, dc):
         )
 
     @given(X=hu.tensor(min_dim=2), **hu.gcs_cpu_only)
-    @unittest.skip("Tensor interop enforcement needs fixing")
     def test_layer_norm_op_c10(self, X, gc, dc):
         axis = np.random.randint(0, len(X.shape))
         epsilon = 1e-4
@@ -137,6 +136,28 @@ def test_layer_norm_op_c10(self, X, gc, dc):
             outputs_to_check=[0, 1, 2],
         )
 
+    @given(X=hu.tensor(min_dim=2), **hu.gcs_cpu_only)
+    def test_layer_norm_op_c10_preallocated_outputs(self, X, gc, dc):
+        # This test case ensures that it works correctly when output tensors are preallocated.
+        axis = np.random.randint(0, len(X.shape))
+        epsilon = 1e-4
+        self.ws.create_blob('input').feed(X)
+        m = ModelHelper(name="test")
+        m.net.C10LayerNorm_DontUseThisOpYet(["input"], ["output", "mean", "stdev"], axis=axis, epsilon=epsilon)
+        self.ws.create_net(m.param_init_net).run()
+        net = self.ws.create_net(m.net)
+        net.run()
+        net.run()  # run two times to be extra sure that the outputs are preallocated
+
+        expected_norm, expected_mean, expected_stdev = _layer_norm_ref(axis, epsilon, X)
+        actual_norm = self.ws.fetch_blob('output')
+        actual_mean = self.ws.fetch_blob('mean')
+        actual_stdev = self.ws.fetch_blob('stdev')
+
+        torch.testing.assert_allclose(expected_norm, actual_norm)
+        torch.testing.assert_allclose(expected_mean, actual_mean)
+        torch.testing.assert_allclose(expected_stdev, actual_stdev)
+
     @given(X=hu.tensor(min_dim=2), **hu.gcs)
     def test_layer_norm_op_pytorch(self, X, gc, dc):
         axis = np.random.randint(0, len(X.shape))
