Skip to content

Commit 5b82ddc

Browse files
committed
added storing results, bug with get_object_pose
1 parent 946c5fc commit 5b82ddc

File tree

4 files changed

+22
-4
lines changed

4 files changed

+22
-4
lines changed

Diff for: src/rai_bench/rai_bench/benchmark_model.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,8 @@ def run_next(self, agent):
187187
self._logger.info(
188188
f"RUNNING SCENARIO NUMBER {i+1}, TASK: {scenario.task.get_prompt()}"
189189
)
190+
initial_result = scenario.task.calculate_result(self.simulation_bridge)
191+
self._logger.info(f"RESULT OF THE INITIAL SETUP: {initial_result}")
190192
ts = time.perf_counter()
191193
for state in agent.stream(
192194
{"messages": [HumanMessage(content=scenario.task.get_prompt())]}
@@ -208,15 +210,26 @@ def run_next(self, agent):
208210
raise ValueError(f"Unexpected type of message: {type(msg)}")
209211

210212
self._logger.info(f"AI Message: {msg}")
211-
# TODO (jm) figure out how to get number of tool calls
212213
te = time.perf_counter()
213214

214215
result = scenario.task.calculate_result(self.simulation_bridge)
215216

216217
total_time = te - ts
217218
self._logger.info(f"TASK SCORE: {result}, TOTAL TIME: {total_time:.3f}")
218219

219-
# self.results.append{""})
220+
self.results.append(
221+
{
222+
"task": scenario.task.get_prompt(),
223+
"initial_score": initial_result,
224+
"final_score": result,
225+
"total_time": f"{total_time:.3f}",
226+
# TODO (jm) figure out how to get number of tool calls
227+
"tool_calls": None,
228+
}
229+
)
220230

221231
except StopIteration:
222232
print("No more scenarios left to run.")
233+
234+
def get_results(self) -> list[dict]:
235+
return self.results

Diff for: src/rai_bench/rai_bench/main.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@
4646
node = connector.node
4747
node.declare_parameter("conversion_ratio", 1.0)
4848

49-
o3de = O3DEngineArmManipulationBridge(connector)
50-
5149
# define model
5250
llm = get_llm_model(model_type="complex_model", streaming=True)
5351

@@ -154,6 +152,7 @@
154152
# custom request to arm
155153
base_arm_pose = PoseModel(translation=Translation(x=0.3, y=0.0, z=0.5))
156154

155+
o3de = O3DEngineArmManipulationBridge(connector, logger=agent_logger)
157156
# define benchamrk
158157
benchmark = Benchmark(
159158
simulation_bridge=o3de,

Diff for: src/rai_bench/rai_bench/o3de_test_bench/tasks/grab_carrot_task.py

+4
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def calculate_result(self, simulation_bridge: SimulationBridge) -> float:
3636
)
3737
num_initial_carrots = len(initial_carrots)
3838

39+
self.logger.info(initial_carrots)
40+
self.logger.info(final_carrots)
3941
if num_initial_carrots != len(final_carrots):
4042
raise EntitiesMismatchException(
4143
"Number of initially spawned entities does not match number of entities present at the end."
@@ -51,6 +53,8 @@ def calculate_result(self, simulation_bridge: SimulationBridge) -> float:
5153
# NOTE the specific coords that refer to for example
5254
# middle of the table can differ across simulations,
5355
# take that into consideration
56+
self.logger.info(initial_y)
57+
self.logger.info(final_y)
5458
if (
5559
initial_y <= 0.0
5660
): # Carrot started in the incorrect place (right side)

Diff for: src/rai_sim/rai_sim/o3de/o3de_bridge.py

+2
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ def _despawn_entity(self, entity: SpawnedEntity):
175175
)
176176

177177
def get_object_pose(self, entity: SpawnedEntity) -> PoseModel:
178+
self.logger.info(f"GET OBJECT POSE: {entity}")
178179
object_frame = entity.name + "/"
179180
ros2_pose = do_transform_pose(
180181
Pose(), self.connector.get_transform(object_frame + "odom", object_frame)
@@ -193,6 +194,7 @@ def get_scene_state(self) -> SceneState:
193194
entities: list[SpawnedEntity] = []
194195
for entity in self.spawned_entities:
195196
current_pose = self.get_object_pose(entity)
197+
self.logger.info(f"AFTER GET OBJECT POSE: {current_pose}")
196198
entities.append(
197199
SpawnedEntity(
198200
id=entity.id,

0 commit comments

Comments
 (0)