@@ -488,6 +488,7 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed(
        cpp_kernel_overload_name="",
        op_overload=None,
        raw_args=None,
+        outputs=None,
    ):
        self.writeline(f"{name} = {kernel}({', '.join(codegen_args)})")

@@ -836,7 +837,7 @@ def make_buffer_free(self, buffer):
        return f"del {buffer.get_name()}"

    def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str):
-        return f"{self.declare}{new_name} = {old_name}{del_line}  {self.comment} reuse"
+        return f"{self.declare}{new_name} = {old_name}{del_line}{self.ending}  {self.comment} reuse"

    def make_buffer_reuse(self, old, new):
        assert old.get_dtype() == new.get_dtype()
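Note on the `{self.ending}` fix above: in the C++ wrapper codegen, `self.declare`, `self.ending`, and `self.comment` expand to roughly `auto `, `;`, and `//`. A minimal sketch of the generated reuse line, assuming an empty `del_line` and hypothetical buffer names:

    // before the fix: the declaration is never terminated, so the emitted C++ does not compile
    //   auto buf1 = buf0  // reuse
    // after the fix: {self.ending} supplies the terminating semicolon ahead of the comment
    auto buf1 = buf0;  // reuse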
@@ -1664,10 +1665,16 @@ def fill_args(arg, arg_type):
                torch.Type,
                torch.DeviceObjType,
            )
+            inductor_tensor_buffers = (
+                ir.InputBuffer,
+                ir.ComputedBuffer,
+                ir.ConcatKernel,
+                ir.ExternKernelOut,
+            )

            if isinstance(arg_type, torch.TensorType):
-                assert isinstance(arg, (ir.InputBuffer, ir.ComputedBuffer))
-                new_tensor_args.append(f"&{arg.name}")
+                assert isinstance(arg, inductor_tensor_buffers)
+                new_tensor_args.append(f"{arg.name}.get()")
            elif isinstance(arg_type, (torch.IntType, torch.SymIntType)):
                # int or SymInt
                assert isinstance(arg, int)
@@ -1683,14 +1690,16 @@ def fill_args(arg, arg_type):

                # List[Tensor]
                if isinstance(arg_type.getElementType(), torch.TensorType):
-                    new_tensor_args.extend([f"&{a.name}" for a in arg])
+                    new_tensor_args.extend([f"{a.name}.get()" for a in arg])
                # List[Optional[Tensor]]
                elif isinstance(
                    arg_type.getElementType(), torch.OptionalType
                ) and isinstance(
                    arg_type.getElementType().getElementType(), torch.TensorType
                ):
-                    new_tensor_args.extend([f"&{a.name}" for a in arg if a is not None])
+                    new_tensor_args.extend(
+                        [f"{a.name}.get()" for a in arg if a is not None]
+                    )
                # List[int] or List[SymInt]
                elif isinstance(
                    arg_type.getElementType(), (torch.IntType, torch.SymIntType)
@@ -1723,8 +1732,12 @@ def fill_args(arg, arg_type):

        def fill_output_arg(arg, return_type):
            if isinstance(return_type, torch.TensorType):
-                self.writeline(f"at::Tensor {arg}; // output buffer")
-                new_tensor_args.append(f"&{output_arg}")
+                self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer")
+                self.writeline(
+                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));"
+                )
+                self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);")
+                new_tensor_args.append(f"{arg}.get()")
            elif isinstance(return_type, torch.ListType) and isinstance(
                return_type.getElementType(), torch.TensorType
            ):
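Substituting a hypothetical output name, say `buf2`, for `{arg}`, the new `writeline` calls above emit a generated-wrapper fragment along these lines: the output tensor is allocated through the ABI-stable shim and then owned by an RAII handle.

    AtenTensorHandle buf2_handle; // output buffer
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&buf2_handle));
    RAIIAtenTensorHandle buf2(buf2_handle);
    // "buf2.get()" is then appended to new_tensor_args, replacing the old
    // "&buf2" that pointed at an "at::Tensor buf2; // output buffer" declaration.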
@@ -1763,16 +1776,19 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed(
        cpp_kernel_overload_name="",
        op_overload=None,
        raw_args=None,
+        outputs=None,
    ):
        if config.is_fbcode():
            assert op_overload is not None
            assert raw_args is not None
+            assert outputs is not None

            return self.generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
                name,
                cpp_kernel_key,
                op_overload,
                raw_args,
+                outputs,
            )
        else:
            return self.generate_extern_kernel_alloc_and_find_schema_if_needed_oss(
@@ -1813,8 +1829,12 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
        cpp_kernel_key,
        op_overload,
        raw_args,  # contains both args and flatten kwargs
+        outputs,
    ):
-        output_args = [name]
+        if isinstance(outputs, (list, tuple)):
+            output_args = [output.get_name() for output in outputs]
+        else:
+            output_args = [outputs.get_name()]

        (
            tensor_call_args,
@@ -1825,7 +1845,9 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(

        tensor_args_var = f"tensor_args_var_{next(self.kernel_callsite_id)}"
        tensor_call_args_str = ", ".join(tensor_call_args)
-        self.writeline(f"void* {tensor_args_var}[] = {{{tensor_call_args_str}}};")
+        self.writeline(
+            f"AtenTensorHandle {tensor_args_var}[] = {{{tensor_call_args_str}}};"
+        )

        int_args_var = f"int_args_var_{next(self.kernel_callsite_id)}"
        int_call_args_str = ", ".join(int_call_args)
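Putting the argument-side changes together (buffer names and the callsite counter value are hypothetical): since tensor arguments are now rendered as `.get()` expressions yielding raw handles, the collected call arguments form an `AtenTensorHandle` array instead of a `void*` array. A sketch of the generated line before and after:

    // before:
    //   void* tensor_args_var_0[] = {&buf0, &buf2};
    // after:
    AtenTensorHandle tensor_args_var_0[] = {buf0.get(), buf2.get()};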