-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_pointnet.py
99 lines (79 loc) · 3.14 KB
/
train_pointnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from abc import ABC
from pathlib import Path
from numcodecs import blosc
import pandas as pd, numpy as np
import os
import bisect
import itertools as it
from tqdm import tqdm
import logzero
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.autograd import Variable
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers.neptune import NeptuneLogger
import pickle, copy, re, time, datetime, random, warnings, gc
import zarr
from poinet_model import *
# --- Run configuration -------------------------------------------------------
# `json` is imported explicitly here: no import above brings it in by name, so
# without this line the script depends on the star import leaking it into scope.
import json

# Every tunable setting is read from `parameters.json`, resolved relative to
# the current working directory. The encoding is pinned so the file is parsed
# the same way regardless of the platform's default locale.
with open('parameters.json', encoding='utf-8') as json_file:
    JSON_PARAMETERS = json.load(json_file)

DATA_ROOT = Path("/data/lyft-motion-prediction-autonomous-vehicles")

# The settings below are copied verbatim from the JSON document. Assuming the
# file holds a JSON object, a missing key raises KeyError here, at start-up.
TRAIN_ZARR = JSON_PARAMETERS["TRAIN_ZARR"]
VALID_ZARR = JSON_PARAMETERS["VALID_ZARR"]
HBACKWARD = JSON_PARAMETERS["HBACKWARD"]
HFORWARD = JSON_PARAMETERS["HFORWARD"]
NFRAMES = JSON_PARAMETERS["NFRAMES"]
FRAME_STRIDE = JSON_PARAMETERS["FRAME_STRIDE"]
AGENT_FEATURE_DIM = JSON_PARAMETERS["AGENT_FEATURE_DIM"]
MAX_AGENTS = JSON_PARAMETERS["MAX_AGENTS"]

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NUM_WORKERS = JSON_PARAMETERS["NUM_WORKERS"]
BATCH_SIZE = JSON_PARAMETERS["BATCH_SIZE"]
EPOCHS = JSON_PARAMETERS["EPOCHS"]
LEARNING_RATE = JSON_PARAMETERS["LEARNING_RATE"]
WEIGHT_DECAY = JSON_PARAMETERS["WEIGHT_DECAY"]
GRADIENT_CLIP_VAL = JSON_PARAMETERS["GRADIENT_CLIP_VAL"]
LIMIT_VAL_BATCHES = JSON_PARAMETERS["LIMIT_VAL_BATCHES"]

# Let cuDNN benchmark its kernels and pick the fastest for the shapes it sees.
torch.backends.cudnn.benchmark = True
# Checkpoint to resume from; `None` (the current setting) starts a new model.
# To resume, assign a checkpoint path (str or Path), e.g. the result of
# `get_last_checkpoint(ROOT)`.
# NOTE(review): neither `get_last_checkpoint` nor `ROOT` is defined in this
# file — presumably they come from the star import; confirm before enabling.
last_checkpoint = None

if last_checkpoint is not None:
    # Normalise once so the log line and the loader both accept a str or a
    # Path (calling `.as_posix()` on a plain str would raise AttributeError).
    checkpoint_path = Path(last_checkpoint).as_posix()
    print(f'\n***** RESUMING FROM CHECKPOINT `{checkpoint_path}`***********\n')
    model = LyftNet.load_from_checkpoint(
        checkpoint_path,
        map_location=device,
        num_workers=NUM_WORKERS,
        batch_size=BATCH_SIZE,
    )
else:
    print('\n***** NEW MODEL ***********\n')
    model = LyftNet(
        batch_size=BATCH_SIZE,
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        num_workers=NUM_WORKERS,
    )
# Checkpointing settings, collected in one mapping and handed to
# ModelCheckpoint unchanged.
# NOTE(review): `ROOT` is not assigned anywhere in this file — presumably it is
# provided by the star import of `poinet_model`; confirm.
_checkpoint_options = dict(
    filepath=ROOT,
    save_top_k=5,
    verbose=0,
    monitor='val_loss',  # metric the callback watches ...
    mode='min',          # ... with 'min' mode selected
    prefix='lyfnet_',
)
checkpoint_callback = ModelCheckpoint(**_checkpoint_options)
# Neptune credentials come from the environment; `.get` yields None when the
# variable is unset.
API_KEY = os.environ.get('NEPTUNE_API_KEY')

# Settings recorded with the run. Each value is rendered to a string, and the
# labels are the exact keys sent to the logger.
_logged_params = {
    label: f'{value}'
    for label, value in (
        ('epoch_nr', EPOCHS),
        ('bs', BATCH_SIZE),
        ('LEARNING_RATE', LEARNING_RATE),
        ('WEIGHT_DECAY', WEIGHT_DECAY),
        ('HBACKWARD', HBACKWARD),
        ('HFORWARD', HFORWARD),
        ('NFRAMES', NFRAMES),
        ("FRAME_STRIDE", FRAME_STRIDE),
        ("AGENT_FEATURE_DIM", AGENT_FEATURE_DIM),
        ("MAX_AGENTS", MAX_AGENTS),
    )
}

neptune_logger = NeptuneLogger(
    api_key=API_KEY,
    project_name='hvergnes/KagglePointNet',
    params=_logged_params,
    tags=['baseline'],
)
# Destination of the final weights. The parent directory is created before
# training starts, so a missing `save/` folder cannot make the final
# `torch.save` fail after the whole run has completed.
# NOTE(review): the file name contains ':' characters, which are not valid in
# Windows file names — left unchanged to keep existing artefact names.
weights_path = f'save/PointNetE:{EPOCHS}LR:{LEARNING_RATE}.pt'
Path(weights_path).parent.mkdir(parents=True, exist_ok=True)

trainer = Trainer(
    max_epochs=EPOCHS,
    gradient_clip_val=GRADIENT_CLIP_VAL,
    logger=neptune_logger,
    checkpoint_callback=checkpoint_callback,
    limit_val_batches=LIMIT_VAL_BATCHES,
    # Request a GPU only when CUDA is available, matching how `device` is
    # chosen earlier in the script; None is the Trainer's CPU default.
    gpus=1 if torch.cuda.is_available() else None,
)

# Train, then store only the model's state dict (not the Trainer state).
trainer.fit(model)
torch.save(model.state_dict(), weights_path)