@@ -655,39 +655,37 @@ extension BERT {
             }
         }
 
-        /// Loads this pre-trained BERT model from the specified directory.
+        /// Loads this pre-trained BERT model from the specified URL.
         ///
         /// - Note: This function will download the pre-trained model files to the specified
         ///   directory, if they are not already there.
         ///
         /// - Parameters:
-        ///   - directory: Directory to load the pretrained model from.
-        public func load(from directory: URL) throws -> BERT {
+        ///   - url: URL to load the pretrained model from.
+        public func load(from url: URL? = nil) throws -> BERT {
             print("Loading BERT pre-trained model '\(name)'.")
-            let directory = directory.appendingPathComponent(variant.description, isDirectory: true)
-            try maybeDownload(to: directory)
+
+            let reader = try CheckpointReader(checkpointLocation: url ?? self.url, modelName: name)
+            // TODO(michellecasbon): expose this.
+            reader.isCRCVerificationEnabled = false
+
+            let storage = reader.localCheckpointLocation.deletingLastPathComponent()
 
             // Load the appropriate vocabulary file.
             let vocabulary: Vocabulary = {
                 switch self {
                 case .bertBase, .bertLarge:
-                    let vocabularyURL = directory
-                        .appendingPathComponent(subDirectory)
-                        .appendingPathComponent("vocab.txt")
+                    let vocabularyURL = storage.appendingPathComponent("vocab.txt")
                     return try! Vocabulary(fromFile: vocabularyURL)
                 case .robertaBase, .robertaLarge:
-                    let vocabularyURL = directory
-                        .appendingPathComponent(subDirectory)
-                        .appendingPathComponent("vocab.json")
-                    let dictionaryURL = directory
-                        .appendingPathComponent(subDirectory)
-                        .appendingPathComponent("dict.txt")
+                    let vocabularyURL = storage.appendingPathComponent("vocab.json")
+                    let dictionaryURL = storage.appendingPathComponent("dict.txt")
                     return try! Vocabulary(
                         fromRoBERTaJSONFile: vocabularyURL,
                         dictionaryFile: dictionaryURL)
                 case .albertBase, .albertLarge, .albertXLarge, .albertXXLarge:
-                    let vocabularyURL = directory
-                        .appendingPathComponent(subDirectory)
+                    let vocabularyURL = storage
+                        .deletingLastPathComponent()
                         .appendingPathComponent("assets")
                         .appendingPathComponent("30k-clean.model")
                     return try! Vocabulary(fromSentencePieceModel: vocabularyURL)
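
The net effect of this hunk: auxiliary files are no longer resolved against an explicitly managed download directory, but against the location of the checkpoint that CheckpointReader has already fetched. A hedged sketch of the resulting layout, using only names that appear in the diff (the `<storage>` placeholders are illustrative):

    // `localCheckpointLocation` is the checkpoint file itself, so its parent
    // directory is where the sibling vocabulary files are expected to live:
    //
    //   <storage>/vocab.txt                        // BERT variants
    //   <storage>/vocab.json + <storage>/dict.txt  // RoBERTa variants
    //   <storage>/../assets/30k-clean.model        // ALBERT variants, one level up
    let storage = reader.localCheckpointLocation.deletingLastPathComponent()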
@@ -704,8 +702,7 @@ extension BERT {
                     unknownToken: "[UNK]",
                     maxTokenLength: nil)
             case .robertaBase, .robertaLarge:
-                let mergePairsFileURL = directory
-                    .appendingPathComponent(subDirectory)
+                let mergePairsFileURL = storage
                     .appendingPathComponent("merges.txt")
                 let mergePairs = [BytePairEncoder.Pair: Int](
                     uniqueKeysWithValues:
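
Stepping back from the individual path changes, the call site changes shape too: callers no longer manage a download directory, and can omit the argument entirely to pull the default hosted checkpoint. A hypothetical usage sketch; the `.bertBase` case appears in this diff, while the `BERT.PreTrainedModel` enum spelling and the local path are assumptions:

    // Default hosted checkpoint (url defaults to nil, so self.url is used):
    let model = try BERT.PreTrainedModel.bertBase.load()

    // Or point at a checkpoint that is already on disk:
    let local = URL(fileURLWithPath: "/tmp/checkpoints/bert_base")
    let localModel = try BERT.PreTrainedModel.bertBase.load(from: local)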
@@ -749,150 +746,40 @@ extension BERT {
             initializerStandardDeviation: 0.02,
             useOneHotEmbeddings: false)
 
-            // Load the pre-trained model checkpoint.
-            switch self {
-            case .bertBase, .bertLarge:
-                model.load(fromTensorFlowCheckpoint: directory
-                    .appendingPathComponent(subDirectory)
-                    .appendingPathComponent("bert_model.ckpt"))
-            case .robertaBase, .robertaLarge:
-                model.load(fromTensorFlowCheckpoint: directory
-                    .appendingPathComponent(subDirectory)
-                    .appendingPathComponent("roberta_\(subDirectory).ckpt"))
-            case .albertBase, .albertLarge, .albertXLarge, .albertXXLarge:
-                model.load(fromTensorFlowCheckpoint: directory
-                    .appendingPathComponent(subDirectory)
-                    .appendingPathComponent("variables")
-                    .appendingPathComponent("variables"))
-            }
+            model.loadTensors(reader)
             return model
         }
-
-        /// Downloads this pre-trained model to the specified directory, if it's not already there.
-        public func maybeDownload(to directory: URL) throws {
-            switch self {
-            case .bertBase, .bertLarge, .robertaBase, .robertaLarge:
-                // Download and extract the pretrained model, if necessary.
-                DatasetUtilities.downloadResource(filename: "\(subDirectory)", fileExtension: "zip",
-                    remoteRoot: url.deletingLastPathComponent(),
-                    localStorageDirectory: directory)
-            case .albertBase, .albertLarge, .albertXLarge, .albertXXLarge:
-                // Download the model, if necessary.
-                let compressedFileURL = directory.appendingPathComponent("\(subDirectory).tar.gz")
-                try download(from: url, to: compressedFileURL)
-
-                // Extract the data, if necessary.
-                let extractedDirectoryURL = directory.appendingPathComponent(subDirectory)
-                if !FileManager.default.fileExists(atPath: extractedDirectoryURL.path) {
-                    try extract(tarGZippedFileAt: compressedFileURL, to: extractedDirectoryURL)
-                }
-            }
-        }
     }
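
A point worth flagging in this hunk: the per-variant checkpoint-path switch and the entire `maybeDownload(to:)` helper collapse into one `model.loadTensors(reader)` call, because `CheckpointReader` now owns both fetching the checkpoint on first use and serving named tensors. A minimal sketch of the resulting flow, restricted to API visible in this diff (`checkpointURL` and `makeBERTModel()` are hypothetical placeholders):

    // Hypothetical flow; only the CheckpointReader calls below appear in the diff.
    let reader = try CheckpointReader(checkpointLocation: checkpointURL, modelName: "BERT")
    reader.isCRCVerificationEnabled = false   // TODO(michellecasbon): expose this.
    var model = makeBERTModel()               // placeholder for the BERT(...) construction above
    model.loadTensors(reader)                 // replaces the per-variant checkpoint-path switch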
 
-    /// Loads a BERT model from the provided TensorFlow checkpoint file into this BERT model.
+    /// Loads a BERT model from the provided CheckpointReader into this BERT model.
     ///
     /// - Parameters:
-    ///   - fileURL: Path to the checkpoint file. Note that TensorFlow checkpoints typically
-    ///     consist of multiple files (e.g., `bert_model.ckpt.index`, `bert_model.ckpt.meta`, and
-    ///     `bert_model.ckpt.data-00000-of-00001`). In this case, the file URL should be specified
-    ///     as their common prefix (e.g., `bert_model.ckpt`).
-    public mutating func load(fromTensorFlowCheckpoint fileURL: URL) {
-        let checkpointReader = TensorFlowCheckpointReader(checkpointPath: fileURL.path)
-        tokenEmbedding.embeddings =
-            Tensor(checkpointReader.loadTensor(named: "bert/embeddings/word_embeddings"))
-        positionEmbedding.embeddings =
-            Tensor(checkpointReader.loadTensor(named: "bert/embeddings/position_embeddings"))
-        embeddingLayerNorm.offset =
-            Tensor(checkpointReader.loadTensor(named: "bert/embeddings/LayerNorm/beta"))
-        embeddingLayerNorm.scale =
-            Tensor(checkpointReader.loadTensor(named: "bert/embeddings/LayerNorm/gamma"))
+    ///   - reader: CheckpointReader object to load tensors from.
+    public mutating func loadTensors(_ reader: CheckpointReader) {
+        tokenEmbedding.embeddings = reader.readTensor(name: "bert/embeddings/word_embeddings")
+        positionEmbedding.embeddings = reader.readTensor(name: "bert/embeddings/position_embeddings")
+        embeddingLayerNorm.offset = reader.readTensor(name: "bert/embeddings/LayerNorm/beta")
+        embeddingLayerNorm.scale = reader.readTensor(name: "bert/embeddings/LayerNorm/gamma")
         switch variant {
         case .bert, .albert:
             tokenTypeEmbedding.embeddings =
-                Tensor(checkpointReader.loadTensor(named: "bert/embeddings/token_type_embeddings"))
+                reader.readTensor(name: "bert/embeddings/token_type_embeddings")
         case .roberta: ()
         }
         switch variant {
         case .bert, .roberta:
             for layerIndex in encoderLayers.indices {
-                let prefix = "bert/encoder/layer_\(layerIndex)"
-                encoderLayers[layerIndex].multiHeadAttention.queryWeight =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/self/query/kernel"))
-                encoderLayers[layerIndex].multiHeadAttention.queryBias =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/self/query/bias"))
-                encoderLayers[layerIndex].multiHeadAttention.keyWeight =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/self/key/kernel"))
-                encoderLayers[layerIndex].multiHeadAttention.keyBias =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/self/key/bias"))
-                encoderLayers[layerIndex].multiHeadAttention.valueWeight =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/self/value/kernel"))
-                encoderLayers[layerIndex].multiHeadAttention.valueBias =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/self/value/bias"))
-                encoderLayers[layerIndex].attentionWeight =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/output/dense/kernel"))
-                encoderLayers[layerIndex].attentionBias =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/output/dense/bias"))
-                encoderLayers[layerIndex].attentionLayerNorm.offset =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/output/LayerNorm/beta"))
-                encoderLayers[layerIndex].attentionLayerNorm.scale =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/output/LayerNorm/gamma"))
-                encoderLayers[layerIndex].intermediateWeight =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/intermediate/dense/kernel"))
-                encoderLayers[layerIndex].intermediateBias =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/intermediate/dense/bias"))
-                encoderLayers[layerIndex].outputWeight =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/output/dense/kernel"))
-                encoderLayers[layerIndex].outputBias =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/output/dense/bias"))
-                encoderLayers[layerIndex].outputLayerNorm.offset =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/output/LayerNorm/beta"))
-                encoderLayers[layerIndex].outputLayerNorm.scale =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/output/LayerNorm/gamma"))
+                encoderLayers[layerIndex].load(bert: reader,
+                                               prefix: "bert/encoder/layer_\(layerIndex)")
             }
         case .albert:
             embeddingProjection[0].weight =
-                Tensor(checkpointReader.loadTensor(
-                    named: "bert/encoder/embedding_hidden_mapping_in/kernel"))
+                reader.readTensor(name: "bert/encoder/embedding_hidden_mapping_in/kernel")
             embeddingProjection[0].bias =
-                Tensor(checkpointReader.loadTensor(
-                    named: "bert/encoder/embedding_hidden_mapping_in/bias"))
+                reader.readTensor(name: "bert/encoder/embedding_hidden_mapping_in/bias")
             for layerIndex in encoderLayers.indices {
                 let prefix = "bert/encoder/transformer/group_\(layerIndex)/inner_group_0"
-                encoderLayers[layerIndex].multiHeadAttention.queryWeight =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention_1/self/query/kernel"))
-                encoderLayers[layerIndex].multiHeadAttention.queryBias =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention_1/self/query/bias"))
-                encoderLayers[layerIndex].multiHeadAttention.keyWeight =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention_1/self/key/kernel"))
-                encoderLayers[layerIndex].multiHeadAttention.keyBias =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention_1/self/key/bias"))
-                encoderLayers[layerIndex].multiHeadAttention.valueWeight =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention_1/self/value/kernel"))
-                encoderLayers[layerIndex].multiHeadAttention.valueBias =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention_1/self/value/bias"))
-                encoderLayers[layerIndex].attentionWeight =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention_1/output/dense/kernel"))
-                encoderLayers[layerIndex].attentionBias =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention_1/output/dense/bias"))
-                encoderLayers[layerIndex].attentionLayerNorm.offset =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/LayerNorm/beta"))
-                encoderLayers[layerIndex].attentionLayerNorm.scale =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/LayerNorm/gamma"))
-                encoderLayers[layerIndex].intermediateWeight =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/ffn_1/intermediate/dense/kernel"))
-                encoderLayers[layerIndex].intermediateBias =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/ffn_1/intermediate/dense/bias"))
-                encoderLayers[layerIndex].outputWeight =
-                    Tensor(checkpointReader.loadTensor(
-                        named: "\(prefix)/ffn_1/intermediate/output/dense/kernel"))
-                encoderLayers[layerIndex].outputBias =
-                    Tensor(checkpointReader.loadTensor(
-                        named: "\(prefix)/ffn_1/intermediate/output/dense/bias"))
-                encoderLayers[layerIndex].outputLayerNorm.offset =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/LayerNorm_1/beta"))
-                encoderLayers[layerIndex].outputLayerNorm.scale =
-                    Tensor(checkpointReader.loadTensor(named: "\(prefix)/LayerNorm_1/gamma"))
+                encoderLayers[layerIndex].load(albert: reader, prefix: prefix)
             }
         }
     }
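
The per-layer loading that this hunk deletes moves behind `load(bert:prefix:)` and `load(albert:prefix:)` helpers whose definitions are not shown in this section. A hedged reconstruction of what the BERT-variant helper presumably does, assembled from the removed assignments above (the extension target `TransformerEncoderLayer` and the exact signature are assumptions; the fields and tensor names are taken from the diff):

    // Sketch only: target type and signature are assumptions; the field and
    // tensor names mirror the removed per-layer assignments in this diff.
    extension TransformerEncoderLayer {
        mutating func load(bert reader: CheckpointReader, prefix: String) {
            multiHeadAttention.queryWeight = reader.readTensor(name: "\(prefix)/attention/self/query/kernel")
            multiHeadAttention.queryBias = reader.readTensor(name: "\(prefix)/attention/self/query/bias")
            multiHeadAttention.keyWeight = reader.readTensor(name: "\(prefix)/attention/self/key/kernel")
            multiHeadAttention.keyBias = reader.readTensor(name: "\(prefix)/attention/self/key/bias")
            multiHeadAttention.valueWeight = reader.readTensor(name: "\(prefix)/attention/self/value/kernel")
            multiHeadAttention.valueBias = reader.readTensor(name: "\(prefix)/attention/self/value/bias")
            attentionWeight = reader.readTensor(name: "\(prefix)/attention/output/dense/kernel")
            attentionBias = reader.readTensor(name: "\(prefix)/attention/output/dense/bias")
            attentionLayerNorm.offset = reader.readTensor(name: "\(prefix)/attention/output/LayerNorm/beta")
            attentionLayerNorm.scale = reader.readTensor(name: "\(prefix)/attention/output/LayerNorm/gamma")
            intermediateWeight = reader.readTensor(name: "\(prefix)/intermediate/dense/kernel")
            intermediateBias = reader.readTensor(name: "\(prefix)/intermediate/dense/bias")
            outputWeight = reader.readTensor(name: "\(prefix)/output/dense/kernel")
            outputBias = reader.readTensor(name: "\(prefix)/output/dense/bias")
            outputLayerNorm.offset = reader.readTensor(name: "\(prefix)/output/LayerNorm/beta")
            outputLayerNorm.scale = reader.readTensor(name: "\(prefix)/output/LayerNorm/gamma")
        }
    }

The `load(albert:prefix:)` overload would differ only in its tensor paths (`attention_1/...`, `ffn_1/...`, and the `LayerNorm`/`LayerNorm_1` pair), per the removed ALBERT branch above.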