13
13
// limitations under the License.
14
14
15
15
import Foundation
16
+ @_implementationOnly import STBImage
16
17
import TensorFlow
17
18
19
+ // Image loading and saving is inspired by t-ae's Swim library: https://github.com/t-ae/swim
20
+ // and uses the stb_image single-file C headers from https://github.com/nothings/stb .
21
+
18
22
public struct Image {
19
23
public enum ByteOrdering {
20
24
case bgr
21
25
case rgb
22
26
}
23
27
28
+ public enum Colorspace {
29
+ case rgb
30
+ case grayscale
31
+ }
32
+
24
33
enum ImageTensor {
25
34
case float( data: Tensor < Float > )
26
35
case uint8( data: Tensor < UInt8 > )
@@ -44,20 +53,41 @@ public struct Image {
44
53
}
45
54
46
55
public init ( jpeg url: URL , byteOrdering: ByteOrdering = . rgb) {
47
- let loadedFile = _Raw. readFile ( filename: StringTensor ( url. absoluteString) )
48
- let loadedJpeg = _Raw. decodeJpeg ( contents: loadedFile, channels: 3 , dctMethod: " " )
49
56
if byteOrdering == . bgr {
50
- self . imageData = . uint8 (
51
- data : _Raw . reverse ( loadedJpeg , dims : Tensor < Bool > ( [ false , false , false , true ] ) ) )
57
+ // TODO: Add BGR byte reordering.
58
+ fatalError ( " BGR byte ordering is currently unsupported. " )
52
59
} else {
53
- self . imageData = . uint8( data: loadedJpeg)
60
+ guard FileManager . default. fileExists ( atPath: url. path) else {
61
+ // TODO: Proper error propagation for this.
62
+ fatalError ( " File does not exist at: \( url. path) . " )
63
+ }
64
+
65
+ var width : Int32 = 0
66
+ var height : Int32 = 0
67
+ var bpp : Int32 = 0
68
+ guard let bytes = stbi_load ( url. path, & width, & height, & bpp, 0 ) else {
69
+ // TODO: Proper error propagation for this.
70
+ fatalError ( " Unable to read image at: \( url. path) . " )
71
+ }
72
+
73
+ let data = [ UInt8] ( UnsafeBufferPointer ( start: bytes, count: Int ( width * height * bpp) ) )
74
+ stbi_image_free ( bytes)
75
+ var loadedTensor = Tensor < UInt8 > (
76
+ shape: [ Int ( height) , Int ( width) , Int ( bpp) ] , scalars: data)
77
+ if bpp == 1 {
78
+ loadedTensor = loadedTensor. broadcasted ( to: [ Int ( height) , Int ( width) , 3 ] )
79
+ }
80
+ self . imageData = . uint8( data: loadedTensor)
54
81
}
55
82
}
56
83
57
- public func save( to url: URL , format: _Raw . Format = . rgb, quality: Int64 = 95 ) {
84
+ public func save( to url: URL , format: Colorspace = . rgb, quality: Int64 = 95 ) {
58
85
let outputImageData : Tensor < UInt8 >
86
+ let bpp : Int32
87
+
59
88
switch format {
60
89
case . grayscale:
90
+ bpp = 1
61
91
switch self . imageData {
62
92
case let . uint8( data) : outputImageData = data
63
93
case let . float( data) :
@@ -67,51 +97,51 @@ public struct Image {
67
97
outputImageData = Tensor < UInt8 > ( adjustedData)
68
98
}
69
99
case . rgb:
100
+ bpp = 3
70
101
switch self . imageData {
71
102
case let . uint8( data) : outputImageData = data
72
103
case let . float( data) :
73
- outputImageData = Tensor < UInt8 > (
74
- _Raw. clipByValue ( t: data, clipValueMin: Tensor ( 0 ) , clipValueMax: Tensor ( 255 ) ) )
104
+ outputImageData = Tensor < UInt8 > ( data. clipped ( min: 0 , max: 255 ) )
105
+ }
106
+ }
107
+
108
+ let height = Int32 ( outputImageData. shape [ 0 ] )
109
+ let width = Int32 ( outputImageData. shape [ 1 ] )
110
+ outputImageData. scalars. withUnsafeBufferPointer { bytes in
111
+ let status = stbi_write_jpg (
112
+ url. path, width, height, bpp, bytes. baseAddress!, Int32 ( quality) )
113
+ guard status != 0 else {
114
+ // TODO: Proper error propagation for this.
115
+ fatalError ( " Unable to save image to: \( url. path) . " )
75
116
}
76
- default :
77
- print ( " Image saving isn't supported for the format \( format) . " )
78
- exit ( - 1 )
79
117
}
80
-
81
- let encodedJpeg = _Raw. encodeJpeg (
82
- image: outputImageData, format: format, quality: quality, xmpMetadata: " " )
83
- _Raw. writeFile ( filename: StringTensor ( url. absoluteString) , contents: encodedJpeg)
84
118
}
85
119
86
120
public func resized( to size: ( Int , Int ) ) -> Image {
87
121
switch self . imageData {
88
122
case let . uint8( data) :
89
- return Image (
90
- tensor: _Raw. resizeBilinear (
91
- images: Tensor < UInt8 > ( [ data] ) ,
92
- size: Tensor < Int32 > ( [ Int32 ( size. 0 ) , Int32 ( size. 1 ) ] ) ) . squeezingShape ( at: 0 ) )
123
+ let resizedImage = resize ( images: Tensor < Float > ( data) , size: size, method: . bilinear)
124
+ return Image ( tensor: Tensor < UInt8 > ( resizedImage) )
93
125
case let . float( data) :
94
- return Image (
95
- tensor: _Raw. resizeBilinear (
96
- images: Tensor < Float > ( [ data] ) ,
97
- size: Tensor < Int32 > ( [ Int32 ( size. 0 ) , Int32 ( size. 1 ) ] ) ) . squeezingShape ( at: 0 ) )
126
+ let resizedImage = resize ( images: data, size: size, method: . bilinear)
127
+ return Image ( tensor: resizedImage)
98
128
}
99
-
100
129
}
101
130
}
102
131
103
- public func saveImage( _ tensor: Tensor < Float > , shape: ( Int , Int ) , size: ( Int , Int ) ? = nil ,
104
- format: _Raw . Format = . rgb, directory: String , name: String ,
105
- quality: Int64 = 95 ) throws {
132
+ public func saveImage(
133
+ _ tensor: Tensor < Float > , shape: ( Int , Int ) , size: ( Int , Int ) ? = nil ,
134
+ format: Image . Colorspace = . rgb, directory: String , name: String ,
135
+ quality: Int64 = 95
136
+ ) throws {
106
137
try createDirectoryIfMissing ( at: directory)
138
+
107
139
let channels : Int
108
140
switch format {
109
141
case . rgb: channels = 3
110
142
case . grayscale: channels = 1
111
- default :
112
- print ( " \( format) is not supported yet. " )
113
- exit ( - 1 )
114
143
}
144
+
115
145
let reshapedTensor = tensor. reshaped ( to: [ shape. 0 , shape. 1 , channels] )
116
146
let image = Image ( tensor: reshapedTensor)
117
147
let resizedImage = size != nil ? image. resized ( to: ( size!. 0 , size!. 1 ) ) : image
0 commit comments