5
5
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
7
7
from datetime import datetime as dt
8
- import os
9
8
import time
10
9
import json
11
10
import argparse
12
- import base64
13
11
import asyncio
14
12
import aiohttp
15
13
import requests
16
14
17
- from PIL import Image
15
+ from shortfin_apps .types .Base64CharacterEncodedByteSequence import (
16
+ Base64CharacterEncodedByteSequence ,
17
+ )
18
+
19
+ from shortfin_apps .utilities .image import (
20
+ save_to_file ,
21
+ image_from ,
22
+ )
18
23
19
24
sample_request = {
20
25
"prompt" : [
31
36
}
32
37
33
38
34
- def bytes_to_img (in_bytes , outputdir , idx = 0 , width = 1024 , height = 1024 ):
35
- timestamp = dt .now ().strftime ("%Y-%m-%d_%H-%M-%S" )
36
- image = Image .frombytes (
37
- mode = "RGB" , size = (width , height ), data = base64 .b64decode (in_bytes )
38
- )
39
- if not os .path .isdir (outputdir ):
40
- os .mkdir (outputdir )
41
- im_path = os .path .join (outputdir , f"shortfin_sd_output_{ timestamp } _{ idx } .png" )
42
- image .save (im_path )
43
- print (f"Saved to { im_path } " )
44
-
45
-
46
- def get_batched (request , arg , idx ):
47
- if isinstance (request [arg ], list ):
48
- # some args are broadcasted to each prompt, hence overriding idx for single-item entries
49
- if len (request [arg ]) == 1 :
50
- indexed = request [arg ][0 ]
51
- else :
52
- indexed = request [arg ][idx ]
53
- else :
54
- indexed = request [arg ]
55
- return indexed
56
-
57
-
58
39
async def send_request (session : aiohttp .ClientSession , rep , args , data ):
59
40
print ("Sending request batch #" , rep )
60
41
url = f"{ args .host } :{ args .port } /generate"
@@ -66,13 +47,22 @@ async def send_request(session: aiohttp.ClientSession, rep, args, data):
66
47
response .raise_for_status () # Raise an error for bad responses
67
48
res_json = await response .json (content_type = None )
68
49
if args .save :
69
- for idx , item in enumerate (res_json ["images" ]):
70
- width = get_batched (data , "width" , idx )
71
- height = get_batched (data , "height" , idx )
72
- print ("Saving response as image..." )
73
- bytes_to_img (
74
- item .encode ("utf-8" ), args .outputdir , idx , width , height
50
+ for idx , each_png in enumerate (res_json ["images" ]):
51
+ if not isinstance (each_png , str ):
52
+ raise ValueError (f"png was not string at index { idx } " )
53
+
54
+ each_image = image_from (
55
+ Base64CharacterEncodedByteSequence (each_png )
56
+ )
57
+
58
+ timestamp = dt .now ().strftime ("%Y-%m-%d_%H-%M-%S" )
59
+ each_file_name = f"shortfin_sd_output_{ timestamp } _{ idx } .png"
60
+
61
+ each_file_path = save_to_file (
62
+ each_image , args .outputdir , each_file_name
75
63
)
64
+
65
+ print (f"Saved to { each_file_path } " )
76
66
latency = end - start
77
67
print ("Responses processed." )
78
68
return latency , len (data ["prompt" ])
0 commit comments