forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathio_manager.h
More file actions
345 lines (329 loc) · 11.4 KB
/
io_manager.h
File metadata and controls
345 lines (329 loc) · 11.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
/*
* Copyright (c) Qualcomm Innovation Center, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <cstddef>
#include <cstdint>
#include <future>
#include <limits>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
#include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>
#include <executorch/extension/module/module.h>
#include <executorch/runtime/executor/method_meta.h>
namespace example {
/// Evaluation strategy requested from the IO managers in this header.
///
/// The enumerator values are spelled out explicitly. They are identical to the
/// implicit 0/1/2 numbering, so code that stores or compares the raw integer
/// keeps working.
/// NOTE(review): the exact meaning of each mode (e.g. whether kHybrid means
/// "prefill method, then kv method") is defined by the runner, not here.
enum EvalMode {
  kKVCached = 0,
  kHybrid = 1,
  // Sentinel for modes the IO managers do not handle.
  kUnsupported = 2,
};
/// Abstract base for managing the input/output tensors of a set of model
/// shards that each expose a prefill method and a KV-cache ("kv") method.
///
/// The base stores:
///   * the module handles (`modules_`),
///   * a type-erased pointer to subclass-specific IO storage (`data_ptr_`),
///   * per-method tables of TensorImpl pointers (`input_tensors_`,
///     `output_tensors_`).
/// Subclasses decide the buffer layout and how the KV cache is carried from
/// one step to the next.
class IoMgrBase {
 public:
  /// @param modules Module handles, one per shard. NOTE(review): presumably
  ///        copied into `modules_`; confirm in the implementation file.
  explicit IoMgrBase(
      std::vector<std::shared_ptr<executorch::extension::Module>>& modules);
  virtual ~IoMgrBase();

  // `data_ptr_` uniquely owns the IO storage, so the type cannot be copied.
  // The copy operations are already implicitly deleted; this makes it explicit.
  IoMgrBase(const IoMgrBase&) = delete;
  IoMgrBase& operator=(const IoMgrBase&) = delete;

  /// Initializes the subclass-specific IO storage.
  virtual void init_io() = 0;

  /// Resets the IO state using the metadata of the prefill and kv methods
  /// (presumably one MethodMeta per shard; confirm in subclasses).
  virtual void reset_io(
      const std::vector<executorch::runtime::Result<
          executorch::runtime::MethodMeta>>& prefill_methods_meta,
      const std::vector<
          executorch::runtime::Result<executorch::runtime::MethodMeta>>&
          kv_methods_meta) = 0;

  /// Prepares the prefill method's input/output tensors from their metadata.
  virtual void prepare_prefill_io(
      const std::vector<
          executorch::runtime::Result<executorch::runtime::MethodMeta>>&
          methods_meta) = 0;

  /// Prepares the kv method's input/output tensors from their metadata.
  virtual void prepare_kv_io(
      const std::vector<
          executorch::runtime::Result<executorch::runtime::MethodMeta>>&
          methods_meta) = 0;

  /// Writes `prompt_tokens` into the prefill inputs, starting at `start_pos`.
  virtual void fill_prefill_toks(
      int64_t start_pos,
      std::vector<uint64_t>& prompt_tokens) = 0;

  /// Writes `cur_token` and the mask entry for `pos` into the kv inputs.
  virtual void fill_kv_tok_mask(int64_t pos, int64_t cur_token) = 0;

  /// Carries prefill outputs over into the kv inputs.
  virtual void update_prefill_to_kv_io(
      int64_t cur_token,
      int64_t pos,
      std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) = 0;

  /// Carries kv outputs over into the prefill inputs.
  virtual void update_kv_to_prefill_io(
      int64_t pos,
      std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) = 0;

  /// Updates kv inputs from the previous kv step's outputs.
  virtual void update_kv_io(
      int64_t cur_token,
      int64_t pos,
      std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) = 0;

  /// Updates prefill inputs from the previous prefill step's outputs.
  virtual void update_prefill_io(
      int64_t cur_token,
      int64_t pos,
      std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) = 0;

  /// Returns the raw pointer to the subclass IO storage (presumably
  /// `data_ptr_.get()`; confirm in the implementation file).
  void* get_mutable_ptr();

  /// Returns the input tensors registered for `method_name` on shard
  /// `shard_index`.
  std::vector<executorch::aten::Tensor> get_input_tensors(
      int shard_index,
      const std::string& method_name);

  /// Returns the output tensors registered for `method_name` on shard
  /// `shard_index`.
  std::vector<executorch::aten::Tensor> get_output_tensors(
      int shard_index,
      const std::string& method_name);

 protected:
  // Type-erased, uniquely owned IO storage. The subclass supplies the deleter.
  std::unique_ptr<void, void (*)(void*)> data_ptr_;
  // Method name -> per-shard list of input TensorImpls (non-owning pointers).
  std::unordered_map<
      std::string,
      std::vector<std::vector<executorch::aten::TensorImpl*>>>
      input_tensors_;
  // Method name -> per-shard list of output TensorImpls (non-owning pointers).
  std::unordered_map<
      std::string,
      std::vector<std::vector<executorch::aten::TensorImpl*>>>
      output_tensors_;
  std::vector<std::shared_ptr<executorch::extension::Module>> modules_;
};
/// IoMgrBase implementation backed by per-field std::vector storage (see IO).
/// NOTE(review): the name suggests the KV cache is advanced by shifting data
/// pointers between steps; confirm against the implementation file.
class ShiftPointerIoMgr : public IoMgrBase {
 public:
  ShiftPointerIoMgr(
      std::vector<std::shared_ptr<executorch::extension::Module>>& modules,
      int32_t context_len,
      int32_t prefill_ar_len,
      int32_t prefill_cache_len,
      int32_t kv_ar_len,
      int32_t kv_cache_len,
      int32_t vocab_size,
      int32_t num_layers,
      int32_t head_dim,
      int32_t num_heads,
      EvalMode eval_mode,
      const std::string& prefill_forward_name,
      const std::string& kv_forward_name,
      const bool use_int64_token);

  void init_io() override;
  void reset_io(
      const std::vector<executorch::runtime::Result<
          executorch::runtime::MethodMeta>>& prefill_methods_meta,
      const std::vector<
          executorch::runtime::Result<executorch::runtime::MethodMeta>>&
          kv_methods_meta) override;
  void prepare_prefill_io(
      const std::vector<
          executorch::runtime::Result<executorch::runtime::MethodMeta>>&
          methods_meta) override;
  void prepare_kv_io(
      const std::vector<
          executorch::runtime::Result<executorch::runtime::MethodMeta>>&
          methods_meta) override;
  void fill_prefill_toks(
      int64_t start_pos,
      std::vector<uint64_t>& prompt_tokens) override;
  void fill_kv_tok_mask(int64_t pos, int64_t cur_token) override;
  void update_prefill_to_kv_io(
      int64_t cur_token,
      int64_t pos,
      std::vector<std::vector<executorch::aten::Tensor>>& output_tensors)
      override;
  void update_kv_to_prefill_io(
      int64_t pos,
      std::vector<std::vector<executorch::aten::Tensor>>& output_tensors)
      override;
  void update_kv_io(
      int64_t cur_token,
      int64_t pos,
      std::vector<std::vector<executorch::aten::Tensor>>& output_tensors)
      override;
  void update_prefill_io(
      int64_t cur_token,
      int64_t pos,
      std::vector<std::vector<executorch::aten::Tensor>>& output_tensors)
      override;

  /// Backing storage for all IO tensors of this manager.
  /// Scalars are zero-initialized so that a default-initialized IO
  /// (e.g. `new IO`) never exposes indeterminate values.
  struct IO {
    int64_t kv_input_toks{0};
    int32_t kv_input_pos{0};
    // NOTE(review): nesting depth differs between k_cache (3 levels) and
    // v_cache / k_cache_out (2 levels); the per-level meaning (layer / head /
    // bytes) is defined in the implementation file. Confirm before relying
    // on it.
    std::vector<std::vector<std::vector<uint8_t>>> k_cache;
    std::vector<std::vector<uint8_t>> v_cache;
    std::vector<std::vector<uint8_t>> k_cache_out;
    std::vector<uint16_t> kv_attention_mask;
    std::vector<uint16_t> kv_logits;
    std::vector<int64_t> prefill_input_toks;
    std::vector<int32_t> prefill_input_pos;
    std::vector<uint16_t> prefill_attention_mask;
    std::vector<uint16_t> prefill_logits;
  };

 private:
  std::unique_ptr<executorch::aten::TensorImpl> kv_input_toks_;
  std::unique_ptr<executorch::aten::TensorImpl> kv_input_pos_;
  std::unique_ptr<executorch::aten::TensorImpl> kv_attention_mask_;
  std::unique_ptr<executorch::aten::TensorImpl> prefill_input_toks_;
  std::unique_ptr<executorch::aten::TensorImpl> prefill_input_pos_;
  std::unique_ptr<executorch::aten::TensorImpl> prefill_attention_mask_;
  std::unique_ptr<executorch::aten::TensorImpl> prefill_logits_;
  // Method name -> owned cache TensorImpls for that method.
  std::unordered_map<
      std::string,
      std::vector<std::unique_ptr<executorch::aten::TensorImpl>>>
      k_cache_in_;
  std::unordered_map<
      std::string,
      std::vector<std::unique_ptr<executorch::aten::TensorImpl>>>
      v_cache_in_;
  std::unordered_map<
      std::string,
      std::vector<std::unique_ptr<executorch::aten::TensorImpl>>>
      k_cache_out_;
  std::unordered_map<
      std::string,
      std::vector<std::unique_ptr<executorch::aten::TensorImpl>>>
      v_cache_out_;
  std::unique_ptr<executorch::aten::TensorImpl> kv_logits_;
  std::vector<int> shard_layers_;
  int32_t context_len_{0};
  int32_t kv_ar_len_{0};
  int32_t kv_cache_len_{0};
  int32_t prefill_ar_len_{0};
  int32_t prefill_cache_len_{0};
  // Defaults below are overwritten by the constructor. They only guarantee
  // that the members never hold indeterminate values.
  int32_t vocab_size_{0};
  int32_t num_layers_{0};
  int32_t head_dim_{0};
  int32_t num_heads_{0};
  EvalMode eval_mode_{kUnsupported};
  std::string prefill_forward_name_;
  std::string kv_forward_name_;
  const bool use_int64_token_{false};
  // NOTE(review): SmartMaskIoMgr documents its is_bert_ as "cache length is
  // zero => BERT model without position ids or KV cache inputs"; presumably
  // the same rule applies here. Confirm in the constructor.
  const bool is_bert_{false};
};
/// IoMgrBase implementation whose IO tensors are raw views into a single
/// shared buffer (IO::shared_buffer_base), released through
/// QnnExecuTorchFreeCustomMem.
class SmartMaskIoMgr : public IoMgrBase {
 public:
  SmartMaskIoMgr(
      std::vector<std::shared_ptr<executorch::extension::Module>>& modules,
      int32_t context_len,
      int32_t prefill_ar_len,
      int32_t prefill_cache_len,
      int32_t kv_ar_len,
      int32_t kv_cache_len,
      int32_t vocab_size,
      int32_t num_layers,
      int32_t head_dim,
      int32_t num_heads,
      EvalMode eval_mode,
      const std::string& prefill_forward_name,
      const std::string& kv_forward_name,
      const bool use_int64_token);

  void init_io() override;
  void reset_io(
      const std::vector<executorch::runtime::Result<
          executorch::runtime::MethodMeta>>& prefill_methods_meta,
      const std::vector<
          executorch::runtime::Result<executorch::runtime::MethodMeta>>&
          kv_methods_meta) override;
  void prepare_prefill_io(
      const std::vector<
          executorch::runtime::Result<executorch::runtime::MethodMeta>>&
          methods_meta) override;
  void prepare_kv_io(
      const std::vector<
          executorch::runtime::Result<executorch::runtime::MethodMeta>>&
          methods_meta) override;
  void fill_prefill_toks(
      int64_t start_pos,
      std::vector<uint64_t>& prompt_tokens) override;
  void fill_kv_tok_mask(int64_t pos, int64_t cur_token) override;
  void update_prefill_to_kv_io(
      int64_t cur_token,
      int64_t pos,
      std::vector<std::vector<executorch::aten::Tensor>>& output_tensors)
      override;
  void update_kv_to_prefill_io(
      int64_t pos,
      std::vector<std::vector<executorch::aten::Tensor>>& output_tensors)
      override;
  void update_kv_io(
      int64_t cur_token,
      int64_t pos,
      std::vector<std::vector<executorch::aten::Tensor>>& output_tensors)
      override;
  void update_prefill_io(
      int64_t cur_token,
      int64_t pos,
      std::vector<std::vector<executorch::aten::Tensor>>& output_tensors)
      override;

  /// Per-IO-name element counts (keys and units defined in the .cpp).
  std::unordered_map<std::string, size_t> get_io_elements();
  /// Per-IO-name byte sizes (presumably the map later passed to
  /// IO::init_io_ptrs; confirm in the .cpp).
  std::unordered_map<std::string, size_t> get_io_bytes();

  /// Non-owning views into one shared buffer. The buffer itself is owned
  /// through `shared_buffer_base` and freed in the destructor.
  struct IO {
    IO() = default;
    // The destructor frees `shared_buffer_base`; a copy would free it a
    // second time, so copying (and therefore moving) is disabled.
    IO(const IO&) = delete;
    IO& operator=(const IO&) = delete;

    // Start of the shared allocation. It stays nullptr until it is assigned
    // (presumably by init_io_ptrs).
    void* shared_buffer_base{nullptr};
    int64_t* kv_input_toks{nullptr};
    int32_t* kv_input_pos{nullptr};
    // layer -> head -> head_dim * seq_len
    std::vector<std::vector<uint8_t*>> k_cache;
    std::vector<std::vector<uint8_t*>> v_cache;
    // layer -> head -> head_dim
    std::vector<std::vector<uint8_t*>> k_cache_out;
    std::vector<std::vector<uint8_t*>> v_cache_out;
    // kv_ar_len_ * context_len_
    uint16_t* kv_attention_mask{nullptr};
    // kv_ar_len_ * vocab_size
    uint16_t* kv_logits{nullptr};
    // prefill_ar_len_
    int64_t* prefill_input_toks{nullptr};
    int32_t* prefill_input_pos{nullptr};
    // prefill_ar_len_ * context_len_
    uint16_t* prefill_attention_mask{nullptr};
    // vocab_size * prefill_ar_len_
    uint16_t* prefill_logits{nullptr};
    size_t num_layers_{0};
    size_t num_heads_{0};
    size_t head_dim_{0};
    std::unordered_map<std::byte*, size_t> io_pos_map;

    ~IO() {
      // If the buffer was never assigned there is nothing to release. Passing
      // an unset pointer to the allocator would be undefined behavior.
      if (shared_buffer_base != nullptr) {
        QnnExecuTorchFreeCustomMem(shared_buffer_base);
      }
    }

    void init_io_ptrs(
        void* shared_buffer_ptr,
        std::unordered_map<std::string, size_t>& io_bytes_map);
    void add_custom_mem_info(
        void* ptr,
        size_t nbytes,
        executorch::aten::ScalarType scalar_type,
        executorch::runtime::TensorInfo& tensor_info);
  };

 private:
  std::unique_ptr<executorch::aten::TensorImpl> kv_input_toks_;
  std::unique_ptr<executorch::aten::TensorImpl> kv_input_pos_;
  std::unique_ptr<executorch::aten::TensorImpl> kv_attention_mask_;
  std::unique_ptr<executorch::aten::TensorImpl> prefill_input_toks_;
  std::unique_ptr<executorch::aten::TensorImpl> prefill_input_pos_;
  std::unique_ptr<executorch::aten::TensorImpl> prefill_attention_mask_;
  std::unique_ptr<executorch::aten::TensorImpl> prefill_logits_;
  // Method name -> owned cache TensorImpls for that method.
  std::unordered_map<
      std::string,
      std::vector<std::unique_ptr<executorch::aten::TensorImpl>>>
      k_cache_in_;
  std::unordered_map<
      std::string,
      std::vector<std::unique_ptr<executorch::aten::TensorImpl>>>
      v_cache_in_;
  std::unordered_map<
      std::string,
      std::vector<std::unique_ptr<executorch::aten::TensorImpl>>>
      k_cache_out_;
  std::unordered_map<
      std::string,
      std::vector<std::unique_ptr<executorch::aten::TensorImpl>>>
      v_cache_out_;
  std::unique_ptr<executorch::aten::TensorImpl> kv_logits_;
  std::vector<int> shard_layers_;
  int32_t context_len_{0};
  int32_t kv_ar_len_{0};
  int32_t kv_cache_len_{0};
  int32_t prefill_ar_len_{0};
  int32_t prefill_cache_len_{0};
  // Defaults below are overwritten by the constructor. They only guarantee
  // that the members never hold indeterminate values.
  int32_t vocab_size_{0};
  int32_t num_layers_{0};
  int32_t head_dim_{0};
  int32_t num_heads_{0};
  EvalMode eval_mode_{kUnsupported};
  std::string prefill_forward_name_;
  std::string kv_forward_name_;
  const bool use_int64_token_{false};
  // If the cache length is zero, it indicates a BERT model, which does not use
  // position ids or KV cache inputs.
  const bool is_bert_{false};
};
} // namespace example