Skip to content

Commit 2511f77

Browse files
committed
add jerk cost
1 parent 3c65bde commit 2511f77

File tree

2 files changed

+62
-18
lines changed

2 files changed

+62
-18
lines changed

README.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ We'll be using driving segments from the [comma-steering-control](https://github
1212
bash ./download_dataset.sh
1313
1414
# Test this works
15-
python tinyphysics.py --model_path ./models/tinyphysics.onnx --data_path ./data/00000.csv --do_sim_step --do_control_step --vis
15+
python tinyphysics.py --model_path ./models/tinyphysics.onnx --data_path ./data/00000.csv --do_sim_step --do_control_step --debug
1616
1717
1818
# Batch Metrics on lots of routes
19+
python tinyphysics.py --model_path ./models/tinyphysics.onnx --data_path ./data --num_segs 1000 --do_sim_step --do_control_step
1920
2021
```
2122

@@ -27,3 +28,14 @@ This is a "simulated car" that has been trained to mimic a very simple physics m
2728
## Controllers
2829
Your controller should implement an [update function](https://github.com/commaai/controls_challenge/blob/1a25ee200f5466cb7dc1ab0bf6b7d0c67a2481db/controllers.py#L2) that returns a `steer_action` in the range `[-1, 1]`. The controller is then run in the loop inside the simulator, which autoregressively predicts the car's response.
2930

31+
*Note: The `steerFiltered` column in the dataset is not relevant here; it was the steer command for one particular platform. The dataset is used only to provide realistic driving scenarios with respect to road roll, desired acceleration, and car state (velocity, forward acceleration).*
32+
33+
34+
## Evaluation
35+
Each rollout will result in 2 costs:
36+
- `lat_accel_cost`: $\dfrac{\Sigma(actual\_lat\_accel - target\_lat\_accel)^2}{steps}$
37+
38+
- `jerk_cost`: $\dfrac{\Sigma((actual\_lat\_accel_{t} - actual\_lat\_accel_{t-1}) / \Delta t)^2}{steps - 1}$
39+
40+
41+
Minimizing both costs is very important.

tinyphysics.py

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,21 @@
66

77
from collections import namedtuple
88
from hashlib import md5
9+
from pathlib import Path
910
from typing import List, Union, Tuple
11+
from tqdm import tqdm
1012

1113
from controllers import BaseController, SimpleController
1214

15+
1316
ACC_G = 9.81
1417
SIM_START_IDX = 100
1518
CONTEXT_LENGTH = 20
1619
VOCAB_SIZE = 1024
1720
LATACCEL_RANGE = [-4, 4]
21+
STEER_RANGE = [-1, 1]
1822
MAX_ACC_DELTA = 0.5
23+
DEL_T = 0.1
1924

2025
State = namedtuple('State', ['roll_lataccel', 'vEgo', 'aEgo'])
2126

@@ -37,17 +42,19 @@ def clip(self, value: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
3742

3843

3944
class TinyPhysicsModel:
40-
def __init__(self, model_path: str) -> None:
45+
def __init__(self, model_path: str, debug: bool) -> None:
4146
self.tokenizer = LataccelTokenizer()
4247
options = ort.SessionOptions()
4348
options.intra_op_num_threads = 1
4449
options.inter_op_num_threads = 1
4550
options.log_severity_level = 3
4651
if 'CUDAExecutionProvider' in ort.get_available_providers():
47-
print("ONNX Runtime is using GPU")
52+
if debug:
53+
print("ONNX Runtime is using GPU")
4854
provider = ('CUDAExecutionProvider', {'cudnn_conv_algo_search': 'DEFAULT'})
4955
else:
50-
print("ONNX Runtime is using CPU")
56+
if debug:
57+
print("ONNX Runtime is using CPU")
5158
provider = 'CPUExecutionProvider'
5259

5360
with open(model_path, "rb") as f:
@@ -78,13 +85,15 @@ def get_current_lataccel(self, sim_states: List[State], actions: List[float], pa
7885

7986

8087
class TinyPhysicsSimulator:
81-
def __init__(self, model_path: str, data_path: str, do_sim_step: bool, do_control_step: bool, controller: BaseController) -> None:
88+
def __init__(self, model: TinyPhysicsModel, data_path: str, do_sim_step: bool, do_control_step: bool, controller: BaseController, debug: bool = False) -> None:
8289
self.data_path = data_path
83-
self.sim_model = TinyPhysicsModel(model_path)
90+
self.sim_model = model
8491
self.data = self.get_data(data_path)
8592
self.do_sim_step = do_sim_step
8693
self.do_control_step = do_control_step
8794
self.controller = controller
95+
self.debug = debug
96+
self.times = []
8897
self.reset()
8998

9099
def reset(self) -> None:
@@ -124,6 +133,7 @@ def control_step(self, step_idx: int) -> None:
124133
action = self.controller.update(self.target_lataccel_history[step_idx], self.current_lataccel, self.state_history[step_idx])
125134
else:
126135
action = 0.
136+
action = np.clip(action, STEER_RANGE[0], STEER_RANGE[1])
127137
self.action_history.append(action)
128138

129139
def get_state_target(self, step_idx: int) -> Tuple[List, float]:
@@ -147,42 +157,64 @@ def plot_data(self, ax, lines, axis_labels, title) -> None:
147157
ax.set_xlabel(axis_labels[0])
148158
ax.set_ylabel(axis_labels[1])
149159

150-
def compute_score(self) -> float:
160+
def compute_cost(self) -> Tuple[float, float]:
  """Compute the two rollout costs over the scored part of the simulation.

  Only steps from SIM_START_IDX onward are scored.

  Returns:
    A `(lat_accel_cost, jerk_cost)` tuple where
      - lat_accel_cost is the mean squared error between the target and the
        simulated lateral acceleration, and
      - jerk_cost is the mean squared finite-difference rate of change of the
        simulated lateral acceleration, i.e. consecutive differences divided
        by the time step DEL_T.
  """
  target = np.array(self.target_lataccel_history)[SIM_START_IDX:]
  pred = np.array(self.current_lataccel_history)[SIM_START_IDX:]

  # Tracking error is a plain acceleration difference; it must not be scaled by the time step.
  lat_accel_cost = np.mean((target - pred)**2)
  # Jerk is d(lataccel)/dt: np.diff yields len(pred) - 1 step-to-step deltas,
  # which are divided by DEL_T to turn them into a rate before squaring.
  jerk_cost = np.mean((np.diff(pred) / DEL_T)**2)
  return lat_accel_cost, jerk_cost
167+
168+
def rollout(self) -> None:
169+
if self.debug:
157170
plt.ion()
158171
fig, ax = plt.subplots(4, figsize=(12, 14))
159172

160173
for _ in range(len(self.data)):
161174
self.step()
162-
if debug and self.step_idx % 10 == 0:
175+
if self.debug and self.step_idx % 10 == 0:
163176
print(f"Step {self.step_idx:<5}: Current lataccel: {self.current_lataccel:>6.2f}, Target lataccel: {self.target_lataccel_history[-1]:>6.2f}")
164177
self.plot_data(ax[0], [(self.target_lataccel_history, 'Target lataccel'), (self.current_lataccel_history, 'Current lataccel')], ['Step', 'Lateral Acceleration'], 'Lateral Acceleration')
165178
self.plot_data(ax[1], [(self.action_history, 'Action')], ['Step', 'Action'], 'Action')
166179
self.plot_data(ax[2], [(np.array(self.state_history)[:, 0], 'Roll Lateral Acceleration')], ['Step', 'Lateral Accel due to Road Roll'], 'Lateral Accel due to Road Roll')
167180
self.plot_data(ax[3], [(np.array(self.state_history)[:, 1], 'vEgo')], ['Step', 'vEgo'], 'vEgo')
168181
plt.pause(0.01)
169182

170-
if debug:
183+
if self.debug:
171184
plt.ioff()
172185
plt.show()
173-
174-
return self.compute_score()
186+
return self.compute_cost()
175187

176188

177189
if __name__ == "__main__":
178190
parser = argparse.ArgumentParser()
179191
parser.add_argument("--model_path", type=str, required=True)
180192
parser.add_argument("--data_path", type=str, required=True)
193+
parser.add_argument("--num_segs", type=int, default=1000)
181194
parser.add_argument("--do_sim_step", action='store_true')
182195
parser.add_argument("--do_control_step", action='store_true')
183-
parser.add_argument("--vis", action='store_true')
196+
parser.add_argument("--debug", action='store_true')
184197
args = parser.parse_args()
185198

186-
sim = TinyPhysicsSimulator(args.model_path, args.data_path, args.do_sim_step, args.do_control_step, controller=SimpleController())
187-
score = sim.rollout(args.vis)
188-
print(f"Final score: {score:>6.4}")
199+
tinyphysicsmodel = TinyPhysicsModel(args.model_path, debug=args.debug)
200+
201+
data_path = Path(args.data_path)
202+
if data_path.is_file():
203+
sim = TinyPhysicsSimulator(tinyphysicsmodel, args.data_path, args.do_sim_step, args.do_control_step, controller=SimpleController(), debug=args.debug)
204+
lat_accel_cost, jerk_cost = sim.rollout()
205+
print(f"\nAverage lat_accel_cost: {lat_accel_cost:>6.4}, average jerk_cost: {jerk_cost:>6.4}")
206+
elif data_path.is_dir():
207+
costs = []
208+
files = sorted(data_path.iterdir())[:args.num_segs]
209+
for data_file in tqdm(files, total=len(files)):
210+
sim = TinyPhysicsSimulator(tinyphysicsmodel, str(data_file), args.do_sim_step, args.do_control_step, controller=SimpleController(), debug=args.debug)
211+
cost = sim.rollout()
212+
costs.append(cost)
213+
costs = np.array(costs)
214+
print(f"\nAverage lat_accel_cost: {np.mean(costs[:, 0]):>6.4}, average jerk_cost: {np.mean(costs[:, 1]):>6.4}")
215+
plt.hist(costs[:, 0], bins=np.arange(0, 2, 0.1), label='lat_accel_cost', alpha=0.5)
216+
plt.hist(costs[:, 1], bins=np.arange(0, 2, 0.1), label='jerk_cost', alpha=0.5)
217+
plt.xlabel('costs')
218+
plt.ylabel('Frequency')
219+
plt.title('costs Distribution')
220+
plt.show()

0 commit comments

Comments
 (0)