19
19
from rcl_interfaces .msg import Log
20
20
from sensor_msgs .msg import CompressedImage , Image
21
21
22
- from rai .agents . postprocessors import BaseStatePostprocessor
22
+ from rai .aggregators import BaseAggregator
23
23
from rai .communication .ros2 .api .conversion import encode_ros2_img_to_base64
24
24
from rai .initialization .model_initialization import get_llm_model
25
25
from rai .messages import HumanMultimodalMessage
26
26
27
27
28
- class ROS2LogsPostprocessor (BaseStatePostprocessor [Log ]):
28
+ class ROS2LogsPostprocessor (BaseAggregator [Log ]):
29
29
"""Returns only unique messages while keeping their order"""
30
30
31
31
levels = {10 : "DEBUG" , 20 : "INFO" , 30 : "WARNING" , 40 : "ERROR" , 50 : "FATAL" }
32
32
33
- def __call__ (self , msgs : Sequence [Log ]) -> HumanMessage :
33
+ def get (self ) -> HumanMessage :
34
+ msgs = self .get_buffer ()
34
35
buffer = []
35
36
prev_parsed = None
36
37
counter = 0
@@ -50,12 +51,11 @@ def __call__(self, msgs: Sequence[Log]) -> HumanMessage:
50
51
return HumanMessage (content = result )
51
52
52
53
53
- class ROS2GetLastImagePostprocessor (BaseStatePostprocessor [Image | CompressedImage ]):
54
+ class ROS2GetLastImagePostprocessor (BaseAggregator [Image | CompressedImage ]):
54
55
"""Returns the last image from the buffer as base64 encoded string"""
55
56
56
- def __call__ (
57
- self , msgs : Sequence [Image | CompressedImage ]
58
- ) -> HumanMultimodalMessage | None :
57
+ def get (self ) -> HumanMultimodalMessage | None :
58
+ msgs = self .get_buffer ()
59
59
if len (msgs ) == 0 :
60
60
return None
61
61
ros2_img = msgs [- 1 ]
@@ -64,21 +64,23 @@ def __call__(
64
64
65
65
66
66
class ROS2ImgVLMDescriptionPostprocessor (
67
- BaseStatePostprocessor [Image | CompressedImage ]
67
+ BaseAggregator [Image | CompressedImage ]
68
68
):
69
69
"""
70
70
Returns the VLM analysis of the last image in the aggregation buffer
71
71
"""
72
72
73
- def __init__ (self ) -> None :
74
- super ().__init__ ()
73
+ def __init__ (self , max_size : int | None = None ) -> None :
74
+ super ().__init__ (max_size )
75
75
self .llm = get_llm_model (model_type = "simple_model" , streaming = True )
76
76
77
- def __call__ (self , msgs : Sequence [Image | CompressedImage ]) -> HumanMessage | None :
77
+ def get (self ) -> HumanMessage | None :
78
+ msgs : List [Image | CompressedImage ] = self .get_buffer ()
78
79
if len (msgs ) == 0 :
79
80
return None
80
81
81
82
b64_images : List [str ] = [encode_ros2_img_to_base64 (msg ) for msg in msgs ]
83
+ self .clear ()
82
84
83
85
system_prompt = "You are an expert in image analysis and your speciality is the"
84
86
"description of images"
@@ -102,18 +104,18 @@ class ROS2ImgDescription(BaseModel):
102
104
)
103
105
104
106
105
- class ROS2ImgVLMDiffPostprocessor (BaseStatePostprocessor [Image | CompressedImage ]):
107
+ class ROS2ImgVLMDiffPostprocessor (BaseAggregator [Image | CompressedImage ]):
106
108
"""
107
109
Returns the LLM analysis of the differences between 3 images in the
108
110
aggregation buffer: 1st, middle, last
109
111
"""
110
112
111
- def __init__ (self ) -> None :
112
- super ().__init__ ()
113
+ def __init__ (self , max_size : int | None = None ) -> None :
114
+ super ().__init__ (max_size )
113
115
self .llm = get_llm_model (model_type = "simple_model" , streaming = True )
114
116
115
117
@staticmethod
116
- def get_key_elements (elements : Sequence [Any ]) -> List [Any ]:
118
+ def get_key_elements (elements : List [Any ]) -> List [Any ]:
117
119
"""
118
120
Returns 1st, last and middle elements of the list
119
121
"""
@@ -122,11 +124,14 @@ def get_key_elements(elements: Sequence[Any]) -> List[Any]:
122
124
middle_index = len (elements ) // 2
123
125
return [elements [0 ], elements [middle_index ], elements [- 1 ]]
124
126
125
- def __call__ (self , msgs : Sequence [ Any ] ) -> HumanMessage | None :
126
- if len (msgs ) == 0 :
127
+ def get (self ) -> HumanMessage | None :
128
+ if len (self . get_buffer () ) == 0 :
127
129
return None
128
130
129
- b64_images = [encode_ros2_img_to_base64 (msg ) for msg in msgs ]
131
+ b64_images = [encode_ros2_img_to_base64 (msg ) for msg in self ._buffer ]
132
+
133
+ self .clear ()
134
+
130
135
b64_images = self .get_key_elements (b64_images )
131
136
132
137
system_prompt = "You are an expert in image analysis and your speciality is the comparison of 2 images"
0 commit comments