Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 6ddca08

Browse files
authored
Switch off Batcher from CycleGAN (#589)
1 parent 521637a commit 6ddca08

File tree

2 files changed

+136
-75
lines changed

2 files changed

+136
-75
lines changed

CycleGAN/Data/Dataset.swift

Lines changed: 120 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,128 @@
1414

1515
import Foundation
1616
import ModelSupport
17+
import Datasets
1718
import TensorFlow
18-
import Batcher
19-
20-
public class Images {
21-
var batcher: Batcher<[Tensorf]>
22-
23-
public init(folderURL: URL) throws {
24-
let folderContents = try FileManager.default
25-
.contentsOfDirectory(at: folderURL,
26-
includingPropertiesForKeys: [.isDirectoryKey],
27-
options: [.skipsHiddenFiles])
28-
let imageFiles = folderContents.filter { $0.pathExtension == "jpg" }
29-
30-
let imageTensors = imageFiles.map {
31-
Image(jpeg: $0).tensor / 127.5 - 1.0
19+
20+
21+
/// The image collections that can be fetched for CycleGAN.
///
/// The raw value of each case is also used as the archive and directory name.
public enum CycleGANDatasetVariant: String {
    case horse2zebra

    /// Remote address of the zip archive that holds this variant.
    public var url: URL {
        let address: String
        switch self {
        case .horse2zebra:
            address =
                "https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip"
        }
        // The addresses above are compile-time constants, so a failed parse is a programmer error.
        return URL(string: address)!
    }
}
32+
33+
/// An in-memory, two-domain image dataset for CycleGAN.
///
/// Images are read from the `trainA`, `trainB`, `testA` and `testB` subdirectories of a root
/// directory. Within each split, domain A and domain B images are paired positionally with `zip`,
/// so a split is truncated to the size of its smaller domain.
public struct CycleGANDataset<Entropy: RandomNumberGenerator> {
    /// Every image pair of one split.
    public typealias Samples = [(domainA: Tensor<Float>, domainB: Tensor<Float>)]
    /// The batches that make up a single training epoch.
    public typealias Batches = Slices<Sampling<Samples, ArraySlice<Int>>>
    /// One collated batch: a tensor per domain.
    public typealias PairedImageBatch = (domainA: Tensor<Float>, domainB: Tensor<Float>)
    /// A lazy sequence of epochs, each of which is a lazy sequence of collated batches.
    public typealias Training = LazyMapSequence<
        TrainingEpochs<Samples, Entropy>,
        LazyMapSequence<Batches, PairedImageBatch>
    >
    /// A lazy sequence of collated test batches.
    public typealias Testing = LazyMapSequence<
        Slices<Samples>,
        PairedImageBatch
    >

    /// The image pairs loaded from `trainA` / `trainB`.
    public let trainSamples: Samples
    /// The image pairs loaded from `testA` / `testB`.
    public let testSamples: Samples
    /// Training epochs built from `trainSamples`.
    public let training: Training
    /// Test batches built from `testSamples`, in sample order.
    public let testing: Testing

    /// Loads the dataset, downloading it first when no local directory is given.
    ///
    /// - Parameters:
    ///   - rootDirPath: Directory containing `trainA`, `trainB`, `testA` and `testB`. When `nil`,
    ///     `variant` is downloaded (if not already present) below
    ///     `DatasetUtilities.defaultDirectory` and used instead.
    ///   - variant: Which dataset to download when `rootDirPath` is `nil`. Defaults to
    ///     `.horse2zebra`. Ignored when `rootDirPath` is provided.
    ///   - trainBatchSize: Number of image pairs per training batch.
    ///   - testBatchSize: Number of image pairs per test batch.
    ///   - entropy: Random number generator handed to `TrainingEpochs`.
    /// - Throws: Any error raised while listing one of the four image directories.
    public init(
        from rootDirPath: String? = nil,
        variant: CycleGANDatasetVariant? = nil,
        trainBatchSize: Int = 1,
        testBatchSize: Int = 1,
        entropy: Entropy
    ) throws {
        let rootDirPath = rootDirPath ?? CycleGANDataset.downloadIfNotPresent(
            variant: variant ?? .horse2zebra,
            to: DatasetUtilities.defaultDirectory.appendingPathComponent(
                "CycleGAN", isDirectory: true))
        let rootDirURL = URL(fileURLWithPath: rootDirPath, isDirectory: true)

        trainSamples = Array(zip(
            try CycleGANDataset.loadSamples(from: rootDirURL.appendingPathComponent("trainA")),
            try CycleGANDataset.loadSamples(from: rootDirURL.appendingPathComponent("trainB"))))

        testSamples = Array(zip(
            try CycleGANDataset.loadSamples(from: rootDirURL.appendingPathComponent("testA")),
            try CycleGANDataset.loadSamples(from: rootDirURL.appendingPathComponent("testB"))))

        // Each epoch is a sequence of sample slices; collate every slice into one tensor per
        // domain (presumably stacked along a new leading axis -- confirm against `Tensor.init`).
        training = TrainingEpochs(
            samples: trainSamples,
            batchSize: trainBatchSize,
            entropy: entropy
        ).lazy.map { (batches: Batches) -> LazyMapSequence<Batches, PairedImageBatch> in
            batches.lazy.map {
                (
                    domainA: Tensor<Float>($0.map(\.domainA)),
                    domainB: Tensor<Float>($0.map(\.domainB))
                )
            }
        }

        testing = testSamples.inBatches(of: testBatchSize)
            .lazy.map {
                (
                    domainA: Tensor<Float>($0.map(\.domainA)),
                    domainB: Tensor<Float>($0.map(\.domainB))
                )
            }
    }

    /// Downloads the archive for `variant` into `directory` unless a non-empty directory named
    /// after the variant is already there.
    ///
    /// - Returns: The path of the variant's directory inside `directory`.
    private static func downloadIfNotPresent(
        variant: CycleGANDatasetVariant,
        to directory: URL
    ) -> String {
        let rootDirPath = directory.appendingPathComponent(variant.rawValue).path

        // A missing or unreadable directory yields `nil`, which is treated like an empty one, so
        // no separate existence check (and no force unwrap) is needed.
        let contents = try? FileManager.default.contentsOfDirectory(atPath: rootDirPath)
        guard contents?.isEmpty ?? true else { return rootDirPath }

        _ = DatasetUtilities.downloadResource(
            filename: variant.rawValue,
            fileExtension: "zip",
            remoteRoot: variant.url.deletingLastPathComponent(),
            localStorageDirectory: directory)

        return rootDirPath
    }

    /// Loads every non-hidden `.jpg` file directly inside `directory` as a tensor rescaled with
    /// `pixel / 127.5 - 1.0`.
    ///
    /// Files are loaded in ascending file-name order so that the result, and therefore the A/B
    /// pairing built from it, is the same on every run.
    private static func loadSamples(from directory: URL) throws -> [Tensor<Float>] {
        return try FileManager.default
            .contentsOfDirectory(
                at: directory,
                includingPropertiesForKeys: [.isDirectoryKey],
                options: [.skipsHiddenFiles])
            .filter { $0.pathExtension == "jpg" }
            // The directory listing order is unspecified; impose a stable one.
            .sorted { $0.lastPathComponent < $1.lastPathComponent }
            .map {
                Image(jpeg: $0).tensor / 127.5 - 1.0
            }
    }
}
124+
125+
extension CycleGANDataset where Entropy == SystemRandomNumberGenerator {
    /// Convenience initializer that supplies a fresh `SystemRandomNumberGenerator` as the
    /// entropy source and otherwise forwards every argument unchanged.
    public init(
        from rootDirPath: String? = nil,
        variant: CycleGANDatasetVariant? = nil,
        trainBatchSize: Int = 1,
        testBatchSize: Int = 1
    ) throws {
        let systemEntropy = SystemRandomNumberGenerator()
        try self.init(
            from: rootDirPath,
            variant: variant,
            trainBatchSize: trainBatchSize,
            testBatchSize: testBatchSize,
            entropy: systemEntropy)
    }
}
141+

CycleGAN/main.swift

Lines changed: 16 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -19,44 +19,10 @@ import Datasets
1919

2020
let options = Options.parseOrExit()
2121

22-
let datasetFolder: URL
23-
let trainFolderA: URL
24-
let trainFolderB: URL
25-
let testFolderA: URL
26-
let testFolderB: URL
27-
28-
if let datasetPath = options.datasetPath {
29-
datasetFolder = URL(fileURLWithPath: datasetPath, isDirectory: true)
30-
trainFolderA = datasetFolder.appendingPathComponent("trainA")
31-
trainFolderB = datasetFolder.appendingPathComponent("trainB")
32-
testFolderA = datasetFolder.appendingPathComponent("testA")
33-
testFolderB = datasetFolder.appendingPathComponent("testB")
34-
} else {
35-
func downloadZebraDataSetIfNotPresent(to directory: URL) {
36-
let downloadPath = directory.appendingPathComponent("horse2zebra").path
37-
let directoryExists = FileManager.default.fileExists(atPath: downloadPath)
38-
let contentsOfDir = try? FileManager.default.contentsOfDirectory(atPath: downloadPath)
39-
let directoryEmpty = (contentsOfDir == nil) || (contentsOfDir!.isEmpty)
40-
41-
guard !directoryExists || directoryEmpty else { return }
42-
43-
let location = URL(
44-
string: "https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip")!
45-
let _ = DatasetUtilities.downloadResource(
46-
filename: "horse2zebra", fileExtension: "zip",
47-
remoteRoot: location.deletingLastPathComponent(), localStorageDirectory: directory)
48-
}
49-
50-
datasetFolder = DatasetUtilities.defaultDirectory.appendingPathComponent("CycleGAN", isDirectory: true)
51-
downloadZebraDataSetIfNotPresent(to: datasetFolder)
52-
trainFolderA = datasetFolder.appendingPathComponent("horse2zebra/trainA")
53-
trainFolderB = datasetFolder.appendingPathComponent("horse2zebra/trainB")
54-
testFolderA = datasetFolder.appendingPathComponent("horse2zebra/testA")
55-
testFolderB = datasetFolder.appendingPathComponent("horse2zebra/testB")
56-
}
57-
58-
let trainDatasetA = try Images(folderURL: trainFolderA)
59-
let trainDatasetB = try Images(folderURL: trainFolderB)
22+
// Load the paired image dataset from `options.datasetPath`, or download the default variant when
// no path was given. A plain `try` lets a loading error propagate as a top-level error, matching
// the other throwing calls in this script, instead of trapping inside a `try!`.
let dataset = try CycleGANDataset(
    from: options.datasetPath,
    trainBatchSize: 1,
    testBatchSize: 1)
6026

6127
var generatorG = ResNetGenerator(inputChannels: 3, outputChannels: 3, blocks: 9, ngf: 64, normalization: InstanceNorm2D.self)
6228
var generatorF = ResNetGenerator(inputChannels: 3, outputChannels: 3, blocks: 9, ngf: 64, normalization: InstanceNorm2D.self)
@@ -68,30 +34,27 @@ let optimizerGG = Adam(for: generatorG, learningRate: 0.0002, beta1: 0.5)
6834
let optimizerDX = Adam(for: discriminatorX, learningRate: 0.0002, beta1: 0.5)
6935
let optimizerDY = Adam(for: discriminatorY, learningRate: 0.0002, beta1: 0.5)
7036

71-
let epochs = options.epochs
72-
let batchSize = 1
37+
let epochCount = options.epochs
7338
let lambdaL1 = Tensorf(10)
7439
let _zeros = Tensorf.zero
7540
let _ones = Tensorf.one
7641

7742
var step = 0
7843

79-
var sampleImage = trainDatasetA.batcher.dataset[0].expandingShape(at: 0)
80-
let sampleImageURL = URL(string: FileManager.default.currentDirectoryPath)!.appendingPathComponent("sample.jpg")
44+
// Fixed domain A training image, with a leading axis of size 1 added, that is run through
// `generatorG` every `options.sampleLogPeriod` steps to produce a sample output.
var validationImage = dataset.trainSamples[0].domainA.expandingShape(at: 0)
// Build a real file URL. `URL(string:)` returns nil for paths that are not valid URL strings
// (for example ones containing spaces), which made the previous force unwrap crash.
let validationImageURL = URL(fileURLWithPath: FileManager.default.currentDirectoryPath)
    .appendingPathComponent("sample.jpg")
8146

8247
// MARK: Train
8348

84-
for epoch in 0 ..< epochs {
49+
for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() {
8550
print("Epoch \(epoch) started at: \(Date())")
8651
Context.local.learningPhase = .training
87-
88-
let zippedAB = zip(trainDatasetA.batcher.sequenced(), trainDatasetB.batcher.sequenced())
8952

90-
for batch in zippedAB {
53+
for batch in epochBatches {
9154
Context.local.learningPhase = .training
9255

93-
let inputX = batch.0
94-
let inputY = batch.1
56+
let inputX = batch.domainA
57+
let inputY = batch.domainB
9558

9659
// We do this outside of the GPU scope so that dataset shuffling happens on the CPU side.
9760
let concatanatedImages = inputX.concatenated(with: inputY)
@@ -187,10 +150,10 @@ for epoch in 0 ..< epochs {
187150
if step % options.sampleLogPeriod == 0 {
188151
Context.local.learningPhase = .inference
189152

190-
let fakeSample = generatorG(sampleImage) * 0.5 + 0.5
153+
let fakeSample = generatorG(validationImage) * 0.5 + 0.5
191154

192155
let fakeSampleImage = Image(tensor: fakeSample[0] * 255)
193-
fakeSampleImage.save(to: sampleImageURL, format: .rgb)
156+
fakeSampleImage.save(to: validationImageURL, format: .rgb)
194157

195158
print("GeneratorG loss: \(gLoss.scalars[0])")
196159
print("GeneratorF loss: \(fLoss.scalars[0])")
@@ -204,20 +167,15 @@ for epoch in 0 ..< epochs {
204167

205168
// MARK: Final test
206169

207-
let testDatasetA = try Images(folderURL: testFolderA).batcher.sequenced()
208-
let testDatasetB = try Images(folderURL: testFolderB).batcher.sequenced()
209-
210-
let zippedTest = zip(testDatasetA, testDatasetB)
211-
212170
let aResultsFolder = try createDirectoryIfNeeded(path: FileManager.default
213171
.currentDirectoryPath + "/testA_results")
214172
let bResultsFolder = try createDirectoryIfNeeded(path: FileManager.default
215173
.currentDirectoryPath + "/testB_results")
216174

217175
var testStep = 0
218-
for testBatch in zippedTest {
219-
let realX = testBatch.0
220-
let realY = testBatch.1
176+
for testBatch in dataset.testing {
177+
let realX = testBatch.domainA
178+
let realY = testBatch.domainB
221179

222180
let fakeY = generatorG(realX)
223181
let fakeX = generatorF(realY)

0 commit comments

Comments
 (0)