1313// limitations under the License.
1414
1515import Foundation
16+ @_implementationOnly import STBImage
1617import TensorFlow
1718
19+ // Image loading and saving is inspired by t-ae's Swim library: https://github.com/t-ae/swim
20+ // and uses the stb_image single-file C headers from https://github.com/nothings/stb .
21+
1822public struct Image {
1923 public enum ByteOrdering {
2024 case bgr
2125 case rgb
2226 }
2327
28+ public enum Colorspace {
29+ case rgb
30+ case grayscale
31+ }
32+
2433 enum ImageTensor {
2534 case float( data: Tensor < Float > )
2635 case uint8( data: Tensor < UInt8 > )
@@ -44,20 +53,41 @@ public struct Image {
4453 }
4554
4655 public init ( jpeg url: URL , byteOrdering: ByteOrdering = . rgb) {
47- let loadedFile = _Raw. readFile ( filename: StringTensor ( url. absoluteString) )
48- let loadedJpeg = _Raw. decodeJpeg ( contents: loadedFile, channels: 3 , dctMethod: " " )
4956 if byteOrdering == . bgr {
50- self . imageData = . uint8 (
51- data : _Raw . reverse ( loadedJpeg , dims : Tensor < Bool > ( [ false , false , false , true ] ) ) )
57+ // TODO: Add BGR byte reordering.
58+ fatalError ( " BGR byte ordering is currently unsupported. " )
5259 } else {
53- self . imageData = . uint8( data: loadedJpeg)
60+ guard FileManager . default. fileExists ( atPath: url. path) else {
61+ // TODO: Proper error propagation for this.
62+ fatalError ( " File does not exist at: \( url. path) . " )
63+ }
64+
65+ var width : Int32 = 0
66+ var height : Int32 = 0
67+ var bpp : Int32 = 0
68+ guard let bytes = stbi_load ( url. path, & width, & height, & bpp, 0 ) else {
69+ // TODO: Proper error propagation for this.
70+ fatalError ( " Unable to read image at: \( url. path) . " )
71+ }
72+
73+ let data = [ UInt8] ( UnsafeBufferPointer ( start: bytes, count: Int ( width * height * bpp) ) )
74+ stbi_image_free ( bytes)
75+ var loadedTensor = Tensor < UInt8 > (
76+ shape: [ Int ( height) , Int ( width) , Int ( bpp) ] , scalars: data)
77+ if bpp == 1 {
78+ loadedTensor = loadedTensor. broadcasted ( to: [ Int ( height) , Int ( width) , 3 ] )
79+ }
80+ self . imageData = . uint8( data: loadedTensor)
5481 }
5582 }
5683
57- public func save( to url: URL , format: _Raw . Format = . rgb, quality: Int64 = 95 ) {
84+ public func save( to url: URL , format: Colorspace = . rgb, quality: Int64 = 95 ) {
5885 let outputImageData : Tensor < UInt8 >
86+ let bpp : Int32
87+
5988 switch format {
6089 case . grayscale:
90+ bpp = 1
6191 switch self . imageData {
6292 case let . uint8( data) : outputImageData = data
6393 case let . float( data) :
@@ -67,51 +97,51 @@ public struct Image {
6797 outputImageData = Tensor < UInt8 > ( adjustedData)
6898 }
6999 case . rgb:
100+ bpp = 3
70101 switch self . imageData {
71102 case let . uint8( data) : outputImageData = data
72103 case let . float( data) :
73- outputImageData = Tensor < UInt8 > (
74- _Raw. clipByValue ( t: data, clipValueMin: Tensor ( 0 ) , clipValueMax: Tensor ( 255 ) ) )
104+ outputImageData = Tensor < UInt8 > ( data. clipped ( min: 0 , max: 255 ) )
105+ }
106+ }
107+
108+ let height = Int32 ( outputImageData. shape [ 0 ] )
109+ let width = Int32 ( outputImageData. shape [ 1 ] )
110+ outputImageData. scalars. withUnsafeBufferPointer { bytes in
111+ let status = stbi_write_jpg (
112+ url. path, width, height, bpp, bytes. baseAddress!, Int32 ( quality) )
113+ guard status != 0 else {
114+ // TODO: Proper error propagation for this.
115+ fatalError ( " Unable to save image to: \( url. path) . " )
75116 }
76- default :
77- print ( " Image saving isn't supported for the format \( format) . " )
78- exit ( - 1 )
79117 }
80-
81- let encodedJpeg = _Raw. encodeJpeg (
82- image: outputImageData, format: format, quality: quality, xmpMetadata: " " )
83- _Raw. writeFile ( filename: StringTensor ( url. absoluteString) , contents: encodedJpeg)
84118 }
85119
86120 public func resized( to size: ( Int , Int ) ) -> Image {
87121 switch self . imageData {
88122 case let . uint8( data) :
89- return Image (
90- tensor: _Raw. resizeBilinear (
91- images: Tensor < UInt8 > ( [ data] ) ,
92- size: Tensor < Int32 > ( [ Int32 ( size. 0 ) , Int32 ( size. 1 ) ] ) ) . squeezingShape ( at: 0 ) )
123+ let resizedImage = resize ( images: Tensor < Float > ( data) , size: size, method: . bilinear)
124+ return Image ( tensor: Tensor < UInt8 > ( resizedImage) )
93125 case let . float( data) :
94- return Image (
95- tensor: _Raw. resizeBilinear (
96- images: Tensor < Float > ( [ data] ) ,
97- size: Tensor < Int32 > ( [ Int32 ( size. 0 ) , Int32 ( size. 1 ) ] ) ) . squeezingShape ( at: 0 ) )
126+ let resizedImage = resize ( images: data, size: size, method: . bilinear)
127+ return Image ( tensor: resizedImage)
98128 }
99-
100129 }
101130}
102131
103- public func saveImage( _ tensor: Tensor < Float > , shape: ( Int , Int ) , size: ( Int , Int ) ? = nil ,
104- format: _Raw . Format = . rgb, directory: String , name: String ,
105- quality: Int64 = 95 ) throws {
132+ public func saveImage(
133+ _ tensor: Tensor < Float > , shape: ( Int , Int ) , size: ( Int , Int ) ? = nil ,
134+ format: Image . Colorspace = . rgb, directory: String , name: String ,
135+ quality: Int64 = 95
136+ ) throws {
106137 try createDirectoryIfMissing ( at: directory)
138+
107139 let channels : Int
108140 switch format {
109141 case . rgb: channels = 3
110142 case . grayscale: channels = 1
111- default :
112- print ( " \( format) is not supported yet. " )
113- exit ( - 1 )
114143 }
144+
115145 let reshapedTensor = tensor. reshaped ( to: [ shape. 0 , shape. 1 , channels] )
116146 let image = Image ( tensor: reshapedTensor)
117147 let resizedImage = size != nil ? image. resized ( to: ( size!. 0 , size!. 1 ) ) : image
0 commit comments