This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 521637a

Authored by leoxzhao, BradLarson, and xihui-wu
[BERT] Switch to use CheckpointReader to load tensors from checkpoint. (#554)
* [BERT] Switch to use CheckpointReader to load tensors from a checkpoint (it was using the TensorFlow API directly). Also add a simple command-line parameter to the example app, BERT-CoLA, which takes "BERT", "ALBERT", or "RoBERTa" to switch between models.
* Add `BERTCheckpointReader.swift` to `CMakeLists.txt`. Revert a change to load checkpoints of a different format from TensorFlow Hub. Fix indentation.
* Make BERT-CoLA display the correct model name.
* Pass the URL of a checkpoint into BERT to create a CheckpointReader, then load tensors from it.
* Clean up BERT-CoLA/main.swift.

Co-authored-by: Brad Larson <[email protected]>
Co-authored-by: Xihui Wu <[email protected]>
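The heart of the change is visible in the diffs below: instead of calling the TensorFlow checkpoint API directly, BERT now goes through the repository's `CheckpointReader`. A minimal sketch of the new loading path, using only calls that appear in the diffs; the module name and checkpoint URL here are assumptions for illustration:

```swift
import Foundation
import TensorFlow
import Checkpoints  // module name assumed for swift-models' CheckpointReader

// Hypothetical checkpoint location; in the real code this comes from the
// pre-trained model's own `url` property or a caller-supplied override.
let checkpointURL = URL(fileURLWithPath: "/tmp/bert_models/uncased_L-12_H-768_A-12")

// Create a reader for the checkpoint and disable CRC verification,
// mirroring what the diff below does.
let reader = try CheckpointReader(checkpointLocation: checkpointURL, modelName: "BERT")
reader.isCRCVerificationEnabled = false

// Individual tensors are then read back by their checkpoint names.
let wordEmbeddings: Tensor<Float> =
    reader.readTensor(name: "bert/embeddings/word_embeddings")
```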
1 parent 81390ad commit 521637a

5 files changed: +119 additions, −163 deletions


Examples/BERT-CoLA/main.swift

Lines changed: 18 additions & 8 deletions

@@ -21,13 +21,20 @@ import x10_optimizers_optimizer
 
 let device = Device.defaultXLA
 
-let bertPretrained = BERT.PreTrainedModel.bertBase(cased: false, multilingual: false)
-let workspaceURL = URL(
-    fileURLWithPath: "bert_models", isDirectory: true,
-    relativeTo: URL(
-        fileURLWithPath: NSTemporaryDirectory(),
-        isDirectory: true))
-let bert = try BERT.PreTrainedModel.load(bertPretrained)(from: workspaceURL)
+var bertPretrained: BERT.PreTrainedModel
+if CommandLine.arguments.count >= 2 {
+    if CommandLine.arguments[1].lowercased() == "albert" {
+        bertPretrained = BERT.PreTrainedModel.albertBase
+    } else if CommandLine.arguments[1].lowercased() == "roberta" {
+        bertPretrained = BERT.PreTrainedModel.robertaBase
+    } else {
+        bertPretrained = BERT.PreTrainedModel.bertBase(cased: false, multilingual: false)
+    }
+} else {
+    bertPretrained = BERT.PreTrainedModel.bertBase(cased: false, multilingual: false)
+}
+
+let bert = try bertPretrained.load()
 
 var bertClassifier = BERTClassifier(bert: bert, classCount: 1)
 bertClassifier.move(to: device)
@@ -48,6 +55,9 @@ let epochCount = 3
 let stepsPerEpoch = 1068 // function of training set size and batching configuration
 let peakLearningRate: Float = 2e-5
 
+let workspaceURL = URL(fileURLWithPath: "bert_models", isDirectory: true,
+    relativeTo: URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true))
+
 var cola = try CoLA(
     taskDirectoryURL: workspaceURL,
     maxSequenceLength: maxSequenceLength,
@@ -85,7 +95,7 @@ var scheduledLearningRate = LinearlyDecayedParameter(
     startStep: 10
 )
 
-print("Training BERT for the CoLA task!")
+print("Training \(bertPretrained.name) for the CoLA task!")
 for (epoch, epochBatches) in cola.trainingEpochs.prefix(3).enumerated() {
     print("[Epoch \(epoch + 1)]")
     Context.local.learningPhase = .training
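A note on usage: with the argument handling above, the first command-line argument selects the model, compared case-insensitively, and anything unrecognized (or no argument at all) falls back to uncased BERT-base. Assuming the repository's usual SwiftPM setup, an invocation would look like `swift run BERT-CoLA RoBERTa`.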

Models/Text/BERT.swift

Lines changed: 29 additions & 142 deletions

@@ -655,39 +655,37 @@ extension BERT {
       }
     }
 
-    /// Loads this pre-trained BERT model from the specified directory.
+    /// Loads this pre-trained BERT model from the specified URL.
     ///
     /// - Note: This function will download the pre-trained model files to the specified
     //    directory, if they are not already there.
     ///
     /// - Parameters:
-    ///   - directory: Directory to load the pretrained model from.
-    public func load(from directory: URL) throws -> BERT {
+    ///   - url: URL to load the pretrained model from.
+    public func load(from url: URL? = nil) throws -> BERT {
       print("Loading BERT pre-trained model '\(name)'.")
-      let directory = directory.appendingPathComponent(variant.description, isDirectory: true)
-      try maybeDownload(to: directory)
+
+      let reader = try CheckpointReader(checkpointLocation: url ?? self.url, modelName: name)
+      // TODO(michellecasbon): expose this.
+      reader.isCRCVerificationEnabled = false
+
+      let storage = reader.localCheckpointLocation.deletingLastPathComponent()
 
       // Load the appropriate vocabulary file.
       let vocabulary: Vocabulary = {
         switch self {
         case .bertBase, .bertLarge:
-          let vocabularyURL = directory
-            .appendingPathComponent(subDirectory)
-            .appendingPathComponent("vocab.txt")
+          let vocabularyURL = storage.appendingPathComponent("vocab.txt")
           return try! Vocabulary(fromFile: vocabularyURL)
         case .robertaBase, .robertaLarge:
-          let vocabularyURL = directory
-            .appendingPathComponent(subDirectory)
-            .appendingPathComponent("vocab.json")
-          let dictionaryURL = directory
-            .appendingPathComponent(subDirectory)
-            .appendingPathComponent("dict.txt")
+          let vocabularyURL = storage.appendingPathComponent("vocab.json")
+          let dictionaryURL = storage.appendingPathComponent("dict.txt")
           return try! Vocabulary(
             fromRoBERTaJSONFile: vocabularyURL,
             dictionaryFile: dictionaryURL)
         case .albertBase, .albertLarge, .albertXLarge, .albertXXLarge:
-          let vocabularyURL = directory
-            .appendingPathComponent(subDirectory)
+          let vocabularyURL = storage
+            .deletingLastPathComponent()
             .appendingPathComponent("assets")
             .appendingPathComponent("30k-clean.model")
           return try! Vocabulary(fromSentencePieceModel: vocabularyURL)
@@ -704,8 +702,7 @@ extension BERT {
           unknownToken: "[UNK]",
           maxTokenLength: nil)
         case .robertaBase, .robertaLarge:
-          let mergePairsFileURL = directory
-            .appendingPathComponent(subDirectory)
+          let mergePairsFileURL = storage
             .appendingPathComponent("merges.txt")
           let mergePairs = [BytePairEncoder.Pair: Int](
             uniqueKeysWithValues:
@@ -749,150 +746,40 @@ extension BERT {
         initializerStandardDeviation: 0.02,
         useOneHotEmbeddings: false)
 
-      // Load the pre-trained model checkpoint.
-      switch self {
-      case .bertBase, .bertLarge:
-        model.load(fromTensorFlowCheckpoint: directory
-          .appendingPathComponent(subDirectory)
-          .appendingPathComponent("bert_model.ckpt"))
-      case .robertaBase, .robertaLarge:
-        model.load(fromTensorFlowCheckpoint: directory
-          .appendingPathComponent(subDirectory)
-          .appendingPathComponent("roberta_\(subDirectory).ckpt"))
-      case .albertBase, .albertLarge, .albertXLarge, .albertXXLarge:
-        model.load(fromTensorFlowCheckpoint: directory
-          .appendingPathComponent(subDirectory)
-          .appendingPathComponent("variables")
-          .appendingPathComponent("variables"))
-      }
+      model.loadTensors(reader)
       return model
     }
-
-    /// Downloads this pre-trained model to the specified directory, if it's not already there.
-    public func maybeDownload(to directory: URL) throws {
-      switch self {
-      case .bertBase, .bertLarge, .robertaBase, .robertaLarge:
-        // Download and extract the pretrained model, if necessary.
-        DatasetUtilities.downloadResource(filename: "\(subDirectory)", fileExtension: "zip",
-          remoteRoot: url.deletingLastPathComponent(),
-          localStorageDirectory: directory)
-      case .albertBase, .albertLarge, .albertXLarge, .albertXXLarge:
-        // Download the model, if necessary.
-        let compressedFileURL = directory.appendingPathComponent("\(subDirectory).tar.gz")
-        try download(from: url, to: compressedFileURL)
-
-        // Extract the data, if necessary.
-        let extractedDirectoryURL = directory.appendingPathComponent(subDirectory)
-        if !FileManager.default.fileExists(atPath: extractedDirectoryURL.path) {
-          try extract(tarGZippedFileAt: compressedFileURL, to: extractedDirectoryURL)
-        }
-      }
-    }
   }
 
-  /// Loads a BERT model from the provided TensorFlow checkpoint file into this BERT model.
+  /// Loads a BERT model from the provided CheckpointReader into this BERT model.
   ///
   /// - Parameters:
-  ///   - fileURL: Path to the checkpoint file. Note that TensorFlow checkpoints typically
-  ///     consist of multiple files (e.g., `bert_model.ckpt.index`, `bert_model.ckpt.meta`, and
-  ///     `bert_model.ckpt.data-00000-of-00001`). In this case, the file URL should be specified
-  ///     as their common prefix (e.g., `bert_model.ckpt`).
-  public mutating func load(fromTensorFlowCheckpoint fileURL: URL) {
-    let checkpointReader = TensorFlowCheckpointReader(checkpointPath: fileURL.path)
-    tokenEmbedding.embeddings =
-      Tensor(checkpointReader.loadTensor(named: "bert/embeddings/word_embeddings"))
-    positionEmbedding.embeddings =
-      Tensor(checkpointReader.loadTensor(named: "bert/embeddings/position_embeddings"))
-    embeddingLayerNorm.offset =
-      Tensor(checkpointReader.loadTensor(named: "bert/embeddings/LayerNorm/beta"))
-    embeddingLayerNorm.scale =
-      Tensor(checkpointReader.loadTensor(named: "bert/embeddings/LayerNorm/gamma"))
+  ///   - reader: CheckpointReader object to load tensors from.
+  public mutating func loadTensors(_ reader: CheckpointReader) {
+    tokenEmbedding.embeddings = reader.readTensor(name: "bert/embeddings/word_embeddings")
+    positionEmbedding.embeddings = reader.readTensor(name: "bert/embeddings/position_embeddings")
+    embeddingLayerNorm.offset = reader.readTensor(name: "bert/embeddings/LayerNorm/beta")
+    embeddingLayerNorm.scale = reader.readTensor(name: "bert/embeddings/LayerNorm/gamma")
     switch variant {
     case .bert, .albert:
       tokenTypeEmbedding.embeddings =
-        Tensor(checkpointReader.loadTensor(named: "bert/embeddings/token_type_embeddings"))
+        reader.readTensor(name: "bert/embeddings/token_type_embeddings")
     case .roberta: ()
     }
     switch variant {
     case .bert, .roberta:
       for layerIndex in encoderLayers.indices {
-        let prefix = "bert/encoder/layer_\(layerIndex)"
-        encoderLayers[layerIndex].multiHeadAttention.queryWeight =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/self/query/kernel"))
-        encoderLayers[layerIndex].multiHeadAttention.queryBias =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/self/query/bias"))
-        encoderLayers[layerIndex].multiHeadAttention.keyWeight =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/self/key/kernel"))
-        encoderLayers[layerIndex].multiHeadAttention.keyBias =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/self/key/bias"))
-        encoderLayers[layerIndex].multiHeadAttention.valueWeight =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/self/value/kernel"))
-        encoderLayers[layerIndex].multiHeadAttention.valueBias =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/self/value/bias"))
-        encoderLayers[layerIndex].attentionWeight =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/output/dense/kernel"))
-        encoderLayers[layerIndex].attentionBias =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/output/dense/bias"))
-        encoderLayers[layerIndex].attentionLayerNorm.offset =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/output/LayerNorm/beta"))
-        encoderLayers[layerIndex].attentionLayerNorm.scale =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention/output/LayerNorm/gamma"))
-        encoderLayers[layerIndex].intermediateWeight =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/intermediate/dense/kernel"))
-        encoderLayers[layerIndex].intermediateBias =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/intermediate/dense/bias"))
-        encoderLayers[layerIndex].outputWeight =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/output/dense/kernel"))
-        encoderLayers[layerIndex].outputBias =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/output/dense/bias"))
-        encoderLayers[layerIndex].outputLayerNorm.offset =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/output/LayerNorm/beta"))
-        encoderLayers[layerIndex].outputLayerNorm.scale =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/output/LayerNorm/gamma"))
+        encoderLayers[layerIndex].load(bert: reader,
+          prefix: "bert/encoder/layer_\(layerIndex)")
       }
     case .albert:
       embeddingProjection[0].weight =
-        Tensor(checkpointReader.loadTensor(
-          named: "bert/encoder/embedding_hidden_mapping_in/kernel"))
+        reader.readTensor(name: "bert/encoder/embedding_hidden_mapping_in/kernel")
       embeddingProjection[0].bias =
-        Tensor(checkpointReader.loadTensor(
-          named: "bert/encoder/embedding_hidden_mapping_in/bias"))
+        reader.readTensor(name: "bert/encoder/embedding_hidden_mapping_in/bias")
       for layerIndex in encoderLayers.indices {
         let prefix = "bert/encoder/transformer/group_\(layerIndex)/inner_group_0"
-        encoderLayers[layerIndex].multiHeadAttention.queryWeight =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention_1/self/query/kernel"))
-        encoderLayers[layerIndex].multiHeadAttention.queryBias =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention_1/self/query/bias"))
-        encoderLayers[layerIndex].multiHeadAttention.keyWeight =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention_1/self/key/kernel"))
-        encoderLayers[layerIndex].multiHeadAttention.keyBias =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention_1/self/key/bias"))
-        encoderLayers[layerIndex].multiHeadAttention.valueWeight =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention_1/self/value/kernel"))
-        encoderLayers[layerIndex].multiHeadAttention.valueBias =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention_1/self/value/bias"))
-        encoderLayers[layerIndex].attentionWeight =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention_1/output/dense/kernel"))
-        encoderLayers[layerIndex].attentionBias =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/attention_1/output/dense/bias"))
-        encoderLayers[layerIndex].attentionLayerNorm.offset =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/LayerNorm/beta"))
-        encoderLayers[layerIndex].attentionLayerNorm.scale =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/LayerNorm/gamma"))
-        encoderLayers[layerIndex].intermediateWeight =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/ffn_1/intermediate/dense/kernel"))
-        encoderLayers[layerIndex].intermediateBias =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/ffn_1/intermediate/dense/bias"))
-        encoderLayers[layerIndex].outputWeight =
-          Tensor(checkpointReader.loadTensor(
-            named: "\(prefix)/ffn_1/intermediate/output/dense/kernel"))
-        encoderLayers[layerIndex].outputBias =
-          Tensor(checkpointReader.loadTensor(
-            named: "\(prefix)/ffn_1/intermediate/output/dense/bias"))
-        encoderLayers[layerIndex].outputLayerNorm.offset =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/LayerNorm_1/beta"))
-        encoderLayers[layerIndex].outputLayerNorm.scale =
-          Tensor(checkpointReader.loadTensor(named: "\(prefix)/LayerNorm_1/gamma"))
+        encoderLayers[layerIndex].load(albert: reader, prefix: prefix)
       }
     }
   }
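The per-layer tensor assignments deleted above now live behind the `load(bert:prefix:)` and `load(albert:prefix:)` helpers, whose definitions are in one of the five changed files not shown in this excerpt. A sketch of what the BERT-variant helper presumably looks like, reconstructed purely from the tensor names in the deleted lines; the `TransformerEncoderLayer` extension point and exact signature are assumptions:

```swift
// Sketch only: reconstructed from the tensor names deleted above. The real
// definition lives in another file of this commit; the extended type name
// and signature are assumptions.
extension TransformerEncoderLayer {
  mutating func load(bert reader: CheckpointReader, prefix: String) {
    multiHeadAttention.queryWeight = reader.readTensor(name: "\(prefix)/attention/self/query/kernel")
    multiHeadAttention.queryBias = reader.readTensor(name: "\(prefix)/attention/self/query/bias")
    multiHeadAttention.keyWeight = reader.readTensor(name: "\(prefix)/attention/self/key/kernel")
    multiHeadAttention.keyBias = reader.readTensor(name: "\(prefix)/attention/self/key/bias")
    multiHeadAttention.valueWeight = reader.readTensor(name: "\(prefix)/attention/self/value/kernel")
    multiHeadAttention.valueBias = reader.readTensor(name: "\(prefix)/attention/self/value/bias")
    attentionWeight = reader.readTensor(name: "\(prefix)/attention/output/dense/kernel")
    attentionBias = reader.readTensor(name: "\(prefix)/attention/output/dense/bias")
    attentionLayerNorm.offset = reader.readTensor(name: "\(prefix)/attention/output/LayerNorm/beta")
    attentionLayerNorm.scale = reader.readTensor(name: "\(prefix)/attention/output/LayerNorm/gamma")
    intermediateWeight = reader.readTensor(name: "\(prefix)/intermediate/dense/kernel")
    intermediateBias = reader.readTensor(name: "\(prefix)/intermediate/dense/bias")
    outputWeight = reader.readTensor(name: "\(prefix)/output/dense/kernel")
    outputBias = reader.readTensor(name: "\(prefix)/output/dense/bias")
    outputLayerNorm.offset = reader.readTensor(name: "\(prefix)/output/LayerNorm/beta")
    outputLayerNorm.scale = reader.readTensor(name: "\(prefix)/output/LayerNorm/gamma")
  }
}
```

Factoring these assignments out removes the sixteen near-identical lines per variant from `loadTensors(_:)` and lets the ALBERT path reuse the same pattern with its own tensor-name prefixes.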
