Skip to content

Commit fa889c9

Browse files
feat: spatial reasoning tasks
1 parent 28fe8f6 commit fa889c9

File tree

12 files changed

+320
-4
lines changed

12 files changed

+320
-4
lines changed
64.6 KB
Loading
66.2 KB
Loading
77.8 KB
Loading
75.1 KB
Loading
95.6 KB
Loading
83.3 KB
Loading
69.5 KB
Loading
Lines changed: 100 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,100 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List, Sequence
16+
17+
from rai_bench.tool_calling_agent_bench.agent_tasks_interfaces import (
18+
SpatialReasoningAgentTask,
19+
)
20+
from rai_bench.tool_calling_agent_bench.spatial_reasoning_tasks import (
21+
BoolImageTask,
22+
BoolImageTaskInput,
23+
)
24+
25+
# Directory with the example images. The path is relative, so it is presumably
# resolved against the directory the benchmark is launched from -- TODO confirm.
_IMAGES_DIR = "src/rai_bench/rai_bench/examples/images"

# Yes/no questions about the example images, each paired with the expected
# boolean answer. Several questions reuse the same image.
inputs: List[BoolImageTaskInput] = [
    BoolImageTaskInput(
        question="Is the door on the left from the desk?",
        images_paths=[f"{_IMAGES_DIR}/image_1.jpg"],
        expected_response=True,
    ),
    BoolImageTaskInput(
        question="Is the door open?",
        images_paths=[f"{_IMAGES_DIR}/image_1.jpg"],
        expected_response=False,
    ),
    BoolImageTaskInput(
        question="Is someone in the room?",
        images_paths=[f"{_IMAGES_DIR}/image_1.jpg"],
        expected_response=False,
    ),
    BoolImageTaskInput(
        question="Is the light on in the room?",
        images_paths=[f"{_IMAGES_DIR}/image_2.jpg"],
        expected_response=True,
    ),
    # The same question is asked about two images with opposite answers.
    BoolImageTaskInput(
        question="Do you see the plant?",
        images_paths=[f"{_IMAGES_DIR}/image_2.jpg"],
        expected_response=True,
    ),
    BoolImageTaskInput(
        question="Do you see the plant?",
        images_paths=[f"{_IMAGES_DIR}/image_3.jpg"],
        expected_response=False,
    ),
    BoolImageTaskInput(
        question="Are there any pictures on the wall?",
        images_paths=[f"{_IMAGES_DIR}/image_3.jpg"],
        expected_response=True,
    ),
    BoolImageTaskInput(
        question="Are there 3 pictures on the wall?",
        images_paths=[f"{_IMAGES_DIR}/image_4.jpg"],
        expected_response=True,
    ),
    BoolImageTaskInput(
        question="Are there 4 pictures on the wall?",
        images_paths=[f"{_IMAGES_DIR}/image_4.jpg"],
        expected_response=False,
    ),
    BoolImageTaskInput(
        question="Is there a rack on the left from the sofa?",
        images_paths=[f"{_IMAGES_DIR}/image_4.jpg"],
        expected_response=False,
    ),
    BoolImageTaskInput(
        question="Is there a plant behind the rack?",
        images_paths=[f"{_IMAGES_DIR}/image_5.jpg"],
        expected_response=True,
    ),
    BoolImageTaskInput(
        question="Is there a plant on the right from the window?",
        images_paths=[f"{_IMAGES_DIR}/image_6.jpg"],
        expected_response=False,
    ),
    BoolImageTaskInput(
        # Fixed typo: the question previously read "armchain".
        question="Is there a pillow on the armchair?",
        images_paths=[f"{_IMAGES_DIR}/image_7.jpg"],
        expected_response=True,
    ),
    BoolImageTaskInput(
        question="Is there a red pillow on the armchair?",
        images_paths=[f"{_IMAGES_DIR}/image_7.jpg"],
        expected_response=False,
    ),
]
97+
98+
# One BoolImageTask per entry of `inputs`, in the same order.
tasks: Sequence[SpatialReasoningAgentTask] = [
    BoolImageTask(task_input=task_input) for task_input in inputs
]
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
from datetime import datetime
17+
from pathlib import Path
18+
19+
from rai.agents.conversational_agent import create_conversational_agent
20+
from rai.utils.model_initialization import (
21+
get_llm_model,
22+
get_llm_model_config_and_vendor,
23+
)
24+
25+
from rai_bench.examples.spatial_reasoning_tasks import tasks
26+
from rai_bench.tool_calling_agent_bench.agent_bench import ToolCallingAgentBenchmark
27+
28+
def _info_logger(name: str, handler: logging.Handler) -> logging.Logger:
    """Return the logger called *name*, set to INFO and writing to *handler*."""
    named_logger = logging.getLogger(name)
    named_logger.setLevel(logging.INFO)
    named_logger.addHandler(handler)
    return named_logger


def main() -> None:
    """Run the imported spatial reasoning tasks through the benchmark.

    Creates a timestamped experiment directory named after this script,
    routes the benchmark and agent loggers into ``benchmark.log`` inside it,
    and calls ``benchmark.run_next`` once per task with a freshly created
    conversational agent. The benchmark is given ``results.csv`` in the same
    directory as its results file.
    """
    # One directory per run: <experiments>/<script name>/<timestamp>.
    run_stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    experiment_dir = Path("src/rai_bench/rai_bench/experiments").joinpath(
        Path(__file__).stem, run_stamp
    )
    experiment_dir.mkdir(parents=True, exist_ok=True)

    # Both loggers share one handler, so their records end up in one file.
    log_handler = logging.FileHandler(experiment_dir / "benchmark.log")
    log_handler.setLevel(logging.DEBUG)
    log_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    bench_logger = _info_logger("Benchmark logger", log_handler)
    agent_logger = _info_logger("Agent logger", log_handler)

    for task in tasks:
        task.logger = bench_logger

    benchmark = ToolCallingAgentBenchmark(
        tasks=tasks,
        logger=bench_logger,
        results_filename=experiment_dir / "results.csv",
    )

    model_type = "simple_model"
    config_and_vendor = get_llm_model_config_and_vendor(model_type=model_type)
    # The configured model name is stored on the config under the model type.
    model_name = getattr(config_and_vendor[0], model_type)

    # run_next() takes no task argument; the agent is built from `task` on the
    # assumption that the benchmark walks `tasks` in the same order -- TODO
    # confirm against ToolCallingAgentBenchmark.
    for task in tasks:
        llm = get_llm_model(model_type=model_type)
        agent = create_conversational_agent(
            llm=llm,
            tools=task.expected_tools,
            system_prompt=task.get_system_prompt(),
            logger=agent_logger,
        )
        benchmark.run_next(agent=agent, model_name=model_name)


if __name__ == "__main__":
    main()

src/rai_bench/rai_bench/tool_calling_agent_bench/agent_bench.py

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
from rai.messages.multimodal import HumanMultimodalMessage
2929

3030
from rai_bench.tool_calling_agent_bench.agent_tasks_interfaces import (
31+
SpatialReasoningAgentTask,
3132
ToolCallingAgentTask,
3233
)
3334
from rai_bench.tool_calling_agent_bench.scores_tracing import ScoreTracingHandler
@@ -160,10 +161,26 @@ def run_next(self, agent: CompiledStateGraph, model_name: str) -> None:
160161

161162
ts = time.perf_counter()
162163
try:
163-
response = agent.invoke(
164-
{"messages": [HumanMultimodalMessage(content=task.get_prompt())]},
165-
config=config,
166-
)
164+
if isinstance(task, SpatialReasoningAgentTask):
165+
response = agent.invoke(
166+
{
167+
"messages": [
168+
HumanMultimodalMessage(
169+
content=task.get_prompt(), images=task.get_images()
170+
)
171+
]
172+
},
173+
config=config,
174+
)
175+
else:
176+
response = agent.invoke(
177+
{
178+
"messages": [
179+
HumanMultimodalMessage(content=task.get_prompt())
180+
]
181+
},
182+
config=config,
183+
)
167184
task.verify_tool_calls(response=response)
168185
except GraphRecursionError as e:
169186
task.log_error(msg=f"Graph Recursion Error: {e}")

0 commit comments

Comments
 (0)