Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 2ae53ba

Browse files
authored
Switch off Batcher from Pix2Pix (#602)
1 parent 9c73365 commit 2ae53ba

File tree

2 files changed

+150
-81
lines changed

2 files changed

+150
-81
lines changed

pix2pix/Dataset.swift

+137-33
Original file line numberDiff line numberDiff line change
@@ -14,44 +14,148 @@
1414

1515
import Foundation
1616
import ModelSupport
17+
import Datasets
1718
import TensorFlow
18-
import Batcher
1919

20-
public class PairedImages {
21-
public struct ImagePair: _Collatable {
22-
public init(oldCollating: [PairedImages.ImagePair]) {
23-
self.source = .init(stacking: oldCollating.map(\.source))
24-
self.target = .init(stacking: oldCollating.map(\.target))
20+
public enum Pix2PixDatasetVariant: String {
21+
case facades
22+
23+
public var url: URL {
24+
switch self {
25+
case .facades:
26+
return URL(string:
27+
"https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/facades.zip")!
2528
}
29+
}
30+
}
31+
32+
public struct Pix2PixDataset<Entropy: RandomNumberGenerator> {
33+
public typealias Samples = [(source: Tensor<Float>, target: Tensor<Float>)]
34+
public typealias Batches = Slices<Sampling<Samples, ArraySlice<Int>>>
35+
public typealias PairedImageBatch = (source: Tensor<Float>, target: Tensor<Float>)
36+
public typealias Training = LazyMapSequence<
37+
TrainingEpochs<Samples, Entropy>,
38+
LazyMapSequence<Batches, PairedImageBatch>
39+
>
40+
public typealias Testing = LazyMapSequence<
41+
Slices<Samples>,
42+
PairedImageBatch
43+
>
44+
45+
public let trainSamples: Samples
46+
public let testSamples: Samples
47+
public let training: Training
48+
public let testing: Testing
49+
50+
public init(
51+
from rootDirPath: String? = nil,
52+
variant: Pix2PixDatasetVariant? = nil,
53+
trainBatchSize: Int = 1,
54+
testBatchSize: Int = 1,
55+
entropy: Entropy) throws {
2656

27-
public init(source: Tensor<Float>, target: Tensor<Float>) {
28-
self.source = source
29-
self.target = target
30-
}
57+
let rootDirPath = rootDirPath ?? Pix2PixDataset.downloadIfNotPresent(
58+
variant: variant ?? .facades,
59+
to: DatasetUtilities.defaultDirectory.appendingPathComponent("pix2pix", isDirectory: true))
60+
let rootDirURL = URL(fileURLWithPath: rootDirPath, isDirectory: true)
3161

32-
var source: Tensor<Float>
33-
var target: Tensor<Float>
34-
}
35-
var batcher: Batcher<[ImagePair]>
36-
37-
public init(folderAURL: URL, folderBURL: URL) throws {
38-
let folderAContents = try FileManager.default
39-
.contentsOfDirectory(at: folderAURL,
40-
includingPropertiesForKeys: [.isDirectoryKey],
41-
options: [.skipsHiddenFiles])
42-
.filter { $0.pathExtension == "jpg" }
43-
44-
let imageTensors = folderAContents.map { (url: URL) -> ImagePair in
45-
let tensorA = Image(jpeg: url).tensor / 127.5 - 1.0
46-
47-
let tensorBImageURL = folderBURL.appendingPathComponent(url.lastPathComponent.replacingOccurrences(of: "_A.jpg", with: "_B.jpg"))
48-
let tensorB = Image(jpeg: tensorBImageURL).tensor / 127.5 - 1.0
49-
50-
return ImagePair(source: tensorA, target: tensorB)
51-
}
62+
trainSamples = Array(zip(
63+
try Pix2PixDataset.loadSortedSamples(
64+
from: rootDirURL.appendingPathComponent("trainA"),
65+
fileIndexRetriever: "_"
66+
),
67+
try Pix2PixDataset.loadSortedSamples(
68+
from: rootDirURL.appendingPathComponent("trainB"),
69+
fileIndexRetriever: "_"
70+
)
71+
))
5272

53-
self.batcher = Batcher(on: imageTensors,
54-
batchSize: 1,
55-
shuffle: true)
73+
testSamples = Array(zip(
74+
try Pix2PixDataset.loadSortedSamples(
75+
from: rootDirURL.appendingPathComponent("testA"),
76+
fileIndexRetriever: "."
77+
),
78+
try Pix2PixDataset.loadSortedSamples(
79+
from: rootDirURL.appendingPathComponent("testB"),
80+
fileIndexRetriever: "."
81+
)
82+
))
83+
84+
training = TrainingEpochs(
85+
samples: trainSamples,
86+
batchSize: trainBatchSize,
87+
entropy: entropy
88+
).lazy.map { (batches: Batches) -> LazyMapSequence<Batches, PairedImageBatch> in
89+
batches.lazy.map {
90+
(
91+
source: Tensor<Float>($0.map(\.source)),
92+
target: Tensor<Float>($0.map(\.target))
93+
)
94+
}
95+
}
96+
97+
testing = testSamples.inBatches(of: testBatchSize)
98+
.lazy.map {
99+
(
100+
source: Tensor<Float>($0.map(\.source)),
101+
target: Tensor<Float>($0.map(\.target))
102+
)
103+
}
104+
}
105+
106+
private static func downloadIfNotPresent(
107+
variant: Pix2PixDatasetVariant,
108+
to directory: URL) -> String {
109+
let rootDirPath = directory.appendingPathComponent(variant.rawValue).path
110+
111+
let directoryExists = FileManager.default.fileExists(atPath: rootDirPath)
112+
let contentsOfDir = try? FileManager.default.contentsOfDirectory(atPath: rootDirPath)
113+
let directoryEmpty = (contentsOfDir == nil) || (contentsOfDir!.isEmpty)
114+
guard !directoryExists || directoryEmpty else { return rootDirPath }
115+
116+
let _ = DatasetUtilities.downloadResource(
117+
filename: variant.rawValue,
118+
fileExtension: "zip",
119+
remoteRoot: variant.url.deletingLastPathComponent(),
120+
localStorageDirectory: directory)
121+
print("\(rootDirPath) downloaded.")
122+
123+
return rootDirPath
124+
}
125+
126+
private static func loadSortedSamples(
127+
from directory: URL,
128+
fileIndexRetriever: String
129+
) throws -> [Tensor<Float>] {
130+
return try FileManager.default
131+
.contentsOfDirectory(
132+
at: directory,
133+
includingPropertiesForKeys: [.isDirectoryKey],
134+
options: [.skipsHiddenFiles])
135+
.filter { $0.pathExtension == "jpg" }
136+
.sorted {
137+
Int($0.lastPathComponent.components(separatedBy: fileIndexRetriever)[0])! <
138+
Int($1.lastPathComponent.components(separatedBy: fileIndexRetriever)[0])!
139+
}
140+
.map {
141+
Image(jpeg: $0).tensor / 127.5 - 1.0
142+
}
56143
}
57144
}
145+
146+
extension Pix2PixDataset where Entropy == SystemRandomNumberGenerator {
147+
public init(
148+
from rootDirPath: String? = nil,
149+
variant: Pix2PixDatasetVariant? = nil,
150+
trainBatchSize: Int = 1,
151+
testBatchSize: Int = 1
152+
) throws {
153+
try self.init(
154+
from: rootDirPath,
155+
variant: variant,
156+
trainBatchSize: trainBatchSize,
157+
testBatchSize: testBatchSize,
158+
entropy: SystemRandomNumberGenerator()
159+
)
160+
}
161+
}

pix2pix/main.swift

+13-48
Original file line numberDiff line numberDiff line change
@@ -19,67 +19,32 @@ import ModelSupport
1919

2020
let options = Options.parseOrExit()
2121

22-
let datasetFolder: URL
23-
let trainFolderA: URL
24-
let trainFolderB: URL
25-
let testFolderA: URL
26-
let testFolderB: URL
27-
28-
if let datasetPath = options.datasetPath {
29-
datasetFolder = URL(fileURLWithPath: datasetPath, isDirectory: true)
30-
trainFolderA = datasetFolder.appendingPathComponent("trainA")
31-
testFolderA = datasetFolder.appendingPathComponent("testA")
32-
trainFolderB = datasetFolder.appendingPathComponent("trainB")
33-
testFolderB = datasetFolder.appendingPathComponent("testB")
34-
} else {
35-
func downloadFacadesDataSetIfNotPresent(to directory: URL) {
36-
let downloadPath = directory.appendingPathComponent("facades").path
37-
let directoryExists = FileManager.default.fileExists(atPath: downloadPath)
38-
let contentsOfDir = try? FileManager.default.contentsOfDirectory(atPath: downloadPath)
39-
let directoryEmpty = (contentsOfDir == nil) || (contentsOfDir!.isEmpty)
40-
41-
guard !directoryExists || directoryEmpty else { return }
42-
43-
let location = URL(
44-
string: "https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/facades.zip")!
45-
let _ = DatasetUtilities.downloadResource(
46-
filename: "facades", fileExtension: "zip",
47-
remoteRoot: location.deletingLastPathComponent(), localStorageDirectory: directory)
48-
}
22+
let dataset = try! Pix2PixDataset(
23+
from: options.datasetPath,
24+
trainBatchSize: 1,
25+
testBatchSize: 1)
4926

50-
datasetFolder = DatasetUtilities.defaultDirectory.appendingPathComponent("pix2pix", isDirectory: true)
51-
downloadFacadesDataSetIfNotPresent(to: datasetFolder)
52-
trainFolderA = datasetFolder.appendingPathComponent("facades/trainA", isDirectory: true)
53-
testFolderA = datasetFolder.appendingPathComponent("facades/testA", isDirectory: true)
54-
trainFolderB = datasetFolder.appendingPathComponent("facades/trainB", isDirectory: true)
55-
testFolderB = datasetFolder.appendingPathComponent("facades/testB", isDirectory: true)
56-
}
27+
var validationImage = dataset.testSamples[0].source.expandingShape(at: 0)
28+
let validationImageURL = URL(string: FileManager.default.currentDirectoryPath)!.appendingPathComponent("sample.jpg")
5729

5830
var generator = NetG(inputChannels: 3, outputChannels: 3, ngf: 64, useDropout: false)
5931
var discriminator = NetD(inChannels: 6, lastConvFilters: 64)
6032

6133
let optimizerG = Adam(for: generator, learningRate: 0.0002, beta1: 0.5)
6234
let optimizerD = Adam(for: discriminator, learningRate: 0.0002, beta1: 0.5)
6335

64-
let batchSize = 1
65-
let lambdaL1 = Tensor<Float>(100)
66-
67-
let trainDataset = try PairedImages(folderAURL: trainFolderA, folderBURL: trainFolderB)
68-
let testDataset = try PairedImages(folderAURL: testFolderA, folderBURL: testFolderB)
69-
70-
var sampleImage = testDataset.batcher.dataset[0].source.expandingShape(at: 0)
71-
let sampleImageURL = URL(string: FileManager.default.currentDirectoryPath)!.appendingPathComponent("sample.jpg")
72-
36+
let epochCount = options.epochs
7337
var step = 0
38+
let lambdaL1 = Tensor<Float>(100)
7439

75-
for epoch in 0..<options.epochs {
40+
for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() {
7641
print("Epoch \(epoch) started at: \(Date())")
7742

7843
var discriminatorTotalLoss = Tensor<Float>(0)
7944
var generatorTotalLoss = Tensor<Float>(0)
8045
var discriminatorCount = 0
8146

82-
for batch in trainDataset.batcher.sequenced() {
47+
for batch in epochBatches {
8348
defer { step += 1 }
8449

8550
Context.local.learningPhase = .training
@@ -137,10 +102,10 @@ for epoch in 0..<options.epochs {
137102
if step % options.sampleLogPeriod == 0 {
138103
Context.local.learningPhase = .inference
139104

140-
let fakeSample = generator(sampleImage) * 0.5 + 0.5
105+
let fakeSample = generator(validationImage) * 0.5 + 0.5
141106

142107
let fakeSampleImage = Image(tensor: fakeSample[0] * 255)
143-
fakeSampleImage.save(to: sampleImageURL, format: .rgb)
108+
fakeSampleImage.save(to: validationImageURL, format: .rgb)
144109
}
145110

146111
discriminatorCount += 1
@@ -158,7 +123,7 @@ var totalLoss = Tensor<Float>(0)
158123
var count = 0
159124

160125
let resultsFolder = try createDirectoryIfNeeded(path: FileManager.default.currentDirectoryPath + "/results")
161-
for batch in testDataset.batcher.sequenced() {
126+
for batch in dataset.testing {
162127
let fakeImages = generator(batch.source)
163128

164129
let tensorImage = batch.source

0 commit comments

Comments
 (0)