forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathfused_kernel.h
98 lines (85 loc) · 3.23 KB
/
fused_kernel.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#pragma once
#include <ATen/ATen.h>
#include <ATen/Utils.h>
#include <torch/csrc/jit/codegen/fuser/partition_desc.h>
#include <torch/csrc/jit/codegen/fuser/tensor_desc.h>
#include <cstdint>
#include <string>
#include <vector>
namespace torch::jit::fuser {
// A compiled kernel for a single fusion group. This is the backend-agnostic
// base class: concrete subclasses (selected via backend()) own the compiled
// artifact and implement launch_raw(). All descriptive metadata is fixed at
// construction time (every member is const).
struct FusedKernel {
  // Non-copyable: an instance represents one compiled artifact.
  AT_DISALLOW_COPY_AND_ASSIGN(FusedKernel);

  // Takes ownership (by move) of all descriptive metadata:
  //   name        - kernel/entry-point name
  //   code        - generated source code for the kernel
  //   input_desc  - one TensorDesc per fusion-group input
  //   output_desc - one TensorDesc per fusion-group output
  //   chunk_desc  - same size as input_desc (see chunk_desc_ below)
  //   concat_desc - same size as output_desc (see concat_desc_ below)
  //   has_random  - whether the generated code uses random number generation
  FusedKernel(
      std::string name,
      std::string code,
      std::vector<TensorDesc> input_desc,
      std::vector<TensorDesc> output_desc,
      std::vector<PartitionDesc> chunk_desc,
      std::vector<PartitionDesc> concat_desc,
      bool has_random)
      : name_(std::move(name)),
        code_(std::move(code)),
        input_desc_(std::move(input_desc)),
        output_desc_(std::move(output_desc)),
        chunk_desc_(std::move(chunk_desc)),
        concat_desc_(std::move(concat_desc)),
        has_random_(has_random) {}

  // Polymorphic base: virtual destructor so subclasses are destroyed
  // correctly when deleted through a FusedKernel pointer.
  virtual ~FusedKernel() = default;

  // arguments is a list of pointers to the arguments for the compiled CUDA/CPU
  // code.
  // The format of arguments is suitable for directly passing to a call to
  // cuLaunchKernel as the kernel arguments.
  // Currently the first argument is a pointer to numel (for passing to
  // CUDA code), and the remainder are pointers to the TensorInfo<T> structs
  // that compiled code uses to load Tensor data.
  // launch_with_tensors handles packing at::Tensors into this arguments array.
  // CPU code uses the same convention so that launch_with_tensors can be
  // shared.
  virtual void launch_raw(const uint32_t numel, std::vector<void*>& arguments)
      const = 0;

  // Which backend (e.g. CPU or CUDA) this kernel was compiled for.
  virtual at::Backend backend() const = 0;

  // Getters (all return references to immutable, construction-time state)
  const std::string& name() const {
    return name_;
  }
  const std::string& code() const {
    return code_;
  }
  const std::vector<TensorDesc>& inputDesc() const {
    return input_desc_;
  }
  const std::vector<TensorDesc>& outputDesc() const {
    return output_desc_;
  }
  const std::vector<PartitionDesc>& chunkDesc() const {
    return chunk_desc_;
  }
  const std::vector<PartitionDesc>& concatDesc() const {
    return concat_desc_;
  }
  bool hasRandom() const {
    return has_random_;
  }

 protected:
  // Kernel/entry-point name used by the generated code.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const std::string name_;
  // Generated source code that was (or will be) compiled for this kernel.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const std::string code_;
  // One descriptor per fusion-group input.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const std::vector<TensorDesc> input_desc_;
  // One descriptor per fusion-group output.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const std::vector<TensorDesc> output_desc_;
  // same size as input_desc, describes whether an
  // input should be broken into subtensors (chunks)
  // to be consumed by the fusion group
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const std::vector<PartitionDesc> chunk_desc_;
  // same size as output_desc, describes whether
  // an output is actually a concatenation of
  // many subtensors that the fusion group produces
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const std::vector<PartitionDesc> concat_desc_;
  // True if the generated code uses random number generation.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const bool has_random_;
};
} // namespace torch::jit::fuser