Skip to content

Commit 4f11d20

Browse files
committed
refactor: funtions naming, typing and rotation of arm change
1 parent 09206e5 commit 4f11d20

File tree

5 files changed

+19
-22
lines changed

5 files changed

+19
-22
lines changed

src/rai_bench/rai_bench/benchmark_model.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import logging
1616
import time
1717
from abc import ABC, abstractmethod
18-
from typing import List, TypeVar, Union
18+
from typing import Any, Dict, List, TypeVar, Union
1919

2020
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
2121
from rclpy.impl.rcutils_logger import RcutilsLogger
@@ -59,9 +59,8 @@ def get_prompt(self) -> str:
5959
pass
6060

6161
@abstractmethod
62-
def validate_scene(self, simulation_config: SimulationConfig) -> bool:
63-
"""Task should be able to verify if given scene is suitable for specific task
64-
for example: GrabCarrotTask should verify if there is any carrots in the scene
62+
def validate_config(self, simulation_config: SimulationConfig) -> bool:
63+
"""Task should be able to verify if given config is suitable for specific task
6564
6665
Args:
6766
simulation_config (SimulationConfig): initial scene setup
@@ -136,7 +135,7 @@ class Scenario:
136135
"""Single instances are run separatly by benchmark"""
137136

138137
def __init__(self, task: Task, simulation_config: SimulationConfig) -> None:
139-
if not task.validate_scene(simulation_config):
138+
if not task.validate_config(simulation_config):
140139
raise ValueError("This scene is invalid for this task.")
141140
self.task = task
142141
self.simulation_config = simulation_config
@@ -149,13 +148,13 @@ class Benchmark:
149148

150149
def __init__(
151150
self,
152-
simulation_bridge: SimulationBridge,
151+
simulation_bridge: SimulationBridge[SimulationConfig],
153152
scenarios: list[Scenario],
154153
logger: loggers_type | None = None,
155154
) -> None:
156155
self.simulation_bridge = simulation_bridge
157156
self.scenarios = enumerate(iter(scenarios))
158-
self.results = []
157+
self.results: List[Dict[str, Any]] = []
159158
if logger:
160159
self._logger = logger
161160
else:
@@ -164,7 +163,7 @@ def __init__(
164163
@classmethod
165164
def create_scenarios(
166165
cls, tasks: List[Task], simulation_configs: List[SimulationConfig]
167-
):
166+
) -> list[Any]:
168167
scenarios = []
169168
for task in tasks:
170169
for sim_conf in simulation_configs:
@@ -235,7 +234,7 @@ def run_next(self, agent):
235234
"initial_score": initial_result,
236235
"final_score": result,
237236
"total_time": f"{total_time:.3f}",
238-
"numer_of_tool_calls": tool_calls_num,
237+
"number_of_tool_calls": tool_calls_num,
239238
}
240239
)
241240

src/rai_bench/rai_bench/main.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
O3DExROS2SimulationConfig,
3636
PoseModel,
3737
)
38-
from rai_sim.simulation_bridge import Translation
38+
from rai_sim.simulation_bridge import Rotation, Translation
3939

4040
if __name__ == "__main__":
4141
rclpy.init()
@@ -147,7 +147,10 @@
147147
)
148148

149149
# custom request to arm
150-
base_arm_pose = PoseModel(translation=Translation(x=0.5, y=0.1, z=0.3))
150+
base_arm_pose = PoseModel(
151+
translation=Translation(x=0.5, y=0.1, z=0.3),
152+
rotation=Rotation(x=1.0, y=0.0, z=0.0, w=0.0),
153+
)
151154

152155
o3de = O3DEngineArmManipulationBridge(connector, logger=agent_logger)
153156
# define benchamrk

src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class GrabCarrotTask(Task):
2626
def get_prompt(self) -> str:
2727
return "Manipulate objects, so that all carrots to the left side of the table (positive y)"
2828

29-
def validate_scene(self, simulation_config: SimulationConfig) -> bool:
29+
def validate_config(self, simulation_config: SimulationConfig) -> bool:
3030
for ent in simulation_config.entities:
3131
if ent.prefab_name == "carrot":
3232
return True

src/rai_bench/rai_bench/o3de_test_bench/tasks/place_cubes_task.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class PlaceCubesTask(Task):
2323
def get_prompt(self) -> str:
2424
return "Manipulate objects, so that all cubes are adjacent to at least one cube"
2525

26-
def validate_scene(self, simulation_config: SimulationConfig) -> bool:
26+
def validate_config(self, simulation_config: SimulationConfig) -> bool:
2727
cube_types = ["red_cube", "blue_cube", "yellow_cube"]
2828
cubes_num = 0
2929
for ent in simulation_config.entities:

src/rai_sim/rai_sim/o3de/o3de_bridge.py

+4-9
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,11 @@ def _is_robotic_stack_ready(
214214
services = self.connector.node.get_service_names_and_types()
215215
topics_names = [tp[0] for tp in topics]
216216
service_names = [srv[0] for srv in services]
217-
self.logger.debug(
217+
self.logger.info(
218218
f"required services: {simulation_config.required_services}"
219219
)
220-
self.logger.debug(f"required topics: {simulation_config.required_topics}")
221-
self.logger.debug(f"required actions: {simulation_config.required_actions}")
220+
self.logger.info(f"required topics: {simulation_config.required_topics}")
221+
self.logger.info(f"required actions: {simulation_config.required_actions}")
222222
# NOTE actions will be listed in services and topics
223223
if (
224224
all(srv in service_names for srv in simulation_config.required_services)
@@ -227,7 +227,7 @@ def _is_robotic_stack_ready(
227227
ac in service_names for ac in simulation_config.required_actions
228228
)
229229
):
230-
self.logger.debug("All required services are available.")
230+
self.logger.info("All required services are available.")
231231
return True
232232

233233
time.sleep(5)
@@ -380,11 +380,6 @@ def move_arm(
380380
request.target_pose.pose.orientation.y = pose.rotation.y
381381
request.target_pose.pose.orientation.z = pose.rotation.z
382382
request.target_pose.pose.orientation.w = pose.rotation.w
383-
else:
384-
request.target_pose.pose.orientation.x = 1.0
385-
request.target_pose.pose.orientation.y = 0.0
386-
request.target_pose.pose.orientation.z = 0.0
387-
request.target_pose.pose.orientation.w = 0.0
388383

389384
client = self.connector.node.create_client(
390385
ManipulatorMoveTo,

0 commit comments

Comments
 (0)