2424
2525logger = logging .getLogger (__name__ )
2626
# Maps user-facing isolation-mode strings (e.g. from a CLI flag) to the
# corresponding shortfin ProgramIsolation enum values. Used below in
# GenerateService.__init__ to resolve the `prog_isolation: str` parameter.
# NOTE(review): `sf` is presumably the shortfin module imported near the top
# of the file — this view is a diff fragment, confirm against the full source.
prog_isolations = {
    "none": sf.ProgramIsolation.NONE,
    "per_fiber": sf.ProgramIsolation.PER_FIBER,
    "per_call": sf.ProgramIsolation.PER_CALL,
}
32+
2733
2834class GenerateService :
2935 """Top level service interface for image generation."""
@@ -39,6 +45,9 @@ def __init__(
3945 sysman : SystemManager ,
4046 tokenizers : list [Tokenizer ],
4147 model_params : ModelParams ,
48+ fibers_per_device : int ,
49+ prog_isolation : str = "per_fiber" ,
50+ show_progress : bool = False ,
4251 ):
4352 self .name = name
4453
@@ -50,17 +59,20 @@ def __init__(
5059 self .inference_modules : dict [str , sf .ProgramModule ] = {}
5160 self .inference_functions : dict [str , dict [str , sf .ProgramFunction ]] = {}
5261 self .inference_programs : dict [str , sf .Program ] = {}
53- self .procs_per_device = 1
62+ self .trace_execution = False
63+ self .show_progress = show_progress
64+ self .fibers_per_device = fibers_per_device
65+ self .prog_isolation = prog_isolations [prog_isolation ]
5466 self .workers = []
5567 self .fibers = []
56- self .locks = []
68+ self .fiber_status = []
5769 for idx , device in enumerate (self .sysman .ls .devices ):
58- for i in range (self .procs_per_device ):
70+ for i in range (self .fibers_per_device ):
5971 worker = sysman .ls .create_worker (f"{ name } -inference-{ device .name } -{ i } " )
6072 fiber = sysman .ls .create_fiber (worker , devices = [device ])
6173 self .workers .append (worker )
6274 self .fibers .append (fiber )
63- self .locks .append (asyncio . Lock () )
75+ self .fiber_status .append (0 )
6476
6577 # Scope dependent objects.
6678 self .batcher = BatcherProcess (self )
@@ -99,7 +111,8 @@ def start(self):
99111 self .inference_programs [component ] = sf .Program (
100112 modules = component_modules ,
101113 devices = fiber .raw_devices ,
102- trace_execution = False ,
114+ isolation = self .prog_isolation ,
115+ trace_execution = self .trace_execution ,
103116 )
104117
105118 # TODO: export vmfbs with multiple batch size entrypoints
@@ -169,6 +182,7 @@ def __init__(self, service: GenerateService):
169182 self .strobe_enabled = True
170183 self .strobes : int = 0
171184 self .ideal_batch_size : int = max (service .model_params .max_batch_size )
185+ self .num_fibers = len (service .fibers )
172186
173187 def shutdown (self ):
174188 self .batcher_infeed .close ()
@@ -199,6 +213,7 @@ async def run(self):
199213 logger .error ("Illegal message received by batcher: %r" , item )
200214
201215 self .board_flights ()
216+
202217 self .strobe_enabled = True
203218 await strober_task
204219
@@ -210,28 +225,40 @@ def board_flights(self):
210225 logger .info ("Waiting a bit longer to fill flight" )
211226 return
212227 self .strobes = 0
228+ batches = self .sort_batches ()
229+ for idx , batch in batches .items ():
230+ for fidx , status in enumerate (self .service .fiber_status ):
231+ if (
232+ status == 0
233+ or self .service .prog_isolation == sf .ProgramIsolation .PER_CALL
234+ ):
235+ self .board (batch ["reqs" ], index = fidx )
236+ break
213237
214- batches = self .sort_pending ()
215- for idx in batches .keys ():
216- self .board (batches [idx ]["reqs" ], index = idx )
217-
218- def sort_pending (self ):
219- """Returns pending requests as sorted batches suitable for program invocations."""
238+ def sort_batches (self ):
239+ """Files pending requests into sorted batches suitable for program invocations."""
240+ reqs = self .pending_requests
241+ next_key = 0
220242 batches = {}
221- for req in self . pending_requests :
243+ for req in reqs :
222244 is_sorted = False
223245 req_metas = [req .phases [phase ]["metadata" ] for phase in req .phases .keys ()]
224- next_key = 0
246+
225247 for idx_key , data in batches .items ():
226248 if not isinstance (data , dict ):
227249 logger .error (
228250 "Expected to find a dictionary containing a list of requests and their shared metadatas."
229251 )
230- if data ["meta" ] == req_metas :
231- batches [idx_key ]["reqs" ].append (req )
252+ if len (batches [idx_key ]["reqs" ]) >= self .ideal_batch_size :
253+ # Batch is full
254+ next_key = idx_key + 1
255+ continue
256+ elif data ["meta" ] == req_metas :
257+ batches [idx_key ]["reqs" ].extend ([req ])
232258 is_sorted = True
233259 break
234- next_key = idx_key + 1
260+ else :
261+ next_key = idx_key + 1
235262 if not is_sorted :
236263 batches [next_key ] = {
237264 "reqs" : [req ],
@@ -251,7 +278,8 @@ def board(self, request_bundle, index):
251278 if exec_process .exec_requests :
252279 for flighted_request in exec_process .exec_requests :
253280 self .pending_requests .remove (flighted_request )
254- print (f"launching exec process for { exec_process .exec_requests } " )
281+ if self .service .prog_isolation != sf .ProgramIsolation .PER_CALL :
282+ self .service .fiber_status [index ] = 1
255283 exec_process .launch ()
256284
257285
@@ -284,22 +312,22 @@ async def run(self):
284312 phases = self .exec_requests [0 ].phases
285313
286314 req_count = len (self .exec_requests )
287- async with self .service .locks [self .worker_index ]:
288- device0 = self .fiber .device (0 )
289- if phases [InferencePhase .PREPARE ]["required" ]:
290- await self ._prepare (device = device0 , requests = self .exec_requests )
291- if phases [InferencePhase .ENCODE ]["required" ]:
292- await self ._encode (device = device0 , requests = self .exec_requests )
293- if phases [InferencePhase .DENOISE ]["required" ]:
294- await self ._denoise (device = device0 , requests = self .exec_requests )
295- if phases [InferencePhase .DECODE ]["required" ]:
296- await self ._decode (device = device0 , requests = self .exec_requests )
297- if phases [InferencePhase .POSTPROCESS ]["required" ]:
298- await self ._postprocess (device = device0 , requests = self .exec_requests )
315+ device0 = self .service .fibers [self .worker_index ].device (0 )
316+ if phases [InferencePhase .PREPARE ]["required" ]:
317+ await self ._prepare (device = device0 , requests = self .exec_requests )
318+ if phases [InferencePhase .ENCODE ]["required" ]:
319+ await self ._encode (device = device0 , requests = self .exec_requests )
320+ if phases [InferencePhase .DENOISE ]["required" ]:
321+ await self ._denoise (device = device0 , requests = self .exec_requests )
322+ if phases [InferencePhase .DECODE ]["required" ]:
323+ await self ._decode (device = device0 , requests = self .exec_requests )
324+ if phases [InferencePhase .POSTPROCESS ]["required" ]:
325+ await self ._postprocess (device = device0 , requests = self .exec_requests )
299326
300327 for i in range (req_count ):
301328 req = self .exec_requests [i ]
302329 req .done .set_success ()
330+ self .service .fiber_status [self .worker_index ] = 0
303331
304332 except Exception :
305333 logger .exception ("Fatal error in image generation" )
@@ -345,7 +373,6 @@ async def _prepare(self, device, requests):
345373 sfnp .fill_randn (sample_host , generator = generator )
346374
347375 request .sample .copy_from (sample_host )
348- await device
349376 return
350377
351378 async def _encode (self , device , requests ):
@@ -385,15 +412,13 @@ async def _encode(self, device, requests):
385412 clip_inputs [idx ].copy_from (host_arrs [idx ])
386413
387414 # Encode tokenized inputs.
388- logger .info (
415+ logger .debug (
389416 "INVOKE %r: %s" ,
390417 fn ,
391418 "" .join ([f"\n { i } : { ary .shape } " for i , ary in enumerate (clip_inputs )]),
392419 )
393- await device
394420 pe , te = await fn (* clip_inputs , fiber = self .fiber )
395421
396- await device
397422 for i in range (req_bs ):
398423 cfg_mult = 2
399424 requests [i ].prompt_embeds = pe .view (slice (i * cfg_mult , (i + 1 ) * cfg_mult ))
@@ -477,35 +502,34 @@ async def _denoise(self, device, requests):
477502 ns_host .items = [step_count ]
478503 num_steps .copy_from (ns_host )
479504
480- await device
505+ init_inputs = [
506+ denoise_inputs ["sample" ],
507+ num_steps ,
508+ ]
509+
481510 # Initialize scheduler.
482- logger .info (
483- "INVOKE %r: %s " ,
511+ logger .debug (
512+ "INVOKE %r" ,
484513 fns ["init" ],
485- "" .join ([f"\n 0: { latents_shape } " ]),
486514 )
487515 (latents , time_ids , timesteps , sigmas ) = await fns ["init" ](
488- denoise_inputs [ "sample" ], num_steps , fiber = self .fiber
516+ * init_inputs , fiber = self .fiber
489517 )
490-
491- await device
492518 for i , t in tqdm (
493519 enumerate (range (step_count )),
520+ disable = (not self .service .show_progress ),
521+ desc = f"Worker #{ self .worker_index } DENOISE (bs{ req_bs } )" ,
494522 ):
495523 step = sfnp .device_array .for_device (device , [1 ], sfnp .sint64 )
496524 s_host = step .for_transfer ()
497525 with s_host .map (write = True ) as m :
498526 s_host .items = [i ]
499527 step .copy_from (s_host )
500528 scale_inputs = [latents , step , timesteps , sigmas ]
501- logger .info (
502- "INVOKE %r: %s " ,
529+ logger .debug (
530+ "INVOKE %r" ,
503531 fns ["scale" ],
504- "" .join (
505- [f"\n { i } : { ary .shape } " for i , ary in enumerate (scale_inputs )]
506- ),
507532 )
508- await device
509533 latent_model_input , t , sigma , next_sigma = await fns ["scale" ](
510534 * scale_inputs , fiber = self .fiber
511535 )
@@ -519,32 +543,25 @@ async def _denoise(self, device, requests):
519543 time_ids ,
520544 denoise_inputs ["guidance_scale" ],
521545 ]
522- logger .info (
523- "INVOKE %r: %s " ,
546+ logger .debug (
547+ "INVOKE %r" ,
524548 fns ["unet" ],
525- "" .join ([f"\n { i } : { ary .shape } " for i , ary in enumerate (unet_inputs )]),
526549 )
527- await device
528550 (noise_pred ,) = await fns ["unet" ](* unet_inputs , fiber = self .fiber )
529- await device
530551
531552 step_inputs = [noise_pred , latents , sigma , next_sigma ]
532- logger .info (
533- "INVOKE %r: %s " ,
553+ logger .debug (
554+ "INVOKE %r" ,
534555 fns ["step" ],
535- "" .join ([f"\n { i } : { ary .shape } " for i , ary in enumerate (step_inputs )]),
536556 )
537- await device
538557 (latent_model_output ,) = await fns ["step" ](* step_inputs , fiber = self .fiber )
539558 latents .copy_from (latent_model_output )
540- await device
541559
542560 for idx , req in enumerate (requests ):
543561 req .denoised_latents = sfnp .device_array .for_device (
544562 device , latents_shape , self .service .model_params .vae_dtype
545563 )
546564 req .denoised_latents .copy_from (latents .view (idx ))
547- await device
548565 return
549566
550567 async def _decode (self , device , requests ):
@@ -569,6 +586,11 @@ async def _decode(self, device, requests):
569586
570587 await device
571588 # Decode the denoised latents.
589+ logger .debug (
590+ "INVOKE %r: %s" ,
591+ fn ,
592+ "" .join ([f"\n 0: { latents .shape } " ]),
593+ )
572594 (image ,) = await fn (latents , fiber = self .fiber )
573595
574596 await device
0 commit comments