kernel_spec.h
#pragma once

#include <ATen/ATen.h>
#include <ATen/core/stack.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/codegen/fuser/arg_spec.h>
#include <torch/csrc/jit/codegen/fuser/fused_kernel.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/interpreter.h>

#include <algorithm>
#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>
#include <type_traits>
#include <unordered_map>
#include <vector>

namespace torch::jit::fuser {

// Helper struct containing partition information: the number of tensors
// created and the dimension the partitioning is performed on.
// Note: created during upfront compilation; once the tensors are known
// at runtime the partition info is logically combined with the tensor
// descriptions to create PartitionDesc objects.
struct TORCH_API PartitionInfo {
  PartitionInfo(const int64_t _nSubTensors, const int64_t _dim)
      : nSubTensors_{_nSubTensors}, dim_{_dim} {}

  int64_t nSubTensors() const {
    return nSubTensors_;
  }
  int64_t dim() const {
    return dim_;
  }

 private:
  int64_t nSubTensors_;
  int64_t dim_;
};
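
// Illustrative example (not part of the original header): a fused chunk
// that splits its input into three sub-tensors along dimension 1 would be
// recorded as PartitionInfo{3, 1}. At runtime a [4, 6] input would then
// be combined with this info to describe three [4, 2] sub-tensors:
//
//   PartitionInfo info{/*_nSubTensors=*/3, /*_dim=*/1};
//   // info.nSubTensors() == 3, info.dim() == 1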
// "Kernel Specification." - Contains device-independent fusion information.
// Each kernel specification contains a map of instantiated generated functions
// that implement some or most of its functionality. Multiple generated
// functions are needed by each abstract specification because of different
// devices (cpu vs gpu, different gpus) and different inputs (int vs float,
// contiguous vs discontiguous).
// Note: uses a mutex to control access to its kernel store
// Note: unordered containers do not invalidate references/pointers on
// rehashing, which is critical for thread-safety.
// TODO: allow abstract kernels to use multiple generated kernels
// TODO: allow abstract kernels to reuse generated kernels from common pool
struct TORCH_API KernelSpec {
  // Note: assumes the spec is a single block
  // Note: This is the appropriate place to generalize if you want to add
  // other passes to upfront compilation that walk the graph.
  KernelSpec(const int64_t _key, const std::shared_ptr<Graph>& _graph)
      : key_{_key},
        graph_{_graph},
        code_{_graph, "<fused code>"},
        nInputs_{_graph->inputs().size()} {
    // nodes() yields Node* pointers, so iterating by value (rather than
    // by reference) copies nothing but the pointer.
    for (const auto n : graph_->nodes()) {
      static_assert(std::is_pointer_v<decltype(n)>, "n must be a pointer");
      if (n->kind() == aten::rand_like) {
        has_random_ = true;
        break;
      }
    }
    nTensorInputs_ = std::count_if(
        graph_->inputs().begin(), graph_->inputs().end(), [](const Value* v) {
          return v->type()->isSubtypeOf(*TensorType::get());
        });
  }
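
  // Illustrative sketch (hypothetical key and graph, not part of the
  // original header): upfront compilation builds one KernelSpec per
  // fusion group, keyed by the group's fusion key:
  //
  //   std::shared_ptr<Graph> fusionGraph = /* subgraph of the group */;
  //   KernelSpec spec{/*_key=*/0, fusionGraph};
  //   // spec.nInputs() counts all graph inputs; spec.nTensorInputs()
  //   // counts only those of TensorType; spec.hasRandom() is true iff
  //   // the graph contains an aten::rand_like node.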
  // Getters
  int64_t key() const {
    return key_;
  }
  std::shared_ptr<Graph> graph() const {
    return graph_;
  }
  const Code& code() const {
    return code_;
  }
  int64_t nInputs() const {
    return nInputs_;
  }
  int64_t nTensorInputs() const {
    return nTensorInputs_;
  }
  std::vector<std::vector<int64_t>>& inputBroadcastGroups() {
    return inputBroadcastGroups_;
  }
  const std::vector<std::vector<int64_t>>& inputBroadcastGroups() const {
    return inputBroadcastGroups_;
  }
  std::vector<PartitionInfo>& inputChunks() {
    return inputChunks_;
  }
  const std::vector<PartitionInfo>& inputChunks() const {
    return inputChunks_;
  }
  bool hasRandom() const {
    return has_random_;
  }
  // Cache functions
  std::optional<std::shared_ptr<FusedKernel>> findKernel(
      const ArgSpec& arg_spec) const {
    std::lock_guard<std::mutex> guard{mutex_};
    const auto it = kernels_.find(arg_spec);
    if (it == kernels_.end())
      return std::nullopt;
    return it->second;
  }
  void cacheKernel(
      const ArgSpec& arg_spec,
      const std::shared_ptr<FusedKernel>& kernel) const {
    std::lock_guard<std::mutex> guard{mutex_};
    kernels_.emplace(arg_spec, kernel);
  }
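
  // Illustrative sketch of the intended find-or-compile pattern
  // (compileKernel is a hypothetical helper, not part of this header):
  //
  //   std::shared_ptr<FusedKernel> getOrCompile(
  //       const KernelSpec& spec, const ArgSpec& arg_spec) {
  //     if (auto cached = spec.findKernel(arg_spec)) {
  //       return *cached; // reuse the kernel generated for this ArgSpec
  //     }
  //     auto kernel = compileKernel(spec, arg_spec);
  //     spec.cacheKernel(arg_spec, kernel); // safe: cache is mutex-guarded
  //     return kernel;
  //   }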
 private:
  int64_t key_;
  std::shared_ptr<Graph> graph_;
  Code code_;
  uint64_t nInputs_;
  uint64_t nTensorInputs_{};
  std::vector<std::vector<int64_t>> inputBroadcastGroups_;
  std::vector<PartitionInfo> inputChunks_;
  bool has_random_{false};
  mutable std::mutex mutex_;
  mutable std::unordered_map<
      ArgSpec,
      std::shared_ptr<FusedKernel>,
      c10::hash<ArgSpec>>
      kernels_;
};

} // namespace torch::jit::fuser