Skip to content

Commit fa889c9

Browse files
feat: spatial reasoning tasks
1 parent 28fe8f6 commit fa889c9

12 files changed

+320
-4
lines changed
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List, Sequence
16+
17+
from rai_bench.tool_calling_agent_bench.agent_tasks_interfaces import (
18+
SpatialReasoningAgentTask,
19+
)
20+
from rai_bench.tool_calling_agent_bench.spatial_reasoning_tasks import (
21+
BoolImageTask,
22+
BoolImageTaskInput,
23+
)
24+
25+
inputs: List[BoolImageTaskInput] = [
26+
BoolImageTaskInput(
27+
question="Is the door on the left from the desk?",
28+
images_paths=["src/rai_bench/rai_bench/examples/images/image_1.jpg"],
29+
expected_response=True,
30+
),
31+
BoolImageTaskInput(
32+
question="Is the door open?",
33+
images_paths=["src/rai_bench/rai_bench/examples/images/image_1.jpg"],
34+
expected_response=False,
35+
),
36+
BoolImageTaskInput(
37+
question="Is someone in the room?",
38+
images_paths=["src/rai_bench/rai_bench/examples/images/image_1.jpg"],
39+
expected_response=False,
40+
),
41+
BoolImageTaskInput(
42+
question="Is the light on in the room?",
43+
images_paths=["src/rai_bench/rai_bench/examples/images/image_2.jpg"],
44+
expected_response=True,
45+
),
46+
BoolImageTaskInput(
47+
question="Do you see the plant?",
48+
images_paths=["src/rai_bench/rai_bench/examples/images/image_2.jpg"],
49+
expected_response=True,
50+
),
51+
BoolImageTaskInput(
52+
question="Do you see the plant?",
53+
images_paths=["src/rai_bench/rai_bench/examples/images/image_3.jpg"],
54+
expected_response=False,
55+
),
56+
BoolImageTaskInput(
57+
question="Are there any pictures on the wall?",
58+
images_paths=["src/rai_bench/rai_bench/examples/images/image_3.jpg"],
59+
expected_response=True,
60+
),
61+
BoolImageTaskInput(
62+
question="Are there 3 pictures on the wall?",
63+
images_paths=["src/rai_bench/rai_bench/examples/images/image_4.jpg"],
64+
expected_response=True,
65+
),
66+
BoolImageTaskInput(
67+
question="Are there 4 pictures on the wall?",
68+
images_paths=["src/rai_bench/rai_bench/examples/images/image_4.jpg"],
69+
expected_response=False,
70+
),
71+
BoolImageTaskInput(
72+
question="Is there a rack on the left from the sofa?",
73+
images_paths=["src/rai_bench/rai_bench/examples/images/image_4.jpg"],
74+
expected_response=False,
75+
),
76+
BoolImageTaskInput(
77+
question="Is there a plant behind the rack?",
78+
images_paths=["src/rai_bench/rai_bench/examples/images/image_5.jpg"],
79+
expected_response=True,
80+
),
81+
BoolImageTaskInput(
82+
question="Is there a plant on the right from the window?",
83+
images_paths=["src/rai_bench/rai_bench/examples/images/image_6.jpg"],
84+
expected_response=False,
85+
),
86+
BoolImageTaskInput(
87+
question="Is there a pillow on the armchain?",
88+
images_paths=["src/rai_bench/rai_bench/examples/images/image_7.jpg"],
89+
expected_response=True,
90+
),
91+
BoolImageTaskInput(
92+
question="Is there a red pillow on the armchair?",
93+
images_paths=["src/rai_bench/rai_bench/examples/images/image_7.jpg"],
94+
expected_response=False,
95+
),
96+
]
97+
98+
tasks: Sequence[SpatialReasoningAgentTask] = [
99+
BoolImageTask(task_input=input_item) for input_item in inputs
100+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
from datetime import datetime
17+
from pathlib import Path
18+
19+
from rai.agents.conversational_agent import create_conversational_agent
20+
from rai.utils.model_initialization import (
21+
get_llm_model,
22+
get_llm_model_config_and_vendor,
23+
)
24+
25+
from rai_bench.examples.spatial_reasoning_tasks import tasks
26+
from rai_bench.tool_calling_agent_bench.agent_bench import ToolCallingAgentBenchmark
27+
28+
if __name__ == "__main__":
29+
current_test_name = Path(__file__).stem
30+
31+
now = datetime.now()
32+
experiment_dir = (
33+
Path("src/rai_bench/rai_bench/experiments")
34+
/ current_test_name
35+
/ now.strftime("%Y-%m-%d_%H-%M-%S")
36+
)
37+
experiment_dir.mkdir(parents=True, exist_ok=True)
38+
log_filename = experiment_dir / "benchmark.log"
39+
results_filename = experiment_dir / "results.csv"
40+
41+
file_handler = logging.FileHandler(log_filename)
42+
file_handler.setLevel(logging.DEBUG)
43+
formatter = logging.Formatter(
44+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
45+
)
46+
file_handler.setFormatter(formatter)
47+
48+
bench_logger = logging.getLogger("Benchmark logger")
49+
bench_logger.setLevel(logging.INFO)
50+
bench_logger.addHandler(file_handler)
51+
52+
agent_logger = logging.getLogger("Agent logger")
53+
agent_logger.setLevel(logging.INFO)
54+
agent_logger.addHandler(file_handler)
55+
56+
for task in tasks:
57+
task.logger = bench_logger
58+
59+
benchmark = ToolCallingAgentBenchmark(
60+
tasks=tasks, logger=bench_logger, results_filename=results_filename
61+
)
62+
63+
model_type = "simple_model"
64+
model_config = get_llm_model_config_and_vendor(model_type=model_type)[0]
65+
model_name = getattr(model_config, model_type)
66+
67+
for task in tasks:
68+
agent = create_conversational_agent(
69+
llm=get_llm_model(model_type=model_type),
70+
tools=task.expected_tools,
71+
system_prompt=task.get_system_prompt(),
72+
logger=agent_logger,
73+
)
74+
benchmark.run_next(agent=agent, model_name=model_name)

src/rai_bench/rai_bench/tool_calling_agent_bench/agent_bench.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from rai.messages.multimodal import HumanMultimodalMessage
2929

3030
from rai_bench.tool_calling_agent_bench.agent_tasks_interfaces import (
31+
SpatialReasoningAgentTask,
3132
ToolCallingAgentTask,
3233
)
3334
from rai_bench.tool_calling_agent_bench.scores_tracing import ScoreTracingHandler
@@ -160,10 +161,26 @@ def run_next(self, agent: CompiledStateGraph, model_name: str) -> None:
160161

161162
ts = time.perf_counter()
162163
try:
163-
response = agent.invoke(
164-
{"messages": [HumanMultimodalMessage(content=task.get_prompt())]},
165-
config=config,
166-
)
164+
if isinstance(task, SpatialReasoningAgentTask):
165+
response = agent.invoke(
166+
{
167+
"messages": [
168+
HumanMultimodalMessage(
169+
content=task.get_prompt(), images=task.get_images()
170+
)
171+
]
172+
},
173+
config=config,
174+
)
175+
else:
176+
response = agent.invoke(
177+
{
178+
"messages": [
179+
HumanMultimodalMessage(content=task.get_prompt())
180+
]
181+
},
182+
config=config,
183+
)
167184
task.verify_tool_calls(response=response)
168185
except GraphRecursionError as e:
169186
task.log_error(msg=f"Graph Recursion Error: {e}")

src/rai_bench/rai_bench/tool_calling_agent_bench/agent_tasks_interfaces.py

+21
Original file line numberDiff line numberDiff line change
@@ -274,3 +274,24 @@ def _is_ai_message_requesting_get_ros2_topics_and_types(
274274
):
275275
return False
276276
return True
277+
278+
279+
class SpatialReasoningAgentTask(ToolCallingAgentTask):
280+
"""Abstract class for spatial reasoning tasks for tool calling agent."""
281+
282+
def __init__(self, logger: loggers_type | None = None) -> None:
283+
super().__init__(logger)
284+
self.expected_tools: List[BaseTool]
285+
self.question: str
286+
self.images_paths: List[str]
287+
288+
@abstractmethod
289+
def get_images(self) -> List[str]:
290+
"""Get the images related to the task.
291+
292+
Returns
293+
-------
294+
List[str]
295+
List of image paths
296+
"""
297+
pass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import logging
17+
from typing import Any, List, Optional, Sequence
18+
19+
from langchain_core.messages import AIMessage
20+
from langchain_core.tools import BaseTool
21+
from pydantic import BaseModel, Field
22+
from rai.messages import preprocess_image
23+
24+
from rai_bench.tool_calling_agent_bench.agent_tasks_interfaces import (
25+
SpatialReasoningAgentTask,
26+
)
27+
28+
29+
class TaskParametrizationError(Exception):
30+
"""Exception raised when the task parameters are not valid."""
31+
32+
pass
33+
34+
35+
SPATIAL_REASONING_SYSTEM_PROMPT = "You are a helpful and knowledgeable AI assistant that specializes in interpreting and analyzing visual content. Your task is to answer questions based on the images provided to you. Please response with the use of the provided tools."
36+
37+
38+
class ReturnBoolResponseToolInput(BaseModel):
39+
response: bool = Field(..., description="The response to the question.")
40+
41+
42+
class ReturnBoolResponseTool(BaseTool):
43+
"""Tool that returns a boolean response."""
44+
45+
name: str = "return_bool_response"
46+
description: str = "Return a bool response to the question."
47+
args_schema = ReturnBoolResponseToolInput
48+
49+
def _run(self, response: bool) -> bool:
50+
if type(response) is bool:
51+
return response
52+
raise ValueError("Invalid response type. Response must be a boolean.")
53+
54+
55+
class BoolImageTaskInput(BaseModel):
56+
question: str = Field(..., description="The question to be answered.")
57+
images_paths: List[str] = Field(
58+
...,
59+
description="List of image file paths to be used for answering the question.",
60+
)
61+
expected_response: bool = Field(
62+
..., description="The expected answer to the question."
63+
)
64+
65+
66+
class BoolImageTask(SpatialReasoningAgentTask):
67+
complexity = "easy"
68+
69+
def __init__(
70+
self,
71+
task_input: BoolImageTaskInput,
72+
logger: Optional[logging.Logger] = None,
73+
) -> None:
74+
super().__init__(logger)
75+
self.expected_tools = [ReturnBoolResponseTool()]
76+
self.question = task_input.question
77+
self.images_paths = task_input.images_paths
78+
self.expected_response = task_input.expected_response
79+
80+
def get_system_prompt(self) -> str:
81+
return SPATIAL_REASONING_SYSTEM_PROMPT
82+
83+
def get_prompt(self):
84+
return self.question
85+
86+
def get_images(self):
87+
images = [preprocess_image(image_path) for image_path in self.images_paths]
88+
return images
89+
90+
def verify_tool_calls(self, response: dict[str, Any]):
91+
messages = response["messages"]
92+
ai_messages: Sequence[AIMessage] = [
93+
message for message in messages if isinstance(message, AIMessage)
94+
]
95+
96+
if ai_messages:
97+
if self._check_tool_calls_num_in_ai_message(ai_messages[0], expected_num=1):
98+
self._check_tool_call(
99+
tool_call=ai_messages[0].tool_calls[0],
100+
expected_name="return_bool_response",
101+
expected_args={"response": self.expected_response},
102+
)
103+
if not self.result.errors:
104+
self.result.success = True

0 commit comments

Comments
 (0)