Skip to content

Commit 9727138

Browse files
committed
Prepared training of the DEM model from the CLI
1 parent 0b6a586 commit 9727138

File tree

3 files changed

+64
-90
lines changed

3 files changed

+64
-90
lines changed

dem_autoencoder.ipynb

+25-60
Large diffs are not rendered by default.

dem_model_train.py

+37-23
Original file line numberDiff line numberDiff line change
@@ -10,35 +10,49 @@
1010
pl.seed_everything(9999)
1111
random.seed(9999)
1212

13-
#Load DEM data
14-
data = np.load("data/data.npz")
15-
dem_data = dem_scale(data["dem"])
13+
#Define main function
14+
def dem_training():
1615

17-
train = dem_data[:8000]
18-
valid = dem_data[8000:]
16+
#Load DEM data
17+
data = np.load("data/data.npz")
18+
dem_data = dem_scale(data["dem"])
1919

20-
#Load mask data
21-
random_mask = np.load("data/lakes_random.npz")
22-
mask_data = random_mask["mask"]
20+
train = dem_data[:8000]
21+
valid = dem_data[8000:]
2322

24-
random_mask_idx = random.choices(range(mask_data.shape[0]), k=valid.shape[0])
25-
valid_mask = mask_data[random_mask_idx]
23+
#Load mask data
24+
random_mask = np.load("data/lakes_random.npz")
25+
mask_data = random_mask["mask"]
2626

27-
#Create datasets
28-
train_dataset = DEMTrain(train, mask_data)
29-
valid_dataset = DEMValid(valid, valid_mask)
27+
random_mask_idx = random.choices(range(mask_data.shape[0]), k=valid.shape[0])
28+
valid_mask = mask_data[random_mask_idx]
3029

31-
#Create dataloaders
32-
train_loader = DataLoader(train_dataset, batch_size=32, num_workers=8, shuffle = True)
33-
val_loader = DataLoader(valid_dataset, batch_size=32, num_workers=8, shuffle = False)
30+
#Create datasets
31+
train_dataset = DEMTrain(train, mask_data)
32+
valid_dataset = DEMValid(valid, valid_mask)
3433

35-
#Initiate model
36-
lake_model = AutoEncoder(init_features=8)
34+
#Create dataloaders
35+
train_loader = DataLoader(train_dataset, batch_size=32, num_workers=0, shuffle = True)
36+
val_loader = DataLoader(valid_dataset, batch_size=32, num_workers=0, shuffle = False)
3737

38-
#Initiate trainer
39-
checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True)
38+
#Initiate models of differing complexity
39+
dem_model_8 = AutoEncoder(init_features=8)
40+
dem_model_16 = AutoEncoder(init_features=16)
41+
dem_model_32 = AutoEncoder(init_features=32)
42+
dem_model_64 = AutoEncoder(init_features=64)
4043

41-
trainer = pl.Trainer(gpus=1, max_epochs=1000, callbacks=checkpoint_callback, precision=16)
44+
#Initiate callbacks
45+
checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True)
46+
47+
#Initiate trainer
48+
trainer = pl.Trainer(gpus=1, max_epochs=1000, callbacks=checkpoint_callback)
49+
50+
#Train model
51+
trainer.fit(dem_model_8, train_loader, val_loader)
52+
trainer.fit(dem_model_16, train_loader, val_loader)
53+
trainer.fit(dem_model_32, train_loader, val_loader)
54+
trainer.fit(dem_model_64, train_loader, val_loader)
55+
56+
if __name__ == "__main__":
57+
dem_training()
4258

43-
#Train model
44-
trainer.fit(lake_model, train_loader, val_loader)

model.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#https://github.com/mateuszbuda/brain-segmentation-pytorch
33

44
from collections import OrderedDict
5-
65
import torch
76
import torch.nn as nn
87

@@ -72,23 +71,19 @@ def _block(in_channels, features, name):
7271
in_channels=in_channels,
7372
out_channels=features,
7473
kernel_size=3,
75-
padding=1,
76-
#bias=False,
74+
padding=1
7775
),
7876
),
79-
#(name + "norm1", nn.BatchNorm2d(num_features=features)),
8077
(name + "act1", nn.LeakyReLU()),
8178
(
8279
name + "conv2",
8380
nn.Conv2d(
8481
in_channels=features,
8582
out_channels=features,
8683
kernel_size=3,
87-
padding=1,
88-
#bias=False,
84+
padding=1
8985
),
9086
),
91-
#(name + "norm2", nn.BatchNorm2d(num_features=features)),
9287
(name + "act2", nn.LeakyReLU()),
9388
]
9489
)

0 commit comments

Comments (0)