@@ -19,44 +19,10 @@ import Datasets
19
19
20
20
let options = Options . parseOrExit ( )
21
21
22
- let datasetFolder : URL
23
- let trainFolderA : URL
24
- let trainFolderB : URL
25
- let testFolderA : URL
26
- let testFolderB : URL
27
-
28
- if let datasetPath = options. datasetPath {
29
- datasetFolder = URL ( fileURLWithPath: datasetPath, isDirectory: true )
30
- trainFolderA = datasetFolder. appendingPathComponent ( " trainA " )
31
- trainFolderB = datasetFolder. appendingPathComponent ( " trainB " )
32
- testFolderA = datasetFolder. appendingPathComponent ( " testA " )
33
- testFolderB = datasetFolder. appendingPathComponent ( " testB " )
34
- } else {
35
- func downloadZebraDataSetIfNotPresent( to directory: URL ) {
36
- let downloadPath = directory. appendingPathComponent ( " horse2zebra " ) . path
37
- let directoryExists = FileManager . default. fileExists ( atPath: downloadPath)
38
- let contentsOfDir = try ? FileManager . default. contentsOfDirectory ( atPath: downloadPath)
39
- let directoryEmpty = ( contentsOfDir == nil ) || ( contentsOfDir!. isEmpty)
40
-
41
- guard !directoryExists || directoryEmpty else { return }
42
-
43
- let location = URL (
44
- string: " https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip " ) !
45
- let _ = DatasetUtilities . downloadResource (
46
- filename: " horse2zebra " , fileExtension: " zip " ,
47
- remoteRoot: location. deletingLastPathComponent ( ) , localStorageDirectory: directory)
48
- }
49
-
50
- datasetFolder = DatasetUtilities . defaultDirectory. appendingPathComponent ( " CycleGAN " , isDirectory: true )
51
- downloadZebraDataSetIfNotPresent ( to: datasetFolder)
52
- trainFolderA = datasetFolder. appendingPathComponent ( " horse2zebra/trainA " )
53
- trainFolderB = datasetFolder. appendingPathComponent ( " horse2zebra/trainB " )
54
- testFolderA = datasetFolder. appendingPathComponent ( " horse2zebra/testA " )
55
- testFolderB = datasetFolder. appendingPathComponent ( " horse2zebra/testB " )
56
- }
57
-
58
- let trainDatasetA = try Images ( folderURL: trainFolderA)
59
- let trainDatasetB = try Images ( folderURL: trainFolderB)
22
+ let dataset = try ! CycleGANDataset (
23
+ from: options. datasetPath,
24
+ trainBatchSize: 1 ,
25
+ testBatchSize: 1 )
60
26
61
27
var generatorG = ResNetGenerator ( inputChannels: 3 , outputChannels: 3 , blocks: 9 , ngf: 64 , normalization: InstanceNorm2D . self)
62
28
var generatorF = ResNetGenerator ( inputChannels: 3 , outputChannels: 3 , blocks: 9 , ngf: 64 , normalization: InstanceNorm2D . self)
@@ -68,30 +34,27 @@ let optimizerGG = Adam(for: generatorG, learningRate: 0.0002, beta1: 0.5)
68
34
let optimizerDX = Adam ( for: discriminatorX, learningRate: 0.0002 , beta1: 0.5 )
69
35
let optimizerDY = Adam ( for: discriminatorY, learningRate: 0.0002 , beta1: 0.5 )
70
36
71
- let epochs = options. epochs
72
- let batchSize = 1
37
+ let epochCount = options. epochs
73
38
let lambdaL1 = Tensorf ( 10 )
74
39
let _zeros = Tensorf . zero
75
40
let _ones = Tensorf . one
76
41
77
42
var step = 0
78
43
79
- var sampleImage = trainDatasetA . batcher . dataset [ 0 ] . expandingShape ( at: 0 )
80
- let sampleImageURL = URL ( string: FileManager . default. currentDirectoryPath) !. appendingPathComponent ( " sample.jpg " )
44
+ var validationImage = dataset. trainSamples [ 0 ] . domainA . expandingShape ( at: 0 )
45
+ let validationImageURL = URL ( string: FileManager . default. currentDirectoryPath) !. appendingPathComponent ( " sample.jpg " )
81
46
82
47
// MARK: Train
83
48
84
- for epoch in 0 ..< epochs {
49
+ for ( epoch, epochBatches ) in dataset . training . prefix ( epochCount ) . enumerated ( ) {
85
50
print ( " Epoch \( epoch) started at: \( Date ( ) ) " )
86
51
Context . local. learningPhase = . training
87
-
88
- let zippedAB = zip ( trainDatasetA. batcher. sequenced ( ) , trainDatasetB. batcher. sequenced ( ) )
89
52
90
- for batch in zippedAB {
53
+ for batch in epochBatches {
91
54
Context . local. learningPhase = . training
92
55
93
- let inputX = batch. 0
94
- let inputY = batch. 1
56
+ let inputX = batch. domainA
57
+ let inputY = batch. domainB
95
58
96
59
// we do it outside of GPU scope so that dataset shuffling happens on CPU side
97
60
let concatanatedImages = inputX. concatenated ( with: inputY)
@@ -187,10 +150,10 @@ for epoch in 0 ..< epochs {
187
150
if step % options. sampleLogPeriod == 0 {
188
151
Context . local. learningPhase = . inference
189
152
190
- let fakeSample = generatorG ( sampleImage ) * 0.5 + 0.5
153
+ let fakeSample = generatorG ( validationImage ) * 0.5 + 0.5
191
154
192
155
let fakeSampleImage = Image ( tensor: fakeSample [ 0 ] * 255 )
193
- fakeSampleImage. save ( to: sampleImageURL , format: . rgb)
156
+ fakeSampleImage. save ( to: validationImageURL , format: . rgb)
194
157
195
158
print ( " GeneratorG loss: \( gLoss. scalars [ 0 ] ) " )
196
159
print ( " GeneratorF loss: \( fLoss. scalars [ 0 ] ) " )
@@ -204,20 +167,15 @@ for epoch in 0 ..< epochs {
204
167
205
168
// MARK: Final test
206
169
207
- let testDatasetA = try Images ( folderURL: testFolderA) . batcher. sequenced ( )
208
- let testDatasetB = try Images ( folderURL: testFolderB) . batcher. sequenced ( )
209
-
210
- let zippedTest = zip ( testDatasetA, testDatasetB)
211
-
212
170
let aResultsFolder = try createDirectoryIfNeeded ( path: FileManager . default
213
171
. currentDirectoryPath + " /testA_results " )
214
172
let bResultsFolder = try createDirectoryIfNeeded ( path: FileManager . default
215
173
. currentDirectoryPath + " /testB_results " )
216
174
217
175
var testStep = 0
218
- for testBatch in zippedTest {
219
- let realX = testBatch. 0
220
- let realY = testBatch. 1
176
+ for testBatch in dataset . testing {
177
+ let realX = testBatch. domainA
178
+ let realY = testBatch. domainB
221
179
222
180
let fakeY = generatorG ( realX)
223
181
let fakeX = generatorF ( realY)
0 commit comments