Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit cdf7f12

Browse files
authored
Replace TensorFlow image loading with stb_image (#429)
* Replace TensorFlow image loading with Swim.
* Formatting of Image.swift.
* Adding Swim to CMake dependencies.
* Adding tests for image loading and saving.
* Replacing Swim with stb_image.
* Removing Swim references in main CMakeLists.
* Converting grayscale color channels to match RGB expectations.
* Add directory prefix to CMake files.
* Adding C language support to CMake.
* Adding an StbImage modulemap.
* Move StbImage CMakeLists into its own directory, make import @_implementationOnly, update with Saleem's suggestions.
* Fix paths.
* Renamed to STBImage, moved the modulemap.
1 parent a3bcac8 commit cdf7f12

17 files changed

+9545
-33
lines changed

.gitignore

-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
.build
77
*.xcodeproj
8-
*.png
98
.DS_Store
109
.swiftpm
1110
cifar-10-batches-py/

CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
cmake_minimum_required(VERSION 3.16)
22
project(Models
3-
LANGUAGES Swift)
3+
LANGUAGES C Swift)
44

55
if(CMAKE_VERSION VERSION_LESS 3.17)
66
if(NOT CMAKE_SYSTEM_NAME STREQUAL Windows)

Package.swift

+4-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ let package = Package(
4242
targets: [
4343
.target(name: "Batcher", path: "Batcher"),
4444
.target(name: "Datasets", dependencies: ["ModelSupport", "Batcher"], path: "Datasets"),
45-
.target(name: "ModelSupport", dependencies: ["SwiftProtobuf"], path: "Support"),
45+
.target(name: "STBImage", path: "Support/STBImage"),
46+
.target(
47+
name: "ModelSupport", dependencies: ["SwiftProtobuf", "STBImage"], path: "Support",
48+
exclude: ["STBImage"]),
4649
.target(name: "ImageClassificationModels", path: "Models/ImageClassification"),
4750
.target(name: "VideoClassificationModels", path: "Models/Spatiotemporal"),
4851
.target(name: "TextModels", dependencies: ["Datasets"], path: "Models/Text"),

Support/CMakeLists.txt

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
add_subdirectory(STBImage)
2+
13
add_library(ModelSupport
24
BijectiveDictionary.swift
35
Checkpoints/CheckpointIndexReader.swift
@@ -22,6 +24,7 @@ set_target_properties(ModelSupport PROPERTIES
2224
target_compile_options(ModelSupport PRIVATE
2325
$<$<BOOL:${BUILD_TESTING}>:-enable-testing>)
2426
target_link_libraries(ModelSupport PUBLIC
27+
STBImage
2528
SwiftProtobuf)
2629

2730

Support/Image.swift

+60-30
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,23 @@
1313
// limitations under the License.
1414

1515
import Foundation
16+
@_implementationOnly import STBImage
1617
import TensorFlow
1718

19+
// Image loading and saving is inspired by t-ae's Swim library: https://github.com/t-ae/swim
20+
// and uses the stb_image single-file C headers from https://github.com/nothings/stb .
21+
1822
public struct Image {
1923
public enum ByteOrdering {
2024
case bgr
2125
case rgb
2226
}
2327

28+
public enum Colorspace {
29+
case rgb
30+
case grayscale
31+
}
32+
2433
enum ImageTensor {
2534
case float(data: Tensor<Float>)
2635
case uint8(data: Tensor<UInt8>)
@@ -44,20 +53,41 @@ public struct Image {
4453
}
4554

4655
public init(jpeg url: URL, byteOrdering: ByteOrdering = .rgb) {
47-
let loadedFile = _Raw.readFile(filename: StringTensor(url.absoluteString))
48-
let loadedJpeg = _Raw.decodeJpeg(contents: loadedFile, channels: 3, dctMethod: "")
4956
if byteOrdering == .bgr {
50-
self.imageData = .uint8(
51-
data: _Raw.reverse(loadedJpeg, dims: Tensor<Bool>([false, false, false, true])))
57+
// TODO: Add BGR byte reordering.
58+
fatalError("BGR byte ordering is currently unsupported.")
5259
} else {
53-
self.imageData = .uint8(data: loadedJpeg)
60+
guard FileManager.default.fileExists(atPath: url.path) else {
61+
// TODO: Proper error propagation for this.
62+
fatalError("File does not exist at: \(url.path).")
63+
}
64+
65+
var width: Int32 = 0
66+
var height: Int32 = 0
67+
var bpp: Int32 = 0
68+
guard let bytes = stbi_load(url.path, &width, &height, &bpp, 0) else {
69+
// TODO: Proper error propagation for this.
70+
fatalError("Unable to read image at: \(url.path).")
71+
}
72+
73+
let data = [UInt8](UnsafeBufferPointer(start: bytes, count: Int(width * height * bpp)))
74+
stbi_image_free(bytes)
75+
var loadedTensor = Tensor<UInt8>(
76+
shape: [Int(height), Int(width), Int(bpp)], scalars: data)
77+
if bpp == 1 {
78+
loadedTensor = loadedTensor.broadcasted(to: [Int(height), Int(width), 3])
79+
}
80+
self.imageData = .uint8(data: loadedTensor)
5481
}
5582
}
5683

57-
public func save(to url: URL, format: _Raw.Format = .rgb, quality: Int64 = 95) {
84+
public func save(to url: URL, format: Colorspace = .rgb, quality: Int64 = 95) {
5885
let outputImageData: Tensor<UInt8>
86+
let bpp: Int32
87+
5988
switch format {
6089
case .grayscale:
90+
bpp = 1
6191
switch self.imageData {
6292
case let .uint8(data): outputImageData = data
6393
case let .float(data):
@@ -67,51 +97,51 @@ public struct Image {
6797
outputImageData = Tensor<UInt8>(adjustedData)
6898
}
6999
case .rgb:
100+
bpp = 3
70101
switch self.imageData {
71102
case let .uint8(data): outputImageData = data
72103
case let .float(data):
73-
outputImageData = Tensor<UInt8>(
74-
_Raw.clipByValue(t: data, clipValueMin: Tensor(0), clipValueMax: Tensor(255)))
104+
outputImageData = Tensor<UInt8>(data.clipped(min: 0, max: 255))
105+
}
106+
}
107+
108+
let height = Int32(outputImageData.shape[0])
109+
let width = Int32(outputImageData.shape[1])
110+
outputImageData.scalars.withUnsafeBufferPointer { bytes in
111+
let status = stbi_write_jpg(
112+
url.path, width, height, bpp, bytes.baseAddress!, Int32(quality))
113+
guard status != 0 else {
114+
// TODO: Proper error propagation for this.
115+
fatalError("Unable to save image to: \(url.path).")
75116
}
76-
default:
77-
print("Image saving isn't supported for the format \(format).")
78-
exit(-1)
79117
}
80-
81-
let encodedJpeg = _Raw.encodeJpeg(
82-
image: outputImageData, format: format, quality: quality, xmpMetadata: "")
83-
_Raw.writeFile(filename: StringTensor(url.absoluteString), contents: encodedJpeg)
84118
}
85119

86120
public func resized(to size: (Int, Int)) -> Image {
87121
switch self.imageData {
88122
case let .uint8(data):
89-
return Image(
90-
tensor: _Raw.resizeBilinear(
91-
images: Tensor<UInt8>([data]),
92-
size: Tensor<Int32>([Int32(size.0), Int32(size.1)])).squeezingShape(at: 0))
123+
let resizedImage = resize(images: Tensor<Float>(data), size: size, method: .bilinear)
124+
return Image(tensor: Tensor<UInt8>(resizedImage))
93125
case let .float(data):
94-
return Image(
95-
tensor: _Raw.resizeBilinear(
96-
images: Tensor<Float>([data]),
97-
size: Tensor<Int32>([Int32(size.0), Int32(size.1)])).squeezingShape(at: 0))
126+
let resizedImage = resize(images: data, size: size, method: .bilinear)
127+
return Image(tensor: resizedImage)
98128
}
99-
100129
}
101130
}
102131

103-
public func saveImage(_ tensor: Tensor<Float>, shape: (Int, Int), size: (Int, Int)? = nil,
104-
format: _Raw.Format = .rgb, directory: String, name: String,
105-
quality: Int64 = 95) throws {
132+
public func saveImage(
133+
_ tensor: Tensor<Float>, shape: (Int, Int), size: (Int, Int)? = nil,
134+
format: Image.Colorspace = .rgb, directory: String, name: String,
135+
quality: Int64 = 95
136+
) throws {
106137
try createDirectoryIfMissing(at: directory)
138+
107139
let channels: Int
108140
switch format {
109141
case .rgb: channels = 3
110142
case .grayscale: channels = 1
111-
default:
112-
print("\(format) is not supported yet.")
113-
exit(-1)
114143
}
144+
115145
let reshapedTensor = tensor.reshaped(to: [shape.0, shape.1, channels])
116146
let image = Image(tensor: reshapedTensor)
117147
let resizedImage = size != nil ? image.resized(to: (size!.0, size!.1)) : image

Support/STBImage/CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
add_library(STBImage STATIC
2+
stb_image_write.c
3+
stb_image.c)
4+
target_include_directories(STBImage PUBLIC
5+
include)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
module STBImage {
2+
header "stb_image.h"
3+
header "stb_image_write.h"
4+
export *
5+
}

0 commit comments

Comments
 (0)