-
Notifications
You must be signed in to change notification settings - Fork 250
/
Copy pathdistilbert.rs
664 lines (556 loc) · 23.3 KB
/
distilbert.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
use crate::layers::{get_cublas_lt_wrapper, HiddenAct, LayerNorm, Linear};
use crate::models::Model;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Embedding, VarBuilder};
use serde::Deserialize;
use text_embeddings_backend_core::{Batch, ModelType, Pool};
/// Deserialized DistilBERT model configuration.
///
/// Field names follow the keys of the checkpoint configuration file, which is why they
/// differ from the BERT-style names used elsewhere (`dim` instead of `hidden_size`, ...).
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct DistilBertConfig {
    /// Number of rows of the word-embedding table (and outputs of the SPLADE projector).
    pub vocab_size: usize,
    /// Hidden size of every token representation.
    pub dim: usize,
    /// Number of transformer blocks in the encoder.
    pub n_layers: usize,
    /// Number of attention heads; the per-head size is `dim / n_heads`.
    pub n_heads: usize,
    /// Inner size of the feed-forward network (`ffn.lin1` output width).
    pub hidden_dim: usize,
    /// Activation applied after `ffn.lin1` and after the SPLADE `vocab_transform`.
    pub activation: HiddenAct,
    /// Number of rows of the learned position-embedding table.
    pub max_position_embeddings: usize,
    /// Padding token id declared by the checkpoint.
    /// NOTE(review): not read in this file; padded slots are filled with id 0 instead.
    pub pad_token_id: usize,
    /// Optional `model_type` tag from the configuration file.
    pub model_type: Option<String>,
}
/// Input embedding layer: a word-embedding lookup and a learned position-embedding
/// lookup, combined through the crate `LayerNorm` (epsilon `1e-12`).
#[derive(Debug)]
pub struct DistilBertEmbeddings {
    word_embeddings: Embedding,
    position_embeddings: Embedding,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl DistilBertEmbeddings {
    /// Reads `word_embeddings.weight`, `position_embeddings.weight` and the `LayerNorm`
    /// parameters from `vb`.
    pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result<Self> {
        let width = config.dim;

        let token_table = vb
            .pp("word_embeddings")
            .get((config.vocab_size, width), "weight")?;
        let position_table = vb
            .pp("position_embeddings")
            .get((config.max_position_embeddings, width), "weight")?;
        let layer_norm = LayerNorm::load(vb.pp("LayerNorm"), width, 1e-12f32)?;

        Ok(Self {
            word_embeddings: Embedding::new(token_table, width),
            position_embeddings: Embedding::new(position_table, width),
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
        })
    }

    /// Looks up token and position vectors and normalizes them.
    ///
    /// The position vectors are handed to the crate `LayerNorm` as its residual argument
    /// (presumably summed with the token vectors before normalization).
    pub fn forward(&self, input_ids: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let tokens = self.word_embeddings.forward(input_ids)?;
        let positions = self.position_embeddings.forward(position_ids)?;
        self.layer_norm.forward(&tokens, Some(&positions))
    }
}
/// Multi-head self-attention of one DistilBERT block.
///
/// The separate `q_lin`, `k_lin` and `v_lin` projections are fused into a single linear
/// layer at load time, so one matmul produces query, key and value together.
#[derive(Debug)]
struct DistilBertAttention {
    // Fused projection whose weight rows are q_lin, then k_lin, then v_lin (see `load`).
    qkv_linear: Linear,
    // Output projection (`out_lin`).
    dense: Linear,
    num_attention_heads: usize,
    attention_head_size: usize,
    // 1 / sqrt(attention_head_size), applied to the raw attention scores.
    softmax_scale: f64,
    span: tracing::Span,
}

impl DistilBertAttention {
    /// Loads `q_lin`, `k_lin`, `v_lin` and `out_lin` from `vb` and fuses the first three.
    ///
    /// The head size is `config.dim / config.n_heads` (integer division).
    /// NOTE(review): assumes `dim` is divisible by `n_heads`; otherwise the weight shape
    /// requested below would not match the checkpoint.
    pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result<Self> {
        let attention_head_size = config.dim / config.n_heads;
        let all_head_size = config.n_heads * attention_head_size;
        let hidden_size = config.dim;
        let query_weight = vb.pp("q_lin").get((all_head_size, hidden_size), "weight")?;
        let query_bias = vb.pp("q_lin").get(all_head_size, "bias")?;
        let key_weight = vb.pp("k_lin").get((all_head_size, hidden_size), "weight")?;
        let key_bias = vb.pp("k_lin").get(all_head_size, "bias")?;
        let value_weight = vb.pp("v_lin").get((all_head_size, hidden_size), "weight")?;
        let value_bias = vb.pp("v_lin").get(all_head_size, "bias")?;
        // Stack along dim 0. Assuming `Linear` uses the (out, in) weight layout, the fused
        // output is laid out as [q | k | v] along the last dimension.
        let qkv_weight = Tensor::cat(&[&query_weight, &key_weight, &value_weight], 0)?;
        let qkv_bias = Tensor::cat(&[&query_bias, &key_bias, &value_bias], 0)?;
        let qkv_linear = Linear::new(qkv_weight, Some(qkv_bias), None);
        let dense_weight = vb.pp("out_lin").get((hidden_size, hidden_size), "weight")?;
        let dense_bias = vb.pp("out_lin").get(hidden_size, "bias")?;
        let dense = Linear::new(dense_weight, Some(dense_bias), None);
        let softmax_scale = 1. / (attention_head_size as f64).sqrt();
        Ok(Self {
            qkv_linear,
            dense,
            num_attention_heads: config.n_heads,
            attention_head_size,
            softmax_scale,
            span: tracing::span!(tracing::Level::TRACE, "attention"),
        })
    }

    /// Scaled dot-product self-attention followed by the output projection.
    ///
    /// `attention_bias`, when given, is added to the scaled scores before the softmax.
    /// NOTE(review): assumes it is already broadcast to (batch, heads, seq, seq); the
    /// CUDA path flattens its first two dimensions.
    fn forward(&self, hidden_states: &Tensor, attention_bias: Option<&Tensor>) -> Result<Tensor> {
        let _enter = self.span.enter();
        let device = hidden_states.device();
        // One matmul yields the concatenated [q | k | v] projections.
        let qkv = self.qkv_linear.forward(hidden_states)?;
        // Split the last dimension into (3 * heads, head_size), then move that head axis to
        // position 1. NOTE(review): assumes `hidden_states` is (batch, seq, dim); the
        // `dims4` call on the CUDA path relies on the resulting rank.
        let mut new_qkv_shape = qkv.dims().to_vec();
        new_qkv_shape.pop();
        new_qkv_shape.push(self.num_attention_heads * 3);
        new_qkv_shape.push(self.attention_head_size);
        let qkv = qkv.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?;
        // Three equal chunks along the head axis: query heads, key heads, value heads.
        let qkv = qkv.chunk(3, 1)?;
        let query_layer = &qkv[0].contiguous()?;
        let key_layer = &qkv[1].contiguous()?;
        let value_layer = &qkv[2];
        // `cublaslt` is only used inside the `cuda`-gated block below, so the binding is
        // unused when that feature is disabled.
        #[allow(unused_variables)]
        let context_layer = if let (Device::Cuda(_), Some(cublaslt)) =
            (device, get_cublas_lt_wrapper())
        {
            #[cfg(feature = "cuda")]
            {
                // cuBLASLt batch matmul implementation requires inputs to be dims3
                let (batch_size, _, seq_len, _) = key_layer.shape().dims4()?;
                let key_layer = key_layer.flatten(0, 1)?;
                let query_layer = query_layer.flatten(0, 1)?;
                let value_layer = value_layer.flatten(0, 1)?;
                let attention_bias = attention_bias.map(|mask| mask.flatten(0, 1)).transpose()?;
                // If attention_bias is set, we fuse the add by giving it as the output matrix
                // and setting beta to 1.0
                let beta = match attention_bias.is_some() {
                    true => Some(1.0),
                    false => None,
                };
                // Batch matrix multiplication
                // Fuse softmax scale and attention_bias add
                // NOTE(review): operands are passed as (key, query); presumably the
                // wrapper's operand convention makes this query * key^T. Confirm against
                // `get_cublas_lt_wrapper` before reordering.
                let attention_scores = cublaslt.batch_matmul(
                    &key_layer,
                    &query_layer,
                    attention_bias.as_ref(),
                    Some(self.softmax_scale as f32),
                    beta,
                    None,
                    None,
                )?;
                let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
                let context_layer = cublaslt.batch_matmul(
                    &value_layer.t()?.contiguous()?,
                    &attention_probs,
                    // We save one allocation
                    Some(&query_layer),
                    None,
                    None,
                    None,
                    None,
                )?;
                // Reshape to dims4
                context_layer.reshape((
                    batch_size,
                    self.num_attention_heads,
                    seq_len,
                    self.attention_head_size,
                ))
            }
            #[cfg(not(feature = "cuda"))]
            {
                candle::bail!("`cuda` feature is not enabled")
            }
        } else {
            // Portable path: softmax(q * k^T * scale + bias) * v
            let attention_scores = query_layer.matmul(&key_layer.t()?)?;
            let mut attention_scores = (attention_scores * self.softmax_scale)?;
            if let Some(attention_bias) = attention_bias {
                attention_scores = attention_scores.add(attention_bias)?;
            }
            let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
            attention_probs.matmul(&value_layer.contiguous()?)
        }?;
        // Swap the head and sequence axes back, then merge the last two dimensions
        // (heads, head_size) into a single hidden dimension.
        let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?;
        // Output projection (`out_lin`).
        let hidden_states = self.dense.forward(&context_layer)?;
        Ok(hidden_states)
    }
}
/// Position-wise feed-forward network (`ffn`): `lin2(activation(lin1(x)))`.
#[derive(Debug)]
pub struct DistilBertMLP {
    lin1: Linear,
    lin2: Linear,
    span: tracing::Span,
}

impl DistilBertMLP {
    /// Reads `lin1` (dim -> hidden_dim, with the configured activation) and `lin2`
    /// (hidden_dim -> dim, no activation) from `vb`.
    pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result<Self> {
        let (dim, inner) = (config.dim, config.hidden_dim);

        let up_vb = vb.pp("lin1");
        let lin1 = Linear::new(
            up_vb.get((inner, dim), "weight")?,
            Some(up_vb.get(inner, "bias")?),
            Some(config.activation.clone()),
        );

        let down_vb = vb.pp("lin2");
        let lin2 = Linear::new(
            down_vb.get((dim, inner), "weight")?,
            Some(down_vb.get(dim, "bias")?),
            None,
        );

        Ok(Self {
            lin1,
            lin2,
            span: tracing::span!(tracing::Level::TRACE, "mlp"),
        })
    }

    /// Expands every token to `hidden_dim`, applies the activation, and projects back.
    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let expanded = self.lin1.forward(hidden_states)?;
        self.lin2.forward(&expanded)
    }
}
/// One transformer block: self-attention and a feed-forward network, each followed by a
/// residual LayerNorm (`sa_layer_norm` and `output_layer_norm`).
#[derive(Debug)]
struct DistilBertBlock {
    attention: DistilBertAttention,
    mlp: DistilBertMLP,
    post_attention_layer_norm: LayerNorm,
    output_layer_norm: LayerNorm,
    span: tracing::Span,
}

impl DistilBertBlock {
    /// Reads `attention`, `ffn`, `sa_layer_norm` and `output_layer_norm` from `vb`.
    pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result<Self> {
        let norm = |name: &str| LayerNorm::load(vb.pp(name), config.dim, 1e-12f32);
        // Struct fields are initialized in the order written, so loading order is
        // attention, ffn, sa_layer_norm, output_layer_norm.
        Ok(Self {
            attention: DistilBertAttention::load(vb.pp("attention"), config)?,
            mlp: DistilBertMLP::load(vb.pp("ffn"), config)?,
            post_attention_layer_norm: norm("sa_layer_norm")?,
            output_layer_norm: norm("output_layer_norm")?,
            span: tracing::span!(tracing::Level::TRACE, "layer"),
        })
    }

    /// Runs the block on `hidden_states`.
    ///
    /// Each sub-layer output is handed to the crate `LayerNorm` as the residual argument
    /// together with the sub-layer input (presumably summed before normalization).
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        attention_bias: Option<&Tensor>,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        let attended = self.attention.forward(hidden_states, attention_bias)?;
        let normed = self
            .post_attention_layer_norm
            .forward(hidden_states, Some(&attended))?;
        let projected = self.mlp.forward(&normed)?;
        self.output_layer_norm.forward(&normed, Some(&projected))
    }
}
/// Stack of `n_layers` transformer blocks applied in sequence.
#[derive(Debug)]
struct DistilBertEncoder {
    layers: Vec<DistilBertBlock>,
    span: tracing::Span,
}

impl DistilBertEncoder {
    /// Loads blocks `layer.0` .. `layer.{n_layers - 1}` from `vb`, stopping at the first
    /// failure.
    pub fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result<Self> {
        let mut layers = Vec::with_capacity(config.n_layers);
        for index in 0..config.n_layers {
            layers.push(DistilBertBlock::load(vb.pp(format!("layer.{index}")), config)?);
        }
        Ok(DistilBertEncoder {
            layers,
            span: tracing::span!(tracing::Level::TRACE, "encoder"),
        })
    }

    /// Feeds `hidden_states` through every block, sharing the same `attention_bias`.
    fn forward(&self, hidden_states: &Tensor, attention_bias: Option<&Tensor>) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.layers
            .iter()
            .try_fold(hidden_states.clone(), |state, block| {
                block.forward(&state, attention_bias)
            })
    }
}
/// SPLADE head built from the masked-LM vocabulary layers.
///
/// Maps every token state to a vocabulary-sized vector `log(1 + relu(logits))`.
#[derive(Debug)]
pub struct DistilBertSpladeHead {
    vocab_transform: Linear,
    vocab_projector: Linear,
    vocab_layer_norm: LayerNorm,
    span: tracing::Span,
}

impl DistilBertSpladeHead {
    /// Reads `vocab_transform`, `vocab_projector` and `vocab_layer_norm` from `vb`.
    pub(crate) fn load(vb: VarBuilder, config: &DistilBertConfig) -> Result<Self> {
        let (dim, vocab) = (config.dim, config.vocab_size);

        let transform_vb = vb.pp("vocab_transform");
        let vocab_transform = Linear::new(
            transform_vb.get((dim, dim), "weight")?,
            Some(transform_vb.get(dim, "bias")?),
            Some(config.activation.clone()),
        );

        let projector_vb = vb.pp("vocab_projector");
        let vocab_projector = Linear::new(
            projector_vb.get((vocab, dim), "weight")?,
            Some(projector_vb.get(vocab, "bias")?),
            Some(HiddenAct::Relu),
        );

        let vocab_layer_norm = LayerNorm::load(vb.pp("vocab_layer_norm"), dim, 1e-12f32)?;

        Ok(Self {
            vocab_transform,
            vocab_projector,
            vocab_layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "splade"),
        })
    }

    /// Transform, normalize, project onto the vocabulary, then apply `log(1 + x)`.
    pub(crate) fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let transformed = self.vocab_transform.forward(hidden_states)?;
        let normalized = self.vocab_layer_norm.forward(&transformed, None)?;
        // `vocab_projector` carries a ReLU activation, so its output is non-negative.
        let activations = self.vocab_projector.forward(&normalized)?;
        (1.0 + activations)?.log()
    }
}
/// DistilBERT encoder served as an embedding model.
///
/// Holds the embedding layer, the transformer stack and, when SPLADE pooling is
/// configured, the masked-LM vocabulary head used to build sparse embeddings.
#[derive(Debug)]
pub struct DistilBertModel {
    embeddings: DistilBertEmbeddings,
    encoder: DistilBertEncoder,
    pool: Pool,
    // Only loaded when `pool == Pool::Splade`.
    splade: Option<DistilBertSpladeHead>,
    num_attention_heads: usize,
    device: Device,
    dtype: DType,
    span: tracing::Span,
}

impl DistilBertModel {
    /// Loads the model weights from `vb`.
    ///
    /// The encoder is looked up first under the bare layout (`embeddings.*`,
    /// `transformer.*`) and then under the `distilbert.` prefix used by checkpoints that
    /// wrap the encoder in a task head. The SPLADE head is read from the root of `vb`.
    ///
    /// # Errors
    ///
    /// Fails for the `classifier` model type, for pooling strategies this model does not
    /// implement (`last_token`, `bm42`), or when the weights cannot be found under either
    /// layout (the error of the unprefixed attempt is returned).
    pub fn load(vb: VarBuilder, config: &DistilBertConfig, model_type: ModelType) -> Result<Self> {
        let pool = match model_type {
            ModelType::Classifier => {
                candle::bail!("`classifier` model type is not supported for DistilBert")
            }
            ModelType::Embedding(pool) => {
                if pool == Pool::LastToken {
                    candle::bail!("`last_token` is not supported for DistilBert");
                }
                // `forward` has no BM42 implementation: fail here instead of panicking on
                // the first request.
                if pool == Pool::BM42 {
                    candle::bail!("`bm42` is not supported for DistilBert");
                }
                pool
            }
        };

        // The transformer stack lives under `transformer` (matching the prefixed
        // `distilbert.transformer` fallback below), not `encoder`.
        let (embeddings, encoder) = match (
            DistilBertEmbeddings::load(vb.pp("embeddings"), config),
            DistilBertEncoder::load(vb.pp("transformer"), config),
        ) {
            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
            (Err(err), _) | (_, Err(err)) => {
                if let (Ok(embeddings), Ok(encoder)) = (
                    DistilBertEmbeddings::load(vb.pp("distilbert.embeddings"), config),
                    DistilBertEncoder::load(vb.pp("distilbert.transformer"), config),
                ) {
                    (embeddings, encoder)
                } else {
                    // Report the failure of the unprefixed layout.
                    return Err(err);
                }
            }
        };

        let splade = if pool == Pool::Splade {
            Some(DistilBertSpladeHead::load(vb.clone(), config)?)
        } else {
            None
        };

        Ok(Self {
            embeddings,
            encoder,
            pool,
            splade,
            num_attention_heads: config.n_heads,
            device: vb.device().clone(),
            dtype: vb.dtype(),
            span: tracing::span!(tracing::Level::TRACE, "model"),
        })
    }

    /// Runs the model on a flattened `batch`.
    ///
    /// With more than one sequence, every sequence is right-padded to
    /// `batch.max_length`. Returns `(pooled, raw)`:
    /// - `pooled` has one row per entry of `batch.pooled_indices`, or is `None` when no
    ///   member asked for pooling.
    /// - `raw` holds the token embeddings of the members in `batch.raw_indices`, with
    ///   padding removed and flattened to `(tokens, dim)`, or is `None` when no member
    ///   asked for them.
    pub fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
        let _enter = self.span.enter();

        let batch_size = batch.len();
        let max_length = batch.max_length as usize;

        let shape = (batch_size, max_length);

        let (input_ids, position_ids, input_lengths, attention_bias, attention_mask) =
            if batch_size > 1 {
                // Prepare padded batch
                let elems = batch_size * max_length;

                let mut input_ids = Vec::with_capacity(elems);
                let mut position_ids = Vec::with_capacity(elems);
                let mut attention_mask = Vec::with_capacity(elems);
                let mut attention_bias = Vec::with_capacity(elems);
                let mut input_lengths = Vec::with_capacity(batch_size);
                // Bool to know if we need to use the attention mask
                let mut masking = false;

                for i in 0..batch_size {
                    let start = batch.cumulative_seq_lengths[i] as usize;
                    let end = batch.cumulative_seq_lengths[i + 1] as usize;
                    let seq_length = (end - start) as u32;
                    input_lengths.push(seq_length as f32);

                    // Copy values
                    for j in start..end {
                        input_ids.push(batch.input_ids[j]);
                        position_ids.push(batch.position_ids[j]);
                        attention_mask.push(1.0_f32);
                        attention_bias.push(0.0);
                    }

                    // Add padding if needed
                    let padding = batch.max_length - seq_length;
                    if padding > 0 {
                        // Set bool to use attention mask
                        masking = true;
                        for _ in 0..padding {
                            input_ids.push(0);
                            position_ids.push(0);
                            attention_mask.push(0.0_f32);
                            attention_bias.push(f32::NEG_INFINITY);
                        }
                    }
                }

                let (attention_bias, attention_mask) = match masking {
                    true => {
                        // The 0/1 mask is needed by every pooler that must ignore padded
                        // positions. Mean pooling averages real tokens only. SPLADE
                        // max-pools non-negative scores, so zeroed padded rows can never win
                        // the max. CLS pooling only needs the additive bias.
                        let attention_mask = if matches!(self.pool, Pool::Mean | Pool::Splade)
                        {
                            let attention_mask = Tensor::from_vec(
                                attention_mask,
                                (batch_size, max_length, 1),
                                &self.device,
                            )?
                            .to_dtype(self.dtype)?;

                            Some(attention_mask)
                        } else {
                            None
                        };

                        let attention_bias = Tensor::from_vec(
                            attention_bias,
                            (batch_size, 1, 1, max_length),
                            &self.device,
                        )?
                        .to_dtype(self.dtype)?;
                        // Broadcast once instead of at every layer
                        let attention_bias = attention_bias
                            .broadcast_as((
                                batch_size,
                                self.num_attention_heads,
                                max_length,
                                max_length,
                            ))?
                            .contiguous()?;
                        (Some(attention_bias), attention_mask)
                    }
                    false => (None, None),
                };

                (
                    input_ids,
                    position_ids,
                    input_lengths,
                    attention_bias,
                    attention_mask,
                )
            } else {
                (
                    batch.input_ids,
                    batch.position_ids,
                    vec![batch.max_length as f32],
                    None,
                    None,
                )
            };

        // Upload the (possibly padded) inputs to the model device
        let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
        let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?;
        let input_lengths =
            Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?;

        let embedding_output = self.embeddings.forward(&input_ids, &position_ids)?;

        let outputs = self
            .encoder
            .forward(&embedding_output, attention_bias.as_ref())?;

        let has_pooling_requests = !batch.pooled_indices.is_empty();
        let has_raw_requests = !batch.raw_indices.is_empty();

        let pooled_embeddings = if has_pooling_requests {
            let pooled_indices_length = batch.pooled_indices.len();

            // Only use pooled_indices if at least one member of the batch ask for raw embeddings
            let pooled_indices = if has_raw_requests {
                Some(Tensor::from_vec(
                    batch.pooled_indices,
                    pooled_indices_length,
                    &self.device,
                )?)
            } else {
                None
            };

            // Restrict every per-sequence tensor to the pooled members so their batch
            // dimensions stay aligned (outputs, lengths and mask alike).
            let (outputs, input_lengths, attention_mask) = match &pooled_indices {
                Some(indices) => (
                    outputs.index_select(indices, 0)?,
                    input_lengths.index_select(indices, 0)?,
                    attention_mask
                        .as_ref()
                        .map(|mask| mask.index_select(indices, 0))
                        .transpose()?,
                ),
                None => (outputs.clone(), input_lengths, attention_mask),
            };

            let pooled_embeddings = match self.pool {
                // CLS pooling
                Pool::Cls => outputs.i((.., 0))?,
                // Both strategies are rejected in `DistilBertModel::load`
                Pool::LastToken | Pool::BM42 => unreachable!(),
                // Mean pooling
                Pool::Mean => {
                    // Zero padded positions, then divide by each real sequence length
                    let outputs = match &attention_mask {
                        Some(mask) => outputs.broadcast_mul(mask)?,
                        None => outputs,
                    };
                    outputs.sum(1)?.broadcast_div(&input_lengths)?
                }
                Pool::Splade => {
                    let splade_head = self
                        .splade
                        .as_ref()
                        .expect("the SPLADE head is always loaded for `Pool::Splade`");
                    let relu_log = splade_head.forward(&outputs)?;
                    // Mask padded values so they cannot contribute to the max
                    let relu_log = match &attention_mask {
                        Some(mask) => relu_log.broadcast_mul(mask)?,
                        None => relu_log,
                    };
                    relu_log.max(1)?
                }
            };

            Some(pooled_embeddings)
        } else {
            None
        };

        let raw_embeddings = if has_raw_requests {
            // Flatten to one row per (sequence, position) pair
            let (b, l, h) = outputs.shape().dims3()?;
            let outputs = outputs.reshape((b * l, h))?;

            // Padding rows must be removed whenever the batch was padded (`attention_bias`
            // is only set in that case, whatever the pooling strategy), and the raw members
            // must be picked out whenever other members only asked for pooling.
            if (attention_bias.is_some() || has_pooling_requests) && batch_size > 1 {
                let mut final_indices: Vec<u32> = Vec::with_capacity(batch_size * max_length);
                for i in batch.raw_indices.into_iter() {
                    let start = i * batch.max_length;
                    let i = i as usize;
                    let length =
                        batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i];
                    // Indices of the real tokens of this specific member of the batch
                    final_indices.extend(start..start + length);
                }

                let final_indices_length = final_indices.len();
                let final_indices =
                    Tensor::from_vec(final_indices, final_indices_length, &self.device)?;

                // Select the tokens with final indices
                Some(outputs.index_select(&final_indices, 0)?)
            } else {
                Some(outputs)
            }
        } else {
            None
        };

        Ok((pooled_embeddings, raw_embeddings))
    }
}
impl Model for DistilBertModel {
fn is_padded(&self) -> bool {
true
}
fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
self.forward(batch)
}
}