Skip to content

Commit a0665a3

Browse files
committed
chore: pre-commit
1 parent 3998a7b commit a0665a3

File tree

3 files changed

+16
-17
lines changed

3 files changed

+16
-17
lines changed

src/rai_core/rai/aggregators/base.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,37 +14,36 @@
1414

1515
from abc import ABC, abstractmethod
1616
from collections import deque
17-
from typing import Deque, Generic, TypeVar, List
17+
from typing import Deque, Generic, List, TypeVar
1818

1919
from langchain_core.messages import HumanMessage
20-
from rai.messages.multimodal import HumanMultimodalMessage
2120

21+
from rai.messages.multimodal import HumanMultimodalMessage
2222

2323
T = TypeVar("T")
2424

2525

2626
class BaseAggregator(ABC, Generic[T]):
2727
"""
28-
Interface for aggregators.
28+
Interface for aggregators.
2929
3030
`__call__` method receives a message and appends it to the buffer.
3131
`get` method returns the aggregated message.
3232
"""
33-
def __init__(self, max_size: int | None=None) -> None:
33+
34+
def __init__(self, max_size: int | None = None) -> None:
3435
super().__init__()
3536
self._buffer: Deque[T] = deque()
3637
self.max_size = max_size
3738

38-
def __call__(
39-
self, msg: T
40-
) -> None:
39+
def __call__(self, msg: T) -> None:
4140
if self.max_size is not None and len(self._buffer) >= self.max_size:
4241
self._buffer.popleft()
4342
self._buffer.append(msg)
4443

4544
@abstractmethod
4645
def get(self) -> HumanMessage | HumanMultimodalMessage | None:
47-
""" Returns the aggregated message """
46+
"""Returns the aggregated message"""
4847
pass
4948

5049
def clear(self) -> None:
@@ -55,4 +54,3 @@ def get_buffer(self) -> List[T]:
5554

5655
def __str__(self) -> str:
5756
return f"{self.__class__.__name__}"
58-

src/rai_core/rai/aggregators/ros2/aggregators.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, List, Sequence, cast
15+
from typing import Any, List, cast
1616

1717
from langchain_core.messages import HumanMessage, SystemMessage
1818
from pydantic import BaseModel, Field
@@ -63,14 +63,12 @@ def get(self) -> HumanMultimodalMessage | None:
6363
return HumanMultimodalMessage(content="", images=[b64_image])
6464

6565

66-
class ROS2ImgVLMDescriptionAggregator(
67-
BaseAggregator[Image | CompressedImage]
68-
):
66+
class ROS2ImgVLMDescriptionAggregator(BaseAggregator[Image | CompressedImage]):
6967
"""
7068
Returns the VLM analysis of the last image in the aggregation buffer
7169
"""
7270

73-
def __init__(self, max_size: int | None=None) -> None:
71+
def __init__(self, max_size: int | None = None) -> None:
7472
super().__init__(max_size)
7573
self.llm = get_llm_model(model_type="simple_model", streaming=True)
7674

@@ -110,7 +108,7 @@ class ROS2ImgVLMDiffAggregator(BaseAggregator[Image | CompressedImage]):
110108
aggregation buffer: 1st, midden, last
111109
"""
112110

113-
def __init__(self, max_size: int | None=None) -> None:
111+
def __init__(self, max_size: int | None = None) -> None:
114112
super().__init__(max_size)
115113
self.llm = get_llm_model(model_type="simple_model", streaming=True)
116114

@@ -130,7 +128,7 @@ def get(self) -> HumanMessage | None:
130128
return None
131129

132130
b64_images = [encode_ros2_img_to_base64(msg) for msg in msgs]
133-
131+
134132
self.clear()
135133

136134
b64_images = self.get_key_elements(b64_images)

src/rai_core/rai/tools/ros2/generic/topics.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
from pydantic import BaseModel, Field
2323

2424
from rai.communication.ros2 import ROS2Connector, ROS2Message
25-
from rai.communication.ros2.api.conversion import encode_ros2_img_to_base64, ros2_message_to_dict
25+
from rai.communication.ros2.api.conversion import (
26+
encode_ros2_img_to_base64,
27+
ros2_message_to_dict,
28+
)
2629
from rai.messages import MultimodalArtifact
2730
from rai.tools.ros2.base import BaseROS2Tool, BaseROS2Toolkit
2831

0 commit comments

Comments
 (0)