Skip to content

Commit 16d6929

Browse files
jbschlosserpytorchmergebot
authored andcommitted
Use view name instead of view_copy name for functional inverses (pytorch#117056)
Ex: `unsqueeze_copy_inverse()` -> `unsqueeze_inverse()` Pull Request resolved: pytorch#117056 Approved by: https://github.com/bdhirsh
1 parent fdfdba7 commit 16d6929

File tree

7 files changed

+71
-67
lines changed

7 files changed

+71
-67
lines changed

aten/src/ATen/FunctionalInverses.cpp

Lines changed: 43 additions & 43 deletions
Large diffs are not rendered by default.

aten/src/ATen/native/TestOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,10 @@ Tensor _test_check_tensor(const Tensor& self) {
115115

116116
namespace at::functionalization {
117117

118-
// view_copy ops must have a functional inverse registered
119-
Tensor FunctionalInverses::_test_autograd_multiple_dispatch_view_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
118+
// view ops must have a functional inverse registered
119+
Tensor FunctionalInverses::_test_autograd_multiple_dispatch_view_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
120120
TORCH_INTERNAL_ASSERT(false,
121-
"Attempted to call _test_autograd_multiple_dispatch_view_copy_inverse() during the functionalization pass. ",
121+
"Attempted to call _test_autograd_multiple_dispatch_view_inverse() during the functionalization pass. ",
122122
"This function is for testing only and should never be called.");
123123
return Tensor();
124124
}

aten/src/ATen/templates/FunctionalInverses.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ struct FunctionalInverses {
2525
// NB: These are not generated! They're manually implemented in the template.
2626
// TODO: Change codegen to generate these. See the following link:
2727
// https://github.com/pytorch/pytorch/blob/main/torchgen/model.py#L2583-L2585
28-
static at::Tensor chunk_copy_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, int chunks, int dim);
29-
static at::Tensor narrow_copy_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int dim, c10::SymInt start, c10::SymInt length);
28+
static at::Tensor chunk_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, int chunks, int dim);
29+
static at::Tensor narrow_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int dim, c10::SymInt start, c10::SymInt length);
3030

3131
};
3232
}

tools/autograd/gen_inplace_or_view_type.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@
158158

159159
REVERSE_VIEW_DISPATCH = CodeTemplate(
160160
"""\
161-
at::functionalization::FunctionalInverses::${reverse_name}(${unpacked_args})"""
161+
${reverse_name}(${unpacked_args})"""
162162
)
163163

164164
MULTI_OUTPUT_VIEW_ITERATION = CodeTemplate(
@@ -429,8 +429,10 @@ def emit_view_lambda(
429429
*updated_args[1:],
430430
]
431431

432+
from torchgen.api.functionalization import reverse_name
433+
432434
reverse_replay_view_call = REVERSE_VIEW_DISPATCH.substitute(
433-
reverse_name=inverse_view_name(f),
435+
reverse_name=reverse_name(f, include_namespace=True),
434436
unpacked_args=reverse_unpacked_args,
435437
)
436438
reverse_replay_view_func = REVERSE_REPLAY_VIEW_LAMBDA_FUNC.substitute(

torchgen/api/functionalization.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
BaseTy,
1818
BaseType,
1919
FunctionSchema,
20+
NativeFunction,
2021
NativeFunctionsViewGroup,
2122
)
2223

@@ -99,16 +100,7 @@ def name(
99100
# since we always plumb the runtime "reapply_views" argument into the reverse function.
100101
assert is_reverse
101102
if is_reverse:
102-
# for the reverse: the name of the inverse function always involves "view_copy",
103-
# and we plumb the "reapply_views" flag into that function.
104-
# (We could avoid doing that, but that would require writing out twice as many view inverse functions).
105-
assert g.view_copy is not None
106-
api_name = g.view_copy.func.name.unambiguous_name()
107-
# in the reverse case, we codegen both the call-sites (which need the full namespace) and the declarations (which don't)
108-
if include_namespace:
109-
return f"at::functionalization::FunctionalInverses::{api_name}_inverse"
110-
else:
111-
return f"{api_name}_inverse"
103+
return reverse_name(g.view, include_namespace)
112104
# in the forward case, we just directly call into the at::_ops API (so we always need the namespace)
113105
assert include_namespace
114106
assert g.view_copy is not None
@@ -120,6 +112,18 @@ def name(
120112
return f"at::_ops::{api_name}::call"
121113

122114

115+
def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
116+
# for the reverse: we plumb the "reapply_views" flag into that function and support
117+
# both copy and non-copy variants. (We could avoid doing that, but that would require
118+
# writing out twice as many view inverse functions).
119+
api_name = f.func.name.unambiguous_name()
120+
# in the reverse case, we codegen both the call-sites (which need the full namespace) and the declarations (which don't)
121+
if include_namespace:
122+
return f"at::functionalization::FunctionalInverses::{api_name}_inverse"
123+
else:
124+
return f"{api_name}_inverse"
125+
126+
123127
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> List[Binding]:
124128
# capture arguments include all arguments except `self`.
125129
# Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),

torchgen/api/types/signatures.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -289,16 +289,14 @@ class ViewInverseSignature:
289289
g: NativeFunctionsViewGroup
290290

291291
def name(self) -> str:
292-
assert self.g.view_copy is not None
293-
return functionalization.name(self.g, is_reverse=True, include_namespace=False)
292+
return functionalization.reverse_name(self.g.view, include_namespace=False)
294293

295294
def decl(self) -> str:
296-
assert self.g.view_copy is not None
297-
return_type = functionalization.returns_type(self.g.view_copy.func)
295+
return_type = functionalization.returns_type(self.g.view.func)
298296
decls = [
299297
a.decl()
300298
for a in functionalization.inner_arguments(
301-
self.g.view_copy.func, is_reverse=True
299+
self.g.view.func, is_reverse=True
302300
)
303301
]
304302
return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});"

torchgen/gen_functionalization_type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -686,8 +686,8 @@ def gen_functionalization_view_inverse_declaration(
686686
def emit_decl_helper(g: NativeFunctionsViewGroup) -> Optional[str]:
687687
if g.view.has_composite_implicit_autograd_kernel:
688688
return None
689-
view_copy_inverse_sig = ViewInverseSignature(g)
690-
return view_copy_inverse_sig.decl()
689+
view_inverse_sig = ViewInverseSignature(g)
690+
return view_inverse_sig.decl()
691691

692692
return emit_decl_helper(g)
693693

0 commit comments

Comments
 (0)