|
14 | 14 |
|
15 | 15 | import Foundation
|
16 | 16 | import ModelSupport
|
| 17 | +import Datasets |
17 | 18 | import TensorFlow
|
18 |
| -import Batcher |
19 | 19 |
|
20 |
| -public class PairedImages { |
21 |
| - public struct ImagePair: _Collatable { |
22 |
| - public init(oldCollating: [PairedImages.ImagePair]) { |
23 |
| - self.source = .init(stacking: oldCollating.map(\.source)) |
24 |
| - self.target = .init(stacking: oldCollating.map(\.target)) |
| 20 | +public enum Pix2PixDatasetVariant: String { |
| 21 | + case facades |
| 22 | + |
| 23 | + public var url: URL { |
| 24 | + switch self { |
| 25 | + case .facades: |
| 26 | + return URL(string: |
| 27 | + "https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/facades.zip")! |
25 | 28 | }
|
| 29 | + } |
| 30 | +} |
| 31 | + |
| 32 | +public struct Pix2PixDataset<Entropy: RandomNumberGenerator> { |
| 33 | + public typealias Samples = [(source: Tensor<Float>, target: Tensor<Float>)] |
| 34 | + public typealias Batches = Slices<Sampling<Samples, ArraySlice<Int>>> |
| 35 | + public typealias PairedImageBatch = (source: Tensor<Float>, target: Tensor<Float>) |
| 36 | + public typealias Training = LazyMapSequence< |
| 37 | + TrainingEpochs<Samples, Entropy>, |
| 38 | + LazyMapSequence<Batches, PairedImageBatch> |
| 39 | + > |
| 40 | + public typealias Testing = LazyMapSequence< |
| 41 | + Slices<Samples>, |
| 42 | + PairedImageBatch |
| 43 | + > |
| 44 | + |
| 45 | + public let trainSamples: Samples |
| 46 | + public let testSamples: Samples |
| 47 | + public let training: Training |
| 48 | + public let testing: Testing |
| 49 | + |
| 50 | + public init( |
| 51 | + from rootDirPath: String? = nil, |
| 52 | + variant: Pix2PixDatasetVariant? = nil, |
| 53 | + trainBatchSize: Int = 1, |
| 54 | + testBatchSize: Int = 1, |
| 55 | + entropy: Entropy) throws { |
26 | 56 |
|
27 |
| - public init(source: Tensor<Float>, target: Tensor<Float>) { |
28 |
| - self.source = source |
29 |
| - self.target = target |
30 |
| - } |
| 57 | + let rootDirPath = rootDirPath ?? Pix2PixDataset.downloadIfNotPresent( |
| 58 | + variant: variant ?? .facades, |
| 59 | + to: DatasetUtilities.defaultDirectory.appendingPathComponent("pix2pix", isDirectory: true)) |
| 60 | + let rootDirURL = URL(fileURLWithPath: rootDirPath, isDirectory: true) |
31 | 61 |
|
32 |
| - var source: Tensor<Float> |
33 |
| - var target: Tensor<Float> |
34 |
| - } |
35 |
| - var batcher: Batcher<[ImagePair]> |
36 |
| - |
37 |
| - public init(folderAURL: URL, folderBURL: URL) throws { |
38 |
| - let folderAContents = try FileManager.default |
39 |
| - .contentsOfDirectory(at: folderAURL, |
40 |
| - includingPropertiesForKeys: [.isDirectoryKey], |
41 |
| - options: [.skipsHiddenFiles]) |
42 |
| - .filter { $0.pathExtension == "jpg" } |
43 |
| - |
44 |
| - let imageTensors = folderAContents.map { (url: URL) -> ImagePair in |
45 |
| - let tensorA = Image(jpeg: url).tensor / 127.5 - 1.0 |
46 |
| - |
47 |
| - let tensorBImageURL = folderBURL.appendingPathComponent(url.lastPathComponent.replacingOccurrences(of: "_A.jpg", with: "_B.jpg")) |
48 |
| - let tensorB = Image(jpeg: tensorBImageURL).tensor / 127.5 - 1.0 |
49 |
| - |
50 |
| - return ImagePair(source: tensorA, target: tensorB) |
51 |
| - } |
| 62 | + trainSamples = Array(zip( |
| 63 | + try Pix2PixDataset.loadSortedSamples( |
| 64 | + from: rootDirURL.appendingPathComponent("trainA"), |
| 65 | + fileIndexRetriever: "_" |
| 66 | + ), |
| 67 | + try Pix2PixDataset.loadSortedSamples( |
| 68 | + from: rootDirURL.appendingPathComponent("trainB"), |
| 69 | + fileIndexRetriever: "_" |
| 70 | + ) |
| 71 | + )) |
52 | 72 |
|
53 |
| - self.batcher = Batcher(on: imageTensors, |
54 |
| - batchSize: 1, |
55 |
| - shuffle: true) |
| 73 | + testSamples = Array(zip( |
| 74 | + try Pix2PixDataset.loadSortedSamples( |
| 75 | + from: rootDirURL.appendingPathComponent("testA"), |
| 76 | + fileIndexRetriever: "." |
| 77 | + ), |
| 78 | + try Pix2PixDataset.loadSortedSamples( |
| 79 | + from: rootDirURL.appendingPathComponent("testB"), |
| 80 | + fileIndexRetriever: "." |
| 81 | + ) |
| 82 | + )) |
| 83 | + |
| 84 | + training = TrainingEpochs( |
| 85 | + samples: trainSamples, |
| 86 | + batchSize: trainBatchSize, |
| 87 | + entropy: entropy |
| 88 | + ).lazy.map { (batches: Batches) -> LazyMapSequence<Batches, PairedImageBatch> in |
| 89 | + batches.lazy.map { |
| 90 | + ( |
| 91 | + source: Tensor<Float>($0.map(\.source)), |
| 92 | + target: Tensor<Float>($0.map(\.target)) |
| 93 | + ) |
| 94 | + } |
| 95 | + } |
| 96 | + |
| 97 | + testing = testSamples.inBatches(of: testBatchSize) |
| 98 | + .lazy.map { |
| 99 | + ( |
| 100 | + source: Tensor<Float>($0.map(\.source)), |
| 101 | + target: Tensor<Float>($0.map(\.target)) |
| 102 | + ) |
| 103 | + } |
| 104 | + } |
| 105 | + |
| 106 | + private static func downloadIfNotPresent( |
| 107 | + variant: Pix2PixDatasetVariant, |
| 108 | + to directory: URL) -> String { |
| 109 | + let rootDirPath = directory.appendingPathComponent(variant.rawValue).path |
| 110 | + |
| 111 | + let directoryExists = FileManager.default.fileExists(atPath: rootDirPath) |
| 112 | + let contentsOfDir = try? FileManager.default.contentsOfDirectory(atPath: rootDirPath) |
| 113 | + let directoryEmpty = (contentsOfDir == nil) || (contentsOfDir!.isEmpty) |
| 114 | + guard !directoryExists || directoryEmpty else { return rootDirPath } |
| 115 | + |
| 116 | + let _ = DatasetUtilities.downloadResource( |
| 117 | + filename: variant.rawValue, |
| 118 | + fileExtension: "zip", |
| 119 | + remoteRoot: variant.url.deletingLastPathComponent(), |
| 120 | + localStorageDirectory: directory) |
| 121 | + print("\(rootDirPath) downloaded.") |
| 122 | + |
| 123 | + return rootDirPath |
| 124 | + } |
| 125 | + |
| 126 | + private static func loadSortedSamples( |
| 127 | + from directory: URL, |
| 128 | + fileIndexRetriever: String |
| 129 | + ) throws -> [Tensor<Float>] { |
| 130 | + return try FileManager.default |
| 131 | + .contentsOfDirectory( |
| 132 | + at: directory, |
| 133 | + includingPropertiesForKeys: [.isDirectoryKey], |
| 134 | + options: [.skipsHiddenFiles]) |
| 135 | + .filter { $0.pathExtension == "jpg" } |
| 136 | + .sorted { |
| 137 | + Int($0.lastPathComponent.components(separatedBy: fileIndexRetriever)[0])! < |
| 138 | + Int($1.lastPathComponent.components(separatedBy: fileIndexRetriever)[0])! |
| 139 | + } |
| 140 | + .map { |
| 141 | + Image(jpeg: $0).tensor / 127.5 - 1.0 |
| 142 | + } |
56 | 143 | }
|
57 | 144 | }
|
| 145 | + |
| 146 | +extension Pix2PixDataset where Entropy == SystemRandomNumberGenerator { |
| 147 | + public init( |
| 148 | + from rootDirPath: String? = nil, |
| 149 | + variant: Pix2PixDatasetVariant? = nil, |
| 150 | + trainBatchSize: Int = 1, |
| 151 | + testBatchSize: Int = 1 |
| 152 | + ) throws { |
| 153 | + try self.init( |
| 154 | + from: rootDirPath, |
| 155 | + variant: variant, |
| 156 | + trainBatchSize: trainBatchSize, |
| 157 | + testBatchSize: testBatchSize, |
| 158 | + entropy: SystemRandomNumberGenerator() |
| 159 | + ) |
| 160 | + } |
| 161 | +} |
0 commit comments