1
+ #include < ATen/ScalarOps.h>
1
2
#include < torch/csrc/jit/mobile/promoted_prim_ops.h>
2
-
3
3
namespace torch {
4
4
namespace jit {
5
+
5
6
void tupleIndex (Stack& stack) {
6
7
int64_t index = pop (stack).toInt ();
7
8
auto tuple = pop (stack).toTuple ();
@@ -14,9 +15,23 @@ void tupleIndex(Stack& stack) {
14
15
}
15
16
16
17
// Legacy kernel for prim::RaiseException taking a single argument (the
// error string).
// DEPRECATED from bytecode_version 8;
// Please do not make any changes to this to support BC
void raiseException(Stack& stack) {
  // Pop the error string and surface it as a TorchScript exception.
  const auto& error_message = pop(stack).toStringRef();
  throw JITException(error_message);
}
19
23
24
+ void raiseExceptionWithMessage (Stack& stack) {
25
+ // this kernel supports RaiseException with only two arguments: the error and
26
+ // the message Please make changes only to this kernel
27
+ c10::optional<std::string> qualified_class_name =
28
+ pop (stack).toOptional <std::string>();
29
+ std::string message;
30
+ pop (stack, message);
31
+
32
+ throw JITException (message, qualified_class_name);
33
+ }
34
+
20
35
void is (Stack& stack) {
21
36
IValue self, obj;
22
37
pop (stack, self, obj);
@@ -99,15 +114,15 @@ void toList(Stack& stack) {
99
114
100
115
// Rebuild the output type using elem_ty_val and dim_val. Start
101
116
// with the element type corresponding to elem_ty_val.
102
- TypePtr out_ty;
117
+ at:: TypePtr out_ty;
103
118
if (elem_ty_val == 0 ) {
104
- out_ty = IntType::get ();
119
+ out_ty = at:: IntType::get ();
105
120
} else if (elem_ty_val == 1 ) {
106
- out_ty = FloatType::get ();
121
+ out_ty = at:: FloatType::get ();
107
122
} else if (elem_ty_val == 2 ) {
108
- out_ty = BoolType::get ();
123
+ out_ty = at:: BoolType::get ();
109
124
} else if (elem_ty_val == 3 ) {
110
- out_ty = ComplexType::get ();
125
+ out_ty = at:: ComplexType::get ();
111
126
} else {
112
127
TORCH_CHECK (
113
128
false ,
@@ -120,8 +135,8 @@ void toList(Stack& stack) {
120
135
// the elements will be casted to double/c10::complex<double>
121
136
// later.
122
137
TORCH_CHECK (
123
- (out_ty == FloatType::get () && t.is_floating_point ()) ||
124
- (out_ty == ComplexType::get () && t.is_complex ()) ||
138
+ (out_ty == at:: FloatType::get () && t.is_floating_point ()) ||
139
+ (out_ty == at:: ComplexType::get () && t.is_complex ()) ||
125
140
tryScalarTypeFromJitType (*out_ty) == t.scalar_type (),
126
141
" Output annotation element type and runtime tensor element type must match for tolist()" );
127
142
@@ -134,7 +149,7 @@ void toList(Stack& stack) {
134
149
// Wrap out_ty in a ListType dim times.
135
150
for (const auto i : c10::irange (dim_val)) {
136
151
(void )i; // Suppress unused variable warning
137
- out_ty = ListType::create (out_ty);
152
+ out_ty = at:: ListType::create (out_ty);
138
153
}
139
154
140
155
int64_t dim = t.dim ();
@@ -150,7 +165,7 @@ void toList(Stack& stack) {
150
165
// Kernel for prim::NumToTensor.Scalar: pop a scalar off the stack and push
// it back as a 0-dim tensor.
void numToTensorScalar(Stack& stack) {
  at::Scalar value;
  pop(stack, value);
  push(stack, c10::scalar_to_tensor(value));
}
155
170
156
171
void isCuda (Stack& stack) {
@@ -163,7 +178,7 @@ void numToTensorBool(Stack& stack) {
163
178
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
164
179
bool b;
165
180
pop (stack, b);
166
- push (stack, at ::scalar_to_tensor (b));
181
+ push (stack, c10 ::scalar_to_tensor (b));
167
182
}
168
183
169
184
void dictIndex (Stack& stack) {
@@ -181,7 +196,9 @@ static const C10_UNUSED std::array<mobile::prim_op_fn_register, 15> op_reg = {
181
196
mobile::prim_op_fn_register (" aten::Bool.Tensor" , boolTensor),
182
197
mobile::prim_op_fn_register (" aten::format" , aten_format),
183
198
mobile::prim_op_fn_register (" prim::NumToTensor.Scalar" , numToTensorScalar),
184
- mobile::prim_op_fn_register (" prim::RaiseException" , raiseException),
199
+ mobile::prim_op_fn_register (
200
+ " prim::RaiseException" ,
201
+ raiseExceptionWithMessage),
185
202
mobile::prim_op_fn_register (" prim::device" , device),
186
203
mobile::prim_op_fn_register (" prim::dtype" , dtype),
187
204
mobile::prim_op_fn_register (" aten::__not__" , _not),
0 commit comments