Skip to content

Commit 4de55d2

Browse files
committed
added way to automatically create scenarios
1 parent ab162ff commit 4de55d2

File tree

2 files changed

+71
-35
lines changed

2 files changed

+71
-35
lines changed

src/rai_bench/rai_bench/benchmark_model.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,22 @@ def __init__(
160160
else:
161161
self._logger = logging.getLogger(__name__)
162162

163+
@classmethod
164+
def create_scenarios(
165+
cls, tasks: List[Task], simulation_configs: List[SimulationConfig]
166+
):
167+
168+
scenarios = []
169+
for task in tasks:
170+
for sim_conf in simulation_configs:
171+
try:
172+
scenarios.append(Scenario(task=task, simulation_config=sim_conf))
173+
except ValueError as e:
174+
print(
175+
f"Could not create Scenario from task: {task.get_prompt()} and simulation_config: {sim_conf}, {e}"
176+
)
177+
return scenarios
178+
163179
def run_next(self, agent):
164180
"""
165181
Runs the next scenario

src/rai_bench/rai_bench/main.py

Lines changed: 55 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -92,44 +92,64 @@
9292
agent_logger.setLevel(logging.INFO)
9393
agent_logger.addHandler(file_handler)
9494

95-
# load different scenes
9695
configs_dir = "src/rai_bench/rai_bench/o3de_test_bench/configs/"
9796
connector_path = configs_dir + "o3de_config.yaml"
98-
one_carrot_simulation_config = O3DExROS2SimulationConfig.load_config(
99-
base_config_path=Path(configs_dir + "scene1.yaml"),
100-
connector_config_path=Path(connector_path),
101-
)
102-
multiple_carrot_simulation_config = O3DExROS2SimulationConfig.load_config(
103-
base_config_path=Path(configs_dir + "scene2.yaml"),
104-
connector_config_path=Path(connector_path),
105-
)
106-
red_cubes_simulation_config = O3DExROS2SimulationConfig.load_config(
107-
base_config_path=Path(configs_dir + "scene3.yaml"),
108-
connector_config_path=Path(connector_path),
109-
)
110-
multiple_cubes_simulation_config = O3DExROS2SimulationConfig.load_config(
111-
base_config_path=Path(configs_dir + "scene4.yaml"),
112-
connector_config_path=Path(connector_path),
113-
)
114-
# combine different scene configs with the tasks to create various scenarios
115-
scenarios = [
116-
Scenario(
117-
task=GrabCarrotTask(logger=bench_logger),
118-
simulation_config=one_carrot_simulation_config,
119-
),
120-
Scenario(
121-
task=GrabCarrotTask(logger=bench_logger),
122-
simulation_config=multiple_carrot_simulation_config,
123-
),
124-
Scenario(
125-
task=PlaceCubesTask(logger=bench_logger),
126-
simulation_config=red_cubes_simulation_config,
127-
),
128-
Scenario(
129-
task=PlaceCubesTask(logger=bench_logger),
130-
simulation_config=multiple_cubes_simulation_config,
131-
),
97+
#### Create scenarios manually
98+
# load different scenes
99+
# one_carrot_simulation_config = O3DExROS2SimulationConfig.load_config(
100+
# base_config_path=Path(configs_dir + "scene1.yaml"),
101+
# connector_config_path=Path(connector_path),
102+
# )
103+
# multiple_carrot_simulation_config = O3DExROS2SimulationConfig.load_config(
104+
# base_config_path=Path(configs_dir + "scene2.yaml"),
105+
# connector_config_path=Path(connector_path),
106+
# )
107+
# red_cubes_simulation_config = O3DExROS2SimulationConfig.load_config(
108+
# base_config_path=Path(configs_dir + "scene3.yaml"),
109+
# connector_config_path=Path(connector_path),
110+
# )
111+
# multiple_cubes_simulation_config = O3DExROS2SimulationConfig.load_config(
112+
# base_config_path=Path(configs_dir + "scene4.yaml"),
113+
# connector_config_path=Path(connector_path),
114+
# )
115+
# # combine different scene configs with the tasks to create various scenarios
116+
# scenarios = [
117+
# Scenario(
118+
# task=GrabCarrotTask(logger=bench_logger),
119+
# simulation_config=one_carrot_simulation_config,
120+
# ),
121+
# Scenario(
122+
# task=GrabCarrotTask(logger=bench_logger),
123+
# simulation_config=multiple_carrot_simulation_config,
124+
# ),
125+
# Scenario(
126+
# task=PlaceCubesTask(logger=bench_logger),
127+
# simulation_config=red_cubes_simulation_config,
128+
# ),
129+
# Scenario(
130+
# task=PlaceCubesTask(logger=bench_logger),
131+
# simulation_config=multiple_cubes_simulation_config,
132+
# ),
133+
# ]
134+
135+
### Create scenarios automatically
136+
scene_paths = [
137+
configs_dir + "scene1.yaml",
138+
configs_dir + "scene2.yaml",
139+
configs_dir + "scene3.yaml",
140+
configs_dir + "scene4.yaml",
132141
]
142+
simulations_configs = [
143+
O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path))
144+
for path in scene_paths
145+
]
146+
tasks = [
147+
GrabCarrotTask(logger=bench_logger),
148+
PlaceCubesTask(logger=bench_logger),
149+
]
150+
scenarios = Benchmark.create_scenarios(
151+
tasks=tasks, simulation_configs=simulations_configs
152+
)
133153

134154
# custom request to arm
135155
base_arm_pose = PoseModel(translation=Translation(x=0.3, y=0.0, z=0.4))

0 commit comments

Comments
 (0)