Skip to content

Commit c6478b4

Browse files
authored
feature(shortfin_apps): move png construction to server side (#1025)
1 parent 2705fde commit c6478b4

File tree

8 files changed

+186
-71
lines changed

8 files changed

+186
-71
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2025 Advanced Micro Devices, Inc.
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
from PIL.Image import Image
8+
9+
from dataclasses import dataclass
10+
11+
12+
@dataclass
class TextToImageInferenceOutput:
    """Output of one text-to-image inference request."""

    # The generated picture, held as an in-memory PIL image.
    image: Image

shortfin/python/shortfin_apps/sd/components/generate.py

+50-5
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,16 @@
99
import json
1010

1111
from typing import (
12+
TypeVar,
1213
Union,
1314
)
1415

16+
from shortfin_apps.types.Base64CharacterEncodedByteSequence import (
17+
Base64CharacterEncodedByteSequence,
18+
)
19+
20+
from shortfin_apps.utilities.image import png_from
21+
1522
import shortfin as sf
1623

1724
# TODO: Have a generic "Responder" interface vs just the concrete impl.
@@ -21,6 +28,7 @@
2128
from .messages import SDXLInferenceExecRequest
2229
from .service import SDXLGenerateService
2330
from .metrics import measure
31+
from .TextToImageInferenceOutput import TextToImageInferenceOutput
2432

2533
logger = logging.getLogger("shortfin-sd.generate")
2634

@@ -45,13 +53,38 @@ def __init__(
4553
self.client = client
4654
self.gen_req = gen_req
4755
self.index = index
48-
self.result_image: Union[str, None] = None
56+
self.output: Union[TextToImageInferenceOutput, None] = None
4957

5058
async def run(self):
5159
exec = SDXLInferenceExecRequest.from_batch(self.gen_req, self.index)
5260
self.client.batcher.submit(exec)
5361
await exec.done
54-
self.result_image = exec.result_image
62+
63+
self.output = (
64+
TextToImageInferenceOutput(exec.response_image)
65+
if exec.response_image
66+
else None
67+
)
68+
69+
70+
Item = TypeVar("Item")


def from_batch(
    given_subject: Union[list[Item], Item, None],
    given_batch_index: int,
) -> Item:
    """Pick the value that applies to one batch entry.

    Args:
        given_subject: Either a single value, or a list of values.
        given_batch_index: Position of the batch entry being resolved.

    Returns:
        ``given_subject`` itself when it is not a list, the sole element
        when it is a one-element list, and otherwise the element at
        ``given_batch_index``.

    Raises:
        ValueError: If ``given_subject`` is ``None``.
        IndexError: If ``given_subject`` is a list that is empty or has
            more than one element and ``given_batch_index`` is out of range.
    """
    if given_subject is None:
        raise ValueError("Expected an item or batch of items but got `None`")

    if not isinstance(given_subject, list):
        return given_subject

    # Some args are broadcast to every prompt, so a single-item list applies
    # to all batch entries regardless of the requested index.
    if len(given_subject) == 1:
        return given_subject[0]

    return given_subject[given_batch_index]
5588

5689

5790
class ClientGenerateBatchProcess(sf.Process):
@@ -99,8 +132,20 @@ async def run(self):
99132

100133
# TODO: stream image outputs
101134
logging.debug("Responding to one shot batch")
102-
response_data = {"images": [p.result_image for p in gen_processes]}
103-
json_str = json.dumps(response_data)
104-
self.responder.send_response(json_str)
135+
136+
png_images: list[Base64CharacterEncodedByteSequence] = []
137+
138+
for index_of_each_process, each_process in enumerate(gen_processes):
139+
if each_process.output is None:
140+
raise Exception(
141+
f"Expected output for process {index_of_each_process} but got `None`"
142+
)
143+
144+
each_png_image = png_from(each_process.output.image)
145+
png_images.append(each_png_image)
146+
147+
response_body = {"images": png_images}
148+
response_body_in_json = json.dumps(response_body)
149+
self.responder.send_response(response_body_in_json)
105150
finally:
106151
self.responder.ensure_response()

shortfin/python/shortfin_apps/sd/components/messages.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
Union,
1111
)
1212

13+
from PIL.Image import Image
14+
1315
import logging
1416

1517
import shortfin as sf
@@ -90,7 +92,7 @@ def __init__(
9092
# Decode phase.
9193
self.image_array = image_array
9294

93-
self.result_image: Union[str, None] = None
95+
self.response_image: Union[Image, None] = None
9496

9597
self.done = sf.VoidFuture()
9698

shortfin/python/shortfin_apps/sd/components/service.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from pathlib import Path
1313
from PIL import Image
1414
from collections import namedtuple
15-
import base64
1615

1716
import shortfin as sf
1817
import shortfin.array as sfnp
@@ -435,10 +434,8 @@ async def _postprocess(self, device):
435434
# TODO: reimpl with sfnp
436435
permuted = np.transpose(self.exec_request.image_array, (0, 2, 3, 1))[0]
437436
cast_image = (permuted * 255).round().astype("uint8")
438-
image_bytes = Image.fromarray(cast_image).tobytes()
439-
440-
image = base64.b64encode(image_bytes).decode("utf-8")
441-
self.exec_request.result_image = image
437+
processed_image = Image.fromarray(cast_image)
438+
self.exec_request.response_image = processed_image
442439
return
443440

444441

shortfin/python/shortfin_apps/sd/simple_client.py

+23-33
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,21 @@
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
from datetime import datetime as dt
8-
import os
98
import time
109
import json
1110
import argparse
12-
import base64
1311
import asyncio
1412
import aiohttp
1513
import requests
1614

17-
from PIL import Image
15+
from shortfin_apps.types.Base64CharacterEncodedByteSequence import (
16+
Base64CharacterEncodedByteSequence,
17+
)
18+
19+
from shortfin_apps.utilities.image import (
20+
save_to_file,
21+
image_from,
22+
)
1823

1924
sample_request = {
2025
"prompt": [
@@ -31,30 +36,6 @@
3136
}
3237

3338

34-
def bytes_to_img(in_bytes, outputdir, idx=0, width=1024, height=1024):
35-
timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
36-
image = Image.frombytes(
37-
mode="RGB", size=(width, height), data=base64.b64decode(in_bytes)
38-
)
39-
if not os.path.isdir(outputdir):
40-
os.mkdir(outputdir)
41-
im_path = os.path.join(outputdir, f"shortfin_sd_output_{timestamp}_{idx}.png")
42-
image.save(im_path)
43-
print(f"Saved to {im_path}")
44-
45-
46-
def get_batched(request, arg, idx):
47-
if isinstance(request[arg], list):
48-
# some args are broadcasted to each prompt, hence overriding idx for single-item entries
49-
if len(request[arg]) == 1:
50-
indexed = request[arg][0]
51-
else:
52-
indexed = request[arg][idx]
53-
else:
54-
indexed = request[arg]
55-
return indexed
56-
57-
5839
async def send_request(session: aiohttp.ClientSession, rep, args, data):
5940
print("Sending request batch #", rep)
6041
url = f"{args.host}:{args.port}/generate"
@@ -66,13 +47,22 @@ async def send_request(session: aiohttp.ClientSession, rep, args, data):
6647
response.raise_for_status() # Raise an error for bad responses
6748
res_json = await response.json(content_type=None)
6849
if args.save:
69-
for idx, item in enumerate(res_json["images"]):
70-
width = get_batched(data, "width", idx)
71-
height = get_batched(data, "height", idx)
72-
print("Saving response as image...")
73-
bytes_to_img(
74-
item.encode("utf-8"), args.outputdir, idx, width, height
50+
for idx, each_png in enumerate(res_json["images"]):
51+
if not isinstance(each_png, str):
52+
raise ValueError(f"png was not string at index {idx}")
53+
54+
each_image = image_from(
55+
Base64CharacterEncodedByteSequence(each_png)
56+
)
57+
58+
timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
59+
each_file_name = f"shortfin_sd_output_{timestamp}_{idx}.png"
60+
61+
each_file_path = save_to_file(
62+
each_image, args.outputdir, each_file_name
7563
)
64+
65+
print(f"Saved to {each_file_path}")
7666
latency = end - start
7767
print("Responses processed.")
7868
return latency, len(data["prompt"])
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import re
2+
import binascii
3+
4+
5+
class Base64CharacterEncodedByteSequence(str):
    """
    A sequence of 8-bit integers encoded using the [Base64 alphabet](https://www.rfc-editor.org/rfc/rfc4648.html#section-4).

    Instances are ordinary ``str`` objects whose content was validated
    against ``pattern`` at construction time.
    """

    # Groups of four alphabet characters, optionally ending in one group
    # padded with "=" or "==".
    pattern = (
        r"^(?:[A-Za-z0-9+\/]{4})*(?:[A-Za-z0-9+\/]{3}={1}|[A-Za-z0-9+\/]{2}={2})?$"
    )

    def __new__(cls, given_subject: str):
        """Validate ``given_subject`` and wrap it.

        Raises:
            ValueError: If ``given_subject`` is not padded Base64 text.
        """
        # `fullmatch` is required: with `match`, the trailing `$` also
        # matches just before a final newline, so "QUJD\n" would be accepted.
        if re.fullmatch(cls.pattern, given_subject) is None:
            raise ValueError("String cannot be interpreted as a byte sequence")

        return super().__new__(cls, given_subject)

    @property
    def as_bytes(self) -> bytes:
        """The raw bytes obtained by Base64-decoding this string."""
        base64_integer_encoded_byte_sequence = self.encode()

        return binascii.a2b_base64(base64_integer_encoded_byte_sequence)

    @classmethod
    def decoded_from(
        cls, given_raw_byte_sequence: bytes
    ) -> "Base64CharacterEncodedByteSequence":
        """Build an instance by Base64-encoding ``given_raw_byte_sequence``.

        The encoded text is produced without a trailing newline.
        """
        base64_integer_encoded_byte_sequence = binascii.b2a_base64(
            given_raw_byte_sequence, newline=False
        )

        return cls(base64_integer_encoded_byte_sequence.decode())
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import os
2+
3+
from io import (
4+
BytesIO,
5+
)
6+
7+
from PIL import Image
8+
9+
from shortfin_apps.types.Base64CharacterEncodedByteSequence import (
10+
Base64CharacterEncodedByteSequence,
11+
)
12+
13+
14+
def save_to_file(
    given_image: Image.Image,
    given_directory: str,
    given_file_name: str,
) -> str:
    """Save an image under a directory, creating the directory if needed.

    Args:
        given_image: Image to write; its ``save`` method is given the path.
        given_directory: Directory that will contain the file.
        given_file_name: File name, joined onto ``given_directory``.

    Returns:
        The path the image was saved to.
    """
    # `exist_ok=True` removes the check-then-create race of isdir() + mkdir();
    # makedirs() also creates any missing parent directories.
    os.makedirs(given_directory, exist_ok=True)
    derived_file_path = os.path.join(given_directory, given_file_name)
    given_image.save(derived_file_path)
    return derived_file_path
24+
25+
26+
def png_from(given_image: Image.Image) -> Base64CharacterEncodedByteSequence:
    """Serialize an image as PNG in memory and return it as Base64 text."""
    with BytesIO() as png_buffer:
        given_image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()

    return Base64CharacterEncodedByteSequence.decoded_from(png_bytes)
31+
32+
33+
def image_from(given_png: Base64CharacterEncodedByteSequence) -> Image.Image:
    """Open Base64-encoded PNG data as an image.

    The Base64 text is decoded to bytes and handed to ``Image.open`` through
    an in-memory stream, with the accepted formats restricted to PNG.
    """
    png_stream = BytesIO(given_png.as_bytes)
    opened_image = Image.open(png_stream, formats=["PNG"])
    return opened_image

shortfin/tests/apps/sd/e2e_test.py

+19-27
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@
44
# See https://llvm.org/LICENSE.txt for license information.
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7-
import json
87
import requests
98
import time
10-
import base64
119
import pytest
1210
import subprocess
1311
import os
@@ -16,7 +14,15 @@
1614
import copy
1715
from contextlib import closing
1816

19-
from PIL import Image
17+
from PIL.Image import Image
18+
19+
from shortfin_apps.types.Base64CharacterEncodedByteSequence import (
20+
Base64CharacterEncodedByteSequence,
21+
)
22+
23+
from shortfin_apps.utilities.image import (
24+
image_from,
25+
)
2026

2127
BATCH_SIZES = [1]
2228

@@ -197,17 +203,10 @@ def __del__(self):
197203
process.wait()
198204

199205

200-
def bytes_to_img(bytes, idx=0, width=1024, height=1024):
201-
image = Image.frombytes(
202-
mode="RGB", size=(width, height), data=base64.b64decode(bytes)
203-
)
204-
return image
205-
206-
207206
def send_json_file(url="http://0.0.0.0:8000", num_copies=1):
208207
# Read the JSON file
209208
data = copy.deepcopy(sample_request)
210-
imgs = []
209+
imgs: list[Image] = []
211210
# Send the data to the /generate endpoint
212211
data["prompt"] = (
213212
[data["prompt"]]
@@ -217,30 +216,23 @@ def send_json_file(url="http://0.0.0.0:8000", num_copies=1):
217216
try:
218217
response = requests.post(url + "/generate", json=data)
219218
response.raise_for_status() # Raise an error for bad responses
220-
request = json.loads(response.request.body.decode("utf-8"))
219+
response_body = response.json()
220+
221+
for idx, each_png in enumerate(response_body["images"]):
222+
if not isinstance(each_png, str):
223+
raise ValueError(
224+
f"Expected string-encoded png at index {idx}, found {each_png}"
225+
)
221226

222-
for idx, item in enumerate(response.json()["images"]):
223-
width = getbatched(request, idx, "width")
224-
height = getbatched(request, idx, "height")
225-
img = bytes_to_img(item.encode("utf-8"), idx, width, height)
226-
imgs.append(img)
227+
each_image = image_from(Base64CharacterEncodedByteSequence(each_png))
228+
imgs.append(each_image)
227229

228230
except requests.exceptions.RequestException as e:
229231
print(f"Error sending the request: {e}")
230232

231233
return imgs, response.status_code
232234

233235

234-
def getbatched(req, idx, key):
235-
if isinstance(req[key], list):
236-
if len(req[key]) == 1:
237-
return req[key][0]
238-
elif len(req[key]) > idx:
239-
return req[key][idx]
240-
else:
241-
return req[key]
242-
243-
244236
def find_free_port():
245237
"""This tries to find a free port to run a server on for the test.
246238

0 commit comments

Comments
 (0)