@@ -12,8 +12,9 @@ class ChannelShuffleOp final : public Operator<Context> {
12
12
public:
13
13
USE_OPERATOR_CONTEXT_FUNCTIONS;
14
14
15
- ChannelShuffleOp (const OperatorDef& operator_def, Workspace* ws)
16
- : Operator<Context>(operator_def, ws),
15
+ template <class ... Args>
16
+ explicit ChannelShuffleOp (Args&&... args)
17
+ : Operator<Context>(std::forward<Args>(args)...),
17
18
order_(StringToStorageOrder(
18
19
this ->template GetSingleArgument<std::string>(" order" , " NCHW" ))),
19
20
OP_SINGLE_ARG(int , " group" , group_, 1 ) {
@@ -39,8 +40,9 @@ class ChannelShuffleGradientOp final : public Operator<Context> {
39
40
public:
40
41
USE_OPERATOR_CONTEXT_FUNCTIONS;
41
42
42
- ChannelShuffleGradientOp (const OperatorDef& operator_def, Workspace* ws)
43
- : Operator<Context>(operator_def, ws),
43
+ template <class ... Args>
44
+ explicit ChannelShuffleGradientOp (Args&&... args)
45
+ : Operator<Context>(std::forward<Args>(args)...),
44
46
order_(StringToStorageOrder(
45
47
this ->template GetSingleArgument<std::string>(" order" , " NCHW" ))),
46
48
OP_SINGLE_ARG(int , " group" , group_, 1 ) {
0 commit comments