Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit c7d0968

Browse files
authored
Converting Imagenette and Imagewoof datasets to Epochs (#603)
* Initial update of Imagenette and Imagewoof for Epochs. * Formatting. * Update testing batch counts. * Updating comments, removing old Batcher notebook.
1 parent 2ae53ba commit c7d0968

File tree

9 files changed

+315
-1183
lines changed

9 files changed

+315
-1183
lines changed

Datasets/ImageClassificationDataset.swift

-9
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,8 @@
1313
// limitations under the License.
1414

1515
import TensorFlow
16-
import Batcher
1716
import ModelSupport
1817

19-
public protocol ImageClassificationDataset {
20-
associatedtype SourceDataSet: Collection
21-
where SourceDataSet.Element == TensorPair<Float, Int32>, SourceDataSet.Index == Int
22-
init(batchSize: Int)
23-
var training: Batcher<SourceDataSet> { get }
24-
var test: Batcher<SourceDataSet> { get }
25-
}
26-
2718
/// An image with a label.
2819
public typealias LabeledImage = LabeledData<Tensor<Float>, Tensor<Int32>>
2920

Datasets/Imagenette/Imagenette.swift

+179-93
Original file line numberDiff line numberDiff line change
@@ -20,126 +20,212 @@
2020
import Foundation
2121
import ModelSupport
2222
import TensorFlow
23-
import Batcher
24-
25-
public typealias LazyDataSet = LazyMapSequence<[URL], TensorPair<Float, Int32>>
26-
27-
public struct Imagenette: ImageClassificationDataset {
28-
public typealias SourceDataSet = LazyDataSet
29-
public let training: Batcher<SourceDataSet>
30-
public let test: Batcher<SourceDataSet>
31-
32-
public enum ImageSize {
33-
case full
34-
case resized160
35-
case resized320
36-
37-
var suffix: String {
38-
switch self {
39-
case .full: return ""
40-
case .resized160: return "-160"
41-
case .resized320: return "-320"
42-
}
43-
}
44-
}
4523

46-
public init(batchSize: Int) {
47-
self.init(batchSize: batchSize, inputSize: .resized320, outputSize: 224)
24+
/// The three variants of Imagenette, determined by their source image size.
25+
public enum ImagenetteSize {
26+
case full
27+
case resized160
28+
case resized320
29+
30+
var suffix: String {
31+
switch self {
32+
case .full: return ""
33+
case .resized160: return "-160"
34+
case .resized320: return "-320"
4835
}
36+
}
37+
}
4938

50-
public init(
51-
batchSize: Int,
52-
inputSize: ImageSize, outputSize: Int,
53-
localStorageDirectory: URL = DatasetUtilities.defaultDirectory
54-
.appendingPathComponent("Imagenette", isDirectory: true)
55-
) {
56-
do {
57-
training = Batcher<SourceDataSet>(
58-
on: try loadImagenetteTrainingDirectory(
59-
inputSize: inputSize, outputSize: outputSize,
60-
localStorageDirectory: localStorageDirectory),
61-
batchSize: batchSize,
62-
shuffle: true)
63-
test = Batcher<SourceDataSet>(
64-
on: try loadImagenetteValidationDirectory(
65-
inputSize: inputSize, outputSize: outputSize,
66-
localStorageDirectory: localStorageDirectory),
67-
batchSize: batchSize)
68-
} catch {
69-
fatalError("Could not load Imagenette dataset: \(error)")
39+
public struct Imagenette<Entropy: RandomNumberGenerator> {
40+
/// Type of the collection of non-collated batches.
41+
public typealias Batches = Slices<Sampling<[(file: URL, label: Int32)], ArraySlice<Int>>>
42+
/// The type of the training data, represented as a sequence of epochs, which
43+
/// are collections of batches.
44+
public typealias Training = LazyMapSequence<
45+
TrainingEpochs<[(file: URL, label: Int32)], Entropy>,
46+
LazyMapSequence<Batches, LabeledImage>
47+
>
48+
/// The type of the validation data, represented as a collection of batches.
49+
public typealias Validation = LazyMapSequence<Slices<[(file: URL, label: Int32)]>, LabeledImage>
50+
/// The training epochs.
51+
public let training: Training
52+
/// The validation batches.
53+
public let validation: Validation
54+
55+
/// Creates an instance with `batchSize`.
56+
///
57+
/// - Parameters:
58+
/// - batchSize: Number of images provided per batch.
59+
/// - entropy: A source of randomness used to shuffle sample
60+
/// ordering. It will be stored in `self`, so if it is only pseudorandom
61+
/// and has value semantics, the sequence of epochs is deterministic and not
62+
/// dependent on other operations.
63+
/// - device: The Device on which resulting Tensors from this dataset will be placed, as well
64+
/// as where the latter stages of any conversion calculations will be performed.
65+
public init(batchSize: Int, entropy: Entropy, device: Device) {
66+
self.init(
67+
batchSize: batchSize, entropy: entropy, device: device, inputSize: ImagenetteSize.resized320,
68+
outputSize: 224)
69+
}
70+
71+
/// Creates an instance with `batchSize` on `device`, using the `inputSize` variant stored in `localStorageDirectory`.
72+
///
73+
/// - Parameters:
74+
/// - batchSize: Number of images provided per batch.
75+
/// - entropy: A source of randomness used to shuffle sample ordering. It
76+
/// will be stored in `self`, so if it is only pseudorandom and has value
77+
/// semantics, the sequence of epochs is deterministic and not dependent
78+
/// on other operations.
79+
/// - device: The Device on which resulting Tensors from this dataset will be placed, as well
80+
/// as where the latter stages of any conversion calculations will be performed.
81+
/// - inputSize: Which Imagenette image size variant to use.
82+
/// - outputSize: The square width and height of the images returned from this dataset.
83+
/// - localStorageDirectory: Where to place the downloaded and unarchived dataset.
84+
public init(
85+
batchSize: Int, entropy: Entropy, device: Device, inputSize: ImagenetteSize,
86+
outputSize: Int,
87+
localStorageDirectory: URL = DatasetUtilities.defaultDirectory
88+
.appendingPathComponent("Imagenette", isDirectory: true)
89+
) {
90+
do {
91+
let trainingSamples = try loadImagenetteTrainingDirectory(
92+
inputSize: inputSize, localStorageDirectory: localStorageDirectory, base: "imagenette")
93+
94+
let mean = Tensor<Float>([0.485, 0.456, 0.406], on: device)
95+
let standardDeviation = Tensor<Float>([0.229, 0.224, 0.225], on: device)
96+
97+
training = TrainingEpochs(samples: trainingSamples, batchSize: batchSize, entropy: entropy)
98+
.lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledImage> in
99+
return batches.lazy.map {
100+
makeImagenetteBatch(
101+
samples: $0, outputSize: outputSize, mean: mean, standardDeviation: standardDeviation,
102+
device: device)
103+
}
70104
}
105+
106+
let validationSamples = try loadImagenetteValidationDirectory(
107+
inputSize: inputSize, localStorageDirectory: localStorageDirectory, base: "imagenette")
108+
109+
validation = validationSamples.inBatches(of: batchSize).lazy.map {
110+
makeImagenetteBatch(
111+
samples: $0, outputSize: outputSize, mean: mean, standardDeviation: standardDeviation,
112+
device: device)
113+
}
114+
} catch {
115+
fatalError("Could not load Imagenette dataset: \(error)")
71116
}
117+
}
118+
}
119+
120+
extension Imagenette: ImageClassificationData where Entropy == SystemRandomNumberGenerator {
121+
/// Creates an instance with `batchSize`, using the SystemRandomNumberGenerator.
122+
public init(batchSize: Int, on device: Device = Device.default) {
123+
self.init(batchSize: batchSize, entropy: SystemRandomNumberGenerator(), device: device)
124+
}
125+
126+
/// Creates an instance with `batchSize`, `inputSize`, and `outputSize`, using the
127+
/// SystemRandomNumberGenerator.
128+
public init(
129+
batchSize: Int, inputSize: ImagenetteSize, outputSize: Int, on device: Device = Device.default
130+
) {
131+
self.init(
132+
batchSize: batchSize, entropy: SystemRandomNumberGenerator(), device: device,
133+
inputSize: inputSize, outputSize: outputSize)
134+
}
72135
}
73136

74-
func downloadImagenetteIfNotPresent(to directory: URL, size: Imagenette.ImageSize) {
75-
let downloadPath = directory.appendingPathComponent("imagenette\(size.suffix)").path
76-
let directoryExists = FileManager.default.fileExists(atPath: downloadPath)
77-
let contentsOfDir = try? FileManager.default.contentsOfDirectory(atPath: downloadPath)
78-
let directoryEmpty = (contentsOfDir == nil) || (contentsOfDir!.isEmpty)
137+
func downloadImagenetteIfNotPresent(to directory: URL, size: ImagenetteSize, base: String) {
138+
let downloadPath = directory.appendingPathComponent("\(base)\(size.suffix)").path
139+
let directoryExists = FileManager.default.fileExists(atPath: downloadPath)
140+
let contentsOfDir = try? FileManager.default.contentsOfDirectory(atPath: downloadPath)
141+
let directoryEmpty = (contentsOfDir == nil) || (contentsOfDir!.isEmpty)
79142

80-
guard !directoryExists || directoryEmpty else { return }
143+
guard !directoryExists || directoryEmpty else { return }
81144

82-
let location = URL(
83-
string: "https://s3.amazonaws.com/fast-ai-imageclas/imagenette\(size.suffix).tgz")!
84-
let _ = DatasetUtilities.downloadResource(
85-
filename: "imagenette\(size.suffix)", fileExtension: "tgz",
86-
remoteRoot: location.deletingLastPathComponent(), localStorageDirectory: directory)
145+
let location = URL(
146+
string: "https://s3.amazonaws.com/fast-ai-imageclas/\(base)\(size.suffix).tgz")!
147+
let _ = DatasetUtilities.downloadResource(
148+
filename: "\(base)\(size.suffix)", fileExtension: "tgz",
149+
remoteRoot: location.deletingLastPathComponent(), localStorageDirectory: directory)
87150
}
88151

89-
func exploreImagenetteDirectory(named name: String, in directory: URL, inputSize: Imagenette.ImageSize) throws -> [URL] {
90-
downloadImagenetteIfNotPresent(to: directory, size: inputSize)
91-
let path = directory.appendingPathComponent("imagenette\(inputSize.suffix)/\(name)")
92-
let dirContents = try FileManager.default.contentsOfDirectory(
93-
at: path, includingPropertiesForKeys: [.isDirectoryKey], options: [.skipsHiddenFiles])
94-
95-
var urls: [URL] = []
96-
for directoryURL in dirContents {
97-
let subdirContents = try FileManager.default.contentsOfDirectory(
98-
at: directoryURL, includingPropertiesForKeys: [.isDirectoryKey],
99-
options: [.skipsHiddenFiles])
100-
urls += subdirContents
101-
}
102-
return urls
152+
func exploreImagenetteDirectory(
153+
named name: String, in directory: URL, inputSize: ImagenetteSize, base: String
154+
) throws -> [URL] {
155+
downloadImagenetteIfNotPresent(to: directory, size: inputSize, base: base)
156+
let path = directory.appendingPathComponent("\(base)\(inputSize.suffix)/\(name)")
157+
let dirContents = try FileManager.default.contentsOfDirectory(
158+
at: path, includingPropertiesForKeys: [.isDirectoryKey], options: [.skipsHiddenFiles])
159+
160+
var urls: [URL] = []
161+
for directoryURL in dirContents {
162+
let subdirContents = try FileManager.default.contentsOfDirectory(
163+
at: directoryURL, includingPropertiesForKeys: [.isDirectoryKey],
164+
options: [.skipsHiddenFiles])
165+
urls += subdirContents
166+
}
167+
return urls
103168
}
104169

105170
func parentLabel(url: URL) -> String {
106-
return url.deletingLastPathComponent().lastPathComponent
171+
return url.deletingLastPathComponent().lastPathComponent
107172
}
108173

109174
func createLabelDict(urls: [URL]) -> [String: Int] {
110-
let allLabels = urls.map(parentLabel)
111-
let labels = Array(Set(allLabels)).sorted()
112-
return Dictionary(uniqueKeysWithValues: labels.enumerated().map{ ($0.element, $0.offset) })
175+
let allLabels = urls.map(parentLabel)
176+
let labels = Array(Set(allLabels)).sorted()
177+
return Dictionary(uniqueKeysWithValues: labels.enumerated().map { ($0.element, $0.offset) })
113178
}
114179

115180
func loadImagenetteDirectory(
116-
named name: String, in directory: URL, inputSize: Imagenette.ImageSize, outputSize: Int,
117-
labelDict: [String:Int]? = nil
118-
) throws -> LazyDataSet {
119-
let urls = try exploreImagenetteDirectory(named: name, in: directory, inputSize: inputSize)
120-
let unwrappedLabelDict = labelDict ?? createLabelDict(urls: urls)
121-
return urls.lazy.map { (url: URL) -> TensorPair<Float, Int32> in
122-
TensorPair<Float, Int32>(
123-
first: Image(jpeg: url).resized(to: (outputSize, outputSize)).tensor / 255.0,
124-
second: Tensor<Int32>(Int32(unwrappedLabelDict[parentLabel(url: url)]!))
125-
)
126-
}
181+
named name: String, in directory: URL, inputSize: ImagenetteSize, base: String,
182+
labelDict: [String: Int]? = nil
183+
) throws -> [(file: URL, label: Int32)] {
184+
let urls = try exploreImagenetteDirectory(
185+
named: name, in: directory, inputSize: inputSize, base: base)
186+
let unwrappedLabelDict = labelDict ?? createLabelDict(urls: urls)
187+
return urls.lazy.map { (url: URL) -> (file: URL, label: Int32) in
188+
(file: url, label: Int32(unwrappedLabelDict[parentLabel(url: url)]!))
189+
}
127190
}
128191

129192
func loadImagenetteTrainingDirectory(
130-
inputSize: Imagenette.ImageSize, outputSize: Int, localStorageDirectory: URL, labelDict: [String:Int]? = nil
193+
inputSize: ImagenetteSize, localStorageDirectory: URL, base: String,
194+
labelDict: [String: Int]? = nil
131195
) throws
132-
-> LazyDataSet
196+
-> [(file: URL, label: Int32)]
133197
{
134-
return try loadImagenetteDirectory(
135-
named: "train", in: localStorageDirectory, inputSize: inputSize, outputSize: outputSize, labelDict: labelDict)
198+
return try loadImagenetteDirectory(
199+
named: "train", in: localStorageDirectory, inputSize: inputSize, base: base,
200+
labelDict: labelDict)
136201
}
137202

138203
func loadImagenetteValidationDirectory(
139-
inputSize: Imagenette.ImageSize, outputSize: Int, localStorageDirectory: URL, labelDict: [String:Int]? = nil
204+
inputSize: ImagenetteSize, localStorageDirectory: URL, base: String,
205+
labelDict: [String: Int]? = nil
140206
) throws
141-
-> LazyDataSet
207+
-> [(file: URL, label: Int32)]
142208
{
143-
return try loadImagenetteDirectory(
144-
named: "val", in: localStorageDirectory, inputSize: inputSize, outputSize: outputSize, labelDict: labelDict)
145-
}
209+
return try loadImagenetteDirectory(
210+
named: "val", in: localStorageDirectory, inputSize: inputSize, base: base, labelDict: labelDict)
211+
}
212+
213+
func makeImagenetteBatch<BatchSamples: Collection>(
214+
samples: BatchSamples, outputSize: Int, mean: Tensor<Float>?, standardDeviation: Tensor<Float>?,
215+
device: Device
216+
) -> LabeledImage where BatchSamples.Element == (file: URL, label: Int32) {
217+
let images = samples.map(\.file).map { url -> Tensor<Float> in
218+
Image(jpeg: url).resized(to: (outputSize, outputSize)).tensor
219+
}
220+
221+
var imageTensor = Tensor(stacking: images)
222+
imageTensor = Tensor(copying: imageTensor, to: device)
223+
imageTensor /= 255.0
224+
225+
if let mean = mean, let standardDeviation = standardDeviation {
226+
imageTensor = (imageTensor - mean) / standardDeviation
227+
}
228+
229+
let labels = Tensor<Int32>(samples.map(\.label), on: device)
230+
return LabeledImage(data: imageTensor, label: labels)
231+
}

0 commit comments

Comments
 (0)