6
6
7
7
from collections import namedtuple
8
8
from hashlib import md5
9
+ from pathlib import Path
9
10
from typing import List , Union , Tuple
11
+ from tqdm import tqdm
10
12
11
13
from controllers import BaseController , SimpleController
12
14
15
+
13
16
ACC_G = 9.81
14
17
SIM_START_IDX = 100
15
18
CONTEXT_LENGTH = 20
16
19
VOCAB_SIZE = 1024
17
20
LATACCEL_RANGE = [- 4 , 4 ]
21
+ STEER_RANGE = [- 1 , 1 ]
18
22
MAX_ACC_DELTA = 0.5
23
+ DEL_T = 0.1
19
24
20
25
State = namedtuple ('State' , ['roll_lataccel' , 'vEgo' , 'aEgo' ])
21
26
@@ -37,17 +42,19 @@ def clip(self, value: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
37
42
38
43
39
44
class TinyPhysicsModel :
40
- def __init__ (self , model_path : str ) -> None :
45
+ def __init__ (self , model_path : str , debug : bool ) -> None :
41
46
self .tokenizer = LataccelTokenizer ()
42
47
options = ort .SessionOptions ()
43
48
options .intra_op_num_threads = 1
44
49
options .inter_op_num_threads = 1
45
50
options .log_severity_level = 3
46
51
if 'CUDAExecutionProvider' in ort .get_available_providers ():
47
- print ("ONNX Runtime is using GPU" )
52
+ if debug :
53
+ print ("ONNX Runtime is using GPU" )
48
54
provider = ('CUDAExecutionProvider' , {'cudnn_conv_algo_search' : 'DEFAULT' })
49
55
else :
50
- print ("ONNX Runtime is using CPU" )
56
+ if debug :
57
+ print ("ONNX Runtime is using CPU" )
51
58
provider = 'CPUExecutionProvider'
52
59
53
60
with open (model_path , "rb" ) as f :
@@ -78,13 +85,15 @@ def get_current_lataccel(self, sim_states: List[State], actions: List[float], pa
78
85
79
86
80
87
class TinyPhysicsSimulator :
81
- def __init__ (self , model_path : str , data_path : str , do_sim_step : bool , do_control_step : bool , controller : BaseController ) -> None :
88
+ def __init__ (self , model : TinyPhysicsModel , data_path : str , do_sim_step : bool , do_control_step : bool , controller : BaseController , debug : bool = False ) -> None :
82
89
self .data_path = data_path
83
- self .sim_model = TinyPhysicsModel ( model_path )
90
+ self .sim_model = model
84
91
self .data = self .get_data (data_path )
85
92
self .do_sim_step = do_sim_step
86
93
self .do_control_step = do_control_step
87
94
self .controller = controller
95
+ self .debug = debug
96
+ self .times = []
88
97
self .reset ()
89
98
90
99
def reset (self ) -> None :
@@ -124,6 +133,7 @@ def control_step(self, step_idx: int) -> None:
124
133
action = self .controller .update (self .target_lataccel_history [step_idx ], self .current_lataccel , self .state_history [step_idx ])
125
134
else :
126
135
action = 0.
136
+ action = np .clip (action , STEER_RANGE [0 ], STEER_RANGE [1 ])
127
137
self .action_history .append (action )
128
138
129
139
def get_state_target (self , step_idx : int ) -> Tuple [List , float ]:
@@ -147,42 +157,64 @@ def plot_data(self, ax, lines, axis_labels, title) -> None:
147
157
ax .set_xlabel (axis_labels [0 ])
148
158
ax .set_ylabel (axis_labels [1 ])
149
159
150
- def compute_score (self ) -> float :
160
+ def compute_cost (self ) -> float :
151
161
target = np .array (self .target_lataccel_history )[SIM_START_IDX :]
152
162
pred = np .array (self .current_lataccel_history )[SIM_START_IDX :]
153
- return - np .mean ((target - pred )** 2 )
154
163
155
- def rollout (self , debug = True ) -> None :
156
- if debug :
164
+ lat_accel_cost = np .mean (((target - pred ) / DEL_T )** 2 )
165
+ jerk_cost = np .mean (np .diff (pred )** 2 )
166
+ return lat_accel_cost , jerk_cost
167
+
168
+ def rollout (self ) -> None :
169
+ if self .debug :
157
170
plt .ion ()
158
171
fig , ax = plt .subplots (4 , figsize = (12 , 14 ))
159
172
160
173
for _ in range (len (self .data )):
161
174
self .step ()
162
- if debug and self .step_idx % 10 == 0 :
175
+ if self . debug and self .step_idx % 10 == 0 :
163
176
print (f"Step { self .step_idx :<5} : Current lataccel: { self .current_lataccel :>6.2f} , Target lataccel: { self .target_lataccel_history [- 1 ]:>6.2f} " )
164
177
self .plot_data (ax [0 ], [(self .target_lataccel_history , 'Target lataccel' ), (self .current_lataccel_history , 'Current lataccel' )], ['Step' , 'Lateral Acceleration' ], 'Lateral Acceleration' )
165
178
self .plot_data (ax [1 ], [(self .action_history , 'Action' )], ['Step' , 'Action' ], 'Action' )
166
179
self .plot_data (ax [2 ], [(np .array (self .state_history )[:, 0 ], 'Roll Lateral Acceleration' )], ['Step' , 'Lateral Accel due to Road Roll' ], 'Lateral Accel due to Road Roll' )
167
180
self .plot_data (ax [3 ], [(np .array (self .state_history )[:, 1 ], 'vEgo' )], ['Step' , 'vEgo' ], 'vEgo' )
168
181
plt .pause (0.01 )
169
182
170
- if debug :
183
+ if self . debug :
171
184
plt .ioff ()
172
185
plt .show ()
173
-
174
- return self .compute_score ()
186
+ return self .compute_cost ()
175
187
176
188
177
189
if __name__ == "__main__" :
178
190
parser = argparse .ArgumentParser ()
179
191
parser .add_argument ("--model_path" , type = str , required = True )
180
192
parser .add_argument ("--data_path" , type = str , required = True )
193
+ parser .add_argument ("--num_segs" , type = int , default = 1000 )
181
194
parser .add_argument ("--do_sim_step" , action = 'store_true' )
182
195
parser .add_argument ("--do_control_step" , action = 'store_true' )
183
- parser .add_argument ("--vis " , action = 'store_true' )
196
+ parser .add_argument ("--debug " , action = 'store_true' )
184
197
args = parser .parse_args ()
185
198
186
- sim = TinyPhysicsSimulator (args .model_path , args .data_path , args .do_sim_step , args .do_control_step , controller = SimpleController ())
187
- score = sim .rollout (args .vis )
188
- print (f"Final score: { score :>6.4} " )
199
+ tinyphysicsmodel = TinyPhysicsModel (args .model_path , debug = args .debug )
200
+
201
+ data_path = Path (args .data_path )
202
+ if data_path .is_file ():
203
+ sim = TinyPhysicsSimulator (tinyphysicsmodel , args .data_path , args .do_sim_step , args .do_control_step , controller = SimpleController (), debug = args .debug )
204
+ lat_accel_cost , jerk_cost = sim .rollout ()
205
+ print (f"\n Average lat_accel_cost: { lat_accel_cost :>6.4} , average jerk_cost: { jerk_cost :>6.4} " )
206
+ elif data_path .is_dir ():
207
+ costs = []
208
+ files = sorted (data_path .iterdir ())[:args .num_segs ]
209
+ for data_file in tqdm (files , total = len (files )):
210
+ sim = TinyPhysicsSimulator (tinyphysicsmodel , str (data_file ), args .do_sim_step , args .do_control_step , controller = SimpleController (), debug = args .debug )
211
+ cost = sim .rollout ()
212
+ costs .append (cost )
213
+ costs = np .array (costs )
214
+ print (f"\n Average lat_accel_cost: { np .mean (costs [:, 0 ]):>6.4} , average jerk_cost: { np .mean (costs [:, 1 ]):>6.4} " )
215
+ plt .hist (costs [:, 0 ], bins = np .arange (0 , 2 , 0.1 ), label = 'lat_accel_cost' , alpha = 0.5 )
216
+ plt .hist (costs [:, 1 ], bins = np .arange (0 , 2 , 0.1 ), label = 'jerk_cost' , alpha = 0.5 )
217
+ plt .xlabel ('costs' )
218
+ plt .ylabel ('Frequency' )
219
+ plt .title ('costs Distribution' )
220
+ plt .show ()
0 commit comments