
Commit c1176b6
(shortfin-sd) Adds program isolation optionality and fibers_per_device. (#360)
1 parent 9209a36

4 files changed: +232 −93 lines changed

.github/workflows/ci-sdxl.yaml

Lines changed: 1 addition & 1 deletion
@@ -99,4 +99,4 @@ jobs:
         working-directory: ${{ env.LIBSHORTFIN_DIR }}
         run: |
           ctest --timeout 30 --output-on-failure --test-dir build
-          pytest tests/apps/sd/e2e_test.py -v -s --system=amdgpu
+          HIP_VISIBLE_DEVICES=0 pytest tests/apps/sd/e2e_test.py -v -s --system=amdgpu

shortfin/python/shortfin_apps/sd/components/service.py

Lines changed: 79 additions & 57 deletions
@@ -24,6 +24,12 @@

 logger = logging.getLogger(__name__)

+prog_isolations = {
+    "none": sf.ProgramIsolation.NONE,
+    "per_fiber": sf.ProgramIsolation.PER_FIBER,
+    "per_call": sf.ProgramIsolation.PER_CALL,
+}
+

 class GenerateService:
     """Top level service interface for image generation."""
@@ -39,6 +45,9 @@ def __init__(
         sysman: SystemManager,
         tokenizers: list[Tokenizer],
         model_params: ModelParams,
+        fibers_per_device: int,
+        prog_isolation: str = "per_fiber",
+        show_progress: bool = False,
     ):
         self.name = name

@@ -50,17 +59,20 @@ def __init__(
         self.inference_modules: dict[str, sf.ProgramModule] = {}
         self.inference_functions: dict[str, dict[str, sf.ProgramFunction]] = {}
         self.inference_programs: dict[str, sf.Program] = {}
-        self.procs_per_device = 1
+        self.trace_execution = False
+        self.show_progress = show_progress
+        self.fibers_per_device = fibers_per_device
+        self.prog_isolation = prog_isolations[prog_isolation]
         self.workers = []
         self.fibers = []
-        self.locks = []
+        self.fiber_status = []
         for idx, device in enumerate(self.sysman.ls.devices):
-            for i in range(self.procs_per_device):
+            for i in range(self.fibers_per_device):
                 worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}")
                 fiber = sysman.ls.create_fiber(worker, devices=[device])
                 self.workers.append(worker)
                 self.fibers.append(fiber)
-                self.locks.append(asyncio.Lock())
+                self.fiber_status.append(0)

         # Scope dependent objects.
         self.batcher = BatcherProcess(self)
@@ -99,7 +111,8 @@ def start(self):
             self.inference_programs[component] = sf.Program(
                 modules=component_modules,
                 devices=fiber.raw_devices,
-                trace_execution=False,
+                isolation=self.prog_isolation,
+                trace_execution=self.trace_execution,
             )

         # TODO: export vmfbs with multiple batch size entrypoints
@@ -169,6 +182,7 @@ def __init__(self, service: GenerateService):
         self.strobe_enabled = True
         self.strobes: int = 0
         self.ideal_batch_size: int = max(service.model_params.max_batch_size)
+        self.num_fibers = len(service.fibers)

     def shutdown(self):
         self.batcher_infeed.close()
@@ -199,6 +213,7 @@ async def run(self):
                 logger.error("Illegal message received by batcher: %r", item)

             self.board_flights()
+
             self.strobe_enabled = True
         await strober_task

@@ -210,28 +225,40 @@ def board_flights(self):
             logger.info("Waiting a bit longer to fill flight")
             return
         self.strobes = 0
+        batches = self.sort_batches()
+        for idx, batch in batches.items():
+            for fidx, status in enumerate(self.service.fiber_status):
+                if (
+                    status == 0
+                    or self.service.prog_isolation == sf.ProgramIsolation.PER_CALL
+                ):
+                    self.board(batch["reqs"], index=fidx)
+                    break

-        batches = self.sort_pending()
-        for idx in batches.keys():
-            self.board(batches[idx]["reqs"], index=idx)
-
-    def sort_pending(self):
-        """Returns pending requests as sorted batches suitable for program invocations."""
+    def sort_batches(self):
+        """Files pending requests into sorted batches suitable for program invocations."""
+        reqs = self.pending_requests
+        next_key = 0
         batches = {}
-        for req in self.pending_requests:
+        for req in reqs:
             is_sorted = False
             req_metas = [req.phases[phase]["metadata"] for phase in req.phases.keys()]
-            next_key = 0
+
             for idx_key, data in batches.items():
                 if not isinstance(data, dict):
                     logger.error(
                         "Expected to find a dictionary containing a list of requests and their shared metadatas."
                     )
-                if data["meta"] == req_metas:
-                    batches[idx_key]["reqs"].append(req)
+                if len(batches[idx_key]["reqs"]) >= self.ideal_batch_size:
+                    # Batch is full
+                    next_key = idx_key + 1
+                    continue
+                elif data["meta"] == req_metas:
+                    batches[idx_key]["reqs"].extend([req])
                     is_sorted = True
                     break
-                next_key = idx_key + 1
+                else:
+                    next_key = idx_key + 1
             if not is_sorted:
                 batches[next_key] = {
                     "reqs": [req],
@@ -251,7 +278,8 @@ def board(self, request_bundle, index):
         if exec_process.exec_requests:
             for flighted_request in exec_process.exec_requests:
                 self.pending_requests.remove(flighted_request)
-            print(f"launching exec process for {exec_process.exec_requests}")
+            if self.service.prog_isolation != sf.ProgramIsolation.PER_CALL:
+                self.service.fiber_status[index] = 1
             exec_process.launch()


@@ -284,22 +312,22 @@ async def run(self):
             phases = self.exec_requests[0].phases

             req_count = len(self.exec_requests)
-            async with self.service.locks[self.worker_index]:
-                device0 = self.fiber.device(0)
-                if phases[InferencePhase.PREPARE]["required"]:
-                    await self._prepare(device=device0, requests=self.exec_requests)
-                if phases[InferencePhase.ENCODE]["required"]:
-                    await self._encode(device=device0, requests=self.exec_requests)
-                if phases[InferencePhase.DENOISE]["required"]:
-                    await self._denoise(device=device0, requests=self.exec_requests)
-                if phases[InferencePhase.DECODE]["required"]:
-                    await self._decode(device=device0, requests=self.exec_requests)
-                if phases[InferencePhase.POSTPROCESS]["required"]:
-                    await self._postprocess(device=device0, requests=self.exec_requests)
+            device0 = self.service.fibers[self.worker_index].device(0)
+            if phases[InferencePhase.PREPARE]["required"]:
+                await self._prepare(device=device0, requests=self.exec_requests)
+            if phases[InferencePhase.ENCODE]["required"]:
+                await self._encode(device=device0, requests=self.exec_requests)
+            if phases[InferencePhase.DENOISE]["required"]:
+                await self._denoise(device=device0, requests=self.exec_requests)
+            if phases[InferencePhase.DECODE]["required"]:
+                await self._decode(device=device0, requests=self.exec_requests)
+            if phases[InferencePhase.POSTPROCESS]["required"]:
+                await self._postprocess(device=device0, requests=self.exec_requests)

             for i in range(req_count):
                 req = self.exec_requests[i]
                 req.done.set_success()
+            self.service.fiber_status[self.worker_index] = 0

         except Exception:
             logger.exception("Fatal error in image generation")
@@ -345,7 +373,6 @@ async def _prepare(self, device, requests):
             sfnp.fill_randn(sample_host, generator=generator)

             request.sample.copy_from(sample_host)
-        await device
         return

     async def _encode(self, device, requests):
@@ -385,15 +412,13 @@ async def _encode(self, device, requests):
             clip_inputs[idx].copy_from(host_arrs[idx])

         # Encode tokenized inputs.
-        logger.info(
+        logger.debug(
             "INVOKE %r: %s",
             fn,
             "".join([f"\n  {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]),
         )
-        await device
         pe, te = await fn(*clip_inputs, fiber=self.fiber)

-        await device
         for i in range(req_bs):
             cfg_mult = 2
             requests[i].prompt_embeds = pe.view(slice(i * cfg_mult, (i + 1) * cfg_mult))
@@ -477,35 +502,34 @@ async def _denoise(self, device, requests):
            ns_host.items = [step_count]
        num_steps.copy_from(ns_host)

-        await device
+        init_inputs = [
+            denoise_inputs["sample"],
+            num_steps,
+        ]
+
        # Initialize scheduler.
-        logger.info(
-            "INVOKE %r: %s",
+        logger.debug(
+            "INVOKE %r",
            fns["init"],
-            "".join([f"\n  0: {latents_shape}"]),
        )
        (latents, time_ids, timesteps, sigmas) = await fns["init"](
-            denoise_inputs["sample"], num_steps, fiber=self.fiber
+            *init_inputs, fiber=self.fiber
        )
-
-        await device
        for i, t in tqdm(
            enumerate(range(step_count)),
+            disable=(not self.service.show_progress),
+            desc=f"Worker #{self.worker_index} DENOISE (bs{req_bs})",
        ):
            step = sfnp.device_array.for_device(device, [1], sfnp.sint64)
            s_host = step.for_transfer()
            with s_host.map(write=True) as m:
                s_host.items = [i]
            step.copy_from(s_host)
            scale_inputs = [latents, step, timesteps, sigmas]
-            logger.info(
-                "INVOKE %r: %s",
+            logger.debug(
+                "INVOKE %r",
                fns["scale"],
-                "".join(
-                    [f"\n  {i}: {ary.shape}" for i, ary in enumerate(scale_inputs)]
-                ),
            )
-            await device
            latent_model_input, t, sigma, next_sigma = await fns["scale"](
                *scale_inputs, fiber=self.fiber
            )
@@ -519,32 +543,25 @@ async def _denoise(self, device, requests):
                time_ids,
                denoise_inputs["guidance_scale"],
            ]
-            logger.info(
-                "INVOKE %r: %s",
+            logger.debug(
+                "INVOKE %r",
                fns["unet"],
-                "".join([f"\n  {i}: {ary.shape}" for i, ary in enumerate(unet_inputs)]),
            )
-            await device
            (noise_pred,) = await fns["unet"](*unet_inputs, fiber=self.fiber)
-            await device

            step_inputs = [noise_pred, latents, sigma, next_sigma]
-            logger.info(
-                "INVOKE %r: %s",
+            logger.debug(
+                "INVOKE %r",
                fns["step"],
-                "".join([f"\n  {i}: {ary.shape}" for i, ary in enumerate(step_inputs)]),
            )
-            await device
            (latent_model_output,) = await fns["step"](*step_inputs, fiber=self.fiber)
            latents.copy_from(latent_model_output)
-            await device

        for idx, req in enumerate(requests):
            req.denoised_latents = sfnp.device_array.for_device(
                device, latents_shape, self.service.model_params.vae_dtype
            )
            req.denoised_latents.copy_from(latents.view(idx))
-        await device
        return

    async def _decode(self, device, requests):
@@ -569,6 +586,11 @@ async def _decode(self, device, requests):

        await device
        # Decode the denoised latents.
+        logger.debug(
+            "INVOKE %r: %s",
+            fn,
+            "".join([f"\n  0: {latents.shape}"]),
+        )
        (image,) = await fn(latents, fiber=self.fiber)

        await device
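The tqdm changes in _denoise make the progress bar opt-in via --show_progress. For reference, a minimal illustration of the pattern (standard tqdm disable/desc kwargs; the numbers are placeholders):

from tqdm import tqdm

show_progress = False  # mirrors the --show_progress default
for i, t in tqdm(
    enumerate(range(30)),         # placeholder for step_count steps
    disable=(not show_progress),  # silent unless explicitly enabled
    desc="Worker #0 DENOISE (bs1)",  # same label format as the service
):
    pass  # one denoising step per iteration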

shortfin/python/shortfin_apps/sd/server.py

Lines changed: 41 additions & 6 deletions
@@ -31,7 +31,9 @@
 from .components.tokenizer import Tokenizer


-logger = logging.getLogger(__name__)
+from shortfin.support.logging_setup import configure_main_logger
+
+logger = configure_main_logger("server")


 @asynccontextmanager
@@ -87,7 +89,13 @@ def configure(args) -> SystemManager:

     model_params = ModelParams.load_json(args.model_config)
     sm = GenerateService(
-        name="sd", sysman=sysman, tokenizers=tokenizers, model_params=model_params
+        name="sd",
+        sysman=sysman,
+        tokenizers=tokenizers,
+        model_params=model_params,
+        fibers_per_device=args.fibers_per_device,
+        prog_isolation=args.isolation,
+        show_progress=args.show_progress,
     )
     sm.load_inference_module(args.clip_vmfb, component="clip")
     sm.load_inference_module(args.unet_vmfb, component="unet")
@@ -188,10 +196,40 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
         nargs="*",
         help="Parameter archives to load",
     )
+    parser.add_argument(
+        "--fibers_per_device",
+        type=int,
+        default=1,
+        help="Concurrency control -- how many fibers are created per device to run inference.",
+    )
+    parser.add_argument(
+        "--isolation",
+        type=str,
+        default="per_fiber",
+        choices=["per_fiber", "per_call", "none"],
+        help="Concurrency control -- How to isolate programs.",
+    )
+    parser.add_argument(
+        "--log_level", type=str, default="error", choices=["info", "debug", "error"]
+    )
+    parser.add_argument(
+        "--show_progress",
+        action="store_true",
+        help="enable tqdm progress for unet iterations.",
+    )
+    log_levels = {
+        "info": logging.INFO,
+        "debug": logging.DEBUG,
+        "error": logging.ERROR,
+    }
+
     args = parser.parse_args(argv)
+
+    log_level = log_levels[args.log_level]
+    logger.setLevel(log_level)
+
     global sysman
     sysman = configure(args)
-
     uvicorn.run(
         app,
         host=args.host,
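Assuming the usual module layout (the server lives at shortfin_apps/sd/server.py), a launch exercising the new options might look like the following; the config path is illustrative and the required vmfb/parameter flags are elided:

python -m shortfin_apps.sd.server \
    --model_config=sdxl_config.json \
    --fibers_per_device=2 \
    --isolation=per_call \
    --log_level=debug \
    --show_progress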
@@ -202,9 +240,6 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):


 if __name__ == "__main__":
-    from shortfin.support.logging_setup import configure_main_logger
-
-    logger = configure_main_logger("server")
     main(
         sys.argv[1:],
         # Make logging defer to the default shortfin logging config.
