Commit ec5bbef

SherlockNoMad authored and pytorchmergebot committed
[AOTInductor] Switch ProxyExecutor to use AtenTensorHandle (pytorch#109748)
Summary: Switch ProxyExecutor to use AtenTensorHandle.

Test Plan: E2E Test

Differential Revision: D49471659

Pull Request resolved: pytorch#109748
Approved by: https://github.com/yifuwang, https://github.com/desertfire, https://github.com/chenyang78
1 parent 633bd07 commit ec5bbef

6 files changed: +55 -15 lines changed

torch/_export/serde/serialize.py

+6 -1

@@ -526,7 +526,12 @@ def is_sym_bool_arg(self, arg) -> bool:
 
     def serialize_input(self, arg) -> Argument:
         import torch._inductor.ir as inductor_ir
-        inductor_tensor_buffers = (inductor_ir.InputBuffer, inductor_ir.ComputedBuffer, inductor_ir.ConcatKernel)
+        inductor_tensor_buffers = (
+            inductor_ir.InputBuffer,
+            inductor_ir.ComputedBuffer,
+            inductor_ir.ConcatKernel,
+            inductor_ir.ExternKernelOut,
+        )
 
         if isinstance(arg, torch.fx.Node):
             if arg.op == "get_attr":

torch/_inductor/codegen/wrapper.py

+31 -9

@@ -488,6 +488,7 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed(
         cpp_kernel_overload_name="",
         op_overload=None,
         raw_args=None,
+        outputs=None,
     ):
         self.writeline(f"{name} = {kernel}({', '.join(codegen_args)})")

@@ -836,7 +837,7 @@ def make_buffer_free(self, buffer):
         return f"del {buffer.get_name()}"
 
     def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str):
-        return f"{self.declare}{new_name} = {old_name}{del_line} {self.comment} reuse"
+        return f"{self.declare}{new_name} = {old_name}{del_line}{self.ending} {self.comment} reuse"
 
     def make_buffer_reuse(self, old, new):
         assert old.get_dtype() == new.get_dtype()

@@ -1664,10 +1665,16 @@ def fill_args(arg, arg_type):
                 torch.Type,
                 torch.DeviceObjType,
             )
+            inductor_tensor_buffers = (
+                ir.InputBuffer,
+                ir.ComputedBuffer,
+                ir.ConcatKernel,
+                ir.ExternKernelOut,
+            )
 
             if isinstance(arg_type, torch.TensorType):
-                assert isinstance(arg, (ir.InputBuffer, ir.ComputedBuffer))
-                new_tensor_args.append(f"&{arg.name}")
+                assert isinstance(arg, inductor_tensor_buffers)
+                new_tensor_args.append(f"{arg.name}.get()")
             elif isinstance(arg_type, (torch.IntType, torch.SymIntType)):
                 # int or SymInt
                 assert isinstance(arg, int)

@@ -1683,14 +1690,16 @@ def fill_args(arg, arg_type):
 
             # List[Tensor]
             if isinstance(arg_type.getElementType(), torch.TensorType):
-                new_tensor_args.extend([f"&{a.name}" for a in arg])
+                new_tensor_args.extend([f"{a.name}.get()" for a in arg])
             # List[Optional[Tensor]]
             elif isinstance(
                 arg_type.getElementType(), torch.OptionalType
             ) and isinstance(
                 arg_type.getElementType().getElementType(), torch.TensorType
             ):
-                new_tensor_args.extend([f"&{a.name}" for a in arg if a is not None])
+                new_tensor_args.extend(
+                    [f"{a.name}.get()" for a in arg if a is not None]
+                )
             # List [int] or List[SymInt]
             elif isinstance(
                 arg_type.getElementType(), (torch.IntType, torch.SymIntType)

@@ -1723,8 +1732,12 @@ def fill_args(arg, arg_type):
 
         def fill_output_arg(arg, return_type):
             if isinstance(return_type, torch.TensorType):
-                self.writeline(f"at::Tensor {arg}; // output buffer")
-                new_tensor_args.append(f"&{output_arg}")
+                self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer")
+                self.writeline(
+                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));"
+                )
+                self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);")
+                new_tensor_args.append(f"{arg}.get()")
             elif isinstance(return_type, torch.ListType) and isinstance(
                 return_type.getElementType(), torch.TensorType
             ):

@@ -1763,16 +1776,19 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed(
         cpp_kernel_overload_name="",
         op_overload=None,
         raw_args=None,
+        outputs=None,
     ):
         if config.is_fbcode():
             assert op_overload is not None
             assert raw_args is not None
+            assert outputs is not None
 
             return self.generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
                 name,
                 cpp_kernel_key,
                 op_overload,
                 raw_args,
+                outputs,
             )
         else:
             return self.generate_extern_kernel_alloc_and_find_schema_if_needed_oss(

@@ -1813,8 +1829,12 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
         cpp_kernel_key,
         op_overload,
         raw_args,  # contains both args and flatten kwargs
+        outputs,
     ):
-        output_args = [name]
+        if isinstance(outputs, (list, tuple)):
+            output_args = [output.get_name() for output in outputs]
+        else:
+            output_args = [outputs.get_name()]
 
         (
             tensor_call_args,

@@ -1825,7 +1845,9 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
 
         tensor_args_var = f"tensor_args_var_{next(self.kernel_callsite_id)}"
        tensor_call_args_str = ", ".join(tensor_call_args)
-        self.writeline(f"void* {tensor_args_var}[] = {{{tensor_call_args_str}}};")
+        self.writeline(
+            f"AtenTensorHandle {tensor_args_var}[] = {{{tensor_call_args_str}}};"
+        )
 
         int_args_var = f"int_args_var_{next(self.kernel_callsite_id)}"
         int_call_args_str = ", ".join(int_call_args)
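
Note: taken together, the hunks above make the fbcode path emit C++ along these lines for an extern kernel with one tensor output. A sketch with hypothetical buffer names buf0 and arg0, assuming the RAIIAtenTensorHandle wrapper from the AOTInductor runtime:

// Sketch of the generated wrapper code (hypothetical names buf0, arg0).
AtenTensorHandle buf0_handle; // output buffer
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&buf0_handle));
// RAIIAtenTensorHandle takes ownership and frees the tensor on scope exit.
RAIIAtenTensorHandle buf0(buf0_handle);

// Flattened tensor arguments are a typed AtenTensorHandle array now,
// so the executor side no longer needs to cast from void*.
AtenTensorHandle tensor_args_var_0[] = {arg0.get(), buf0.get()};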

torch/_inductor/ir.py

+1 -0

@@ -3829,6 +3829,7 @@ def codegen(self, wrapper):
                 self.cpp_kernel_overlad_name,
                 self.op_overload,
                 exported_args,
+                self.outputs,
             )
         else:
             super().codegen(wrapper)

torch/csrc/inductor/aoti_torch/c/shim.h

+6 -1

@@ -171,6 +171,11 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
     AtenTensorHandle* ret8 // returns new reference
 );
 
+// This function will create a new uninitialized tensor object
+// and its pointer is returned through *ret.
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch_new_uninitialized_tensor(AtenTensorHandle* ret);
+
 AOTI_TORCH_EXPORT AOTITorchError
 aoti_torch_tensor_copy_(AtenTensorHandle src, AtenTensorHandle dst);
 
@@ -214,7 +219,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_proxy_executor_call_function(
     int num_ints,
     int64_t* flatten_int_args,
     int num_tensors,
-    void** flatten_tensor_args);
+    AtenTensorHandle* flatten_tensor_args);
 
 #ifdef __cplusplus
 } // extern "C"
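
Note: a minimal caller-side sketch of the new entry point, using the AOTI_TORCH_ERROR_CODE_CHECK convention the generated code applies to shim calls:

// Create an empty at::Tensor behind an opaque handle; a downstream call
// (here, the proxy executor) is expected to fill it in.
AtenTensorHandle out_handle = nullptr;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&out_handle));

// Tensor arguments now flatten into a typed handle array instead of void**.
AtenTensorHandle flatten_tensor_args[] = {out_handle};

In generated code the fresh handle is immediately wrapped in an owning RAIIAtenTensorHandle, as the wrapper.py hunks above show.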

torch/csrc/inductor/aoti_torch/proxy_executor.h

+3 -3

@@ -2,12 +2,12 @@
 
 #include <ATen/core/ivalue.h>
 #include <c10/macros/Export.h>
-#include <string>
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
 
 namespace torch {
 namespace aot_inductor {
 
-class TORCH_API ProxyExecutor : public torch::CustomClassHolder {
+class ProxyExecutor {
  public:
   ProxyExecutor() {}
   virtual ~ProxyExecutor() {}
 
@@ -17,7 +17,7 @@ class TORCH_API ProxyExecutor : public torch::CustomClassHolder {
       int num_ints,
       int64_t* flatten_int_args,
       int num_tensors,
-      void** flatten_tensor_args) = 0;
+      AtenTensorHandle* flatten_tensor_args) = 0;
 };
 
 } // namespace aot_inductor
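
Note: with TORCH_API and CustomClassHolder dropped, ProxyExecutor is a plain virtual interface over the C shim types. A hypothetical implementation might look like this; the leading extern_node_index parameter is an assumption, since the hunk above shows only the tail of call_function's signature:

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>

namespace torch {
namespace aot_inductor {

class MyProxyExecutor : public ProxyExecutor {
 public:
  void call_function(
      int extern_node_index, // assumed leading parameter
      int num_ints,
      int64_t* flatten_int_args,
      int num_tensors,
      AtenTensorHandle* flatten_tensor_args) override {
    // Look up the serialized extern kernel node by index and re-dispatch it,
    // reading inputs and writing outputs through the typed handle array.
  }
};

} // namespace aot_inductor
} // namespace torch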

torch/csrc/inductor/aoti_torch/shim_common.cpp

+8 -1

@@ -239,6 +239,13 @@ AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
   });
 }
 
+AOTITorchError aoti_torch_new_uninitialized_tensor(AtenTensorHandle* ret) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    at::Tensor* out_tensor = new at::Tensor();
+    *ret = tensor_pointer_to_tensor_handle(out_tensor);
+  });
+}
+
 // TODO: implement a more efficient version instead of calling into aten
 AOTITorchError aoti_torch_tensor_copy_(
     AtenTensorHandle src,

@@ -301,7 +308,7 @@ AOTITorchError aoti_torch_proxy_executor_call_function(
     int num_ints,
     int64_t* flatten_int_args,
     int num_tensors,
-    void** flatten_tensor_args) {
+    AtenTensorHandle* flatten_tensor_args) {
   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
     ProxyExecutor* executor = reinterpret_cast<ProxyExecutor*>(proxy_executor);
     executor->call_function(
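
Note: aoti_torch_new_uninitialized_tensor has no explicit return statement because AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE supplies the success/failure returns around its body. Conceptually the macro expands to something like this sketch (an assumption about its exact shape; the real macro also reports the error message):

AOTITorchError aoti_torch_new_uninitialized_tensor(AtenTensorHandle* ret) {
  try {
    // Heap-allocate an empty at::Tensor; the caller owns it via the handle.
    at::Tensor* out_tensor = new at::Tensor();
    *ret = tensor_pointer_to_tensor_handle(out_tensor);
  } catch (...) {
    return AOTI_TORCH_FAILURE;
  }
  return AOTI_TORCH_SUCCESS;
}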
