
Commit 1c5f63d

JacobSzwejbka authored and facebook-github-bot committed
[Pytorch Edge] Model Ops compatibility api (pytorch#57501)
Summary: Pull Request resolved: pytorch#57501

Add an API, _get_model_ops_and_info, to get the root operators and versioning info of a model, in both C++ and Python; the input can come from a file path or a buffer.

ghstack-source-id: 129620112
Test Plan: unit test.
Reviewed By: xcheng16, raziel
Differential Revision: D28162765
fbshipit-source-id: 4413c1e906b8a872e4a717d849da37347adbbea4
1 parent 2a456e4 commit 1c5f63d
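
For context, a minimal usage sketch of the new Python API (the model path below is a hypothetical placeholder; a lite-interpreter model exported with bytecode version 6+ is assumed so that schema info is present):

    from torch.jit.mobile import _get_model_ops_and_info

    # "my_model.ptl" is a placeholder path to a lite-interpreter model.
    ops_and_info = _get_model_ops_and_info("my_model.ptl")
    for op_name, info in ops_and_info.items():
        # num_schema_args is optional; models older than bytecode v6
        # carry no schema info, so it may be None.
        print(op_name, info.num_schema_args)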

File tree: 9 files changed, +201 -9 lines changed


test/cpp/jit/script_module_v6.ptl (3.64 KB, binary file not shown)

test/mobile/test_bytecode.py (+45 -5)

@@ -8,12 +8,13 @@
 from torch.jit.mobile import (
     _load_for_lite_interpreter,
     _get_model_bytecode_version,
+    _get_model_ops_and_info,
     _backport_for_mobile_to_buffer,
     _backport_for_mobile)
 from torch.testing._internal.common_utils import TestCase, run_tests
 from pathlib import Path

-pytorch_test_dri = Path(__file__).resolve().parents[1]
+pytorch_test_dir = Path(__file__).resolve().parents[1]

 # script_module_v4.ptl and script_module_v5.ptl source code
 # class TestModule(torch.nn.Module):
@@ -97,6 +98,38 @@
     ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
 '''

+SCRIPT_MODULE_V6_BYTECODE_PKL = '''
+(6,
+ ('__torch__.*.TestModule.forward',
+  (('instructions',
+    (('STOREN', 1, 2),
+     ('DROPR', 1, 0),
+     ('LOADC', 0, 0),
+     ('LOADC', 1, 0),
+     ('MOVE', 2, 0),
+     ('OP', 0, 0),
+     ('OP', 1, 0),
+     ('RET', 0, 0))),
+   ('operators', (('aten::add', 'int', 2), ('aten::add', 'Scalar', 2))),
+   ('constants',
+    (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, '0', 'cpu', 8),),
+       0,
+       (2, 4),
+       (4, 1),
+       False,
+       collections.OrderedDict()),
+     1)),
+   ('types', ()),
+   ('register_size', 2)),
+  (('arguments',
+    ((('name', 'self'),
+      ('type', '__torch__.*.TestModule'),
+      ('default_value', None)),
+     (('name', 'y'), ('type', 'int'), ('default_value', None)))),
+   ('returns',
+    ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
+'''
+
 SCRIPT_MODULE_BYTECODE_PKL = {
     4: {
         "bytecode_pkl": SCRIPT_MODULE_V4_BYTECODE_PKL,
@@ -113,7 +146,7 @@ def check_model_version(model_path, expect_version):
             actual_version = _get_model_bytecode_version(model_path)
             assert(actual_version == expect_version)
         for version, model_info in SCRIPT_MODULE_BYTECODE_PKL.items():
-            model_path = pytorch_test_dri / "cpp" / "jit" / model_info["model_name"]
+            model_path = pytorch_test_dir / "cpp" / "jit" / model_info["model_name"]
             check_model_version(model_path, version)

     def test_bytecode_values_for_all_backport_functions(self):
@@ -130,7 +163,7 @@ def test_bytecode_values_for_all_backport_functions(self):
         while current_from_version > MINIMUM_TO_VERSION:
             # Load model v5 and run forward method
             model_name = SCRIPT_MODULE_BYTECODE_PKL[current_from_version]["model_name"]
-            input_model_path = pytorch_test_dri / "cpp" / "jit" / model_name
+            input_model_path = pytorch_test_dir / "cpp" / "jit" / model_name

             # A temporary model file will be export to this path, and run through bytecode.pkl
             # content check.
@@ -205,7 +238,7 @@ def forward(self, y: int):
     # Check just the test_backport_bytecode_from_file_to_file mechanism but not the function implementations
     def test_backport_bytecode_from_file_to_file(self):
         maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())
-        script_module_v5_path = pytorch_test_dri / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[
+        script_module_v5_path = pytorch_test_dir / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[
             maximum_checked_in_model_version]["model_name"]

         if (maximum_checked_in_model_version > MINIMUM_TO_VERSION):
@@ -241,7 +274,7 @@ def test_backport_bytecode_from_file_to_file(self):
     # Check just the _backport_for_mobile_to_buffer mechanism but not the function implementations
     def test_backport_bytecode_from_file_to_buffer(self):
         maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys())
-        script_module_v5_path = pytorch_test_dri / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[
+        script_module_v5_path = pytorch_test_dir / "cpp" / "jit" / SCRIPT_MODULE_BYTECODE_PKL[
             maximum_checked_in_model_version]["model_name"]

         if (maximum_checked_in_model_version > MINIMUM_TO_VERSION):
@@ -264,5 +297,12 @@ def test_backport_bytecode_from_file_to_buffer(self):
         torch.testing.assert_allclose(mobile_module_result, expected_mobile_module_result)


+    def test_get_model_ops_and_info(self):
+        # TODO update this to be more in the style of the above tests after a backport from 6 -> 5 exists
+        script_module_v6 = pytorch_test_dir / "cpp" / "jit" / "script_module_v6.ptl"
+        ops_v6 = _get_model_ops_and_info(script_module_v6)
+        assert(ops_v6["aten::add.int"].num_schema_args == 2)
+        assert(ops_v6["aten::add.Scalar"].num_schema_args == 2)
+
 if __name__ == '__main__':
     run_tests()

torch/_C/__init__.pyi.in (+2)

@@ -260,6 +260,8 @@ def _backport_for_mobile(filename_input: Union[str, Path], filename_output: Unio
 def _backport_for_mobile_from_buffer(buffer: BinaryIO, filename_output: Union[str, Path], to_version: _int) -> None: ...
 def _backport_for_mobile_to_buffer(filename_input: Union[str, Path], to_version: _int) -> bytes:...
 def _backport_for_mobile_from_buffer_to_buffer(buffer: BinaryIO, to_version: _int) -> bytes:...
+def _get_model_ops_and_info(filename: Union[str, Path]): ...
+def _get_model_ops_and_info_from_buffer(buffer: BinaryIO): ...
 def _logging_set_logger(logger: LoggerBase) -> LoggerBase: ...
 def _get_graph_executor_optimize() -> _bool: ...
 def _set_graph_executor_optimize(optimize: _bool): ...

torch/csrc/jit/mobile/backport_manager.cpp (+1 -1)

@@ -149,7 +149,7 @@ bool backport_v5_to_v4(
     PyTorchStreamReader& reader,
     PyTorchStreamWriter& writer) {
   // 1) read from archive `bytecode` archive
-  std::vector<IValue> bytecode_values = get_bytecode_values(reader);
+  std::vector<IValue> bytecode_values = get_bytecode_ivalues(reader);
   if (!check_bytecode_version(bytecode_values, kBytecodeVersionV5)) {
     TORCH_WARN("Incorrect bytecode version for input model.");
     return false;

torch/csrc/jit/mobile/model_compatibility.cpp (+86 -2)

@@ -49,12 +49,14 @@ c10::IValue readArchive(
   return ivalues;
 }

-std::vector<IValue> get_bytecode_values(PyTorchStreamReader& reader) {
+std::vector<IValue> get_bytecode_ivalues(PyTorchStreamReader& reader) {
   std::vector<IValue> bytecode_values;
   bytecode_values = readArchive("bytecode", reader).toTuple()->elements();
   return bytecode_values;
 }

+/********************** Bytecode **********************/
+
 // Forward declare
 int64_t _get_model_bytecode_version(
     const std::vector<IValue>& bytecode_ivalues);
@@ -76,7 +78,7 @@ int64_t _get_model_bytecode_version(std::shared_ptr<ReadAdapterInterface> rai) {
     return -1;
   }
   PyTorchStreamReader reader(std::move(rai));
-  auto bytecode_values = get_bytecode_values(reader);
+  auto bytecode_values = get_bytecode_ivalues(reader);
   return _get_model_bytecode_version(bytecode_values);
 }

@@ -90,5 +92,87 @@ int64_t _get_model_bytecode_version(
   return -1;
 }

+/********************** Operators and Info **********************/
+
+// Forward declare
+std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
+    std::vector<IValue> bytecode_ivalues);
+
+std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
+    std::istream& in) {
+  std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
+  return _get_model_ops_and_info(std::move(rai));
+}
+
+std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
+    const std::string& filename) {
+  std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
+  return _get_model_ops_and_info(std::move(rai));
+}
+
+std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
+    std::shared_ptr<ReadAdapterInterface> rai) {
+  if (!check_zip_file(rai)) {
+    TORCH_WARN("Failed to open zip file for model ops.");
+    return std::unordered_map<std::string, OperatorInfo>{};
+  }
+  PyTorchStreamReader reader(std::move(rai));
+  auto bytecode_values = get_bytecode_ivalues(reader);
+  return _get_model_ops_and_info(bytecode_values);
+}
+
+/* A function to retrieve the root (top level) operators of a model and their
+ * corresponding compatibility info. These root operators can call other
+ * operators within them (traced ops), and a root op can call many different
+ * traced ops depending on internal code paths in the root op. These traced ops
+ * are not returned by this function. Those operators are abstracted into the
+ * runtime as an implementation detail (and the traced ops themselves can also
+ * call other operators) making retrieving them difficult and their value from
+ * this api negligible since they will differ between which runtime version the
+ * model is run on. Because of this, there is a false positive this api can't
+ * prevent in a compatibility usecase. All the root ops of a model are present
+ * in a target runtime, but not all the traced ops are which prevents a model
+ * from being able to run.
+ **/
+std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
+    std::vector<IValue> bytecode_ivalues) {
+  constexpr uint64_t min_version_with_schema = 6;
+  if (_get_model_bytecode_version(bytecode_ivalues) < min_version_with_schema) {
+    TORCH_WARN(
+        "Only models with bytecode version 6 and above contain operator schema information. Please re-export your model to generate it");
+  }
+  std::unordered_map<std::string, OperatorInfo> result;
+  if (bytecode_ivalues.empty()) {
+    TORCH_WARN("Failed to get model ops and info.");
+    return result;
+  }
+  // loop over all the functions in the bytecode
+  for (int i = 1; i < bytecode_ivalues.size(); i++) {
+    // descend to the operators list
+    auto method_tuple = bytecode_ivalues.at(i).toTuple()->elements();
+    auto operators_tuple = method_tuple.at(1).toTuple()->elements()[1];
+    auto operators = operators_tuple.toTuple()->elements()[1];
+    for (auto& op_tuple : operators.toTuple()->elements()) {
+      auto op = op_tuple.toTuple()->elements();
+
+      // grab name
+      std::string op_name = op.at(0).toStringRef();
+      std::string op_overload_name = op.at(1).toStringRef();
+      if (op_overload_name != "") {
+        op_name.append(".");
+        op_name.append(op_overload_name);
+      }
+
+      // grab schema size
+      if (op.size() > 2) {
+        result.emplace(op_name, OperatorInfo{(int)op.at(2).toInt()});
+      } else { // no schema information use default
+        result.emplace(op_name, OperatorInfo{});
+      }
+    }
+  }
+  return result;
+}
+
 } // namespace jit
 } // namespace torch
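
The new C++ function above descends through the unpickled bytecode tuple to collect root-op names and schema sizes. As a hedged illustration (not part of this commit), the same walk in Python over an already-unpickled tuple shaped like SCRIPT_MODULE_V6_BYTECODE_PKL would look roughly like this:

    # Illustrative helper, assuming `bytecode` is the unpickled tuple:
    # (version, (method_name, (('instructions', ...), ('operators', ...), ...), schema), ...)
    def ops_and_schema_sizes(bytecode):
        result = {}
        for method in bytecode[1:]:     # element 0 is the bytecode version
            table = method[1]           # (('instructions', ...), ('operators', ...), ...)
            operators = table[1][1]     # payload of the ('operators', ...) pair
            for op in operators:
                # join name and overload, e.g. 'aten::add' + 'int' -> 'aten::add.int'
                name = op[0] if not op[1] else f"{op[0]}.{op[1]}"
                # op[2] (number of schema args) exists only for bytecode v6+
                result[name] = op[2] if len(op) > 2 else None
        return result

    # e.g. (('aten::add', 'int', 2), ('aten::add', 'Scalar', 2))
    # yields {'aten::add.int': 2, 'aten::add.Scalar': 2}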

torch/csrc/jit/mobile/model_compatibility.h (+14 -1)

@@ -1,7 +1,10 @@
 #pragma once

+#include <torch/csrc/jit/mobile/runtime_compatibility.h>
+
 #include <istream>
 #include <memory>
+#include <unordered_map>

 namespace caffe2 {
 namespace serialize {
@@ -24,7 +27,7 @@ TORCH_API int64_t _get_model_bytecode_version(
 int64_t _get_model_bytecode_version(
     const std::vector<c10::IValue>& bytecode_ivalues);

-std::vector<c10::IValue> get_bytecode_values(
+std::vector<c10::IValue> get_bytecode_ivalues(
     caffe2::serialize::PyTorchStreamReader& reader);

 c10::IValue readArchive(
@@ -34,5 +37,15 @@ c10::IValue readArchive(
 bool check_zip_file(
     std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai);

+// The family of methods below to get the root ops and information from a model
+TORCH_API std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
+    std::istream& in);
+
+TORCH_API std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
+    const std::string& filename);
+
+TORCH_API std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
+    std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai);
+
 } // namespace jit
 } // namespace torch

torch/csrc/jit/mobile/runtime_compatibility.h (+4)

@@ -1,12 +1,16 @@
 #pragma once

+#include <c10/util/Optional.h>
+
 #include <memory>
 #include <unordered_map>

 namespace torch {
 namespace jit {

+// Struct storing metadata of an operator that can be useful for versioning
 struct OperatorInfo {
+  // The number of arguments within the schema of the op
   c10::optional<int> num_schema_args;
 };
torch/csrc/jit/python/script_init.cpp (+9)

@@ -1770,6 +1770,15 @@ void initJitScriptBindings(PyObject* module) {
     std::istringstream in(buffer);
     return _get_model_bytecode_version(in);
   });
+  py::class_<OperatorInfo>(m, "OperatorInfo")
+      .def_readonly("num_schema_args", &OperatorInfo::num_schema_args);
+  m.def("_get_model_ops_and_info", [](const std::string& filename) {
+    return _get_model_ops_and_info(filename);
+  });
+  m.def("_get_model_ops_and_info_from_buffer", [](const std::string& buffer) {
+    std::istringstream in(buffer);
+    return _get_model_ops_and_info(in);
+  });
   m.def("_export_operator_list", [](torch::jit::mobile::Module& sm) {
     return debugMakeSet(torch::jit::mobile::_export_operator_list(sm));
   });
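
From Python, the new buffer binding can be exercised roughly as follows (a sketch with a hypothetical model path; pybind11 converts the bytes read from the file into the std::string the lambda expects):

    import torch

    with open("my_model.ptl", "rb") as f:  # placeholder path
        ops = torch._C._get_model_ops_and_info_from_buffer(f.read())
    # assumes the model's root ops include aten::add.int, as in the test model
    print(ops["aten::add.int"].num_schema_args)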

torch/jit/mobile/__init__.py (+40)

@@ -145,3 +145,43 @@ def _backport_for_mobile_to_buffer(f_input, to_version):
         return torch._C._backport_for_mobile_to_buffer(str(f_input), to_version)  # type: ignore[attr-defined]
     else:
         return torch._C._backport_for_mobile_from_buffer_to_buffer(f_input.read(), to_version)  # type: ignore[attr-defined]
+
+def _get_model_ops_and_info(f_input):
+    r"""
+    A function to retrieve the root (top level) operators of a model and their corresponding
+    compatibility info. These root operators can call other operators within them (traced ops), and
+    a root op can call many different traced ops depending on internal code paths in the root op.
+    These traced ops are not returned by this function. Those operators are abstracted into the
+    runtime as an implementation detail (and the traced ops themselves can also call other operators)
+    making retrieving them difficult and their value from this api negligible since they will differ
+    between which runtime version the model is run on. Because of this, there is a false positive this
+    api can't prevent in a compatibility usecase. All the root ops of a model are present in a
+    target runtime, but not all the traced ops are which prevents a model from being able to run.
+    Args:
+        f_input: a file-like object (has to implement read, readline, tell, and seek),
+            or a string containing a file name
+
+    Returns:
+        Operators and info: A Dictionary mapping strings (the qualified names of the root operators)
+        of the model to their OperatorInfo structs.
+
+    Example:
+
+    .. testcode::
+
+        from torch.jit.mobile import _get_model_ops_and_info
+
+        # Get bytecode version from a saved file path
+        ops_and_info = _get_model_ops_and_info("path/to/model.ptl")
+
+    """
+    if isinstance(f_input, str):
+        if not os.path.exists(f_input):
+            raise ValueError(f"The provided filename {f_input} does not exist")
+        if os.path.isdir(f_input):
+            raise ValueError(f"The provided filename {f_input} is a directory")
+
+    if (isinstance(f_input, str) or isinstance(f_input, pathlib.Path)):
+        return torch._C._get_model_ops_and_info(str(f_input))  # type: ignore[attr-defined]
+    else:
+        return torch._C._get_model_ops_and_info(f_input.read())  # type: ignore[attr-defined]
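
To sketch the compatibility use case the docstring describes (the runtime op set here is a made-up placeholder; a real check would source it from the target runtime build):

    from torch.jit.mobile import _get_model_ops_and_info

    target_runtime_ops = {"aten::add.int", "aten::add.Scalar"}  # hypothetical
    model_ops = _get_model_ops_and_info("my_model.ptl")         # placeholder path
    missing = set(model_ops) - target_runtime_ops
    if missing:
        print("model may not run; missing root ops:", sorted(missing))
    # Per the docstring, traced (non-root) ops are not covered, so a passing
    # check here can still be a false positive.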
