Commit e86a5a3

d1jang authored and facebook-github-bot committed
[Static Runtime] Add PyTorchPredictor::predict_managed_result to return managed output tensors (pytorch#65598)
Summary:
Pull Request resolved: pytorch#65598

This change adds `PyTorchPredictor::predict_managed_result` to enable Static Runtime to return managed output tensors, allocated and owned by Static Runtime, to accelerate inference workloads.

- `PyTorchPredictor::predict_managed_result` only does meaningful work in the overridden `PyTorchStaticRuntimePredictor::predict_managed_result`. For other subclasses, it returns a simple object that just wraps the returned `IValue`.

- When `manage_output_tensors` is enabled, a `StaticRuntime` cannot be reentered until its return value is deallocated by calling `StaticRuntime::deallocateOutputTensors`. Currently an instance of `StaticRuntime` is pushed back to `static_runtime_pool` immediately after a prediction so that it can be reentered, and this cannot be done when `manage_output_tensors` is enabled. `PyTorchStaticRuntimePredictorManagedResult` makes sure a `StaticRuntime` instance is pushed back to the pool only after `StaticRuntime::deallocateOutputTensors` has been called on it.

- When `manage_output_tensors` is enabled, `PyTorchStaticRuntimePredictor::predict_managed_result` returns a prediction result whose backing memory is managed by an instance of `StaticRuntime`. The lifetime of any value reachable from `PyTorchStaticRuntimePredictorManagedResult.get()` is expected to end before the `PyTorchStaticRuntimePredictorManagedResult` is destructed. As explained above, the result's destruction pushes the runtime instance that produced it back to `static_runtime_pool` for reuse.

- Adding `predict_managed_result` instead of forcing `operator()` to return `PyTorchPredictorManagedResult` was motivated by the fact that `manage_output_tensors` will be selectively enabled for only a few models. If `manage_output_tensors` becomes a commonly used feature, we should revisit this API design and merge the two.

Reviewed By: hlu1

Differential Revision: D31149323

fbshipit-source-id: 5ca026188077232d6a49a46759124a978439d7b2
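For intuition, here is a minimal, self-contained sketch of the ownership scheme the summary describes: an RAII result object keeps its runtime checked out of the pool, and its destructor deallocates the managed outputs and only then returns the runtime. All names below (`FakeRuntime`, the toy `pool`, this `ManagedResult`) are illustrative stand-ins, not the actual predictor classes.

// Minimal sketch (illustrative stand-ins only; not the real predictor API).
#include <cassert>
#include <memory>
#include <stack>
#include <utility>

struct FakeRuntime {
  int run() { return 42; }           // stand-in for StaticRuntime::operator()
  void deallocateOutputTensors() {}  // stand-in for the real deallocation
};

// Stand-in for static_runtime_pool.
static std::stack<std::unique_ptr<FakeRuntime>> pool;

// RAII result: keeps the runtime checked out of the pool while the output is
// alive; on destruction it frees the managed outputs, then returns the runtime.
class ManagedResult {
 public:
  ManagedResult(std::unique_ptr<FakeRuntime> rt, int value)
      : rt_(std::move(rt)), value_(value) {}
  ManagedResult(ManagedResult&&) = default;
  ~ManagedResult() {
    if (rt_) {
      rt_->deallocateOutputTensors();  // outputs are invalid past this point
      pool.push(std::move(rt_));       // only now is the runtime reentrant
    }
  }
  int get() const { return value_; }   // mirrors ManagedResult.get()

 private:
  std::unique_ptr<FakeRuntime> rt_;
  int value_;
};

ManagedResult predict_managed_result() {
  auto rt = std::move(pool.top());
  pool.pop();
  int out = rt->run();
  return ManagedResult(std::move(rt), out);
}

int main() {
  pool.push(std::make_unique<FakeRuntime>());
  {
    ManagedResult res = predict_managed_result();
    assert(res.get() == 42);  // use the managed output only while res lives
  }                           // res destructs: runtime is back in the pool
  assert(pool.size() == 1);
  return 0;
}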
1 parent 18955d3 commit e86a5a3

File tree

1 file changed: +17 -1 lines changed


torch/csrc/jit/runtime/static/impl.cpp

Lines changed: 17 additions & 1 deletion
@@ -1145,17 +1145,24 @@ float StaticRuntime::benchmark_model(
 
   const bool is_kwargs_empty = kwargs_list.size() == 0;
   const std::unordered_map<std::string, c10::IValue> empty_kwargs;
+  bool manage_output_tensors = static_module_.opts().manage_output_tensors;
   for (const auto i : c10::irange(warmup_runs)) {
     (void)i; // Suppress unused variable warning
     for (const auto j : c10::irange(args_list.size())) {
       operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
+      if (manage_output_tensors) {
+        deallocateOutputTensors();
+      }
     }
   }
   caffe2::Timer timer;
   for (const auto i : c10::irange(main_runs)) {
     (void)i; // Suppress unused variable warning
     for (const auto j : c10::irange(args_list.size())) {
       operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
+      if (manage_output_tensors) {
+        deallocateOutputTensors();
+      }
     }
   }
   float millis = timer.MilliSeconds();
@@ -1253,7 +1260,7 @@ StaticRuntime::IndividualMetrics StaticRuntime::benchmark_individual_ops(
 
   const bool is_kwargs_empty = kwargs_list.size() == 0;
   const std::unordered_map<std::string, c10::IValue> empty_kwargs;
-
+  bool manage_output_tensors = static_module_.opts().manage_output_tensors;
   // See comment on above use of InferenceMode for
   // explanation.
   c10::InferenceMode mode;
@@ -1273,13 +1280,19 @@
   // iterations just use the already established memory planning.
   timer.Start();
   operator()(args_list[0], is_kwargs_empty ? empty_kwargs : kwargs_list[0]);
+  if (manage_output_tensors) {
+    deallocateOutputTensors();
+  }
   results.first_iter_time = timer.MilliSeconds();
 
   // warmup runs
   for (const auto i : c10::irange(warmup_runs - 1)) {
     (void)i; // Suppress unused variable warning
     for (const auto j : c10::irange(args_list.size())) {
       operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]);
+      if (manage_output_tensors) {
+        deallocateOutputTensors();
+      }
     }
   }
 
@@ -1310,6 +1323,9 @@
     // clean up owning refs of input tensors
     clean_up_input_ivalues();
   }
+  if (manage_output_tensors) {
+    deallocateOutputTensors();
+  }
   millis = timer.MilliSeconds();
   results.memory_dealloc_time += millis;
 
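The diff encodes a caller-side contract: with `manage_output_tensors` enabled, every `operator()` call on a `StaticRuntime` must be paired with `deallocateOutputTensors()` before the runtime is reentered. A rough sketch of that pairing against `StaticRuntime` directly follows; it assumes the runtime was built from a `StaticModule` with `manage_output_tensors` enabled, and the surrounding setup details are assumptions, not taken from this commit.

// Rough sketch of the call/deallocate pairing the benchmark loops enforce.
// Assumes `runtime` wraps a StaticModule whose options set
// manage_output_tensors = true; setup is not shown in this commit.
#include <torch/csrc/jit/runtime/static/impl.h>

#include <string>
#include <unordered_map>
#include <vector>

void run_twice(torch::jit::StaticRuntime& runtime,
               const std::vector<c10::IValue>& args) {
  const std::unordered_map<std::string, c10::IValue> no_kwargs;

  // Outputs returned here are backed by memory owned by the runtime.
  c10::IValue out = runtime(args, no_kwargs);
  // ... consume `out` while it is still alive ...

  // Required before reentry: free the managed output tensors.
  runtime.deallocateOutputTensors();

  // Safe to run again now.
  out = runtime(args, no_kwargs);
  runtime.deallocateOutputTensors();
}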
