Skip to content

Commit 91ceb95

Browse files
committed
swap to new dataset
1 parent 5b31674 commit 91ceb95

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

Diff for: download_dataset.sh

100644100755
+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ fi
66

77
cd data
88

9-
URL="https://huggingface.co/datasets/commaai/commaSteeringControl/resolve/main/data/TOYOTA_COROLLA_TSS2_2019.zip"
9+
URL="https://huggingface.co/datasets/commaai/commaSteeringControl/resolve/main/data/SYNTHETIC_V0.zip"
1010

1111
echo "Downloading dataset from $URL"
1212
wget "$URL"

Diff for: tinyphysics.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
MAX_ACC_DELTA = 0.5
2727
DEL_T = 0.1
2828

29-
State = namedtuple('State', ['roll_lataccel', 'vEgo', 'aEgo'])
29+
State = namedtuple('State', ['roll_lataccel', 'v_ego', 'a_ego'])
3030

3131

3232
class LataccelTokenizer:
@@ -114,9 +114,10 @@ def get_data(self, data_path: str) -> pd.DataFrame:
114114
df = pd.read_csv(data_path)
115115
processed_df = pd.DataFrame({
116116
'roll_lataccel': np.sin(df['roll'].values) * ACC_G,
117-
'vEgo': df['vEgo'].values,
118-
'aEgo': df['aEgo'].values,
119-
'target_lataccel': df['latAccelSteeringAngle'].values,
117+
'v_ego': df['vEgo'].values,
118+
'a_ego': df['aEgo'].values,
119+
'target_lataccel': df['targetLateralAcceleration'].values,
120+
'steer_command': df['steerCommand'].values
120121
})
121122
return processed_df
122123

@@ -142,7 +143,7 @@ def control_step(self, step_idx: int) -> None:
142143

143144
def get_state_target(self, step_idx: int) -> Tuple[List, float]:
144145
state = self.data.iloc[step_idx]
145-
return State(roll_lataccel=state['roll_lataccel'], vEgo=state['vEgo'], aEgo=state['aEgo']), state['target_lataccel']
146+
return State(roll_lataccel=state['roll_lataccel'], v_ego=state['v_ego'], a_ego=state['a_ego']), state['target_lataccel']
146147

147148
def step(self) -> None:
148149
state, target = self.get_state_target(self.step_idx)
@@ -181,7 +182,7 @@ def rollout(self) -> None:
181182
self.plot_data(ax[0], [(self.target_lataccel_history, 'Target lataccel'), (self.current_lataccel_history, 'Current lataccel')], ['Step', 'Lateral Acceleration'], 'Lateral Acceleration')
182183
self.plot_data(ax[1], [(self.action_history, 'Action')], ['Step', 'Action'], 'Action')
183184
self.plot_data(ax[2], [(np.array(self.state_history)[:, 0], 'Roll Lateral Acceleration')], ['Step', 'Lateral Accel due to Road Roll'], 'Lateral Accel due to Road Roll')
184-
self.plot_data(ax[3], [(np.array(self.state_history)[:, 1], 'vEgo')], ['Step', 'vEgo'], 'vEgo')
185+
self.plot_data(ax[3], [(np.array(self.state_history)[:, 1], 'v_ego')], ['Step', 'v_ego'], 'v_ego')
185186
plt.pause(0.01)
186187

187188
if self.debug:

0 commit comments

Comments
 (0)