-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathmain.py
61 lines (59 loc) · 2.07 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import ray
from utils import (
parse_args,
setup,
load_config,
setup_train_vae,
seed_all
)
from common_utils import tqdm_remote_get
from tqdm import tqdm
from environment.meshUtils import (
create_collision_mesh,
create_grasp_object_urdf
)
from multiprocessing import Pool
from train import cotrain, train_vae, pretrain, pretrain_gn
from generate_data import (
create_grasp_objects,
generate_pretrain_data,
generate_pretrain_imprint_data,
generate_imprints
)
from pathlib import Path
from torch.multiprocessing import spawn
# Glob (matched recursively under the --objects root) that selects the
# normalized object meshes processed by the 'collision_mesh' and 'urdf' modes.
OBJECT_MESH_PATTERN = '*normalized.obj'


def find_normalized_objects(objects_root, pattern=OBJECT_MESH_PATTERN):
    """Return every file under ``objects_root`` whose name matches ``pattern``.

    The search is recursive (``Path.rglob``), so meshes in nested
    sub-directories are included. Returns a list of ``Path`` objects in
    filesystem traversal order.
    """
    return list(Path(objects_root).rglob(pattern))


def _run_remote_over_objects(fn, objects_root, desc):
    """Run ``fn(object_path)`` as a Ray task for every normalized mesh.

    Waits on all submitted tasks through ``tqdm_remote_get``, which is
    given ``desc`` as its progress label. ``ray.init()`` must already have
    been called.
    """
    remote_fn = ray.remote(fn)
    task_handles = [remote_fn.remote(object_path)
                    for object_path in find_normalized_objects(objects_root)]
    tqdm_remote_get(task_handles=task_handles, desc=desc)


def main():
    """Parse CLI arguments and run the pipeline stage selected by ``--mode``.

    Mesh-preprocessing modes ('collision_mesh', 'urdf') operate on
    ``args.objects`` and never read a config file. Every other mode loads
    ``args.config`` first and then dispatches to its handler.

    Raises:
        ValueError: if ``args.mode`` names no known mode. Previously such a
            mode silently did nothing.
    """
    args = parse_args()
    seed_all(args.seed)
    ray.init()

    # Mesh-preprocessing modes: fan out over the object files with Ray.
    mesh_modes = {
        'collision_mesh': (create_collision_mesh, 'creating collision meshs'),
        'urdf': (create_grasp_object_urdf, "creating grasp object urdf"),
    }
    if args.mode in mesh_modes:
        fn, desc = mesh_modes[args.mode]
        _run_remote_over_objects(fn, args.objects, desc)
        # Plain return instead of the site-module exit(), which is meant
        # for the interactive interpreter and is missing under `python -S`.
        return

    # Config-driven modes. Each handler receives the loaded config
    # explicitly, so there is no late-binding closure over a later variable.
    config_modes = {
        'cotrain': lambda cfg: cotrain(*setup(args, cfg)),
        'vae': lambda cfg: train_vae(*setup_train_vae(args, cfg)),
        'pretrain': lambda cfg: pretrain(args, cfg),
        'pretrain_gn': lambda cfg: pretrain_gn(args, cfg),
        'pretrain_dataset':
            lambda cfg: generate_pretrain_data(*setup(args, cfg)),
        'pretrain_imprint_dataset':
            lambda cfg: generate_pretrain_imprint_data(*setup(args, cfg)),
        'grasp_objects':
            lambda cfg: create_grasp_objects(*setup(args, cfg)),
        'imprint_baseline':
            lambda cfg: generate_imprints(*setup(args, cfg)),
    }
    handler = config_modes.get(args.mode)
    if handler is None:
        known = sorted([*mesh_modes, *config_modes])
        raise ValueError(f"Unknown mode {args.mode!r}; expected one of {known}")
    handler(load_config(args.config))


if __name__ == "__main__":
    main()