Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 47a4eaa

Browse files
authored
Migrating OxfordIIITPets dataset to Epochs (#604)
* Initial conversion of the OxfordIIITPets dataset. * Formatting. * Adding convenience initializers. * Documentation fix.
1 parent c7d0968 commit 47a4eaa

File tree

3 files changed

+145
-66
lines changed

3 files changed

+145
-66
lines changed

Datasets/ImageSegmentationDataset.swift

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,32 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
import Batcher
15+
import ModelSupport
1616
import TensorFlow
1717

18-
public protocol ImageSegmentationDataset {
19-
associatedtype SourceDataSet: Collection
20-
where SourceDataSet.Element == TensorPair<Float, Int32>, SourceDataSet.Index == Int
21-
init(batchSize: Int)
22-
var training: Batcher<SourceDataSet> { get }
23-
var test: Batcher<SourceDataSet> { get }
18+
/// An image with a label.
19+
public typealias SegmentedImage = LabeledData<Tensor<Float>, Tensor<Int32>>
20+
21+
/// Types whose elements represent an image segmentation dataset (with both
22+
/// training and validation data).
23+
public protocol ImageSegmentationData {
24+
/// The type of the training data, represented as a sequence of epochs, which
25+
/// are collection of batches.
26+
associatedtype Training: Sequence
27+
where Training.Element: Collection, Training.Element.Element == SegmentedImage
28+
/// The type of the validation data, represented as a collection of batches.
29+
associatedtype Validation: Collection where Validation.Element == SegmentedImage
30+
/// Creates an instance from a given `batchSize`.
31+
init(batchSize: Int, on device: Device)
32+
/// The `training` epochs.
33+
var training: Training { get }
34+
/// The `validation` batches.
35+
var validation: Validation { get }
36+
37+
// The following is probably going to be necessary since we can't extract that
38+
// information from `Epochs` or `Batches`.
39+
/// The number of samples in the `training` set.
40+
//var trainingSampleCount: Int {get}
41+
/// The number of samples in the `validation` set.
42+
//var validationSampleCount: Int {get}
2443
}

Datasets/OxfordIIITPets/OxfordIIITPets.swift

Lines changed: 109 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -17,48 +17,98 @@
1717
// Omkar M Parkhi, Andrea Vedaldi, Andrew Zisserman and C. V. Jawahar
1818
// https://www.robots.ox.ac.uk/~vgg/data/pets/
1919

20-
import Batcher
2120
import Foundation
2221
import ModelSupport
2322
import TensorFlow
2423

25-
public typealias LazyDataSet = LazyMapSequence<[URL], TensorPair<Float, Int32>>
26-
27-
public struct OxfordIIITPets: ImageSegmentationDataset {
28-
public typealias SourceDataSet = LazyDataSet
29-
public let training: Batcher<SourceDataSet>
30-
public let test: Batcher<SourceDataSet>
31-
32-
public init(batchSize: Int) {
33-
self.init(batchSize: batchSize, imageSize: 224)
24+
public struct OxfordIIITPets<Entropy: RandomNumberGenerator> {
25+
/// Type of the collection of non-collated batches.
26+
public typealias Batches = Slices<Sampling<[(file: URL, annotation: URL)], ArraySlice<Int>>>
27+
/// The type of the training data, represented as a sequence of epochs, which
28+
/// are collection of batches.
29+
public typealias Training = LazyMapSequence<
30+
TrainingEpochs<[(file: URL, annotation: URL)], Entropy>,
31+
LazyMapSequence<Batches, SegmentedImage>
32+
>
33+
/// The type of the validation data, represented as a collection of batches.
34+
public typealias Validation = LazyMapSequence<
35+
Slices<[(file: URL, annotation: URL)]>, LabeledImage
36+
>
37+
/// The training epochs.
38+
public let training: Training
39+
/// The validation batches.
40+
public let validation: Validation
41+
42+
/// Creates an instance with `batchSize`.
43+
///
44+
/// - Parameters:
45+
/// - batchSize: Number of images provided per batch.
46+
/// - entropy: A source of randomness used to shuffle sample
47+
/// ordering. It will be stored in `self`, so if it is only pseudorandom
48+
/// and has value semantics, the sequence of epochs is deterministic and not
49+
/// dependent on other operations.
50+
/// - device: The Device on which resulting Tensors from this dataset will be placed, as well
51+
/// as where the latter stages of any conversion calculations will be performed.
52+
public init(batchSize: Int, entropy: Entropy, device: Device) {
53+
self.init(
54+
batchSize: batchSize, entropy: entropy, device: device, imageSize: 224)
3455
}
3556

57+
/// Creates an instance with `batchSize` on `device` using `remoteBinaryArchiveLocation`.
58+
///
59+
/// - Parameters:
60+
/// - batchSize: Number of images provided per batch.
61+
/// - entropy: A source of randomness used to shuffle sample ordering. It
62+
/// will be stored in `self`, so if it is only pseudorandom and has value
63+
/// semantics, the sequence of epochs is deterministic and not dependent
64+
/// on other operations.
65+
/// - device: The Device on which resulting Tensors from this dataset will be placed, as well
66+
/// as where the latter stages of any conversion calculations will be performed.
67+
/// - imageSize: The square width and height of the images returned from this dataset.
68+
/// - localStorageDirectory: Where to place the downloaded and unarchived dataset.
3669
public init(
37-
batchSize: Int,
70+
batchSize: Int, entropy: Entropy, device: Device, imageSize: Int,
3871
localStorageDirectory: URL = DatasetUtilities.defaultDirectory
39-
.appendingPathComponent("OxfordIIITPets", isDirectory: true),
40-
imageSize: Int
72+
.appendingPathComponent("OxfordIIITPets", isDirectory: true)
4173
) {
4274
do {
43-
training = Batcher<SourceDataSet>(
44-
on: try loadOxfordIITPetsTraining(
45-
imageSize: imageSize,
46-
localStorageDirectory: localStorageDirectory
47-
),
48-
batchSize: batchSize,
49-
shuffle: true)
50-
test = Batcher<SourceDataSet>(
51-
on: try loadOxfordIIITPetsValidation(
52-
imageSize: imageSize,
53-
localStorageDirectory: localStorageDirectory
54-
),
55-
batchSize: batchSize)
75+
let trainingSamples = try loadOxfordIITPetsTraining(
76+
localStorageDirectory: localStorageDirectory)
77+
78+
training = TrainingEpochs(samples: trainingSamples, batchSize: batchSize, entropy: entropy)
79+
.lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledImage> in
80+
return batches.lazy.map {
81+
makeBatch(samples: $0, imageSize: imageSize, device: device)
82+
}
83+
}
84+
85+
let validationSamples = try loadOxfordIITPetsTraining(
86+
localStorageDirectory: localStorageDirectory)
87+
88+
validation = validationSamples.inBatches(of: batchSize).lazy.map {
89+
makeBatch(samples: $0, imageSize: imageSize, device: device)
90+
}
5691
} catch {
57-
fatalError("Could not load Oxford IIIT Pets dataset: \(error)")
92+
fatalError("Could not load the Oxford IIIT Pets dataset: \(error)")
5893
}
5994
}
6095
}
6196

97+
extension OxfordIIITPets: ImageSegmentationData where Entropy == SystemRandomNumberGenerator {
98+
/// Creates an instance with `batchSize`, using the SystemRandomNumberGenerator.
99+
public init(batchSize: Int, on device: Device = Device.default) {
100+
self.init(batchSize: batchSize, entropy: SystemRandomNumberGenerator(), device: device)
101+
}
102+
103+
/// Creates an instance with `batchSize`, `inputSize`, and `outputSize`, using the
104+
/// SystemRandomNumberGenerator.
105+
public init(batchSize: Int, imageSize: Int, on device: Device = Device.default) {
106+
self.init(
107+
batchSize: batchSize, entropy: SystemRandomNumberGenerator(), device: device,
108+
imageSize: imageSize)
109+
}
110+
}
111+
62112
func downloadOxfordIIITPetsIfNotPresent(to directory: URL) {
63113
let downloadPath = directory.appendingPathComponent("images", isDirectory: true).path
64114
let directoryExists = FileManager.default.fileExists(atPath: downloadPath)
@@ -80,21 +130,13 @@ func downloadOxfordIIITPetsIfNotPresent(to directory: URL) {
80130
)
81131
}
82132

83-
func loadOxfordIIITPets(filename: String, in directory: URL, imageSize: Int) throws -> LazyDataSet {
133+
func loadOxfordIIITPets(filename: String, in directory: URL) throws -> [(
134+
file: URL, annotation: URL
135+
)] {
84136
downloadOxfordIIITPetsIfNotPresent(to: directory)
85137
let imageURLs = getImageURLs(filename: filename, directory: directory)
86-
return imageURLs.lazy.map { (imageURL: URL) -> TensorPair<Float, Int32> in
87-
TensorPair<Float, Int32>(
88-
first:
89-
Image(jpeg: imageURL).resized(to: (imageSize, imageSize)).tensor[0..., 0..., 0..<3]
90-
/ 255.0,
91-
second: Tensor<Int32>(
92-
Image(jpeg: makeAnnotationURL(imageURL: imageURL, directory: directory)).resized(
93-
to: (imageSize, imageSize)
94-
).tensor[0..., 0..., 0...0] - 1
95-
)
96-
)
97-
138+
return imageURLs.lazy.map { (imageURL: URL) -> (file: URL, annotation: URL) in
139+
(file: imageURL, annotation: makeAnnotationURL(imageURL: imageURL, directory: directory))
98140
}
99141
}
100142

@@ -114,20 +156,36 @@ func getImageURLs(filename: String, directory: URL) -> [URL] {
114156
}
115157
}
116158

117-
func loadOxfordIITPetsTraining(
118-
imageSize: Int, localStorageDirectory: URL
119-
) throws
120-
-> LazyDataSet
159+
func loadOxfordIITPetsTraining(localStorageDirectory: URL) throws -> [(file: URL, annotation: URL)]
121160
{
122161
return try loadOxfordIIITPets(
123-
filename: "trainval.txt", in: localStorageDirectory, imageSize: imageSize)
162+
filename: "trainval.txt", in: localStorageDirectory)
124163
}
125164

126-
func loadOxfordIIITPetsValidation(
127-
imageSize: Int, localStorageDirectory: URL
128-
) throws
129-
-> LazyDataSet
130-
{
165+
func loadOxfordIIITPetsValidation(localStorageDirectory: URL) throws -> [(
166+
file: URL, annotation: URL
167+
)] {
131168
return try loadOxfordIIITPets(
132-
filename: "test.txt", in: localStorageDirectory, imageSize: imageSize)
169+
filename: "test.txt", in: localStorageDirectory)
170+
}
171+
172+
fileprivate func makeBatch<BatchSamples: Collection>(
173+
samples: BatchSamples, imageSize: Int, device: Device
174+
) -> SegmentedImage where BatchSamples.Element == (file: URL, annotation: URL) {
175+
let images = samples.map(\.file).map { url -> Tensor<Float> in
176+
Image(jpeg: url).resized(to: (imageSize, imageSize)).tensor[0..., 0..., 0..<3]
177+
}
178+
179+
var imageTensor = Tensor(stacking: images)
180+
imageTensor = Tensor(copying: imageTensor, to: device)
181+
imageTensor /= 255.0
182+
183+
let annotations = samples.map(\.annotation).map { url -> Tensor<Int32> in
184+
Tensor<Int32>(
185+
Image(jpeg: url).resized(to: (imageSize, imageSize)).tensor[0..., 0..., 0...0] - 1)
186+
}
187+
var annotationTensor = Tensor(stacking: annotations)
188+
annotationTensor = Tensor(copying: annotationTensor, to: device)
189+
190+
return SegmentedImage(data: imageTensor, label: annotationTensor)
133191
}

Tests/DatasetsTests/OxfordIIITPets/OxfordIIITPetsTests.swift

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,22 @@ final class OxfordIIITPetsTests: XCTestCase {
77
let dataset = OxfordIIITPets(batchSize: 64)
88

99
var batchCount = 0
10-
for batch in dataset.training.sequenced() {
11-
batchCount += 1
12-
/// 3680 samples make 57 batches of size 64 and one batch of size 32
13-
let expectedBS = batchCount <= 57 ? 64 : 32
10+
for epochBatches in dataset.training.prefix(1) {
11+
for batch in epochBatches {
12+
batchCount += 1
13+
/// 3680 samples make 57 batches of size 64 and one batch of size 32
14+
let expectedBS = batchCount <= 57 ? 64 : 32
1415

15-
XCTAssertEqual(batch.first.shape, [expectedBS, 224, 224, 3])
16-
XCTAssertEqual(batch.second.shape, [expectedBS, 224, 224, 1])
16+
XCTAssertEqual(batch.data.shape, [expectedBS, 224, 224, 3])
17+
XCTAssertEqual(batch.label.shape, [expectedBS, 224, 224, 1])
18+
}
1719
}
18-
XCTAssertEqual(batchCount, dataset.training.count)
20+
XCTAssertEqual(batchCount, 57)
1921
}
2022
}
2123

2224
extension OxfordIIITPetsTests {
2325
static var allTests = [
24-
("testCreateOxfordIIITPets", testCreateOxfordIIITPets),
26+
("testCreateOxfordIIITPets", testCreateOxfordIIITPets)
2527
]
2628
}

0 commit comments

Comments
 (0)