Skip to content

Commit e8649f3

Browse files
committed
add requirements, fix metric
1 parent 2511f77 commit e8649f3

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

Diff for: requirements.txt

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
numpy==1.25.2
2+
onnxruntime-gpu==1.16.3
3+
pandas==2.1.2
4+
matplotlib==3.8.1
5+
seaborn==0.13.2
6+
tqdm

Diff for: tinyphysics.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import onnxruntime as ort
44
import pandas as pd
55
import matplotlib.pyplot as plt
6+
import seaborn as sns
67

78
from collections import namedtuple
89
from hashlib import md5
@@ -12,6 +13,7 @@
1213

1314
from controllers import BaseController, SimpleController
1415

16+
sns.set_theme()
1517

1618
ACC_G = 9.81
1719
SIM_START_IDX = 100
@@ -161,14 +163,14 @@ def compute_cost(self) -> float:
161163
target = np.array(self.target_lataccel_history)[SIM_START_IDX:]
162164
pred = np.array(self.current_lataccel_history)[SIM_START_IDX:]
163165

164-
lat_accel_cost = np.mean(((target - pred) / DEL_T)**2)
165-
jerk_cost = np.mean(np.diff(pred)**2)
166+
lat_accel_cost = np.mean((target - pred)**2)
167+
jerk_cost = np.mean((np.diff(pred) / DEL_T)**2)
166168
return lat_accel_cost, jerk_cost
167169

168170
def rollout(self) -> None:
169171
if self.debug:
170172
plt.ion()
171-
fig, ax = plt.subplots(4, figsize=(12, 14))
173+
fig, ax = plt.subplots(4, figsize=(12, 14), constrained_layout=True)
172174

173175
for _ in range(len(self.data)):
174176
self.step()

0 commit comments

Comments
 (0)