Skip to content

Commit 7a3488e

Browse files
smessmerfacebook-github-bot
authored andcommitted
Expose c10 cuda ops to caffe2 (pytorch#18036)
Summary: Pull Request resolved: pytorch#18036 - Add macros to export c10 cuda operators to caffe2 frontend - Instead of having a separate caffe2 registry for the c10 operator wrappers, use the existing caffe2 registries Reviewed By: ezyang Differential Revision: D14467495 fbshipit-source-id: 7715ed2e38d2bbe16f1446ae82c17193a3fabcb9
1 parent cb2ea17 commit 7a3488e

24 files changed

+62
-71
lines changed

caffe2/core/operator.cc

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ GlobalEnginePrefType& g_global_engine_pref() {
119119
return *g_global_engine_pref_;
120120
}
121121

122-
unique_ptr<OperatorBase> TryCreateC2Operator(
122+
unique_ptr<OperatorBase> TryCreateOperator(
123123
const string& key,
124124
const OperatorDef& operator_def,
125125
Workspace* ws) {
@@ -143,24 +143,6 @@ unique_ptr<OperatorBase> TryCreateC2Operator(
143143
}
144144
}
145145

146-
unique_ptr<OperatorBase> TryCreateC10Operator(
147-
const string& key,
148-
const OperatorDef& operator_def,
149-
Workspace* ws) {
150-
return C10OperatorRegistry()->Create(key, operator_def, ws);
151-
}
152-
153-
unique_ptr<OperatorBase> TryCreateOperator(
154-
const string& key,
155-
const OperatorDef& operator_def,
156-
Workspace* ws) {
157-
if (auto op = TryCreateC10Operator(key, operator_def, ws)) {
158-
return op;
159-
} else {
160-
return TryCreateC2Operator(key, operator_def, ws);
161-
}
162-
}
163-
164146
unique_ptr<OperatorBase> _CreateOperator(
165147
const OperatorDef& operator_def,
166148
Workspace* ws) {
@@ -726,11 +708,6 @@ std::set<std::string> GetRegisteredOperators() {
726708
all_keys.emplace(name);
727709
}
728710

729-
// C10 operators
730-
for (const auto& name : C10OperatorRegistry()->Keys()) {
731-
all_keys.emplace(name);
732-
}
733-
734711
return all_keys;
735712
}
736713

caffe2/core/operator_c10wrapper.h

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -172,14 +172,14 @@ class C10OperatorWrapper final : public Operator<Context> {
172172
const std::string& name,
173173
const c10::optional<IValue>& default_value) {
174174
if (default_value.has_value()) {
175-
return OperatorBase::GetSingleArgument<T>(name, default_value->to<T>());
175+
return this->template GetSingleArgument<T>(name, default_value->to<T>());
176176
} else {
177177
AT_CHECK(
178-
OperatorBase::HasSingleArgumentOfType<T>(name),
178+
this->template HasSingleArgumentOfType<T>(name),
179179
"Error in caffe2->c10 wrapper: Expected argument '",
180180
name,
181181
"' missing or wrong type.");
182-
return OperatorBase::GetSingleArgument<T>(name, 0);
182+
return this->template GetSingleArgument<T>(name, 0);
183183
}
184184
}
185185

@@ -211,22 +211,22 @@ createC10OperatorWrapper(const c10::OperatorHandle& op_handle) {
211211

212212
} // namespace detail
213213

214-
C10_DECLARE_REGISTRY(
215-
C10OperatorRegistry,
216-
OperatorBase,
217-
const OperatorDef&,
218-
Workspace*);
219-
220214
// TODO Also register c10 operators on mobile
221215
#ifndef C10_MOBILE
222216
// TODO Currently we only register the CPU variant. This is going to be fixed
223217
// once the tensor detemplatization lands.
224-
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(OperatorHandle, Name) \
225-
C10_REGISTER_CREATOR( \
226-
C10OperatorRegistry, \
227-
Name, \
228-
detail::createC10OperatorWrapper<CPUContext>(OperatorHandle))
218+
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(OperatorHandle, Name) \
219+
REGISTER_CPU_OPERATOR_CREATOR( \
220+
Name, detail::createC10OperatorWrapper<CPUContext>(OperatorHandle))
221+
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CUDA(OperatorHandle, Name) \
222+
REGISTER_CUDA_OPERATOR_CREATOR( \
223+
Name, detail::createC10OperatorWrapper<CUDAContext>(OperatorHandle))
224+
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_HIP(OperatorHandle, Name) \
225+
REGISTER_HIP_OPERATOR_CREATOR( \
226+
Name, detail::createC10OperatorWrapper<HIPContext>(OperatorHandle))
229227
#else
230-
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(OperatorHandle, Name)
228+
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(OperatorHandle, Name)
229+
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CUDA(OperatorHandle, Name)
230+
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_HIP(OperatorHandle, Name)
231231
#endif
232232
} // namespace caffe2

caffe2/operators/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c1
6060
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/enforce_finite_cpu.cc)
6161
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/cpu/add_cpu.cc)
6262
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/sigmoid.cc)
63-
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/layer_norm.cc)
6463
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/filler.cc)
6564
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/expand_dims.cc)
6665
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${CMAKE_CURRENT_LIST_DIR}/experimental/c10/schemas/mul.cc)

caffe2/operators/experimental/c10/schemas/add.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,7 @@ C10_DEFINE_OP_SCHEMA(Add, FunctionSchema(
2121
}
2222

2323
namespace caffe2 {
24-
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(ops::Add(), C10Add_DontUseThisOpYet)
24+
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
25+
ops::Add(),
26+
C10Add_DontUseThisOpYet)
2527
}

caffe2/operators/experimental/c10/schemas/averaged_loss.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ C10_DEFINE_OP_SCHEMA(AveragedLoss, FunctionSchema(
1919
}
2020

2121
namespace caffe2 {
22-
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(
22+
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
2323
ops::AveragedLoss(),
2424
C10AveragedLoss_DontUseThisOpYet)
2525
}

caffe2/operators/experimental/c10/schemas/batch_gather.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ C10_DEFINE_OP_SCHEMA(BatchGather, FunctionSchema(
2020
}
2121

2222
namespace caffe2 {
23-
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(
23+
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
2424
ops::BatchGather(),
2525
C10BatchGather_DontUseThisOpYet)
2626
}

caffe2/operators/experimental/c10/schemas/batch_matmul.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ C10_DEFINE_OP_SCHEMA(BatchMatmul, FunctionSchema(
2424

2525
namespace caffe2 {
2626

27-
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(
27+
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
2828
ops::BatchMatmul(),
2929
C10BatchMatMul_DontUseThisOpYet)
3030
}

caffe2/operators/experimental/c10/schemas/cast.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,7 @@ C10_DEFINE_OP_SCHEMA(Cast, FunctionSchema(
2121
}
2222

2323
namespace caffe2 {
24-
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(ops::Cast(), C10Cast_DontUseThisOpYet)
24+
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
25+
ops::Cast(),
26+
C10Cast_DontUseThisOpYet)
2527
}

caffe2/operators/experimental/c10/schemas/concat.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ C10_DEFINE_OP_SCHEMA(Concat, FunctionSchema(
2222
}
2323

2424
namespace caffe2 {
25-
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(
25+
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
2626
ops::Concat(),
2727
C10Concat_DontUseThisOpYet)
2828
}

caffe2/operators/experimental/c10/schemas/enforce_finite.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ C10_DEFINE_OP_SCHEMA(EnforceFinite, FunctionSchema(
1818
}
1919

2020
namespace caffe2 {
21-
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH(
21+
REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
2222
ops::EnforceFinite(),
2323
C10EnforceFinite_DontUseThisOpYet)
2424
}

0 commit comments

Comments
 (0)