diff --git a/shortfin/python/shortfin_apps/sd/components/generate.py b/shortfin/python/shortfin_apps/sd/components/generate.py index fee914a4e..508ecd690 100644 --- a/shortfin/python/shortfin_apps/sd/components/generate.py +++ b/shortfin/python/shortfin_apps/sd/components/generate.py @@ -5,13 +5,23 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import asyncio +import base64 import logging import json from typing import ( + TypeVar, Union, ) +from PIL import Image + +from shortfin_apps.types.Base64CharacterEncodedByteSequence import ( + Base64CharacterEncodedByteSequence, +) + +from shortfin_apps.utilities.image import png_from + import shortfin as sf # TODO: Have a generic "Responder" interface vs just the concrete impl. @@ -54,6 +64,26 @@ async def run(self): self.result_image = exec.result_image +Item = TypeVar("Item") + + +def from_batch( + given_subject: list[Item] | Item | None, + given_batch_index, +) -> Item: + if given_subject is None: + raise Exception("Expected an item or batch of items but got `None`") + + if not isinstance(given_subject, list): + return given_subject + + # some args are broadcasted to each prompt, hence overriding index for single-item entries + if len(given_subject) == 1: + return given_subject[0] + + return given_subject[given_batch_index] + + class ClientGenerateBatchProcess(sf.Process): """Process instantiated for handling a batch from a client. @@ -99,8 +129,34 @@ async def run(self): # TODO: stream image outputs logging.debug("Responding to one shot batch") - response_data = {"images": [p.result_image for p in gen_processes]} - json_str = json.dumps(response_data) - self.responder.send_response(json_str) + + png_images: list[Base64CharacterEncodedByteSequence] = [] + + for index_of_each_process, each_process in enumerate(gen_processes): + if each_process.result_image is None: + raise Exception( + f"Expected image result for batch {index_of_each_process} but got `None`" + ) + + size_of_each_image = ( + from_batch(self.gen_req.width, index_of_each_process), + from_batch(self.gen_req.height, index_of_each_process), + ) + + rgb_sequence_of_each_image = Base64CharacterEncodedByteSequence( + each_process.result_image + ) + + each_image = Image.frombytes( + mode="RGB", + size=size_of_each_image, + data=rgb_sequence_of_each_image.as_bytes, + ) + + png_images.append(png_from(each_image)) + + response_body = {"images": png_images} + response_body_in_json = json.dumps(response_body) + self.responder.send_response(response_body_in_json) finally: self.responder.ensure_response() diff --git a/shortfin/python/shortfin_apps/sd/simple_client.py b/shortfin/python/shortfin_apps/sd/simple_client.py index 0cf42f953..3e225c269 100644 --- a/shortfin/python/shortfin_apps/sd/simple_client.py +++ b/shortfin/python/shortfin_apps/sd/simple_client.py @@ -5,19 +5,20 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from datetime import datetime as dt -import os import time import json import argparse -import base64 import asyncio import aiohttp import requests -from PIL import Image +from shortfin_apps.types.Base64CharacterEncodedByteSequence import ( + Base64CharacterEncodedByteSequence, +) from shortfin_apps.utilities.image import ( save_to_file, + image_from, ) sample_request = { @@ -35,18 +36,6 @@ } -def get_batched(request, arg, idx): - if isinstance(request[arg], list): - # some args are broadcasted to each prompt, hence overriding idx for single-item entries - if len(request[arg]) == 1: - indexed = request[arg][0] - else: - indexed = request[arg][idx] - else: - indexed = request[arg] - return indexed - - async def send_request(session: aiohttp.ClientSession, rep, args, data): print("Sending request batch #", rep) url = f"{args.host}:{args.port}/generate" @@ -58,15 +47,12 @@ async def send_request(session: aiohttp.ClientSession, rep, args, data): response.raise_for_status() # Raise an error for bad responses res_json = await response.json(content_type=None) if args.save: - for idx, item in enumerate(res_json["images"]): - width = get_batched(data, "width", idx) - height = get_batched(data, "height", idx) - print("Saving response as image...") - - each_image = Image.frombytes( - mode="RGB", - size=(width, height), - data=base64.b64decode(item.encode("utf-8")), + for idx, each_png in enumerate(res_json["images"]): + if not isinstance(each_png, str): + raise ValueError(f"png was not string at index {idx}") + + each_image = image_from( + Base64CharacterEncodedByteSequence(each_png) ) timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") diff --git a/shortfin/python/shortfin_apps/types/Base64CharacterEncodedByteSequence.py b/shortfin/python/shortfin_apps/types/Base64CharacterEncodedByteSequence.py new file mode 100644 index 000000000..5d82d90b2 --- /dev/null +++ b/shortfin/python/shortfin_apps/types/Base64CharacterEncodedByteSequence.py @@ -0,0 +1,40 @@ +import re +import binascii + + +class Base64CharacterEncodedByteSequence(str): + """ + A sequence of 8-bit integers encoded using the [Base64 alphabet](https://www.rfc-editor.org/rfc/rfc4648.html#section-4). + """ + + pattern = ( + r"^(?:[A-Za-z0-9+\/]{4})*(?:[A-Za-z0-9+\/]{3}={1}|[A-Za-z0-9+\/]{2}={2})?$" + ) + + def __new__(Self, given_subject: str): + if re.match(Self.pattern, given_subject) == None: + raise ValueError("String cannot be interpreted as a byte sequence") + + return super().__new__(Self, given_subject) + + @property + def as_bytes(self): + base64_integer_encoded_byte_sequence = self.encode() + + derived_raw_byte_sequence = binascii.a2b_base64( + base64_integer_encoded_byte_sequence + ) + + return derived_raw_byte_sequence + + @classmethod + def decoded_from(Self, given_raw_byte_sequence: bytes): + base64_integer_encoded_byte_sequence = binascii.b2a_base64( + given_raw_byte_sequence, newline=False + ) + + base64_character_encoded_byte_sequence = ( + base64_integer_encoded_byte_sequence.decode() + ) + + return Self(base64_character_encoded_byte_sequence) diff --git a/shortfin/python/shortfin_apps/utilities/image.py b/shortfin/python/shortfin_apps/utilities/image.py index 316ecab7d..4a2374764 100644 --- a/shortfin/python/shortfin_apps/utilities/image.py +++ b/shortfin/python/shortfin_apps/utilities/image.py @@ -1,7 +1,15 @@ import os +from io import ( + BytesIO, +) + from PIL import Image +from shortfin_apps.types.Base64CharacterEncodedByteSequence import ( + Base64CharacterEncodedByteSequence, +) + def save_to_file( given_image: Image.Image, @@ -13,3 +21,15 @@ def save_to_file( derived_file_path = os.path.join(given_directory, given_file_name) given_image.save(derived_file_path) return derived_file_path + + +def png_from(given_image: Image.Image) -> Base64CharacterEncodedByteSequence: + memory_for_png = BytesIO() + given_image.save(memory_for_png, format="PNG") + png_from_memory = memory_for_png.getvalue() + return Base64CharacterEncodedByteSequence.decoded_from(png_from_memory) + + +def image_from(given_png: Base64CharacterEncodedByteSequence) -> Image.Image: + memory_for_png = BytesIO(given_png.as_bytes) + return Image.open(memory_for_png, formats=["PNG"])