28
28
SpawnedEntity ,
29
29
)
30
30
31
- SimulationBridgeT = TypeVar ("SimulationBridgeT" , bound = SimulationBridge )
31
+ SimulationBridgeT = TypeVar (
32
+ "SimulationBridgeT" , bound = SimulationBridge [SimulationConfig ]
33
+ )
32
34
loggers_type = Union [RcutilsLogger , logging .Logger ]
33
35
34
36
@@ -70,7 +72,9 @@ def validate_config(self, simulation_config: SimulationConfig) -> bool:
70
72
pass
71
73
72
74
@abstractmethod
73
- def calculate_result (self , simulation_bridge : SimulationBridge ) -> float :
75
+ def calculate_result (
76
+ self , simulation_bridge : SimulationBridge [SimulationConfig ]
77
+ ) -> float :
74
78
"""
75
79
Calculate result of the task
76
80
"""
@@ -95,7 +99,7 @@ def is_adjacent(self, pos1: PoseModel, pos2: PoseModel, threshold_distance: floa
95
99
Check if positions are adjacent to each other, the threshold_distance is a distance
96
100
in simulation, refering to how close they have to be to classify them as adjacent
97
101
"""
98
- self .logger .debug (
102
+ self .logger .debug ( # type: ignore
99
103
f"Euclidean distance: { self .euclidean_distance (pos1 , pos2 )} , pos1: { pos1 } , pos2: { pos2 } "
100
104
)
101
105
return self .euclidean_distance (pos1 , pos2 ) < threshold_distance
@@ -164,7 +168,7 @@ def __init__(
164
168
def create_scenarios (
165
169
cls , tasks : List [Task ], simulation_configs : List [SimulationConfig ]
166
170
) -> list [Any ]:
167
- scenarios = []
171
+ scenarios : List [ Scenario ] = []
168
172
for task in tasks :
169
173
for sim_conf in simulation_configs :
170
174
try :
@@ -175,22 +179,22 @@ def create_scenarios(
175
179
)
176
180
return scenarios
177
181
178
- def run_next (self , agent ):
182
+ def run_next (self , agent ) -> None :
179
183
"""
180
184
Runs the next scenario
181
185
"""
182
186
try :
183
187
i , scenario = next (self .scenarios ) # Get the next scenario
184
188
185
189
self .simulation_bridge .setup_scene (scenario .simulation_config )
186
- self ._logger .info (
190
+ self ._logger .info ( # type: ignore
187
191
"======================================================================================"
188
192
)
189
- self ._logger .info (
193
+ self ._logger .info ( # type: ignore
190
194
f"RUNNING SCENARIO NUMBER { i + 1 } , TASK: { scenario .task .get_prompt ()} "
191
195
)
192
196
initial_result = scenario .task .calculate_result (self .simulation_bridge )
193
- self ._logger .info (f"RESULT OF THE INITIAL SETUP: { initial_result } " )
197
+ self ._logger .info (f"RESULT OF THE INITIAL SETUP: { initial_result } " ) # type: ignore
194
198
tool_calls_num = 0
195
199
196
200
ts = time .perf_counter ()
@@ -209,7 +213,7 @@ def run_next(self, agent):
209
213
last_msg = msg .content [0 ].get ("text" , "" )
210
214
else :
211
215
last_msg = msg .content
212
- self ._logger .debug (f"{ graph_node_name } : { last_msg } " )
216
+ self ._logger .debug (f"{ graph_node_name } : { last_msg } " ) # type: ignore
213
217
214
218
else :
215
219
raise ValueError (f"Unexpected type of message: { type (msg )} " )
@@ -218,13 +222,13 @@ def run_next(self, agent):
218
222
# TODO (jm) figure out more robust way of counting tool calls
219
223
tool_calls_num += len (msg .tool_calls )
220
224
221
- self ._logger .info (f"AI Message: { msg } " )
225
+ self ._logger .info (f"AI Message: { msg } " ) # type: ignore
222
226
223
227
te = time .perf_counter ()
224
228
225
229
result = scenario .task .calculate_result (self .simulation_bridge )
226
230
total_time = te - ts
227
- self ._logger .info (
231
+ self ._logger .info ( # type: ignore
228
232
f"TASK SCORE: { result } , TOTAL TIME: { total_time :.3f} , NUM_OF_TOOL_CALLS: { tool_calls_num } "
229
233
)
230
234
@@ -241,5 +245,5 @@ def run_next(self, agent):
241
245
except StopIteration :
242
246
print ("No more scenarios left to run." )
243
247
244
- def get_results (self ) -> list [ dict ]:
248
+ def get_results (self ) -> List [ Dict [ str , Any ] ]:
245
249
return self .results
0 commit comments