Skip to content

Commit 01ad964

Browse files
author
Zach Teed
committed
added small 1M parameter model
1 parent c86b3dc commit 01ad964

5 files changed

+18
-27
lines changed

README.md

+8-9
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,13 @@ Pretrained models can be downloaded by running
2424
```Shell
2525
./download_models.sh
2626
```
27-
or downloaded from [google drive](https://drive.google.com/file/d/10-BYgHqRNPGvmNUWr8razjb1xHu55pyA/view?usp=sharing)
27+
or downloaded from [google drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing)
2828

2929
You can demo a trained model on a sequence of frames
3030
```Shell
3131
python demo.py --model=models/raft-things.pth --path=demo-frames
3232
```
3333

34-
## (Optional) Efficient Implementation
35-
You can optionally use our alternate (efficient) implementation by compiling the provided cuda extension
36-
```Shell
37-
cd alt_cuda_corr && python setup.py install && cd ..
38-
```
39-
and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
40-
41-
4234
## Required Data
4335
To evaluate/train RAFT, you will need to download the required datasets.
4436
* [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
@@ -83,3 +75,10 @@ If you have a RTX GPU, training can be accelerated using mixed precision. You ca
8375
```Shell
8476
./train_mixed.sh
8577
```
78+
79+
## (Optional) Efficient Implementation
80+
You can optionally use our alternate (efficient) implementation by compiling the provided cuda extension
81+
```Shell
82+
cd alt_cuda_corr && python setup.py install && cd ..
83+
```
84+
and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.

download_models.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#!/bin/bash
2-
wget https://www.dropbox.com/s/npt24nvhoojdr0n/models.zip
2+
wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip
33
unzip models.zip

train.py

+5-13
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def update(self):
4444
VAL_FREQ = 5000
4545

4646

47-
def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
47+
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
4848
""" Loss function defined over sequence of flow predictions """
4949

5050
n_predictions = len(flow_preds)
@@ -55,7 +55,7 @@ def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
5555
valid = (valid >= 0.5) & (mag < max_flow)
5656

5757
for i in range(n_predictions):
58-
i_weight = 0.8**(n_predictions - i - 1)
58+
i_weight = gamma**(n_predictions - i - 1)
5959
i_loss = (flow_preds[i] - flow_gt).abs()
6060
flow_loss += i_weight * (valid[:, None] * i_loss).mean()
6161

@@ -71,16 +71,11 @@ def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
7171

7272
return flow_loss, metrics
7373

74-
def show_image(img):
75-
img = img.permute(1,2,0).cpu().numpy()
76-
plt.imshow(img/255.0)
77-
plt.show()
78-
# cv2.imshow('image', img/255.0)
79-
# cv2.waitKey()
8074

8175
def count_parameters(model):
8276
return sum(p.numel() for p in model.parameters() if p.requires_grad)
8377

78+
8479
def fetch_optimizer(args, model):
8580
""" Create the optimizer and learning rate scheduler """
8681
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
@@ -169,17 +164,14 @@ def train(args):
169164
optimizer.zero_grad()
170165
image1, image2, flow, valid = [x.cuda() for x in data_blob]
171166

172-
# show_image(image1[0])
173-
# show_image(image2[0])
174-
175167
if args.add_noise:
176168
stdv = np.random.uniform(0.0, 5.0)
177169
image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
178170
image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)
179171

180172
flow_predictions = model(image1, image2, iters=args.iters)
181173

182-
loss, metrics = sequence_loss(flow_predictions, flow, valid)
174+
loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
183175
scaler.scale(loss).backward()
184176
scaler.unscale_(optimizer)
185177
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
@@ -188,7 +180,6 @@ def train(args):
188180
scheduler.step()
189181
scaler.update()
190182

191-
192183
logger.push(metrics)
193184

194185
if total_steps % VAL_FREQ == VAL_FREQ - 1:
@@ -243,6 +234,7 @@ def train(args):
243234
parser.add_argument('--epsilon', type=float, default=1e-8)
244235
parser.add_argument('--clip', type=float, default=1.0)
245236
parser.add_argument('--dropout', type=float, default=0.0)
237+
parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
246238
parser.add_argument('--add_noise', action='store_true')
247239
args = parser.parse_args()
248240

train_mixed.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
mkdir -p checkpoints
33
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision
44
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision
5-
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --mixed_precision
6-
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --mixed_precision
5+
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision
6+
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision

train_standard.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
mkdir -p checkpoints
33
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
44
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
5-
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001
6-
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001
5+
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85
6+
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85

0 commit comments

Comments
 (0)