Skip to content

Commit

Permalink
feature(shortfin_apps): moves rgb to png conversion to server-side
Browse files Browse the repository at this point in the history
  • Loading branch information
bjacobgordon committed Mar 3, 2025
1 parent c6c8439 commit f398448
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 27 deletions.
62 changes: 59 additions & 3 deletions shortfin/python/shortfin_apps/sd/components/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,23 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import asyncio
import base64
import logging
import json

from typing import (
TypeVar,
Union,
)

from PIL import Image

from shortfin_apps.types.Base64CharacterEncodedByteSequence import (
Base64CharacterEncodedByteSequence,
)

from shortfin_apps.utilities.image import png_from

import shortfin as sf

# TODO: Have a generic "Responder" interface vs just the concrete impl.
Expand Down Expand Up @@ -54,6 +64,26 @@ async def run(self):
self.result_image = exec.result_image


Item = TypeVar("Item")


def from_batch(
given_subject: list[Item] | Item | None,
given_batch_index,
) -> Item:
if given_subject is None:
raise Exception("Expected an item or batch of items but got `None`")

if not isinstance(given_subject, list):
return given_subject

# some args are broadcasted to each prompt, hence overriding index for single-item entries
if len(given_subject) == 1:
return given_subject[0]

return given_subject[given_batch_index]


class ClientGenerateBatchProcess(sf.Process):
"""Process instantiated for handling a batch from a client.
Expand Down Expand Up @@ -99,8 +129,34 @@ async def run(self):

# TODO: stream image outputs
logging.debug("Responding to one shot batch")
response_data = {"images": [p.result_image for p in gen_processes]}
json_str = json.dumps(response_data)
self.responder.send_response(json_str)

png_images: list[Base64CharacterEncodedByteSequence] = []

for index_of_each_process, each_process in enumerate(gen_processes):
if each_process.result_image is None:
raise Exception(
f"Expected image result for batch {index_of_each_process} but got `None`"
)

size_of_each_image = (
from_batch(self.gen_req.width, index_of_each_process),
from_batch(self.gen_req.height, index_of_each_process),
)

rgb_sequence_of_each_image = Base64CharacterEncodedByteSequence(
each_process.result_image
)

each_image = Image.frombytes(
mode="RGB",
size=size_of_each_image,
data=rgb_sequence_of_each_image.as_bytes,
)

png_images.append(png_from(each_image))

response_body = {"images": png_images}
response_body_in_json = json.dumps(response_body)
self.responder.send_response(response_body_in_json)
finally:
self.responder.ensure_response()
34 changes: 10 additions & 24 deletions shortfin/python/shortfin_apps/sd/simple_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from datetime import datetime as dt
import os
import time
import json
import argparse
import base64
import asyncio
import aiohttp
import requests

from PIL import Image
from shortfin_apps.types.Base64CharacterEncodedByteSequence import (
Base64CharacterEncodedByteSequence,
)

from shortfin_apps.utilities.image import (
save_to_file,
image_from,
)

sample_request = {
Expand All @@ -35,18 +36,6 @@
}


def get_batched(request, arg, idx):
if isinstance(request[arg], list):
# some args are broadcasted to each prompt, hence overriding idx for single-item entries
if len(request[arg]) == 1:
indexed = request[arg][0]
else:
indexed = request[arg][idx]
else:
indexed = request[arg]
return indexed


async def send_request(session: aiohttp.ClientSession, rep, args, data):
print("Sending request batch #", rep)
url = f"{args.host}:{args.port}/generate"
Expand All @@ -58,15 +47,12 @@ async def send_request(session: aiohttp.ClientSession, rep, args, data):
response.raise_for_status() # Raise an error for bad responses
res_json = await response.json(content_type=None)
if args.save:
for idx, item in enumerate(res_json["images"]):
width = get_batched(data, "width", idx)
height = get_batched(data, "height", idx)
print("Saving response as image...")

each_image = Image.frombytes(
mode="RGB",
size=(width, height),
data=base64.b64decode(item.encode("utf-8")),
for idx, each_png in enumerate(res_json["images"]):
if not isinstance(each_png, str):
raise ValueError(f"png was not string at index {idx}")

each_image = image_from(
Base64CharacterEncodedByteSequence(each_png)
)

timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import re
import binascii


class Base64CharacterEncodedByteSequence(str):
"""
A sequence of 8-bit integers encoded using the [Base64 alphabet](https://www.rfc-editor.org/rfc/rfc4648.html#section-4).
"""

pattern = (
r"^(?:[A-Za-z0-9+\/]{4})*(?:[A-Za-z0-9+\/]{3}={1}|[A-Za-z0-9+\/]{2}={2})?$"
)

def __new__(Self, given_subject: str):
if re.match(Self.pattern, given_subject) == None:
raise ValueError("String cannot be interpreted as a byte sequence")

return super().__new__(Self, given_subject)

@property
def as_bytes(self):
base64_integer_encoded_byte_sequence = self.encode()

derived_raw_byte_sequence = binascii.a2b_base64(
base64_integer_encoded_byte_sequence
)

return derived_raw_byte_sequence

@classmethod
def decoded_from(Self, given_raw_byte_sequence: bytes):
base64_integer_encoded_byte_sequence = binascii.b2a_base64(
given_raw_byte_sequence, newline=False
)

base64_character_encoded_byte_sequence = (
base64_integer_encoded_byte_sequence.decode()
)

return Self(base64_character_encoded_byte_sequence)
20 changes: 20 additions & 0 deletions shortfin/python/shortfin_apps/utilities/image.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import os

from io import (
BytesIO,
)

from PIL import Image

from shortfin_apps.types.Base64CharacterEncodedByteSequence import (
Base64CharacterEncodedByteSequence,
)


def save_to_file(
given_image: Image.Image,
Expand All @@ -13,3 +21,15 @@ def save_to_file(
derived_file_path = os.path.join(given_directory, given_file_name)
given_image.save(derived_file_path)
return derived_file_path


def png_from(given_image: Image.Image) -> Base64CharacterEncodedByteSequence:
memory_for_png = BytesIO()
given_image.save(memory_for_png, format="PNG")
png_from_memory = memory_for_png.getvalue()
return Base64CharacterEncodedByteSequence.decoded_from(png_from_memory)


def image_from(given_png: Base64CharacterEncodedByteSequence) -> Image.Image:
memory_for_png = BytesIO(given_png.as_bytes)
return Image.open(memory_for_png, formats=["PNG"])

0 comments on commit f398448

Please sign in to comment.