Skip to content

Commit 02c4d87

Browse files
antoniojkimpytorchmergebot
authored andcommitted
Codegen Non-Native IR Nodes (pytorch#76535)
Add codegen infrastructure to generate IR nodes for non-native ops. The proposed change is to add a `non_native` key to the `{backend}_native_functions.yaml` file that contains schema definitions similar to what is found in `native_functions.yaml`. e.g. ``` non_native: ... - func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor ... ``` these definitions are parsed into a `LazyIrSchema` that can be used for generating IR nodes using `GenLazyIR`. Fixes pytorch#74628 CC: @wconstab @desertfire @henrytwo Pull Request resolved: pytorch#76535 Approved by: https://github.com/wconstab
1 parent 13dcba8 commit 02c4d87

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+497
-1348
lines changed

BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1865,6 +1865,7 @@ test_suite(
18651865
"aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp",
18661866
"aten/src/ATen/templates/DispatchKeyNativeFunctions.h",
18671867
"aten/src/ATen/templates/LazyIr.h",
1868+
"aten/src/ATen/templates/LazyNonNativeIr.h",
18681869
"aten/src/ATen/templates/RegisterDispatchKey.cpp",
18691870
"aten/src/ATen/native/native_functions.yaml",
18701871
"aten/src/ATen/native/tags.yaml",

aten/src/ATen/native/ts_native_functions.yaml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,41 @@ supported:
178178
- _unsafe_view
179179
autograd:
180180
- max_pool3d
181+
182+
# Ops that don't have a native schema definitions and are dispatched within Lazy Tensor Core
183+
non_native:
184+
- func: scalar(Scalar value, ScalarType type) -> Tensor
185+
opkind: at::prim::Constant
186+
properties:
187+
- ShapeCompute
188+
- TreatScalarsAsConstants
189+
- func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor
190+
- func: view(Tensor input, int[] output_size) -> Tensor
191+
properties:
192+
- ShapeCompute
193+
- func: cast(Tensor input, ScalarType dtype, ScalarType? stype) -> Tensor
194+
opkind: ltc_cast
195+
properties:
196+
- ShapeCompute
197+
198+
# View ops only required until proper functionalization pass is introduced into LTC
199+
- func: as_strided_view_update(Tensor target, Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor
200+
opkind: ltc_as_strided_view_update
201+
- func: as_strided(Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor
202+
- func: diagonal_view_update(Tensor target, Tensor input, int offset, int dim1, int dim2) -> Tensor
203+
opkind: ltc_diagonal_view_update
204+
properties:
205+
- ShapeCompute
206+
- func: diagonal(Tensor input, int offset, int dim1, int dim2) -> Tensor
207+
- func: narrow_view_update(Tensor input, Tensor source, int[] base_indices) -> Tensor
208+
opkind: ltc_narrow_view_update
209+
- func: narrow(Tensor input, int[] base_indices, int[] sizes) -> Tensor
210+
- func: permute(Tensor input, int[] dims) -> Tensor
211+
- func: resize(Tensor input, int[] size) -> Tensor
212+
- func: select_view_update(Tensor target, Tensor source, int dim, int start, int end, int stride) -> Tensor
213+
opkind: ltc_select_view_update
214+
properties:
215+
- ShapeCompute
216+
- func: select(Tensor input, int dim, int start, int end, int stride) -> Tensor
217+
- func: squeeze(Tensor input, int dim) -> Tensor
218+
- func: unsqueeze(Tensor input, int dim) -> Tensor
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#pragma once
2+
3+
${lazy_non_native_ir_inc}
4+
5+
// This file contains autogenerated LazyTensor Non Native IR nodes
6+
7+
${namespace_prologue}
8+
9+
${non_native_ir_nodes}
10+
11+
${namespace_epilogue}

build.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def define_targets(rules):
2828
":DispatchKeyNativeFunctions.cpp",
2929
":DispatchKeyNativeFunctions.h",
3030
":LazyIr.h",
31+
":LazyNonNativeIr.h",
3132
":RegisterDispatchKey.cpp",
3233
":native_functions.yaml",
3334
":shape_inference.h",
@@ -88,6 +89,7 @@ GENERATED_TESTING_PY = [
8889

8990
GENERATED_LAZY_H = [
9091
"torch/csrc/lazy/generated/LazyIr.h",
92+
"torch/csrc/lazy/generated/LazyNonNativeIr.h",
9193
"torch/csrc/lazy/generated/LazyNativeFunctions.h",
9294
]
9395

caffe2/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
380380
list(APPEND GENERATED_H_TORCH
381381
"${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType.h"
382382
"${TORCH_SRC_DIR}/csrc/lazy/generated/LazyIr.h"
383+
"${TORCH_SRC_DIR}/csrc/lazy/generated/LazyNonNativeIr.h"
383384
"${TORCH_SRC_DIR}/csrc/lazy/generated/LazyNativeFunctions.h"
384385
)
385386
endif()
@@ -444,6 +445,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
444445
"${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.h"
445446
"${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp"
446447
"${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h"
448+
"${TORCH_ROOT}/aten/src/ATen/templates/LazyNonNativeIr.h"
447449
"${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp"
448450
"${TOOLS_PATH}/autograd/templates/VariableType.h"
449451
"${TOOLS_PATH}/autograd/templates/VariableType.cpp"

tools/build_variables.bzl

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -417,35 +417,19 @@ lazy_tensor_core_sources = [
417417
# We can't build all of the ts backend under certain build configurations, e.g. mobile,
418418
# since it depends on things like autograd, meta functions, which may be disabled
419419
lazy_tensor_ts_sources = [
420-
"torch/csrc/lazy/ts_backend/config.cpp",
421420
"torch/csrc/lazy/ts_backend/dynamic_ir.cpp",
421+
"torch/csrc/lazy/ts_backend/config.cpp",
422422
"torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp",
423-
"torch/csrc/lazy/ts_backend/ops/random_ops.cpp",
424-
"torch/csrc/lazy/ts_backend/ops/cast.cpp",
425423
"torch/csrc/lazy/ts_backend/ops/device_data.cpp",
426-
"torch/csrc/lazy/ts_backend/ops/expand.cpp",
424+
"torch/csrc/lazy/ts_backend/ops/random_ops.cpp",
427425
"torch/csrc/lazy/ts_backend/ops/generic.cpp",
428-
"torch/csrc/lazy/ts_backend/ops/scalar.cpp",
429-
"torch/csrc/lazy/ts_backend/view_ops/as_strided.cpp",
430-
"torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.cpp",
431-
"torch/csrc/lazy/ts_backend/view_ops/diagonal.cpp",
432-
"torch/csrc/lazy/ts_backend/view_ops/diagonal_view_update.cpp",
433-
"torch/csrc/lazy/ts_backend/view_ops/narrow.cpp",
434-
"torch/csrc/lazy/ts_backend/view_ops/narrow_view_update.cpp",
435-
"torch/csrc/lazy/ts_backend/view_ops/permute.cpp",
436-
"torch/csrc/lazy/ts_backend/view_ops/resize.cpp",
437-
"torch/csrc/lazy/ts_backend/view_ops/select.cpp",
438-
"torch/csrc/lazy/ts_backend/view_ops/squeeze.cpp",
439-
"torch/csrc/lazy/ts_backend/view_ops/unsqueeze.cpp",
440-
"torch/csrc/lazy/ts_backend/view_ops/select_view_update.cpp",
441-
"torch/csrc/lazy/ts_backend/view_ops/view.cpp",
442-
"torch/csrc/lazy/ts_backend/ts_node.cpp",
443426
"torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp",
444427
"torch/csrc/lazy/ts_backend/ts_autograd_functions.cpp",
445428
"torch/csrc/lazy/ts_backend/ts_backend_impl.cpp",
446429
"torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp",
447430
"torch/csrc/lazy/ts_backend/ts_lowering_context.cpp",
448431
"torch/csrc/lazy/ts_backend/ts_native_functions.cpp",
432+
"torch/csrc/lazy/ts_backend/ts_node.cpp",
449433
"torch/csrc/lazy/ts_backend/ts_node_lowering.cpp",
450434
]
451435

tools/test/test_gen_backend_stubs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def test_unrecognized_key(self) -> None:
237237
output_error = self.get_errors_from_gen_backend_stubs(yaml_str)
238238
self.assertExpectedInline(
239239
output_error,
240-
""" contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen""", # noqa: B950
240+
""" contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen, non_native""", # noqa: B950
241241
)
242242

243243
# if use_out_as_primary is provided, it must be a bool

torch/csrc/lazy/core/ir.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ hash_t Output::hash() const {
1919
return HashCombine(node->hash(), Hash(index));
2020
}
2121

22+
hash_t Output::shapeHash() const {
23+
return HashCombine(node->shapeHash(), Hash(index));
24+
}
25+
2226
std::string Output::ToString() const {
2327
std::stringstream ss;
2428
ss << node->ToString() << ", index=" << index;
@@ -144,7 +148,7 @@ std::string Node::ToString() const {
144148

145149
void Node::AddOperand(NodePtr node, size_t index) {
146150
CHECK_LT(index, node->num_outputs());
147-
operands_.push_back(std::move(node));
151+
operands_.push_back(node);
148152
operands_as_outputs_.emplace_back(operands_.back().get(), index);
149153
}
150154

torch/csrc/lazy/core/ir.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ struct TORCH_API Output {
214214
: node(node), index(index) {}
215215

216216
hash_t hash() const;
217+
hash_t shapeHash() const;
217218

218219
bool operator==(const Output& rhs) const {
219220
return node == rhs.node && index == rhs.index;

torch/csrc/lazy/core/ops/utils.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,33 +6,33 @@
66
namespace torch {
77
namespace lazy {
88

9-
bool StrideIsSupported(c10::ArrayRef<int64_t> stride);
9+
TORCH_API bool StrideIsSupported(c10::ArrayRef<int64_t> stride);
1010

11-
std::vector<int64_t> GetArrayStridePermutation(c10::ArrayRef<int64_t> stride);
11+
TORCH_API std::vector<int64_t> GetArrayStridePermutation(c10::ArrayRef<int64_t> stride);
1212

13-
Shape MakeDiagonalShape(
13+
TORCH_API Shape MakeDiagonalShape(
1414
const Shape& shape,
1515
int64_t offset,
1616
int64_t dim1,
1717
int64_t dim2);
1818

19-
Shape MakePermuteShape(
19+
TORCH_API Shape MakePermuteShape(
2020
const Shape& source_shape,
2121
c10::ArrayRef<int64_t> permutation);
2222

23-
Shape MakeSelectShape(
23+
TORCH_API Shape MakeSelectShape(
2424
const Shape& shape,
2525
int64_t dim,
2626
int64_t start,
2727
int64_t end,
2828
int64_t stride);
2929

30-
int64_t GetStride(int64_t start, int64_t end, int64_t stride);
30+
TORCH_API int64_t GetStride(int64_t start, int64_t end, int64_t stride);
3131

32-
std::vector<int64_t> BuildSqueezedDimensions(c10::ArrayRef<int64_t> dimensions,
32+
TORCH_API std::vector<int64_t> BuildSqueezedDimensions(c10::ArrayRef<int64_t> dimensions,
3333
int64_t squeeze_dim);
3434

35-
std::vector<int64_t> BuildUnsqueezedDimensions(
35+
TORCH_API std::vector<int64_t> BuildUnsqueezedDimensions(
3636
c10::ArrayRef<int64_t> dimensions,
3737
int64_t squeeze_dim);
3838

0 commit comments

Comments
 (0)