Skip to content

Commit 285817b

Browse files
committed
style: change of formatting
1 parent 4f11d20 commit 285817b

File tree

3 files changed

+30
-25
lines changed

3 files changed

+30
-25
lines changed

src/rai_bench/rai_bench/benchmark_model.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
SpawnedEntity,
2929
)
3030

31-
SimulationBridgeT = TypeVar("SimulationBridgeT", bound=SimulationBridge)
31+
SimulationBridgeT = TypeVar(
32+
"SimulationBridgeT", bound=SimulationBridge[SimulationConfig]
33+
)
3234
loggers_type = Union[RcutilsLogger, logging.Logger]
3335

3436

@@ -70,7 +72,9 @@ def validate_config(self, simulation_config: SimulationConfig) -> bool:
7072
pass
7173

7274
@abstractmethod
73-
def calculate_result(self, simulation_bridge: SimulationBridge) -> float:
75+
def calculate_result(
76+
self, simulation_bridge: SimulationBridge[SimulationConfig]
77+
) -> float:
7478
"""
7579
Calculate result of the task
7680
"""
@@ -95,7 +99,7 @@ def is_adjacent(self, pos1: PoseModel, pos2: PoseModel, threshold_distance: floa
9599
Check if positions are adjacent to each other, the threshold_distance is a distance
96100
in simulation, refering to how close they have to be to classify them as adjacent
97101
"""
98-
self.logger.debug(
102+
self.logger.debug( # type: ignore
99103
f"Euclidean distance: {self.euclidean_distance(pos1, pos2)}, pos1: {pos1}, pos2: {pos2}"
100104
)
101105
return self.euclidean_distance(pos1, pos2) < threshold_distance
@@ -164,7 +168,7 @@ def __init__(
164168
def create_scenarios(
165169
cls, tasks: List[Task], simulation_configs: List[SimulationConfig]
166170
) -> list[Any]:
167-
scenarios = []
171+
scenarios: List[Scenario] = []
168172
for task in tasks:
169173
for sim_conf in simulation_configs:
170174
try:
@@ -175,22 +179,22 @@ def create_scenarios(
175179
)
176180
return scenarios
177181

178-
def run_next(self, agent):
182+
def run_next(self, agent) -> None:
179183
"""
180184
Runs the next scenario
181185
"""
182186
try:
183187
i, scenario = next(self.scenarios) # Get the next scenario
184188

185189
self.simulation_bridge.setup_scene(scenario.simulation_config)
186-
self._logger.info(
190+
self._logger.info( # type: ignore
187191
"======================================================================================"
188192
)
189-
self._logger.info(
193+
self._logger.info( # type: ignore
190194
f"RUNNING SCENARIO NUMBER {i + 1}, TASK: {scenario.task.get_prompt()}"
191195
)
192196
initial_result = scenario.task.calculate_result(self.simulation_bridge)
193-
self._logger.info(f"RESULT OF THE INITIAL SETUP: {initial_result}")
197+
self._logger.info(f"RESULT OF THE INITIAL SETUP: {initial_result}") # type: ignore
194198
tool_calls_num = 0
195199

196200
ts = time.perf_counter()
@@ -209,7 +213,7 @@ def run_next(self, agent):
209213
last_msg = msg.content[0].get("text", "")
210214
else:
211215
last_msg = msg.content
212-
self._logger.debug(f"{graph_node_name}: {last_msg}")
216+
self._logger.debug(f"{graph_node_name}: {last_msg}") # type: ignore
213217

214218
else:
215219
raise ValueError(f"Unexpected type of message: {type(msg)}")
@@ -218,13 +222,13 @@ def run_next(self, agent):
218222
# TODO (jm) figure out more robust way of counting tool calls
219223
tool_calls_num += len(msg.tool_calls)
220224

221-
self._logger.info(f"AI Message: {msg}")
225+
self._logger.info(f"AI Message: {msg}") # type: ignore
222226

223227
te = time.perf_counter()
224228

225229
result = scenario.task.calculate_result(self.simulation_bridge)
226230
total_time = te - ts
227-
self._logger.info(
231+
self._logger.info( # type: ignore
228232
f"TASK SCORE: {result}, TOTAL TIME: {total_time:.3f}, NUM_OF_TOOL_CALLS: {tool_calls_num}"
229233
)
230234

@@ -241,5 +245,5 @@ def run_next(self, agent):
241245
except StopIteration:
242246
print("No more scenarios left to run.")
243247

244-
def get_results(self) -> list[dict]:
248+
def get_results(self) -> List[Dict[str, Any]]:
245249
return self.results

src/rai_bench/rai_bench/main.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,25 @@
1616
import logging
1717
import time
1818
from pathlib import Path
19+
from typing import List
1920

2021
import rclpy
21-
import rclpy.qos
22+
from langchain.tools import BaseTool
2223
from rai_open_set_vision.tools import GetGrabbingPointTool
2324

2425
from rai.agents.conversational_agent import create_conversational_agent
2526
from rai.communication.ros2.connectors import ROS2ARIConnector
2627
from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool
2728
from rai.tools.ros2.topics import GetROS2ImageTool, GetROS2TopicsNamesAndTypesTool
2829
from rai.utils.model_initialization import get_llm_model
29-
from rai_bench.benchmark_model import (
30-
Benchmark,
31-
)
30+
from rai_bench.benchmark_model import Benchmark, Task
3231
from rai_bench.o3de_test_bench.tasks import GrabCarrotTask, PlaceCubesTask
3332
from rai_sim.o3de.o3de_bridge import (
3433
O3DEngineArmManipulationBridge,
3534
O3DExROS2SimulationConfig,
3635
PoseModel,
3736
)
38-
from rai_sim.simulation_bridge import Rotation, Translation
37+
from rai_sim.simulation_bridge import Rotation, SimulationConfig, Translation
3938

4039
if __name__ == "__main__":
4140
rclpy.init()
@@ -55,7 +54,7 @@
5554
Before starting the task, make sure to grab the camera image to understand the environment.
5655
"""
5756
# define tools
58-
tools = [
57+
tools: List[BaseTool] = [
5958
GetObjectPositionsTool(
6059
connector=connector,
6160
target_frame="panda_link0",
@@ -80,7 +79,7 @@
8079
file_handler.setFormatter(formatter)
8180

8281
bench_logger = logging.getLogger("Benchmark logger")
83-
bench_logger.setLevel(logging.DEBUG)
82+
bench_logger.setLevel(logging.INFO)
8483
bench_logger.addHandler(file_handler)
8584

8685
agent_logger = logging.getLogger("Agent logger")
@@ -134,11 +133,11 @@
134133
configs_dir + "scene3.yaml",
135134
configs_dir + "scene4.yaml",
136135
]
137-
simulations_configs = [
136+
simulations_configs: List[SimulationConfig] = [
138137
O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path))
139138
for path in scene_paths
140139
]
141-
tasks = [
140+
tasks: List[Task] = [
142141
GrabCarrotTask(logger=bench_logger),
143142
PlaceCubesTask(logger=bench_logger),
144143
]

src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ def validate_config(self, simulation_config: SimulationConfig) -> bool:
3333

3434
return False
3535

36-
def calculate_result(self, simulation_bridge: SimulationBridge) -> float:
36+
def calculate_result(
37+
self, simulation_bridge: SimulationBridge[SimulationConfig]
38+
) -> float:
3739
# TODO (jm) extract common logic to some parent manipulation task?
3840
initially_misplaced_now_correct = 0 # when the object which was in the incorrect place at the start, is in a correct place at the end
3941
initially_misplaced_still_incorrect = 0 # when the object which was in the incorrect place at the start, is in a incorrect place at the end
@@ -55,8 +57,8 @@ def calculate_result(self, simulation_bridge: SimulationBridge) -> float:
5557
)
5658

5759
else:
58-
self.logger.debug(f"initial positions: {initial_carrots}")
59-
self.logger.debug(f"current positions: {final_carrots}")
60+
self.logger.debug(f"initial positions: {initial_carrots}") # type: ignore
61+
self.logger.debug(f"current positions: {final_carrots}") # type: ignore
6062
for ini_carrot in initial_carrots:
6163
for final_carrot in final_carrots:
6264
if ini_carrot.name == final_carrot.name:
@@ -91,7 +93,7 @@ def calculate_result(self, simulation_bridge: SimulationBridge) -> float:
9193
f"Entity with name: {ini_carrot.name} which was present in initial scene, not found in final scene."
9294
)
9395

94-
self.logger.info(
96+
self.logger.info( # type: ignore
9597
f"initially_misplaced_now_correct: {initially_misplaced_now_correct}, initially_misplaced_still_incorrect: {initially_misplaced_still_incorrect}, initially_correct_still_correct: {initially_correct_still_correct}, initially_correct_now_incorrect: {initially_correct_now_incorrect}"
9698
)
9799
return (

0 commit comments

Comments
 (0)