diff --git a/backends/candle/src/models/distilbert.rs b/backends/candle/src/models/distilbert.rs index b7b43893..7b8f0786 100644 --- a/backends/candle/src/models/distilbert.rs +++ b/backends/candle/src/models/distilbert.rs @@ -475,6 +475,11 @@ impl DistilBertModel { DistilBertEncoder::load(vb.pp("distilbert.transformer"), config), ) { (embeddings, encoder) + } else if let (Ok(embeddings), Ok(encoder)) = ( + DistilBertEmbeddings::load(vb.pp("embeddings"), config), + DistilBertEncoder::load(vb.pp("transformer"), config), + ) { + (embeddings, encoder) } else { return Err(err); } diff --git a/backends/candle/src/models/flash_distilbert.rs b/backends/candle/src/models/flash_distilbert.rs index b107e1e3..2664c660 100644 --- a/backends/candle/src/models/flash_distilbert.rs +++ b/backends/candle/src/models/flash_distilbert.rs @@ -213,6 +213,11 @@ impl FlashDistilBertModel { DistilBertEncoder::load(vb.pp("distilbert.transformer"), config), ) { (embeddings, encoder) + } else if let (Ok(embeddings), Ok(encoder)) = ( + DistilBertEmbeddings::load(vb.pp("embeddings"), config), + DistilBertEncoder::load(vb.pp("transformer"), config), + ) { + (embeddings, encoder) } else { return Err(err); }