Skip to content

Commit cf423d1

Browse files
authored
Patch DistilBERT variants with different weight keys (#614)
1 parent 5bcc41f commit cf423d1

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

backends/candle/src/models/distilbert.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,11 @@ impl DistilBertModel {
475475
DistilBertEncoder::load(vb.pp("distilbert.transformer"), config),
476476
) {
477477
(embeddings, encoder)
478+
} else if let (Ok(embeddings), Ok(encoder)) = (
479+
DistilBertEmbeddings::load(vb.pp("embeddings"), config),
480+
DistilBertEncoder::load(vb.pp("transformer"), config),
481+
) {
482+
(embeddings, encoder)
478483
} else {
479484
return Err(err);
480485
}

backends/candle/src/models/flash_distilbert.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,11 @@ impl FlashDistilBertModel {
213213
DistilBertEncoder::load(vb.pp("distilbert.transformer"), config),
214214
) {
215215
(embeddings, encoder)
216+
} else if let (Ok(embeddings), Ok(encoder)) = (
217+
DistilBertEmbeddings::load(vb.pp("embeddings"), config),
218+
DistilBertEncoder::load(vb.pp("transformer"), config),
219+
) {
220+
(embeddings, encoder)
216221
} else {
217222
return Err(err);
218223
}

0 commit comments

Comments
 (0)