Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 0f574e0

Browse files
authored
Cyclegan tweaks (#441)
* make cyclegan demo standalone * add readme * formatting, use let
1 parent cdf7f12 commit 0f574e0

File tree

4 files changed

+56
-7
lines changed

4 files changed

+56
-7
lines changed

CycleGAN/CLI.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import ArgumentParser
1616

1717
struct Options: ParsableArguments {
18-
@Option(default: "./dataset", help: ArgumentHelp("Path to the dataset folder", valueName: "dataset-path"))
18+
@Option(default: "", help: ArgumentHelp("Path to the dataset folder", valueName: "dataset-path"))
1919
var datasetPath: String
2020

2121
@Option(default: 50, help: ArgumentHelp("Number of epochs", valueName: "epochs"))

CycleGAN/README.md

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# CycleGAN
2+
3+
**Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks**
4+
website: https://junyanz.github.io/CycleGAN/
5+
6+
arXiv: https://arxiv.org/abs/1703.10593
7+
8+
## Setup
9+
10+
To begin, you'll need the [latest version of Swift for
11+
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
12+
installed. Make sure you've added the correct version of `swift` to your path.
13+
14+
To train the model, run:
15+
16+
```sh
17+
swift run CycleGAN
18+
```

CycleGAN/main.swift

+36-5
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,45 @@
1515
import Foundation
1616
import ModelSupport
1717
import TensorFlow
18+
import Datasets
1819

1920
let options = Options.parseOrExit()
2021

21-
let datasetFolder = URL(fileURLWithPath: options.datasetPath, isDirectory: true)
22-
let trainFolderA = datasetFolder.appendingPathComponent("trainA")
23-
let trainFolderB = datasetFolder.appendingPathComponent("trainB")
24-
let testFolderA = datasetFolder.appendingPathComponent("testA")
25-
let testFolderB = datasetFolder.appendingPathComponent("testB")
22+
let datasetFolder: URL
23+
let trainFolderA: URL
24+
let trainFolderB: URL
25+
let testFolderA: URL
26+
let testFolderB: URL
27+
28+
if options.datasetPath.length != 0 {
29+
datasetFolder = URL(fileURLWithPath: options.datasetPath, isDirectory: true)
30+
trainFolderA = datasetFolder.appendingPathComponent("trainA")
31+
trainFolderB = datasetFolder.appendingPathComponent("trainB")
32+
testFolderA = datasetFolder.appendingPathComponent("testA")
33+
testFolderB = datasetFolder.appendingPathComponent("testB")
34+
} else {
35+
func downloadZebraDataSetIfNotPresent(to directory: URL) {
36+
let downloadPath = directory.appendingPathComponent("horse2zebra").path
37+
let directoryExists = FileManager.default.fileExists(atPath: downloadPath)
38+
let contentsOfDir = try? FileManager.default.contentsOfDirectory(atPath: downloadPath)
39+
let directoryEmpty = (contentsOfDir == nil) || (contentsOfDir!.isEmpty)
40+
41+
guard !directoryExists || directoryEmpty else { return }
42+
43+
let location = URL(
44+
string: "https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip")!
45+
let _ = DatasetUtilities.downloadResource(
46+
filename: "horse2zebra", fileExtension: "zip",
47+
remoteRoot: location.deletingLastPathComponent(), localStorageDirectory: directory)
48+
}
49+
50+
datasetFolder = DatasetUtilities.defaultDirectory.appendingPathComponent("CycleGAN", isDirectory: true)
51+
downloadZebraDataSetIfNotPresent(to: datasetFolder)
52+
trainFolderA = datasetFolder.appendingPathComponent("horse2zebra/trainA")
53+
trainFolderB = datasetFolder.appendingPathComponent("horse2zebra/trainB")
54+
testFolderA = datasetFolder.appendingPathComponent("horse2zebra/testA")
55+
testFolderB = datasetFolder.appendingPathComponent("horse2zebra/testB")
56+
}
2657

2758
let trainDatasetA = try Images(folderURL: trainFolderA)
2859
let trainDatasetB = try Images(folderURL: trainFolderB)

Package.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ let package = Package(
121121
.testTarget(name: "SupportTests", dependencies: ["ModelSupport"]),
122122
.target(
123123
name: "CycleGAN",
124-
dependencies: ["Batcher", .product(name: "ArgumentParser", package: "swift-argument-parser"), "ModelSupport"],
124+
dependencies: ["Batcher", .product(name: "ArgumentParser", package: "swift-argument-parser"), "ModelSupport", "Datasets"],
125125
path: "CycleGAN"
126126
)
127127
]

0 commit comments

Comments
 (0)