20
20
import Foundation
21
21
import ModelSupport
22
22
import TensorFlow
23
- import Batcher
24
-
25
- public typealias LazyDataSet = LazyMapSequence < [ URL ] , TensorPair < Float , Int32 > >
26
-
27
- public struct Imagenette : ImageClassificationDataset {
28
- public typealias SourceDataSet = LazyDataSet
29
- public let training : Batcher < SourceDataSet >
30
- public let test : Batcher < SourceDataSet >
31
-
32
- public enum ImageSize {
33
- case full
34
- case resized160
35
- case resized320
36
-
37
- var suffix : String {
38
- switch self {
39
- case . full: return " "
40
- case . resized160: return " -160 "
41
- case . resized320: return " -320 "
42
- }
43
- }
44
- }
45
23
46
- public init ( batchSize: Int ) {
47
- self . init ( batchSize: batchSize, inputSize: . resized320, outputSize: 224 )
24
+ /// The three variants of Imagenette, determined by their source image size.
25
+ public enum ImagenetteSize {
26
+ case full
27
+ case resized160
28
+ case resized320
29
+
30
+ var suffix : String {
31
+ switch self {
32
+ case . full: return " "
33
+ case . resized160: return " -160 "
34
+ case . resized320: return " -320 "
48
35
}
36
+ }
37
+ }
49
38
50
- public init (
51
- batchSize: Int ,
52
- inputSize: ImageSize , outputSize: Int ,
53
- localStorageDirectory: URL = DatasetUtilities . defaultDirectory
54
- . appendingPathComponent ( " Imagenette " , isDirectory: true )
55
- ) {
56
- do {
57
- training = Batcher < SourceDataSet > (
58
- on: try loadImagenetteTrainingDirectory (
59
- inputSize: inputSize, outputSize: outputSize,
60
- localStorageDirectory: localStorageDirectory) ,
61
- batchSize: batchSize,
62
- shuffle: true )
63
- test = Batcher < SourceDataSet > (
64
- on: try loadImagenetteValidationDirectory (
65
- inputSize: inputSize, outputSize: outputSize,
66
- localStorageDirectory: localStorageDirectory) ,
67
- batchSize: batchSize)
68
- } catch {
69
- fatalError ( " Could not load Imagenette dataset: \( error) " )
39
+ public struct Imagenette < Entropy: RandomNumberGenerator > {
40
+ /// Type of the collection of non-collated batches.
41
+ public typealias Batches = Slices < Sampling < [ ( file: URL , label: Int32 ) ] , ArraySlice < Int > > >
42
+ /// The type of the training data, represented as a sequence of epochs, which
43
+ /// are collection of batches.
44
+ public typealias Training = LazyMapSequence <
45
+ TrainingEpochs < [ ( file: URL , label: Int32 ) ] , Entropy > ,
46
+ LazyMapSequence < Batches , LabeledImage >
47
+ >
48
+ /// The type of the validation data, represented as a collection of batches.
49
+ public typealias Validation = LazyMapSequence < Slices < [ ( file: URL , label: Int32 ) ] > , LabeledImage >
50
+ /// The training epochs.
51
+ public let training : Training
52
+ /// The validation batches.
53
+ public let validation : Validation
54
+
55
+ /// Creates an instance with `batchSize`.
56
+ ///
57
+ /// - Parameters:
58
+ /// - batchSize: Number of images provided per batch.
59
+ /// - entropy: A source of randomness used to shuffle sample
60
+ /// ordering. It will be stored in `self`, so if it is only pseudorandom
61
+ /// and has value semantics, the sequence of epochs is deterministic and not
62
+ /// dependent on other operations.
63
+ /// - device: The Device on which resulting Tensors from this dataset will be placed, as well
64
+ /// as where the latter stages of any conversion calculations will be performed.
65
+ public init ( batchSize: Int , entropy: Entropy , device: Device ) {
66
+ self . init (
67
+ batchSize: batchSize, entropy: entropy, device: device, inputSize: ImagenetteSize . resized320,
68
+ outputSize: 224 )
69
+ }
70
+
71
+ /// Creates an instance with `batchSize` on `device` using `remoteBinaryArchiveLocation`.
72
+ ///
73
+ /// - Parameters:
74
+ /// - batchSize: Number of images provided per batch.
75
+ /// - entropy: A source of randomness used to shuffle sample ordering. It
76
+ /// will be stored in `self`, so if it is only pseudorandom and has value
77
+ /// semantics, the sequence of epochs is deterministic and not dependent
78
+ /// on other operations.
79
+ /// - device: The Device on which resulting Tensors from this dataset will be placed, as well
80
+ /// as where the latter stages of any conversion calculations will be performed.
81
+ /// - inputSize: Which Imagenette image size variant to use.
82
+ /// - outputSize: The square width and height of the images returned from this dataset.
83
+ /// - localStorageDirectory: Where to place the downloaded and unarchived dataset.
84
+ public init (
85
+ batchSize: Int , entropy: Entropy , device: Device , inputSize: ImagenetteSize ,
86
+ outputSize: Int ,
87
+ localStorageDirectory: URL = DatasetUtilities . defaultDirectory
88
+ . appendingPathComponent ( " Imagenette " , isDirectory: true )
89
+ ) {
90
+ do {
91
+ let trainingSamples = try loadImagenetteTrainingDirectory (
92
+ inputSize: inputSize, localStorageDirectory: localStorageDirectory, base: " imagenette " )
93
+
94
+ let mean = Tensor < Float > ( [ 0.485 , 0.456 , 0.406 ] , on: device)
95
+ let standardDeviation = Tensor < Float > ( [ 0.229 , 0.224 , 0.225 ] , on: device)
96
+
97
+ training = TrainingEpochs ( samples: trainingSamples, batchSize: batchSize, entropy: entropy)
98
+ . lazy. map { ( batches: Batches ) -> LazyMapSequence < Batches , LabeledImage > in
99
+ return batches. lazy. map {
100
+ makeImagenetteBatch (
101
+ samples: $0, outputSize: outputSize, mean: mean, standardDeviation: standardDeviation,
102
+ device: device)
103
+ }
70
104
}
105
+
106
+ let validationSamples = try loadImagenetteValidationDirectory (
107
+ inputSize: inputSize, localStorageDirectory: localStorageDirectory, base: " imagenette " )
108
+
109
+ validation = validationSamples. inBatches ( of: batchSize) . lazy. map {
110
+ makeImagenetteBatch (
111
+ samples: $0, outputSize: outputSize, mean: mean, standardDeviation: standardDeviation,
112
+ device: device)
113
+ }
114
+ } catch {
115
+ fatalError ( " Could not load Imagenette dataset: \( error) " )
71
116
}
117
+ }
118
+ }
119
+
120
+ extension Imagenette : ImageClassificationData where Entropy == SystemRandomNumberGenerator {
121
+ /// Creates an instance with `batchSize`, using the SystemRandomNumberGenerator.
122
+ public init ( batchSize: Int , on device: Device = Device . default) {
123
+ self . init ( batchSize: batchSize, entropy: SystemRandomNumberGenerator ( ) , device: device)
124
+ }
125
+
126
+ /// Creates an instance with `batchSize`, `inputSize`, and `outputSize`, using the
127
+ /// SystemRandomNumberGenerator.
128
+ public init (
129
+ batchSize: Int , inputSize: ImagenetteSize , outputSize: Int , on device: Device = Device . default
130
+ ) {
131
+ self . init (
132
+ batchSize: batchSize, entropy: SystemRandomNumberGenerator ( ) , device: device,
133
+ inputSize: inputSize, outputSize: outputSize)
134
+ }
72
135
}
73
136
74
- func downloadImagenetteIfNotPresent( to directory: URL , size: Imagenette . ImageSize ) {
75
- let downloadPath = directory. appendingPathComponent ( " imagenette \( size. suffix) " ) . path
76
- let directoryExists = FileManager . default. fileExists ( atPath: downloadPath)
77
- let contentsOfDir = try ? FileManager . default. contentsOfDirectory ( atPath: downloadPath)
78
- let directoryEmpty = ( contentsOfDir == nil ) || ( contentsOfDir!. isEmpty)
137
+ func downloadImagenetteIfNotPresent( to directory: URL , size: ImagenetteSize , base : String ) {
138
+ let downloadPath = directory. appendingPathComponent ( " \( base ) \( size. suffix) " ) . path
139
+ let directoryExists = FileManager . default. fileExists ( atPath: downloadPath)
140
+ let contentsOfDir = try ? FileManager . default. contentsOfDirectory ( atPath: downloadPath)
141
+ let directoryEmpty = ( contentsOfDir == nil ) || ( contentsOfDir!. isEmpty)
79
142
80
- guard !directoryExists || directoryEmpty else { return }
143
+ guard !directoryExists || directoryEmpty else { return }
81
144
82
- let location = URL (
83
- string: " https://s3.amazonaws.com/fast-ai-imageclas/imagenette \( size. suffix) .tgz " ) !
84
- let _ = DatasetUtilities . downloadResource (
85
- filename: " imagenette \( size. suffix) " , fileExtension: " tgz " ,
86
- remoteRoot: location. deletingLastPathComponent ( ) , localStorageDirectory: directory)
145
+ let location = URL (
146
+ string: " https://s3.amazonaws.com/fast-ai-imageclas/ \( base ) \( size. suffix) .tgz " ) !
147
+ let _ = DatasetUtilities . downloadResource (
148
+ filename: " \( base ) \( size. suffix) " , fileExtension: " tgz " ,
149
+ remoteRoot: location. deletingLastPathComponent ( ) , localStorageDirectory: directory)
87
150
}
88
151
89
- func exploreImagenetteDirectory( named name: String , in directory: URL , inputSize: Imagenette . ImageSize ) throws -> [ URL ] {
90
- downloadImagenetteIfNotPresent ( to: directory, size: inputSize)
91
- let path = directory. appendingPathComponent ( " imagenette \( inputSize. suffix) / \( name) " )
92
- let dirContents = try FileManager . default. contentsOfDirectory (
93
- at: path, includingPropertiesForKeys: [ . isDirectoryKey] , options: [ . skipsHiddenFiles] )
94
-
95
- var urls : [ URL ] = [ ]
96
- for directoryURL in dirContents {
97
- let subdirContents = try FileManager . default. contentsOfDirectory (
98
- at: directoryURL, includingPropertiesForKeys: [ . isDirectoryKey] ,
99
- options: [ . skipsHiddenFiles] )
100
- urls += subdirContents
101
- }
102
- return urls
152
+ func exploreImagenetteDirectory(
153
+ named name: String , in directory: URL , inputSize: ImagenetteSize , base: String
154
+ ) throws -> [ URL ] {
155
+ downloadImagenetteIfNotPresent ( to: directory, size: inputSize, base: base)
156
+ let path = directory. appendingPathComponent ( " \( base) \( inputSize. suffix) / \( name) " )
157
+ let dirContents = try FileManager . default. contentsOfDirectory (
158
+ at: path, includingPropertiesForKeys: [ . isDirectoryKey] , options: [ . skipsHiddenFiles] )
159
+
160
+ var urls : [ URL ] = [ ]
161
+ for directoryURL in dirContents {
162
+ let subdirContents = try FileManager . default. contentsOfDirectory (
163
+ at: directoryURL, includingPropertiesForKeys: [ . isDirectoryKey] ,
164
+ options: [ . skipsHiddenFiles] )
165
+ urls += subdirContents
166
+ }
167
+ return urls
103
168
}
104
169
105
170
func parentLabel( url: URL ) -> String {
106
- return url. deletingLastPathComponent ( ) . lastPathComponent
171
+ return url. deletingLastPathComponent ( ) . lastPathComponent
107
172
}
108
173
109
174
func createLabelDict( urls: [ URL ] ) -> [ String : Int ] {
110
- let allLabels = urls. map ( parentLabel)
111
- let labels = Array ( Set ( allLabels) ) . sorted ( )
112
- return Dictionary ( uniqueKeysWithValues: labels. enumerated ( ) . map { ( $0. element, $0. offset) } )
175
+ let allLabels = urls. map ( parentLabel)
176
+ let labels = Array ( Set ( allLabels) ) . sorted ( )
177
+ return Dictionary ( uniqueKeysWithValues: labels. enumerated ( ) . map { ( $0. element, $0. offset) } )
113
178
}
114
179
115
180
func loadImagenetteDirectory(
116
- named name: String , in directory: URL , inputSize: Imagenette . ImageSize , outputSize: Int ,
117
- labelDict: [ String : Int ] ? = nil
118
- ) throws -> LazyDataSet {
119
- let urls = try exploreImagenetteDirectory ( named: name, in: directory, inputSize: inputSize)
120
- let unwrappedLabelDict = labelDict ?? createLabelDict ( urls: urls)
121
- return urls. lazy. map { ( url: URL ) -> TensorPair < Float , Int32 > in
122
- TensorPair < Float , Int32 > (
123
- first: Image ( jpeg: url) . resized ( to: ( outputSize, outputSize) ) . tensor / 255.0 ,
124
- second: Tensor < Int32 > ( Int32 ( unwrappedLabelDict [ parentLabel ( url: url) ] !) )
125
- )
126
- }
181
+ named name: String , in directory: URL , inputSize: ImagenetteSize , base: String ,
182
+ labelDict: [ String : Int ] ? = nil
183
+ ) throws -> [ ( file: URL , label: Int32 ) ] {
184
+ let urls = try exploreImagenetteDirectory (
185
+ named: name, in: directory, inputSize: inputSize, base: base)
186
+ let unwrappedLabelDict = labelDict ?? createLabelDict ( urls: urls)
187
+ return urls. lazy. map { ( url: URL ) -> ( file: URL , label: Int32 ) in
188
+ ( file: url, label: Int32 ( unwrappedLabelDict [ parentLabel ( url: url) ] !) )
189
+ }
127
190
}
128
191
129
192
func loadImagenetteTrainingDirectory(
130
- inputSize: Imagenette . ImageSize , outputSize: Int , localStorageDirectory: URL , labelDict: [ String : Int ] ? = nil
193
+ inputSize: ImagenetteSize , localStorageDirectory: URL , base: String ,
194
+ labelDict: [ String : Int ] ? = nil
131
195
) throws
132
- -> LazyDataSet
196
+ -> [ ( file : URL , label : Int32 ) ]
133
197
{
134
- return try loadImagenetteDirectory (
135
- named: " train " , in: localStorageDirectory, inputSize: inputSize, outputSize: outputSize, labelDict: labelDict)
198
+ return try loadImagenetteDirectory (
199
+ named: " train " , in: localStorageDirectory, inputSize: inputSize, base: base,
200
+ labelDict: labelDict)
136
201
}
137
202
138
203
func loadImagenetteValidationDirectory(
139
- inputSize: Imagenette . ImageSize , outputSize: Int , localStorageDirectory: URL , labelDict: [ String : Int ] ? = nil
204
+ inputSize: ImagenetteSize , localStorageDirectory: URL , base: String ,
205
+ labelDict: [ String : Int ] ? = nil
140
206
) throws
141
- -> LazyDataSet
207
+ -> [ ( file : URL , label : Int32 ) ]
142
208
{
143
- return try loadImagenetteDirectory (
144
- named: " val " , in: localStorageDirectory, inputSize: inputSize, outputSize: outputSize, labelDict: labelDict)
145
- }
209
+ return try loadImagenetteDirectory (
210
+ named: " val " , in: localStorageDirectory, inputSize: inputSize, base: base, labelDict: labelDict)
211
+ }
212
+
213
+ func makeImagenetteBatch< BatchSamples: Collection > (
214
+ samples: BatchSamples , outputSize: Int , mean: Tensor < Float > ? , standardDeviation: Tensor < Float > ? ,
215
+ device: Device
216
+ ) -> LabeledImage where BatchSamples. Element == ( file: URL , label: Int32 ) {
217
+ let images = samples. map ( \. file) . map { url -> Tensor < Float > in
218
+ Image ( jpeg: url) . resized ( to: ( outputSize, outputSize) ) . tensor
219
+ }
220
+
221
+ var imageTensor = Tensor ( stacking: images)
222
+ imageTensor = Tensor ( copying: imageTensor, to: device)
223
+ imageTensor /= 255.0
224
+
225
+ if let mean = mean, let standardDeviation = standardDeviation {
226
+ imageTensor = ( imageTensor - mean) / standardDeviation
227
+ }
228
+
229
+ let labels = Tensor < Int32 > ( samples. map ( \. label) , on: device)
230
+ return LabeledImage ( data: imageTensor, label: labels)
231
+ }
0 commit comments