-
Notifications
You must be signed in to change notification settings - Fork 97
Expand file tree
/
Copy pathdecoder_wrappers.py
More file actions
276 lines (225 loc) · 12.2 KB
/
decoder_wrappers.py
File metadata and controls
276 lines (225 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Wrappers used at runtime to prepare inputs for decoder models.
import logging
import torch
import torch.nn.functional as F
from transformers import PretrainedConfig
from ...config import NxDNeuronConfig
from ...model_wrapper import NxDModelWrapper
# Tags identifying which compiled graph variant a wrapper instance drives.
# The wrappers below branch on these to pick the right padding strategy.
CONTEXT_ENCODING_MODEL_TAG = "context_encoding_model"
TOKEN_GENERATION_MODEL_TAG = "token_generation_model"
SPECULATION_MODEL_TAG = "speculation_model"
CHUNKED_PREFILL_MODEL_TAG = "chunked_prefill_model"
class NxDDecoderWrapperForCausalLM(NxDModelWrapper):
    """A decoder wrapper for decoder models used in causal language modeling.

    It prepares the input tensors to match the compiled model static input shapes:
    the sequence dimension is padded up to the compiled length (context encoding and
    chunked prefill graphs), and batches larger than the compiled batch size are
    split into compiled-batch-size chunks, padding the final partial chunk.
    """

    def __init__(
        self, config: PretrainedConfig, neuron_config: NxDNeuronConfig, model: torch.jit.ScriptModule, tag: str
    ) -> None:
        """Store configs, the compiled model, and the tag selecting the graph variant.

        Args:
            config: the Hugging Face model config (pad_token_id is read/defaulted here).
            neuron_config: Neuron compilation config (batch sizes, context lengths, dtype).
            model: the compiled (traced) model callable.
            tag: one of the *_MODEL_TAG constants identifying the graph variant.
        """
        super().__init__()
        self.config = config
        self.neuron_config = neuron_config
        self.model = model
        self.tag = tag
        # Default to float32 when no dtype was specified on the neuron config.
        if not self.neuron_config.torch_dtype:
            self.neuron_config.torch_dtype = torch.float32
        # A pad token id is needed to pad input_ids up to the compiled context length.
        if config.pad_token_id is None:
            config.pad_token_id = 0

    def _forward_with_pad(self, input_ids, position_ids, seq_ids, sampling_params):
        """Pad the batch dimension up to the compiled batch size, run the model,
        and slice the padded rows back out of the returned logits."""

        # pad the inputs up to the compiled batch size in the end
        def pad_to_batch_size(tensor):
            if tensor is None or tensor.shape[0] == self.neuron_config.batch_size:
                return tensor
            padded_shape = list(tensor.shape)
            padded_shape[0] = self.neuron_config.batch_size
            # pad with first batch line values instead of zeros, to reduce chances of NaN
            # NOTE(review): the repeat assumes 2-D (batch, seq) inputs — confirm this
            # holds for sampling_params as well.
            padded_tensor = tensor[0].unsqueeze(0).repeat(padded_shape[0], 1).to(tensor.dtype)
            padded_tensor[: tensor.shape[0]] = tensor
            return padded_tensor

        padded_args = []
        padded_args.append(pad_to_batch_size(input_ids))
        padded_args.append(pad_to_batch_size(position_ids))
        # seq_ids must be handled separately: with a compiled batch of 4, padding
        # seq_ids from [0, 2, 1] to [0, 2, 1, 0] would let the padded row overwrite
        # the KV cache line of sequence 0, so we pad with the unused ids instead,
        # i.e. [0, 2, 1, 3].
        # NOTE(review): the padded list has max_batch_size entries while the other
        # inputs are padded to batch_size — confirm these are equal on this path.
        seq_ids_list = seq_ids.tolist()
        padded_seq_ids = torch.tensor(
            seq_ids_list + [x for x in range(self.neuron_config.max_batch_size) if x not in seq_ids_list],
            dtype=seq_ids.dtype,
        )
        padded_args.append(padded_seq_ids)
        padded_sampling_params = pad_to_batch_size(sampling_params)
        padded_args.append(padded_sampling_params)
        outputs = self._forward(*padded_args)
        # note that we don't do index select here as it should already be handled, simply slice out padding here
        logits = outputs
        return logits[: seq_ids.shape[0]]

    def _forward(self, input_ids, position_ids, seq_ids, sampling_params):
        """Run the compiled model, reordering rows by seq_ids first when continuous
        batching requires each input row to sit at its KV-cache-line position."""
        needs_reordering = False
        if self.tag == TOKEN_GENERATION_MODEL_TAG and self.neuron_config.continuous_batching:
            # if continuous batching is enabled, we need to ensure that the inputs are at the expected positions
            orig_seq_ids = seq_ids.clone()
            needs_reordering = not torch.equal(seq_ids, torch.arange(seq_ids.shape[0]))
            if needs_reordering:
                sorting_index = torch.argsort(seq_ids)
                seq_ids = torch.index_select(seq_ids, 0, sorting_index)
                input_ids = torch.index_select(input_ids, 0, sorting_index)
                position_ids = torch.index_select(position_ids, 0, sorting_index)
                sampling_params = torch.index_select(sampling_params, 0, sorting_index)
        outputs = self.model(input_ids, position_ids, seq_ids, sampling_params)
        if needs_reordering:
            # if we reordered the inputs, we need to reorder the outputs as well:
            # after sorting, output row i belongs to seq_id i, so indexing with the
            # original seq_ids restores the caller's row order.
            outputs = torch.index_select(outputs, 0, orig_seq_ids)
        return outputs

    def convert_int64_to_int32(self, *args):
        """
        Convert int64 args to int32 to match compiled input types.
        Neuron compiler handles int32 better than int64. Context: P165494809
        """
        return [t.to(torch.int32) if t.dtype == torch.int64 else t for t in args]

    def _pad_to_max_context_length(self, x: torch.Tensor, padding_value) -> torch.Tensor:
        """Pad input along dim=1 to max_context_length using a constant value."""
        pad_length = self.neuron_config.max_context_length - x.shape[1]
        return F.pad(x, (0, pad_length), "constant", padding_value)

    def _pad_to_chunk_size(
        self, input_ids: torch.Tensor, position_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pad inputs along dim=1 to chunk_size by repeating the last token.

        Repeating the last real token ensures:
        1. torch.max(position_ids) stays at the last REAL token's index, so
           logit gathering picks the correct hidden state.
        2. Padded tokens scatter their KV to cache[last_real_pos] with the
           identical (input_id, position_id) as the real last token, so the
           scatter is a no-op overwrite -- no corruption.
        """
        chunk_size = self.neuron_config.prefill_chunk_size
        pad_length = chunk_size - input_ids.shape[1]
        if pad_length <= 0:
            return input_ids, position_ids
        last_input_id = input_ids[:, -1:]
        last_pos = position_ids[:, -1:]
        input_ids = torch.cat([input_ids, last_input_id.expand(-1, pad_length)], dim=-1)
        position_ids = torch.cat([position_ids, last_pos.expand(-1, pad_length)], dim=-1)
        return input_ids, position_ids

    def forward(self, input_ids, position_ids, seq_ids, sampling_params):
        """Dispatch to the compiled model, padding sequence and batch dimensions
        to the compiled static shapes.

        Raises:
            ValueError: if the input batch size exceeds neuron_config.max_batch_size.
        """
        input_ids, position_ids, seq_ids = self.convert_int64_to_int32(input_ids, position_ids, seq_ids)
        if self.tag == CONTEXT_ENCODING_MODEL_TAG:
            input_ids = self._pad_to_max_context_length(input_ids, self.config.pad_token_id)
            # position_ids are padded with 1 (not 0) past the real tokens.
            position_ids = self._pad_to_max_context_length(position_ids, 1)
        elif self.tag == CHUNKED_PREFILL_MODEL_TAG:
            input_ids, position_ids = self._pad_to_chunk_size(input_ids, position_ids)
        input_batch_size = seq_ids.shape[0]
        if input_batch_size > self.neuron_config.max_batch_size:
            raise ValueError(
                f"Input batch size {input_batch_size} exceeds the maximum batch size {self.neuron_config.max_batch_size}."
            )
        elif input_batch_size == self.neuron_config.batch_size:
            # Fast path: the input already matches the compiled batch size.
            return self._forward(input_ids, position_ids, seq_ids, sampling_params)
        cur_batch = 0
        output_logits = []
        # Lazy %-args avoid formatting when DEBUG is disabled.
        logging.debug(
            "get input_batch_size as %s but compiled batch_size as %s",
            input_batch_size,
            self.neuron_config.batch_size,
        )
        args = (input_ids, position_ids, seq_ids, sampling_params)
        # Process the oversized batch in compiled-batch-size chunks.
        while cur_batch < input_batch_size:
            if cur_batch + self.neuron_config.batch_size <= input_batch_size:
                # we only process part of the input to run
                logging.debug(
                    "running forward on batch %s:%s", cur_batch, cur_batch + self.neuron_config.batch_size
                )
                outputs = self._forward(*[arg[cur_batch : cur_batch + self.neuron_config.batch_size] for arg in args])
            else:
                # we need to pad the input to run
                logging.debug(
                    "running forward on batch %s:%s, padded up to %s",
                    cur_batch,
                    input_batch_size,
                    self.neuron_config.batch_size,
                )
                outputs = self._forward_with_pad(*[arg[cur_batch:input_batch_size] for arg in args])
            output_logits.append(outputs)
            cur_batch += self.neuron_config.batch_size
        return torch.cat(output_logits, dim=0)
class NxDDecoderWrapperForEmbedding(NxDModelWrapper):
    """A decoder wrapper for decoder models used in embedding extraction.

    It pads the sequence dimension up to the compiled context length and splits
    batches larger than the compiled batch size into compiled-batch-size chunks,
    padding the final partial chunk and slicing the padding off the result.
    """

    def __init__(
        self, config: PretrainedConfig, neuron_config: NxDNeuronConfig, model: torch.jit.ScriptModule
    ) -> None:
        """Store configs and the compiled model.

        Args:
            config: the Hugging Face model config (pad_token_id is read/defaulted here).
            neuron_config: Neuron compilation config (batch sizes, context length, dtype).
            model: the compiled (traced) model callable.
        """
        super().__init__()
        self.config = config
        self.neuron_config = neuron_config
        self.model = model
        # Default to float32 when no dtype was specified on the neuron config.
        if not self.neuron_config.torch_dtype:
            self.neuron_config.torch_dtype = torch.float32
        # A pad token id is needed to pad input_ids up to the compiled context length.
        if config.pad_token_id is None:
            config.pad_token_id = 0

    def _forward_with_pad(self, input_ids, position_ids):
        """Pad the batch dimension up to the compiled batch size and run the model.

        Unlike the causal LM wrapper, padding rows are NOT sliced out here; the
        caller (forward) trims the concatenated output instead.
        """

        # pad the inputs up to the compiled batch size in the end
        def pad_helper(tensor):
            if tensor is None or tensor.shape[0] == self.neuron_config.batch_size:
                return tensor
            padded_shape = list(tensor.shape)
            padded_shape[0] = self.neuron_config.batch_size
            # pad with first batch line values instead of zeros, to reduce chances of NaN
            # NOTE(review): the repeat assumes 2-D (batch, seq) inputs.
            padded_tensor = tensor[0].unsqueeze(0).repeat(padded_shape[0], 1).to(tensor.dtype)
            padded_tensor[: tensor.shape[0]] = tensor
            return padded_tensor

        padded_args = []
        for arg in (input_ids, position_ids):
            padded_args.append(pad_helper(arg))
        return self._forward(*padded_args)

    def _forward(self, input_ids, position_ids):
        """Run the compiled model on inputs already matching the static shapes."""
        return self.model(input_ids, position_ids)

    def convert_int64_to_int32(self, *args):
        """
        Convert int64 args to int32 to match compiled input types.
        Neuron compiler handles int32 better than int64. Context: P165494809
        """
        return [t.to(torch.int32) if t.dtype == torch.int64 else t for t in args]

    def pad_to_max_compiled_seq(self, *args):
        """Pad each arg along dim=1 up to max_context_length.

        input_ids are padded with pad_token_id, position_ids with 0. The third
        pad value (1) is unused with the current two callers' args — presumably a
        vestige of an attention_mask argument; zip() simply ignores it.
        """
        pad_lengths = [self.neuron_config.max_context_length - arg.shape[1] for arg in args]
        tensor_pad_vals = [self.config.pad_token_id, 0, 1]
        padded_args = [
            F.pad(arg, (0, pad_len), "constant", pad_val)
            for arg, pad_val, pad_len in zip(args, tensor_pad_vals, pad_lengths)
        ]
        return padded_args

    def forward(self, input_ids, position_ids):
        """Run the compiled embedding model, padding sequence and batch dimensions
        to the compiled static shapes.

        Raises:
            ValueError: if the input batch size exceeds neuron_config.max_batch_size.
        """
        input_ids, position_ids = self.convert_int64_to_int32(input_ids, position_ids)
        input_ids, position_ids = self.pad_to_max_compiled_seq(input_ids, position_ids)
        input_batch_size = input_ids.shape[0]
        if input_batch_size > self.neuron_config.max_batch_size:
            raise ValueError(
                f"Input batch size {input_batch_size} exceeds the maximum batch size {self.neuron_config.max_batch_size}."
            )
        elif input_batch_size == self.neuron_config.batch_size:
            # Fast path: the input already matches the compiled batch size.
            return self._forward(input_ids, position_ids)
        cur_batch = 0
        output_logits = []
        # Lazy %-args avoid formatting when DEBUG is disabled.
        logging.debug(
            "get input_batch_size as %s but compiled batch_size as %s",
            input_batch_size,
            self.neuron_config.batch_size,
        )
        args = (input_ids, position_ids)
        # Process the oversized batch in compiled-batch-size chunks.
        while cur_batch < input_batch_size:
            if cur_batch + self.neuron_config.batch_size <= input_batch_size:
                # we only process part of the input to run
                logging.debug(
                    "running forward on batch %s:%s", cur_batch, cur_batch + self.neuron_config.batch_size
                )
                outputs = self._forward(*[arg[cur_batch : cur_batch + self.neuron_config.batch_size] for arg in args])
            else:
                # we need to pad the input to run
                logging.debug(
                    "running forward on batch %s:%s, padded up to %s",
                    cur_batch,
                    input_batch_size,
                    self.neuron_config.batch_size,
                )
                outputs = self._forward_with_pad(*[arg[cur_batch:input_batch_size] for arg in args])
            output_logits.append(outputs)
            cur_batch += self.neuron_config.batch_size
        # Padding rows only ever trail the last chunk, so a single slice removes them.
        return torch.cat(output_logits, dim=0)[:input_batch_size]