44# This source code is licensed under the terms described in the LICENSE file in
55# the root directory of this source tree.
66
7- from typing import Optional
7+ from typing import Any , Iterator , Optional , Tuple
88
99from termcolor import cprint
1010
1111from llama_stack_client .types import InterleavedContent , ToolResponseMessage
1212
1313
1414def interleaved_content_as_str (content : InterleavedContent , sep : str = " " ) -> str :
15- def _process (c ) -> str :
15+ def _process (c : Any ) -> str :
1616 if isinstance (c , str ):
1717 return c
1818 elif hasattr (c , "type" ):
@@ -36,36 +36,38 @@ def __init__(
3636 self ,
3737 role : Optional [str ] = None ,
3838 content : str = "" ,
39- end : str = "\n " ,
40- color = "white" ,
41- ):
39+ end : Optional [ str ] = "\n " ,
40+ color : str = "white" ,
41+ ) -> None :
4242 self .role = role
4343 self .content = content
4444 self .color = color
4545 self .end = "\n " if end is None else end
4646
47- def __str__ (self ):
47+ def __str__ (self ) -> str :
4848 if self .role is not None :
4949 return f"{ self .role } > { self .content } "
5050 else :
5151 return f"{ self .content } "
5252
53- def print (self , flush = True ):
53+ def print (self , flush : bool = True ) -> None :
5454 cprint (f"{ str (self )} " , color = self .color , end = self .end , flush = flush )
5555
5656
5757class TurnStreamEventPrinter :
58- def __init__ (self ):
59- self .previous_event_type = None
60- self .previous_step_type = None
58+ def __init__ (self ) -> None :
59+ self .previous_event_type : Optional [ str ] = None
60+ self .previous_step_type : Optional [ str ] = None
6161
62- def yield_printable_events (self , chunk ) :
62+ def yield_printable_events (self , chunk : Any ) -> Iterator [ TurnStreamPrintableEvent ] :
6363 for printable_event in self ._yield_printable_events (chunk , self .previous_event_type , self .previous_step_type ):
6464 yield printable_event
6565
6666 self .previous_event_type , self .previous_step_type = self ._get_event_type_step_type (chunk )
6767
68- def _yield_printable_events (self , chunk , previous_event_type = None , previous_step_type = None ):
68+ def _yield_printable_events (
69+ self , chunk : Any , previous_event_type : Optional [str ] = None , previous_step_type : Optional [str ] = None
70+ ) -> Iterator [TurnStreamPrintableEvent ]:
6971 if hasattr (chunk , "error" ):
7072 yield TurnStreamPrintableEvent (role = None , content = chunk .error ["message" ], color = "red" )
7173 return
@@ -151,7 +153,7 @@ def _yield_printable_events(self, chunk, previous_event_type=None, previous_step
151153 color = "green" ,
152154 )
153155
154- def _get_event_type_step_type (self , chunk ) :
156+ def _get_event_type_step_type (self , chunk : Any ) -> Tuple [ Optional [ str ], Optional [ str ]] :
155157 if hasattr (chunk , "event" ):
156158 previous_event_type = chunk .event .payload .event_type if hasattr (chunk , "event" ) else None
157159 previous_step_type = (
@@ -162,7 +164,7 @@ def _get_event_type_step_type(self, chunk):
162164
163165
164166class EventLogger :
165- def log (self , event_generator ) :
167+ def log (self , event_generator : Iterator [ Any ]) -> Iterator [ TurnStreamPrintableEvent ] :
166168 printer = TurnStreamEventPrinter ()
167169 for chunk in event_generator :
168170 yield from printer .yield_printable_events (chunk )
0 commit comments