Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 23fa457

Browse files
authored
A native Swift checkpoint reader for TensorFlow v2 checkpoint files (#277)
* Initial addition of native-Swift TensorFlow v2 checkpoint reader. * Adding checkpoint tests, fixing remaining issues with uncompressed checkpoints. * Fixing some comment formatting. * Starting to build out Snappy decompression support. * Added a working implementation of Snappy decompression, MiniGo demo now works with new checkpoint loader. * Added missing Snappy allTests. * Cleaned up some formatting, added some documentation, added a couple of thrown error cases for Snappy decompression. * Adding more tests for reading and Snappy decompression, updating a failure comment.
1 parent 7395332 commit 23fa457

21 files changed

+1766
-44
lines changed

FastStyleTransfer/Demo/ColabDemo.ipynb

+2-2
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@
314314
" let remoteCheckpoint = URL(\n",
315315
" string: \"https://storage.googleapis.com/s4tf-hosted-binaries/checkpoints/FastStyleTransfer/\\(s)\")!\n",
316316
" let modelName = \"FastStyleTransfer_\\(s)\"\n",
317-
" let reader = CheckpointReader(checkpointLocation: remoteCheckpoint, modelName: modelName)\n",
317+
" let reader = try! CheckpointReader(checkpointLocation: remoteCheckpoint, modelName: modelName)\n",
318318
" // Load weights into model.\n",
319319
" style.unsafeImport(from: reader, map: map)\n",
320320
" // Apply model to image.\n",
@@ -357,4 +357,4 @@
357357
},
358358
"nbformat": 4,
359359
"nbformat_minor": 1
360-
}
360+
}

FastStyleTransfer/Demo/Helpers.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func importWeights(_ model: inout TransformerNet, for style: String) throws {
4949
exit(-1)
5050
}
5151

52-
let reader = CheckpointReader(checkpointLocation: remoteCheckpoint, modelName: modelName)
52+
let reader = try CheckpointReader(checkpointLocation: remoteCheckpoint, modelName: modelName)
5353

5454
// Names don't match exactly, and axes in filters need to be reversed.
5555
let map = [

MiniGo/main.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ var model = GoModel(configuration: modelConfig)
3232

3333
let remoteCheckpoint = URL(
3434
string: "https://storage.googleapis.com/s4tf-hosted-binaries/checkpoints/MiniGo/000939-heron")!
35-
let reader = MiniGoCheckpointReader(checkpointLocation: remoteCheckpoint, modelName: "MiniGo")
35+
let reader = try MiniGoCheckpointReader(checkpointLocation: remoteCheckpoint, modelName: "MiniGo")
3636
model.load(from: reader)
3737

3838
let predictor = MCTSModelBasedPredictor(boardSize: boardSize, model: model)

Package.resolved

+9
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@
1818
"revision": "f14ff47f45642aa5703900980b014c2e9394b6e5",
1919
"version": "0.9.0"
2020
}
21+
},
22+
{
23+
"package": "SwiftProtobuf",
24+
"repositoryURL": "https://github.com/apple/swift-protobuf.git",
25+
"state": {
26+
"branch": null,
27+
"revision": "da75a93ac017534e0028e83c0e4fc4610d2acf48",
28+
"version": "1.7.0"
29+
}
2130
}
2231
]
2332
},

Package.swift

+3-1
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,13 @@ let package = Package(
2525
.executable(name: "Benchmarks", targets: ["Benchmarks"]),
2626
],
2727
dependencies: [
28+
.package(url: "https://github.com/apple/swift-protobuf.git", from: "1.7.0"),
2829
.package(url: "https://github.com/kylef/Commander.git", from: "0.9.1"),
2930
],
3031
targets: [
3132
.target(name: "ImageClassificationModels", path: "Models/ImageClassification"),
3233
.target(name: "Datasets", dependencies: ["ModelSupport"], path: "Datasets"),
33-
.target(name: "ModelSupport", path: "Support"),
34+
.target(name: "ModelSupport", dependencies: ["SwiftProtobuf"], path: "Support"),
3435
.target(
3536
name: "Autoencoder", dependencies: ["Datasets", "ModelSupport"], path: "Autoencoder"),
3637
.target(name: "Catch", path: "Catch"),
@@ -70,5 +71,6 @@ let package = Package(
7071
name: "Benchmarks",
7172
dependencies: ["Datasets", "ModelSupport", "ImageClassificationModels", "Commander"],
7273
path: "Benchmarks"),
74+
.testTarget(name: "CheckpointTests", dependencies: ["ModelSupport"]),
7375
]
7476
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// TensorFlow v2 checkpoints use an index file as a key-value store to map saved tensor names to
16+
// the metadata for each tensor. The format of this file is defined by tensorflow::table::Table
17+
//
18+
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/lib/io/table_format.txt
19+
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/lib/io/table.h
20+
//
21+
// and consists of a series of string keys and associated data values. It is based on the LevelDB
22+
// table format: https://github.com/google/leveldb
23+
//
24+
// The very first key is a null string and its value is a protobuf containing header information
25+
// about the entire checkpoint bundle (number of shards, etc.). The remaining keys are
26+
// prefix-compressed strings in ascending alphabetical order representing each named tensor in the
27+
// checkpoint, with their values being protobufs that contain metadata about each tensor.
28+
//
29+
// The binary data representing the tensors are stored in one or more shard files, with lookup
30+
// locations determined by this metadata.
31+
32+
import Foundation
33+
34+
// The block footer size is constant, and is obtained from the following:
35+
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/lib/io/format.h
36+
// `2 * BlockHandle::kMaxEncodedLength + 8` where `kMaxEncodedLength = 10 + 10`
37+
let footerSize = 48
38+
39+
/// Holds the bytes of a TensorFlow v2 checkpoint index file in memory, together with a read
/// cursor and the most recently decoded key.
class CheckpointIndexReader {
    /// The index file contents, after Snappy decompression when that was required.
    let binaryData: Data

    /// Offset into `binaryData` at which the next read begins.
    var index: Int = 0

    /// Bytes of the most recently decoded key; the next prefix-compressed key reuses a leading
    /// portion of these bytes.
    var currentPrefix = Data()

    /// True once the cursor is within `footerSize + 1` bytes of the end of the data.
    var atEndOfFile: Bool { return index >= (binaryData.count - footerSize - 1) }

    /// Loads the index file at `file`.
    ///
    /// A file whose first byte is zero is used as-is. Any other first byte is taken to mean the
    /// contents are Snappy-compressed, and they are decompressed before being stored.
    ///
    /// - Parameter file: Location of the checkpoint index file.
    /// - Throws: `CocoaError(.fileReadCorruptFile)` when the file is empty, or any error raised
    ///   while reading the file or decompressing it.
    init(file: URL) throws {
        let fileData = try Data(contentsOf: file)

        // An empty file has no first byte to inspect; subscripting it would trap at runtime, so
        // report it as a recoverable error instead.
        guard let firstByte = fileData.first else {
            throw CocoaError(.fileReadCorruptFile)
        }

        if firstByte == 0 {
            binaryData = fileData
        } else {
            binaryData = try fileData.decompressFromSnappy()
        }
    }

    /// Moves the read cursor back to the start of the data. `currentPrefix` is left untouched.
    func resetHead() {
        index = 0
    }
}
59+
60+
// The main interface for iterating through all metadata contained in the index file.
61+
extension CheckpointIndexReader {
    /// Decodes the bundle header stored at the current read position.
    ///
    /// Three varints are consumed first: the leading one must be zero, the second is discarded,
    /// and the third gives the byte length of the serialized header that follows.
    /// (Presumably the first two are the shared/unshared key lengths, mirroring
    /// `readKeyAndValue` - confirm against the table format.)
    ///
    /// - Returns: The decoded header protobuf.
    /// - Throws: Any error raised while deserializing the protobuf.
    func readHeader() throws -> Tensorflow_BundleHeaderProto {
        // The header lives under the empty-string key, so no key bytes need to be read.
        // A non-zero leading value would mean the data is still Snappy-compressed, which the
        // initializer is expected to have dealt with already.
        let leadingValue = binaryData.readVarint32(at: &index)
        if leadingValue != 0 {
            fatalError("Snappy-compressed data should have been picked up earlier than this.")
        }

        _ = binaryData.readVarint32(at: &index)
        let headerSize = binaryData.readVarint32(at: &index)
        let headerBytes = binaryData.readDataBlock(at: &index, size: headerSize)

        return try Tensorflow_BundleHeaderProto(serializedData: headerBytes)
    }

    /// Reads entries from the current position until `readKeyAndValue()` reports that none
    /// remain, collecting them into a dictionary keyed by entry name.
    ///
    /// - Returns: Every key paired with its decoded bundle entry. A key that occurs more than
    ///   once keeps the value read last.
    /// - Throws: Any error raised while decoding an individual entry.
    func readAllKeysAndValues() throws -> [String: Tensorflow_BundleEntryProto] {
        var entries = [String: Tensorflow_BundleEntryProto]()
        while let pair = try readKeyAndValue() {
            entries[pair.0] = pair.1
        }
        return entries
    }
}
87+
88+
// The internal file parsing methods for smaller datatypes that comprise the key-value groupings.
89+
extension CheckpointIndexReader {
    /// Rebuilds a prefix-compressed key and remembers it for the next call.
    ///
    /// The key is the first `sharedBytes` bytes of the previously decoded key followed by
    /// `unsharedBytes` bytes read from the current position.
    ///
    /// - Parameters:
    ///   - sharedBytes: Number of leading bytes reused from the previous key. Asking for more
    ///     than the previous key holds is a fatal error.
    ///   - unsharedBytes: Number of new bytes to read from the data.
    /// - Returns: The full key decoded as UTF-8. Bytes that are not valid UTF-8 trap on the
    ///   force-unwrap.
    func readKey(sharedBytes: Int, unsharedBytes: Int) -> String {
        let suffix = binaryData.readDataBlock(at: &index, size: unsharedBytes)
        if sharedBytes > currentPrefix.count {
            fatalError(
                "Shared bytes of \(sharedBytes) exceeded stored prefix size of \(currentPrefix.count)."
            )
        }

        currentPrefix = currentPrefix[0..<sharedBytes] + suffix
        return String(bytes: currentPrefix, encoding: .utf8)!
    }

    /// Reads the next key and its bundle entry from the current position.
    ///
    /// - Returns: The key with its decoded entry, or `nil` when the cursor has reached the end of
    ///   the entries or an all-zero record is encountered.
    /// - Throws: Any error raised while deserializing the entry protobuf.
    func readKeyAndValue() throws -> (String, Tensorflow_BundleEntryProto)? {
        if atEndOfFile { return nil }

        // Each record starts with three varints: shared key length, unshared key length, and
        // value length, followed by the unshared key bytes and then the value bytes.
        let sharedCount = binaryData.readVarint32(at: &index)
        let unsharedCount = binaryData.readVarint32(at: &index)
        let payloadSize = binaryData.readVarint32(at: &index)
        let name = readKey(sharedBytes: sharedCount, unsharedBytes: unsharedCount)
        let payload = binaryData.readDataBlock(at: &index, size: payloadSize)

        // TODO: Need to verify if these three being zero always indicates no more tensors to read.
        guard sharedCount + unsharedCount + payloadSize != 0 else { return nil }

        return (name, try Tensorflow_BundleEntryProto(serializedData: payload))
    }
}

0 commit comments

Comments
 (0)