Commit bf69a61

pavithranrao authored and pytorchmergebot committed
(1/2) Make TorchScript Preserve Fully Qualified Class Name for Python Exceptions: backend change
Summary: Reland of D33282878 (pytorch@911d527). Land the backend change first to maintain forward compatibility (FC); wait two weeks after this diff is in, and then land the front-end change in the next diff.

Test Plan: tested in the next diff:

time buck test mode/dev-nosan fblearner/flow/projects/langtech/translation:tests -- test_e2e_base_training

Reviewed By: gmagogsfm

Differential Revision: D33342547

fbshipit-source-id: b3dee9a4bdfd78103848c12629e5fccafdd621e3
(cherry picked from commit ae1935f)
1 parent 027c0d7 commit bf69a61

18 files changed: +527 -163 lines changed

caffe2/serialize/versions.h
Lines changed: 1 addition & 1 deletion

@@ -108,7 +108,7 @@ constexpr uint64_t kProducedBytecodeVersion = 0x7L;
 // we should support this model_version. For example, we provide a wrapper to
 // handle an updated operator.
 constexpr uint64_t kMinSupportedBytecodeVersion = 0x3L;
-constexpr uint64_t kMaxSupportedBytecodeVersion = 0x7L;
+constexpr uint64_t kMaxSupportedBytecodeVersion = 0x8L;
 
 } // namespace serialize
 } // namespace caffe2
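Bumping kMaxSupportedBytecodeVersion to 0x8L is what lets runtimes carrying this change load models that serialize the new two-argument RaiseException form introduced below. A minimal sketch of how a loader could gate on these two constants; the function name and error text are illustrative, not from this commit:

#include <caffe2/serialize/versions.h>
#include <c10/util/Exception.h>

// Hypothetical loader-side guard built on the constants above; the real
// check lives inside the deserializer, this only shows how they are used.
void checkBytecodeVersion(uint64_t model_version) {
  TORCH_CHECK(
      model_version >= caffe2::serialize::kMinSupportedBytecodeVersion &&
          model_version <= caffe2::serialize::kMaxSupportedBytecodeVersion,
      "Unsupported bytecode version: ",
      model_version);
}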

test/cpp/jit/test_backend.cpp
Lines changed: 21 additions & 0 deletions

@@ -187,6 +187,27 @@ TEST(BackendTest, TestComposite) {
   AT_ASSERT(res_jit.toTensor().equal(res_mobile.toTensor()));
 }
 
+TEST(BackendTest, TestPrimDtype) {
+  Module c("name");
+  c.define(R"(
+    def forward(self, x, y):
+      c = y.dtype
+      return c
+  )");
+
+  std::vector<IValue> inputs;
+  inputs.emplace_back(3.0 * torch::ones({}));
+  inputs.emplace_back(1.0 * torch::ones({}));
+  auto res_jit = c.forward(inputs);
+
+  std::stringstream ss;
+  c._save_for_mobile(ss);
+  auto mc = _load_for_mobile(ss);
+  auto res_mobile = mc.forward(inputs);
+
+  ASSERT_EQ(res_jit.toInt(), res_mobile.toInt());
+}
+
 Module getCompositeModuleWithSameNameSubModules() {
   // Two submodules with same module name but different forward and other
   // functions should be serialized and loaded correctly.

test/cpp/jit/test_lite_interpreter.cpp
Lines changed: 36 additions & 0 deletions

@@ -186,6 +186,42 @@ TEST(LiteInterpreterTest, Tuple) {
   AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2);
 }
 
+TEST(LiteInterpreterTest, AtenFormat) {
+  Module m("m");
+  m.define(R"""(
+    def forward(self, fmt:str="first {} {}", num:str="abc"):
+      x = 2
+      x = x * x
+      return fmt.format(num, x)
+  )""");
+  std::stringstream ss;
+  m._save_for_mobile(ss);
+  mobile::Module bc = _load_for_mobile(ss);
+  std::vector<torch::jit::IValue> inputs;
+  auto output_bc = bc.get_method("forward")(inputs);
+  auto output_m = m.get_method("forward")(inputs);
+  // std::cout << output_m.toStringRef() << "\n"
+  //           << output_bc.toStringRef() << std::endl;
+  AT_ASSERT(output_m.toStringRef() == output_bc.toStringRef());
+}
+
+TEST(LiteInterpreterTest, PrimDevice) {
+  Module m("m");
+  m.define(R"""(
+    def forward(self, x:torch.Tensor):
+      return x.device
+  )""");
+  std::stringstream ss;
+  m._save_for_mobile(ss);
+  mobile::Module bc = _load_for_mobile(ss);
+  std::vector<torch::jit::IValue> inputs;
+  auto minput = 3.5 * torch::ones({});
+  inputs.emplace_back(minput);
+  auto output_bc = bc.get_method("forward")(inputs);
+  auto output_m = m.get_method("forward")(inputs);
+  AT_ASSERT(output_bc.toDevice().str() == output_m.toDevice().str());
+}
+
 TEST(LiteInterpreterTest, Dict) {
   Module m("m");
   m.define(R"JIT(

tools/build_variables.bzl
Lines changed: 1 addition & 0 deletions

@@ -115,6 +115,7 @@ core_sources_common = [
     "torch/csrc/jit/runtime/instruction.cpp",
     "torch/csrc/jit/runtime/jit_exception.cpp",
     "torch/csrc/jit/runtime/operator.cpp",
+    "torch/csrc/jit/mobile/register_ops_common_utils.cpp",
     "torch/csrc/jit/runtime/print_handler.cpp",
     "torch/csrc/jit/runtime/slice_indices_adjust.cpp",
     "torch/csrc/jit/runtime/register_ops_utils.cpp",

torch/csrc/jit/mobile/interpreter.cpp
Lines changed: 60 additions & 4 deletions

@@ -5,15 +5,15 @@
 #include <ATen/core/function.h>
 #include <ATen/core/jit_type.h>
 #include <ATen/core/operator_name.h>
-#include <torch/csrc/jit/mobile/function.h>
-#include <torch/csrc/jit/runtime/jit_exception.h>
-#include <torch/csrc/jit/runtime/vararg_functions.h>
-
 #include <ATen/record_function.h>
 #include <c10/util/Exception.h>
 #include <c10/util/irange.h>
 #include <torch/csrc/jit/backends/backend_exception.h>
+#include <torch/csrc/jit/mobile/function.h>
 #include <torch/csrc/jit/mobile/observer.h>
+#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
+#include <torch/csrc/jit/runtime/jit_exception.h>
+#include <torch/csrc/jit/runtime/vararg_functions.h>
 
 namespace torch {
 namespace jit {

@@ -248,6 +248,62 @@ bool InterpreterState::run(Stack& stack) {
         tupleSlice(stack, inst.X, inst.X + inst.N);
         frame.step();
       } break;
+      case TUPLE_INDEX: {
+        tupleIndex(stack);
+        frame.step();
+      } break;
+      case RAISE_EXCEPTION: {
+        raiseExceptionWithMessage(stack);
+        frame.step();
+      } break;
+      case __IS__: {
+        is(stack);
+        frame.step();
+      } break;
+      case UN_INITIALIZED: {
+        unInitialized(stack);
+        frame.step();
+      } break;
+      case __ISNOT__: {
+        isNot(stack);
+        frame.step();
+      } break;
+      case FORMAT: {
+        format(stack, inst.X);
+        frame.step();
+      } break;
+      case DEVICE: {
+        device(stack);
+        frame.step();
+      } break;
+      case DTYPE: {
+        dtype(stack);
+        frame.step();
+      } break;
+      case DIM: {
+        dim(stack);
+        frame.step();
+      } break;
+      case __NOT__: {
+        _not(stack);
+        frame.step();
+      } break;
+      case DICT_INDEX: {
+        dictIndex(stack);
+        frame.step();
+      } break;
+      case TO_LIST: {
+        toList(stack);
+        frame.step();
+      } break;
+      case NUM_TO_TENSOR: {
+        numToTensorScalar(stack);
+        frame.step();
+      } break;
+      case IS_CUDA: {
+        isCuda(stack);
+        frame.step();
+      } break;
       case DICT_CONSTRUCT: {
        dictConstruct(stack, *code.types_[inst.X], inst.N);
        frame.step();
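Each new case dispatches a promoted prim op that follows the same pop/compute/push discipline against the interpreter's value stack. A standalone sketch of that pattern, modeled on the DTYPE handler; the function name is invented for illustration, and the pop/push helpers are the ones declared in ATen/core/stack.h:

#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>

namespace torch {
namespace jit {

// Illustrative only: pops a tensor and pushes its ScalarType as an
// integer, mirroring what the DTYPE instruction's dtype(stack) kernel does.
void dtypeLikeSketch(Stack& stack) {
  at::Tensor t = pop(stack).toTensor();
  push(stack, static_cast<int64_t>(t.scalar_type()));
}

} // namespace jit
} // namespace torch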

torch/csrc/jit/mobile/promoted_prim_ops.cpp
Lines changed: 29 additions & 12 deletions

@@ -1,7 +1,8 @@
+#include <ATen/ScalarOps.h>
 #include <torch/csrc/jit/mobile/promoted_prim_ops.h>
-
 namespace torch {
 namespace jit {
+
 void tupleIndex(Stack& stack) {
   int64_t index = pop(stack).toInt();
   auto tuple = pop(stack).toTuple();

@@ -14,9 +15,23 @@ void tupleIndex(Stack& stack) {
 }
 
 void raiseException(Stack& stack) {
+  // this kernel supports RaiseException with only one argument: the error
+  // DEPRECATED from bytecode_version 8;
+  // Please do not make any changes to this to support BC
   throw JITException(pop(stack).toStringRef());
 }
 
+void raiseExceptionWithMessage(Stack& stack) {
+  // this kernel supports RaiseException with only two arguments: the error and
+  // the message. Please make changes only to this kernel
+  c10::optional<std::string> qualified_class_name =
+      pop(stack).toOptional<std::string>();
+  std::string message;
+  pop(stack, message);
+
+  throw JITException(message, qualified_class_name);
+}
+
 void is(Stack& stack) {
   IValue self, obj;
   pop(stack, self, obj);

@@ -99,15 +114,15 @@ void toList(Stack& stack) {
 
   // Rebuild the output type using elem_ty_val and dim_val. Start
   // with the element type corresponding to elem_ty_val.
-  TypePtr out_ty;
+  at::TypePtr out_ty;
   if (elem_ty_val == 0) {
-    out_ty = IntType::get();
+    out_ty = at::IntType::get();
   } else if (elem_ty_val == 1) {
-    out_ty = FloatType::get();
+    out_ty = at::FloatType::get();
   } else if (elem_ty_val == 2) {
-    out_ty = BoolType::get();
+    out_ty = at::BoolType::get();
   } else if (elem_ty_val == 3) {
-    out_ty = ComplexType::get();
+    out_ty = at::ComplexType::get();
   } else {
     TORCH_CHECK(
         false,

@@ -120,8 +135,8 @@ void toList(Stack& stack) {
   // the elements will be casted to double/c10::complex<double>
   // later.
   TORCH_CHECK(
-      (out_ty == FloatType::get() && t.is_floating_point()) ||
-          (out_ty == ComplexType::get() && t.is_complex()) ||
+      (out_ty == at::FloatType::get() && t.is_floating_point()) ||
+          (out_ty == at::ComplexType::get() && t.is_complex()) ||
          tryScalarTypeFromJitType(*out_ty) == t.scalar_type(),
       "Output annotation element type and runtime tensor element type must match for tolist()");
 

@@ -134,7 +149,7 @@ void toList(Stack& stack) {
   // Wrap out_ty in a ListType dim times.
   for (const auto i : c10::irange(dim_val)) {
     (void)i; // Suppress unused variable warning
-    out_ty = ListType::create(out_ty);
+    out_ty = at::ListType::create(out_ty);
   }
 
   int64_t dim = t.dim();

@@ -150,7 +165,7 @@ void toList(Stack& stack) {
 void numToTensorScalar(Stack& stack) {
   at::Scalar s;
   pop(stack, s);
-  push(stack, at::scalar_to_tensor(s));
+  push(stack, c10::scalar_to_tensor(s));
 }
 
 void isCuda(Stack& stack) {

@@ -163,7 +178,7 @@ void numToTensorBool(Stack& stack) {
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   bool b;
   pop(stack, b);
-  push(stack, at::scalar_to_tensor(b));
+  push(stack, c10::scalar_to_tensor(b));
 }
 
 void dictIndex(Stack& stack) {

@@ -181,7 +196,9 @@ static const C10_UNUSED std::array<mobile::prim_op_fn_register, 15> op_reg = {
     mobile::prim_op_fn_register("aten::Bool.Tensor", boolTensor),
     mobile::prim_op_fn_register("aten::format", aten_format),
     mobile::prim_op_fn_register("prim::NumToTensor.Scalar", numToTensorScalar),
-    mobile::prim_op_fn_register("prim::RaiseException", raiseException),
+    mobile::prim_op_fn_register(
+        "prim::RaiseException",
+        raiseExceptionWithMessage),
     mobile::prim_op_fn_register("prim::device", device),
     mobile::prim_op_fn_register("prim::dtype", dtype),
     mobile::prim_op_fn_register("aten::__not__", _not),

torch/csrc/jit/mobile/promoted_prim_ops.h
Lines changed: 3 additions & 1 deletion

@@ -1,6 +1,6 @@
 #pragma once
 #include <torch/csrc/jit/mobile/prim_ops_registery.h>
-#include <torch/csrc/jit/runtime/register_ops_utils.h>
+#include <torch/csrc/jit/mobile/register_ops_common_utils.h>
 
 namespace torch {
 namespace jit {

@@ -41,5 +41,7 @@ void numToTensorBool(Stack& stack);
 
 void dictIndex(Stack& stack);
 
+void raiseExceptionWithMessage(Stack& stack);
+
 } // namespace jit
 } // namespace torch
torch/csrc/jit/mobile/register_ops_common_utils.cpp (new file; path inferred from its own #include and the tools/build_variables.bzl entry above)
Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+#include <ATen/core/dynamic_type.h>
+#include <ATen/core/type_factory.h>
+#include <torch/csrc/jit/mobile/register_ops_common_utils.h>
+
+namespace torch {
+namespace jit {
+
+int64_t normalizeIndex(int64_t idx, int64_t list_size) {
+  if (idx < 0) {
+    // Handle negative indexing
+    idx = list_size + idx;
+  }
+  return idx;
+}
+
+IValue tensorToListRecursive(
+    char* data,
+    int64_t cur_dim,
+    int64_t num_tensor_dims,
+    at::TypePtr ty,
+    at::ScalarType scalar_ty,
+    at::IntArrayRef sizes,
+    at::IntArrayRef strides,
+    size_t element_size) {
+  // If ty is a ListType, get the element type.
+  if (auto list_type = ty->cast<at::ListType>()) {
+    ty = list_type->getElementType();
+  } else {
+    // If the output type is a scalar, read and push one scalar of
+    // the right type onto the stack.
+    if (ty == at::IntType::get()) {
+      int64_t scalar = *(int64_t*)data;
+      return IValue(scalar);
+    } else if (ty == at::FloatType::get()) {
+      TORCH_INTERNAL_ASSERT(
+          scalar_ty == at::ScalarType::Float ||
+              scalar_ty == at::ScalarType::Double,
+          "Unexpected scalar type for Tensor");
+      double scalar =
+          scalar_ty == at::ScalarType::Float ? *(float*)data : *(double*)data;
+      return IValue(scalar);
+    } else if (ty == at::ComplexType::get()) {
+      TORCH_INTERNAL_ASSERT(
+          scalar_ty == at::ScalarType::ComplexFloat ||
+              scalar_ty == at::ScalarType::ComplexDouble,
+          "Unexpected scalar type for Tensor");
+      c10::complex<double> scalar = scalar_ty == at::ScalarType::ComplexFloat
+          ? *(c10::complex<float>*)data
+          : *(c10::complex<double>*)data;
+      return IValue(scalar);
+    } else if (ty == at::BoolType::get()) {
+      bool scalar = *(bool*)data;
+      return IValue(scalar);
+    } else {
+      TORCH_CHECK(
+          false,
+          ty->repr_str(),
+          " is not one of the supported types for tolist: int, float, bool");
+    }
+  }
+
+  // Make the result list consisting of elements of type ty. Since this
+  // invocation is processing dimension cur_dim, there will be sizes[cur_dim]
+  // output elements.
+  auto result = c10::impl::GenericList(ty);
+  result.reserve(sizes[cur_dim]);
+
+  // Since ty was a list type, tensorToListRecursive needs to be called
+  // recursively on each slice of the tensor in the current dimension.
+  for (int64_t i = 0, e = sizes[cur_dim]; i < e; ++i) {
+    auto inner_result = tensorToListRecursive(
+        data,
+        cur_dim + 1,
+        num_tensor_dims,
+        ty,
+        scalar_ty,
+        sizes,
+        strides,
+        element_size);
+
+    if (inner_result.isList()) {
+      result.emplace_back(inner_result.toList());
+    } else if (inner_result.isComplexDouble()) {
+      result.emplace_back(inner_result.toComplexDouble());
+    } else if (inner_result.isDouble()) {
+      result.emplace_back(inner_result.toDouble());
+    } else if (inner_result.isInt()) {
+      result.emplace_back(inner_result.toInt());
+    } else if (inner_result.isBool()) {
+      result.emplace_back(inner_result.toBool());
+    } else {
+      TORCH_INTERNAL_ASSERT("Unknown return type for tensorToListRecursive");
+    }
+
+    data += strides[cur_dim] * element_size;
+  }
+
+  return result;
+}
+
+} // namespace jit
+} // namespace torch
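For reference, normalizeIndex above mirrors Python's negative-index convention. A quick illustration with plain asserts; not part of the diff:

#include <cassert>
#include <cstdint>

int64_t normalizeIndex(int64_t idx, int64_t list_size); // as defined above

int main() {
  assert(normalizeIndex(2, 4) == 2);  // in-range indices pass through
  assert(normalizeIndex(-1, 4) == 3); // -1 wraps to the last element
  return 0;
}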
