1919from rcl_interfaces .msg import Log
2020from sensor_msgs .msg import CompressedImage , Image
2121
22- from rai .agents . postprocessors import BaseStatePostprocessor
22+ from rai .aggregators import BaseAggregator
2323from rai .communication .ros2 .api .conversion import encode_ros2_img_to_base64
2424from rai .initialization .model_initialization import get_llm_model
2525from rai .messages import HumanMultimodalMessage
2626
2727
28- class ROS2LogsPostprocessor (BaseStatePostprocessor [Log ]):
28+ class ROS2LogsPostprocessor (BaseAggregator [Log ]):
2929 """Returns only unique messages while keeping their order"""
3030
3131 levels = {10 : "DEBUG" , 20 : "INFO" , 30 : "WARNING" , 40 : "ERROR" , 50 : "FATAL" }
3232
33- def __call__ (self , msgs : Sequence [Log ]) -> HumanMessage :
33+ def get (self ) -> HumanMessage :
34+ msgs = self .get_buffer ()
3435 buffer = []
3536 prev_parsed = None
3637 counter = 0
@@ -50,12 +51,11 @@ def __call__(self, msgs: Sequence[Log]) -> HumanMessage:
5051 return HumanMessage (content = result )
5152
5253
53- class ROS2GetLastImagePostprocessor (BaseStatePostprocessor [Image | CompressedImage ]):
54+ class ROS2GetLastImagePostprocessor (BaseAggregator [Image | CompressedImage ]):
5455 """Returns the last image from the buffer as base64 encoded string"""
5556
56- def __call__ (
57- self , msgs : Sequence [Image | CompressedImage ]
58- ) -> HumanMultimodalMessage | None :
57+ def get (self ) -> HumanMultimodalMessage | None :
58+ msgs = self .get_buffer ()
5959 if len (msgs ) == 0 :
6060 return None
6161 ros2_img = msgs [- 1 ]
@@ -64,21 +64,23 @@ def __call__(
6464
6565
6666class ROS2ImgVLMDescriptionPostprocessor (
67- BaseStatePostprocessor [Image | CompressedImage ]
67+ BaseAggregator [Image | CompressedImage ]
6868):
6969 """
7070 Returns the VLM analysis of the last image in the aggregation buffer
7171 """
7272
73- def __init__ (self ) -> None :
74- super ().__init__ ()
73+ def __init__ (self , max_size : int | None = None ) -> None :
74+ super ().__init__ (max_size )
7575 self .llm = get_llm_model (model_type = "simple_model" , streaming = True )
7676
77- def __call__ (self , msgs : Sequence [Image | CompressedImage ]) -> HumanMessage | None :
77+ def get (self ) -> HumanMessage | None :
78+ msgs : List [Image | CompressedImage ] = self .get_buffer ()
7879 if len (msgs ) == 0 :
7980 return None
8081
8182 b64_images : List [str ] = [encode_ros2_img_to_base64 (msg ) for msg in msgs ]
83+ self .clear ()
8284
8385 system_prompt = "You are an expert in image analysis and your speciality is the"
8486 "description of images"
@@ -102,18 +104,18 @@ class ROS2ImgDescription(BaseModel):
102104 )
103105
104106
105- class ROS2ImgVLMDiffPostprocessor (BaseStatePostprocessor [Image | CompressedImage ]):
107+ class ROS2ImgVLMDiffPostprocessor (BaseAggregator [Image | CompressedImage ]):
106108 """
107109 Returns the LLM analysis of the differences between 3 images in the
107109 aggregation buffer: 1st, middle, last
109111 """
110112
111- def __init__ (self ) -> None :
112- super ().__init__ ()
113+ def __init__ (self , max_size : int | None = None ) -> None :
114+ super ().__init__ (max_size )
113115 self .llm = get_llm_model (model_type = "simple_model" , streaming = True )
114116
115117 @staticmethod
116- def get_key_elements (elements : Sequence [Any ]) -> List [Any ]:
118+ def get_key_elements (elements : List [Any ]) -> List [Any ]:
117119 """
118120 Returns 1st, last and middle elements of the list
119121 """
@@ -122,11 +124,14 @@ def get_key_elements(elements: Sequence[Any]) -> List[Any]:
122124 middle_index = len (elements ) // 2
123125 return [elements [0 ], elements [middle_index ], elements [- 1 ]]
124126
125- def __call__ (self , msgs : Sequence [ Any ] ) -> HumanMessage | None :
126- if len (msgs ) == 0 :
127+ def get (self ) -> HumanMessage | None :
128+ if len (self . get_buffer () ) == 0 :
127129 return None
128130
129- b64_images = [encode_ros2_img_to_base64 (msg ) for msg in msgs ]
131+ b64_images = [encode_ros2_img_to_base64 (msg ) for msg in self ._buffer ]
132+
133+ self .clear ()
134+
130135 b64_images = self .get_key_elements (b64_images )
131136
132137 system_prompt = "You are an expert in image analysis and your speciality is the comparison of 2 images"
0 commit comments