Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit d1c0921

Browse files
authored
Fractal image generation (#565)
* Addition of an example for generating the Mandelbrot and Julia set fractals. * Formatting Swift code. * Added notebook example for the Mandelbrot set. * Have the Y-axis coordinate space go from max at top to min at bottom to match final image orientation. * Add ArgumentParser to the root project CMakeLists, update display of default parameters. * README formatting update. * Disabling example and test builds in CMakeLists for ArgumentParser. * Remove logging code for the first iteration timing from the sample notebook. * Combined shared parameters into a reused option group, added notes about the need for equals in region arguments, and changed the visibility of the prismColor() function. * Better formatting of the notes.
1 parent 6a9b642 commit d1c0921

13 files changed

+1007
-0
lines changed

CMakeLists.txt

+35
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,41 @@ set_target_properties(SwiftProtobuf PROPERTIES
8686
INTERFACE_INCLUDE_DIRECTORIES ${CMAKE_Swift_MODULE_DIRECTORY})
8787
add_dependencies(SwiftProtobuf swift-protobuf-install)
8888

89+
if(CMAKE_SYSTEM_NAME STREQUAL Windows)
90+
set(_copy_swift_argument_parser_import_library
91+
${CMAKE_COMMAND} -E copy_if_different <BINARY_DIR>/Sources/ArgumentParser/${CMAKE_IMPORT_LIBRARY_PREFIX}ArgumentParser${CMAKE_IMPORT_LIBRARY_SUFFIX} ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/)
92+
endif()
93+
ExternalProject_Add(swift-argument-parser
94+
GIT_REPOSITORY git://github.com/apple/swift-argument-parser.git
95+
GIT_TAG master
96+
CMAKE_ARGS
97+
-DBUILD_SHARED_LIBS=YES
98+
-DCMAKE_MAKE_PROGRAM=${CMAKE_MAKE_PROGRAM}
99+
-DCMAKE_Swift_COMPILER=${CMAKE_Swift_COMPILER}
100+
-DCMAKE_Swift_FLAGS=${CMAKE_Swift_FLAGS}
101+
-DBUILD_EXAMPLES=NO
102+
-DBUILD_TESTING=NO
103+
INSTALL_COMMAND
104+
COMMAND
105+
${CMAKE_COMMAND} -E copy_if_different <BINARY_DIR>/lib/${CMAKE_SHARED_LIBRARY_PREFIX}ArgumentParser${CMAKE_SHARED_LIBRARY_SUFFIX} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/
106+
COMMAND
107+
${CMAKE_COMMAND} -E copy_if_different <BINARY_DIR>/swift/ArgumentParser.swiftmodule ${CMAKE_Swift_MODULE_DIRECTORY}/
108+
COMMAND
109+
${_copy_swift_argument_parser_import_library}
110+
BUILD_BYPRODUCTS
111+
<BINARY_DIR>/Sources/ArgumentParser/${CMAKE_SHARED_LIBRARY_PREFIX}ArgumentParser${CMAKE_SHARED_LIBRARY_SUFFIX}
112+
<BINARY_DIR>/Sources/ArgumentParser/${CMAKE_IMPORT_LIBRARY_PREFIX}ArgumentParser${CMAKE_IMPORT_LIBRARY_SUFFIX}
113+
${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${CMAKE_SHARED_LIBRARY_PREFIX}ArgumentParser${CMAKE_SHARED_LIBRARY_SUFFIX}
114+
${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/${CMAKE_IMPORT_LIBRARY_PREFIX}ArgumentParser${CMAKE_IMPORT_LIBRARY_SUFFIX}
115+
STEP_TARGETS install)
116+
117+
add_library(ArgumentParser SHARED IMPORTED)
118+
set_target_properties(ArgumentParser PROPERTIES
119+
IMPORTED_LOCATION ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${CMAKE_SHARED_LIBRARY_PREFIX}ArgumentParser${CMAKE_SHARED_LIBRARY_SUFFIX}
120+
IMPORTED_IMPLIB ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/${CMAKE_IMPORT_LIBRARY_PREFIX}ArgumentParser${CMAKE_IMPORT_LIBRARY_SUFFIX}
121+
INTERFACE_INCLUDE_DIRECTORIES ${CMAKE_Swift_MODULE_DIRECTORY})
122+
add_dependencies(ArgumentParser swift-argument-parser-install)
123+
89124
add_subdirectory(Autoencoder)
90125
add_subdirectory(Support)
91126
add_subdirectory(Batcher)

Examples/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ add_subdirectory(GPT2-WikiText2)
99
add_subdirectory(NeuMF-MovieLens)
1010
add_subdirectory(GPT2-Inference)
1111
add_subdirectory(WordSeg)
12+
add_subdirectory(Fractals)

Examples/Fractals/CMakeLists.txt

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
add_executable(Fractals
2+
ComplexTensor.swift
3+
ImageUtilities.swift
4+
JuliaSet.swift
5+
MandelbrotSet.swift
6+
main.swift)
7+
target_link_libraries(Fractals PRIVATE
8+
ArgumentParser
9+
ModelSupport)
10+
11+
12+
install(TARGETS Fractals
13+
DESTINATION bin)

Examples/Fractals/ComplexTensor.swift

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import ArgumentParser
16+
import TensorFlow
17+
18+
struct ComplexTensor {
19+
let real: Tensor<Float>
20+
let imaginary: Tensor<Float>
21+
}
22+
23+
func +(lhs: ComplexTensor, rhs: ComplexTensor) -> ComplexTensor {
24+
let real = lhs.real + rhs.real
25+
let imaginary = lhs.imaginary + rhs.imaginary
26+
return ComplexTensor(real: real, imaginary: imaginary)
27+
}
28+
29+
func *(lhs: ComplexTensor, rhs: ComplexTensor) -> ComplexTensor {
30+
let real = lhs.real .* rhs.real - lhs.imaginary .* rhs.imaginary
31+
let imaginary = lhs.real .* rhs.imaginary + lhs.imaginary .* rhs.real
32+
return ComplexTensor(real: real, imaginary: imaginary)
33+
}
34+
35+
func abs(_ value: ComplexTensor) -> Tensor<Float> {
36+
return value.real .* value.real + value.imaginary .* value.imaginary
37+
}
38+
39+
struct ComplexRegion {
40+
let realMinimum: Float
41+
let realMaximum: Float
42+
let imaginaryMinimum: Float
43+
let imaginaryMaximum: Float
44+
}
45+
46+
extension ComplexRegion: ExpressibleByArgument {
47+
init?(argument: String) {
48+
let subArguments = argument.split(separator: ",").compactMap { Float(String($0)) }
49+
guard subArguments.count >= 4 else { return nil }
50+
51+
self.realMinimum = subArguments[0]
52+
self.realMaximum = subArguments[1]
53+
self.imaginaryMinimum = subArguments[2]
54+
self.imaginaryMaximum = subArguments[3]
55+
}
56+
57+
var defaultValueDescription: String {
58+
"\(self.realMinimum),\(self.realMaximum),\(self.imaginaryMinimum),\(self.imaginaryMaximum)"
59+
}
60+
}
+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import ArgumentParser
16+
import ModelSupport
17+
import TensorFlow
18+
19+
struct ImageSize {
20+
let width: Int
21+
let height: Int
22+
}
23+
24+
extension ImageSize: ExpressibleByArgument {
25+
init?(argument: String) {
26+
let subArguments = argument.split(separator: ",").compactMap { Int(String($0)) }
27+
guard subArguments.count >= 2 else { return nil }
28+
29+
self.width = subArguments[0]
30+
self.height = subArguments[1]
31+
}
32+
33+
var defaultValueDescription: String {
34+
"\(self.width) \(self.height)"
35+
}
36+
}
37+
38+
fileprivate func prismColor(_ value: Float, iterations: Int) -> [Float] {
39+
guard value < Float(iterations) else { return [0.0, 0.0, 0.0] }
40+
41+
let normalizedValue = value / Float(iterations)
42+
43+
// Values drawn from Matplotlib: https://github.com/matplotlib/matplotlib/blob/master/lib/matplotlib/_cm.py
44+
let red = (0.75 * sinf((normalizedValue * 20.9 + 0.25) * Float.pi) + 0.67) * 255
45+
let green = (0.75 * sinf((normalizedValue * 20.9 - 0.25) * Float.pi) + 0.33) * 255
46+
let blue = (-1.1 * sinf((normalizedValue * 20.9) * Float.pi)) * 255
47+
return [red, green, blue]
48+
}
49+
50+
func saveFractalImage(_ divergenceGrid: Tensor<Float>, iterations: Int, fileName: String) throws {
51+
let gridShape = divergenceGrid.shape
52+
53+
let colorValues: [Float] = divergenceGrid.scalars.reduce(into: []) {
54+
$0 += prismColor($1, iterations: iterations)
55+
}
56+
let colorImage = Tensor<Float>(
57+
shape: [gridShape[0], gridShape[1], 3], scalars: colorValues, on: divergenceGrid.device)
58+
59+
try saveImage(
60+
colorImage, shape: (gridShape[0], gridShape[1]),
61+
format: .rgb, directory: "./", name: fileName,
62+
quality: 95)
63+
}

Examples/Fractals/JuliaSet.swift

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import ArgumentParser
16+
import Foundation
17+
import TensorFlow
18+
19+
struct ComplexConstant {
20+
let real: Float
21+
let imaginary: Float
22+
}
23+
24+
func juliaSet(
25+
iterations: Int, constant: ComplexConstant, tolerance: Float, region: ComplexRegion,
26+
imageSize: ImageSize, device: Device
27+
) -> Tensor<Float> {
28+
let xs = Tensor<Float>(
29+
linearSpaceFrom: region.realMinimum, to: region.realMaximum, count: imageSize.width, on: device
30+
).broadcasted(to: [imageSize.width, imageSize.height])
31+
let ys = Tensor<Float>(
32+
linearSpaceFrom: region.imaginaryMaximum, to: region.imaginaryMinimum, count: imageSize.height,
33+
on: device
34+
).expandingShape(at: 1).broadcasted(to: [imageSize.width, imageSize.height])
35+
var Z = ComplexTensor(real: xs, imaginary: ys)
36+
let C = ComplexTensor(
37+
real: Tensor<Float>(repeating: constant.real, shape: xs.shape, on: device),
38+
imaginary: Tensor<Float>(repeating: constant.imaginary, shape: xs.shape, on: device))
39+
var divergence = Tensor<Float>(repeating: Float(iterations), shape: xs.shape, on: device)
40+
41+
// We'll make sure the initialization of these tensors doesn't carry
42+
// into the trace for the first iteration.
43+
LazyTensorBarrier()
44+
45+
let start = Date()
46+
var firstIteration = Date()
47+
48+
for iteration in 0..<iterations {
49+
Z = Z * Z + C
50+
51+
let aboveThreshold = abs(Z) .> tolerance
52+
divergence = divergence.replacing(
53+
with: min(divergence, Float(iteration)), where: aboveThreshold)
54+
55+
// We're cutting the trace to be a single iteration.
56+
LazyTensorBarrier()
57+
if iteration == 1 {
58+
firstIteration = Date()
59+
}
60+
}
61+
62+
print(
63+
"Total calculation time: \(String(format: "%.3f", Date().timeIntervalSince(start))) seconds")
64+
print(
65+
"Time after first iteration: \(String(format: "%.3f", Date().timeIntervalSince(firstIteration))) seconds"
66+
)
67+
68+
return divergence
69+
}
70+
71+
extension ComplexConstant: ExpressibleByArgument {
72+
init?(argument: String) {
73+
let subArguments = argument.split(separator: ",").compactMap { Float(String($0)) }
74+
guard subArguments.count >= 2 else { return nil }
75+
76+
self.real = subArguments[0]
77+
self.imaginary = subArguments[1]
78+
}
79+
80+
var defaultValueDescription: String {
81+
"\(self.real),\(self.imaginary)"
82+
}
83+
}

Examples/Fractals/MandelbrotSet.swift

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
import TensorFlow
17+
18+
func mandelbrotSet(
19+
iterations: Int, tolerance: Float, region: ComplexRegion, imageSize: ImageSize, device: Device
20+
) -> Tensor<Float> {
21+
let xs = Tensor<Float>(
22+
linearSpaceFrom: region.realMinimum, to: region.realMaximum, count: imageSize.width, on: device
23+
).broadcasted(to: [imageSize.width, imageSize.height])
24+
let ys = Tensor<Float>(
25+
linearSpaceFrom: region.imaginaryMaximum, to: region.imaginaryMinimum, count: imageSize.height,
26+
on: device
27+
).expandingShape(at: 1).broadcasted(to: [imageSize.width, imageSize.height])
28+
let X = ComplexTensor(real: xs, imaginary: ys)
29+
var Z = ComplexTensor(real: Tensor(zerosLike: xs), imaginary: Tensor(zerosLike: ys))
30+
var divergence = Tensor<Float>(repeating: Float(iterations), shape: xs.shape, on: device)
31+
32+
// We'll make sure the initialization of these tensors doesn't carry
33+
// into the trace for the first iteration.
34+
LazyTensorBarrier()
35+
36+
let start = Date()
37+
var firstIteration = Date()
38+
39+
for iteration in 0..<iterations {
40+
Z = Z * Z + X
41+
42+
let aboveThreshold = abs(Z) .> tolerance
43+
divergence = divergence.replacing(
44+
with: min(divergence, Float(iteration)), where: aboveThreshold)
45+
46+
// We're cutting the trace to be a single iteration.
47+
LazyTensorBarrier()
48+
if iteration == 1 {
49+
firstIteration = Date()
50+
}
51+
}
52+
53+
print(
54+
"Total calculation time: \(String(format: "%.3f", Date().timeIntervalSince(start))) seconds")
55+
print(
56+
"Time after first iteration: \(String(format: "%.3f", Date().timeIntervalSince(firstIteration))) seconds"
57+
)
58+
59+
return divergence
60+
}

0 commit comments

Comments
 (0)