Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 6fecd19

Browse files
authored
CycleGan (#400)
* introduce data loading * introduce initial model implementation * utils glue code * finalize training code * add model into main Package.swift file * import model support * comply with google swiftformat * fix license * cleanup pix2pix leftovers * bug-fix results saving * bug fix result dumping * comply with google swiftformat * capitilize Net * don't use GPU-index from CLI, CUDA_VISIBLE_DEVICES instead * drop _Raw operations and explicit device specification * remove type inferred init calls * cleanup layer bank * remove type inferred init calls, more clear norm type name * remove type inferred init calls * comply with google's swift format * drop Files dependency * use zeropad from swift-api * remove runid * remove tensorboard arguments from CLI * migrate to Batcher * migrate to batcher and remove tensorboard * update package files * remove intermediate relu calculations * bug fix sample image generation
1 parent 2ac021c commit 6fecd19

File tree

9 files changed

+664
-7
lines changed

9 files changed

+664
-7
lines changed

CycleGAN/CLI.swift

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import ArgumentParser
16+
17+
struct Options: ParsableArguments {
18+
@Option(default: "./dataset", help: ArgumentHelp("Path to the dataset folder", valueName: "dataset-path"))
19+
var datasetPath: String
20+
21+
@Option(default: 50, help: ArgumentHelp("Number of epochs", valueName: "epochs"))
22+
var epochs: Int
23+
24+
@Option(default: 20, help: ArgumentHelp("Number of steps to log a sample image into tensorboard", valueName: "sampleLogPeriod"))
25+
var sampleLogPeriod: Int
26+
}

CycleGAN/Data/Dataset.swift

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
import ModelSupport
17+
import TensorFlow
18+
import Batcher
19+
20+
public class Images {
21+
var batcher: Batcher<[Tensorf]>
22+
23+
public init(folderURL: URL) throws {
24+
let folderContents = try FileManager.default
25+
.contentsOfDirectory(at: folderURL,
26+
includingPropertiesForKeys: [.isDirectoryKey],
27+
options: [.skipsHiddenFiles])
28+
let imageFiles = folderContents.filter { $0.pathExtension == "jpg" }
29+
30+
let imageTensors = imageFiles.map {
31+
Image(jpeg: $0).tensor / 127.5 - 1.0
32+
}
33+
34+
self.batcher = Batcher(on: imageTensors,
35+
batchSize: 1,
36+
shuffle: true)
37+
}
38+
}

CycleGAN/Models/Discriminator.swift

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import TensorFlow
16+
17+
public struct NetD: Layer {
18+
var module: Sequential<Sequential<Conv2D<Float>, Sequential<Function<Tensorf, Tensorf>, Sequential<Conv2D<Float>, Sequential<BatchNorm<Float>, Sequential<Function<Tensorf, Tensorf>, Sequential<Conv2D<Float>, Sequential<BatchNorm<Float>, Function<Tensorf, Tensorf>>>>>>>>, Sequential<ConvLayer, Sequential<BatchNorm<Float>, Sequential<Function<Tensorf, Tensorf>, ConvLayer>>>>
19+
20+
public init(inChannels: Int, lastConvFilters: Int) {
21+
let kw = 4
22+
23+
let module = Sequential {
24+
Conv2D<Float>(filterShape: (kw, kw, inChannels, lastConvFilters),
25+
strides: (2, 2),
26+
padding: .same,
27+
filterInitializer: { Tensorf(randomNormal: $0, standardDeviation: Tensorf(0.02)) })
28+
Function<Tensorf, Tensorf> { leakyRelu($0) }
29+
30+
Conv2D<Float>(filterShape: (kw, kw, lastConvFilters, 2 * lastConvFilters),
31+
strides: (2, 2),
32+
padding: .same,
33+
filterInitializer: { Tensorf(randomNormal: $0, standardDeviation: Tensorf(0.02)) })
34+
BatchNorm<Float>(featureCount: 2 * lastConvFilters)
35+
Function<Tensorf, Tensorf> { leakyRelu($0) }
36+
37+
Conv2D<Float>(filterShape: (kw, kw, 2 * lastConvFilters, 4 * lastConvFilters),
38+
strides: (2, 2),
39+
padding: .same,
40+
filterInitializer: { Tensorf(randomNormal: $0, standardDeviation: Tensorf(0.02)) })
41+
BatchNorm<Float>(featureCount: 4 * lastConvFilters)
42+
Function<Tensorf, Tensorf> { leakyRelu($0) }
43+
}
44+
45+
let module2 = Sequential {
46+
module
47+
ConvLayer(inChannels: 4 * lastConvFilters, outChannels: 8 * lastConvFilters,
48+
kernelSize: 4, stride: 1, padding: 1)
49+
50+
BatchNorm<Float>(featureCount: 8 * lastConvFilters)
51+
Function<Tensorf, Tensorf> { leakyRelu($0) }
52+
53+
ConvLayer(inChannels: 8 * lastConvFilters, outChannels: 1,
54+
kernelSize: 4, stride: 1, padding: 1)
55+
}
56+
57+
self.module = module2
58+
}
59+
60+
@differentiable
61+
public func callAsFunction(_ input: Tensorf) -> Tensorf {
62+
return module(input)
63+
}
64+
}

CycleGAN/Models/Generator.swift

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import TensorFlow
16+
17+
public struct ResNetGenerator<NormalizationType: FeatureChannelInitializable>: Layer where NormalizationType.TangentVector.VectorSpaceScalar == Float, NormalizationType.Input == Tensorf, NormalizationType.Output == Tensorf {
18+
var conv1: Conv2D<Float>
19+
var norm1: NormalizationType
20+
21+
var conv2: Conv2D<Float>
22+
var norm2: NormalizationType
23+
24+
var conv3: Conv2D<Float>
25+
var norm3: NormalizationType
26+
27+
var resblocks: [ResNetBlock<NormalizationType>]
28+
29+
var upConv1: TransposedConv2D<Float>
30+
var upNorm1: NormalizationType
31+
32+
var upConv2: TransposedConv2D<Float>
33+
var upNorm2: NormalizationType
34+
35+
var lastConv: Conv2D<Float>
36+
37+
public init(inputChannels: Int,
38+
outputChannels: Int,
39+
blocks: Int,
40+
ngf: Int,
41+
normalization: NormalizationType.Type,
42+
useDropout: Bool = false) {
43+
norm1 = NormalizationType(featureCount: ngf)
44+
let useBias = norm1 is InstanceNorm2D<Float>
45+
46+
let filterInit: (TensorShape) -> Tensorf = { Tensorf(randomNormal: $0, standardDeviation: Tensorf(0.02)) }
47+
let biasInit: (TensorShape) -> Tensorf = useBias ? filterInit : zeros()
48+
49+
conv1 = Conv2D(filterShape: (7, 7, inputChannels, ngf),
50+
strides: (1, 1),
51+
filterInitializer: filterInit,
52+
biasInitializer: biasInit)
53+
54+
var mult = 1
55+
56+
conv2 = Conv2D(filterShape: (3, 3, ngf * mult, ngf * mult * 2),
57+
strides: (2, 2),
58+
padding: .same,
59+
filterInitializer: filterInit,
60+
biasInitializer: biasInit)
61+
norm2 = NormalizationType(featureCount: ngf * mult * 2)
62+
63+
mult = 2
64+
65+
conv3 = Conv2D(filterShape: (3, 3, ngf * mult, ngf * mult * 2),
66+
strides: (2, 2),
67+
padding: .same,
68+
filterInitializer: filterInit,
69+
biasInitializer: biasInit)
70+
norm3 = NormalizationType(featureCount: ngf * mult * 2)
71+
72+
mult = 4
73+
74+
resblocks = (0 ..< blocks).map { _ in
75+
ResNetBlock(channels: ngf * mult,
76+
paddingMode: .reflect,
77+
normalization: normalization,
78+
useDropOut: useDropout,
79+
filterInit: filterInit,
80+
biasInit: biasInit)
81+
}
82+
83+
mult = 4
84+
85+
upConv1 = TransposedConv2D(filterShape: (3, 3, ngf * mult / 2, ngf * mult),
86+
strides: (2, 2),
87+
padding: .same,
88+
filterInitializer: filterInit,
89+
biasInitializer: biasInit)
90+
upNorm1 = NormalizationType(featureCount: ngf * mult / 2)
91+
92+
mult = 2
93+
94+
upConv2 = TransposedConv2D(filterShape: (3, 3, ngf * mult / 2, ngf * mult),
95+
strides: (2, 2),
96+
padding: .same,
97+
filterInitializer: filterInit,
98+
biasInitializer: biasInit)
99+
upNorm2 = NormalizationType(featureCount: ngf * mult / 2)
100+
101+
lastConv = Conv2D(filterShape: (7, 7, ngf, outputChannels),
102+
padding: .same,
103+
filterInitializer: filterInit,
104+
biasInitializer: biasInit)
105+
}
106+
107+
@differentiable
108+
public func callAsFunction(_ input: Tensorf) -> Tensorf {
109+
var x = input.padded(forSizes: [(0, 0), (3, 3), (3, 3), (0, 0)], mode: .reflect)
110+
x = relu(x.sequenced(through: conv1, norm1))
111+
x = relu(x.sequenced(through: conv2, norm2))
112+
x = relu(x.sequenced(through: conv3, norm3))
113+
114+
x = resblocks(x)
115+
116+
x = relu(x.sequenced(through: upConv1, upNorm1))
117+
x = relu(x.sequenced(through: upConv2, upNorm2))
118+
119+
x = lastConv(x)
120+
x = tanh(x)
121+
122+
return x
123+
}
124+
}

0 commit comments

Comments
 (0)