-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathmain.py
113 lines (93 loc) · 3.14 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Author: Jimmy Wu
# Date: October 2024
import argparse
import time
from itertools import count
from constants import POLICY_CONTROL_PERIOD
from episode_storage import EpisodeWriter
from policies import TeleopPolicy, RemotePolicy
def should_save_episode(writer):
    """Decide whether the episode held by ``writer`` should be saved.

    An episode with zero recorded steps (``len(writer) == 0``) is discarded
    without asking. Otherwise the user is prompted on stdin until they answer
    ``y`` or ``n`` (surrounding whitespace and letter case are ignored).

    Returns:
        True if the user chose to save the episode, False otherwise.
    """
    # Nothing recorded, so there is nothing worth asking about
    if len(writer) == 0:
        print('Discarding empty episode')
        return False

    # Map each accepted reply to the resulting decision
    decisions = {'y': True, 'n': False}
    while True:
        reply = input('Save episode (y/n)? ').strip().lower()
        if reply not in decisions:
            # Keep asking until we get a recognized answer
            print('Invalid response')
            continue
        keep = decisions[reply]
        if not keep:
            print('Discarding episode')
        return keep
def run_episode(env, policy, writer=None):
    """Run a single episode: reset, then step the policy at a fixed rate.

    Each control step fetches the latest observation from ``env`` and asks
    ``policy`` for an action, which is handled as follows:

    * ``None``: nothing is executed for this step.
    * a ``dict``: executed via ``env.step``; additionally recorded via
      ``writer.step`` as long as the episode has not ended yet.
    * ``'end_episode'``: marks the episode as ended (first occurrence only)
      and, when a writer is present, consults ``should_save_episode`` to
      decide whether to start a background save.
    * ``'reset_env'``: leaves the control loop.

    Any other action value is ignored.

    Args:
        env: Environment providing ``reset()``, ``get_obs()`` and ``step()``.
        policy: Policy providing ``reset()`` and ``step(obs)``.
        writer: Optional episode writer. When ``None``, nothing is recorded.
    """
    # Reset the env
    print('Resetting env...')
    env.reset()
    print('Env has been reset')

    # Wait for user to press "Start episode"
    print('Press "Start episode" in the web app when ready to start new episode')
    policy.reset()
    print('Starting new episode')

    episode_ended = False
    try:
        # Use the monotonic clock for pacing: unlike time.time(), it cannot
        # jump when the system wall clock is adjusted, which would otherwise
        # stall the loop or make it run many steps back-to-back.
        start_time = time.monotonic()
        for step_idx in count():
            # Enforce desired control freq
            # NOTE: the schedule is anchored to start_time, so after a long
            # blocking call (e.g. the save prompt below) the following steps
            # run without waiting until the schedule has caught up.
            step_end_time = start_time + step_idx * POLICY_CONTROL_PERIOD
            while time.monotonic() < step_end_time:
                time.sleep(0.0001)

            # Get latest observation
            obs = env.get_obs()

            # Get action
            action = policy.step(obs)

            # No action if teleop not enabled
            if action is None:
                continue

            # Execute valid action on robot
            if isinstance(action, dict):
                env.step(action)

                if writer is not None and not episode_ended:
                    # Record executed action
                    writer.step(obs, action)

            # Episode ended
            elif not episode_ended and action == 'end_episode':
                episode_ended = True
                print('Episode ended')

                if writer is not None and should_save_episode(writer):
                    # Save to disk in background thread
                    writer.flush_async()

                print('Teleop is now active. Press "Reset env" in the web app when ready to proceed.')

            # Ready for env reset
            elif action == 'reset_env':
                break
    finally:
        if writer is not None:
            # Wait for writer to finish saving to disk. Done in a finally so
            # that a background save already started is not cut short when
            # the loop is left through an exception (e.g. Ctrl-C).
            writer.wait_for_flush()
def main(args):
    """Create the env and policy selected by ``args`` and run episodes forever.

    Args:
        args: Parsed command-line namespace with the attributes
            ``sim`` (use ``MujocoEnv`` instead of ``RealEnv``),
            ``teleop`` (use ``TeleopPolicy`` instead of ``RemotePolicy``; in
            sim this also creates the env with ``show_images=True``),
            ``save`` (create a fresh ``EpisodeWriter`` for every episode) and
            ``output_dir`` (passed to ``EpisodeWriter``).

    The episode loop has no exit condition of its own; it only stops when an
    exception (such as ``KeyboardInterrupt``) propagates. ``env.close()`` is
    called in every case once the env has been created.
    """
    # Create env (imports are local so only the selected backend is loaded)
    if args.sim:
        from mujoco_env import MujocoEnv
        if args.teleop:
            env = MujocoEnv(show_images=True)
        else:
            env = MujocoEnv()
    else:
        from real_env import RealEnv
        env = RealEnv()

    try:
        # Create policy inside the try block so that the env is still closed
        # if policy construction raises
        if args.teleop:
            policy = TeleopPolicy()
        else:
            policy = RemotePolicy()

        while True:
            writer = EpisodeWriter(args.output_dir) if args.save else None
            run_episode(env, policy, writer)
    finally:
        env.close()
if __name__ == '__main__':
    # Command-line entry point: three boolean switches (all off by default)
    # followed by the output directory option
    cli = argparse.ArgumentParser()
    for switch in ('--sim', '--teleop', '--save'):
        cli.add_argument(switch, action='store_true')
    cli.add_argument('--output-dir', default='data/demos')
    main(cli.parse_args())