
Commit 0b6a586

Started separating content in the notebook into separate files.
1 parent 358f945 commit 0b6a586

8 files changed, +345 −1,800 lines changed

data_classes.py

+93
```python
import torch
from torch.utils.data import Dataset
import random
from helpers import mask_morph_trans, dem_aug, mask_aug, dem_scale

# Classes (train and validation) for loading DEM data for unsupervised training

class DEMTrain(Dataset):

    def __init__(self, array, masks, dem_transform=dem_aug, mask_transform=mask_aug):
        self.array = array
        self.masks = masks
        self.n_masks = masks.shape[0]
        self.dem_transform = dem_transform
        self.mask_transform = mask_transform

    def __getitem__(self, idx):
        # Augment the target DEM tile
        target = self.array[idx]
        target_transformed = self.dem_transform(image=target)
        target_trans = target_transformed["image"]
        target_tensor = torch.from_numpy(target_trans).unsqueeze(0)

        # Pick a random mask and augment it
        mask = self.masks[random.choice(range(self.n_masks))]
        mask_transformed = self.mask_transform(image=mask)
        mask_trans = mask_transformed["image"]
        mask_trans_morph = mask_morph_trans(mask_trans, p=0.25)
        mask_tensor = torch.from_numpy(mask_trans_morph).unsqueeze(0)

        # Zero out the masked pixels to create the model input
        input_tensor = target_tensor * (1 - mask_tensor)

        return input_tensor, target_tensor, mask_tensor

    def __len__(self):
        return self.array.shape[0]

class DEMValid(Dataset):

    def __init__(self, array, masks):
        self.array = array
        self.masks = masks

    def __getitem__(self, idx):
        # No augmentation for validation: fixed DEM/mask pairs
        target = self.array[idx]
        target_tensor = torch.from_numpy(target).unsqueeze(0)

        mask = self.masks[idx]
        mask_tensor = torch.from_numpy(mask).unsqueeze(0)

        input_tensor = target_tensor * (1 - mask_tensor)

        return input_tensor, target_tensor, mask_tensor

    def __len__(self):
        return self.array.shape[0]

# Class for loading lake data from dicts

class Lakes(Dataset):

    def __init__(self, lakes_list, transform=None):
        self.lakes_list = lakes_list
        self.transform = transform

    def __getitem__(self, idx):
        item = self.lakes_list[idx]
        lake = item["lake"]
        mask = item["mask"]

        lake = dem_scale(lake)

        if self.transform is not None:
            arrays_trans = self.transform(image=lake, mask=mask)
            lake = arrays_trans["image"]
            mask = arrays_trans["mask"]
            mask = mask_morph_trans(mask)

        target_tensor = torch.from_numpy(lake).unsqueeze(0)
        mask_tensor = torch.from_numpy(mask).unsqueeze(0)

        input_tensor = target_tensor * (1 - mask_tensor)

        return input_tensor, target_tensor, mask_tensor

    def __len__(self):
        return len(self.lakes_list)
```
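For reference, a minimal usage sketch (not part of the commit), assuming the DEM tiles and masks are float32 NumPy arrays of shape (N, 256, 256) with mask values in {0, 1}:

```python
import numpy as np
from data_classes import DEMTrain

# Hypothetical toy inputs; the real data comes from data/data.npz and data/lakes_random.npz
dem = np.random.rand(16, 256, 256).astype("float32")
masks = (np.random.rand(4, 256, 256) > 0.9).astype("float32")

dataset = DEMTrain(dem, masks)
input_tensor, target_tensor, mask_tensor = dataset[0]
# input_tensor is target_tensor with the randomly chosen mask zeroed out
print(input_tensor.shape)  # torch.Size([1, 256, 256])
```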

dem_autoencoder.ipynb

+38 −1,800 (large diff not rendered)

dem_model_train.py

+44
```python
import numpy as np
import random
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from lightning_module import AutoEncoder
from helpers import dem_scale
from data_classes import DEMTrain, DEMValid

pl.seed_everything(9999)
random.seed(9999)

# Load DEM data and scale to [-1, 1]
data = np.load("data/data.npz")
dem_data = dem_scale(data["dem"])

train = dem_data[:8000]
valid = dem_data[8000:]

# Load mask data
random_mask = np.load("data/lakes_random.npz")
mask_data = random_mask["mask"]

# Assign one fixed random mask to each validation tile
random_mask_idx = random.choices(range(mask_data.shape[0]), k=valid.shape[0])
valid_mask = mask_data[random_mask_idx]

# Create datasets
train_dataset = DEMTrain(train, mask_data)
valid_dataset = DEMValid(valid, valid_mask)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, num_workers=8, shuffle=True)
val_loader = DataLoader(valid_dataset, batch_size=32, num_workers=8, shuffle=False)

# Initialize model
lake_model = AutoEncoder(init_features=8)

# Initialize trainer; ModelCheckpoint keeps the best model by validation loss
checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True)

trainer = pl.Trainer(gpus=1, max_epochs=1000, callbacks=[checkpoint_callback], precision=16)

# Train model
trainer.fit(lake_model, train_loader, val_loader)
```
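After training, the best checkpoint can be restored through the callback; a hypothetical follow-up, not part of this script:

```python
from lightning_module import AutoEncoder

# best_model_path is populated by ModelCheckpoint during trainer.fit
best_model = AutoEncoder.load_from_checkpoint(checkpoint_callback.best_model_path)
best_model.eval()
```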

helpers.py

+62
```python
import random
import numpy as np
import cv2
import albumentations as A

# Helper functions

# Scale data range to [-1, 1]
def dem_scale(dem, min_val=-25, max_val=175):
    zero_one = (dem - min_val) / (max_val - min_val)
    minus_one = (zero_one * 2) - 1

    return minus_one

# Scale range back to original (from [-1, 1])
def dem_inv_scale(dem_scaled, min_val=-25, max_val=175):
    zero_one = (dem_scaled + 1) / 2
    orig_scale = zero_one * (max_val - min_val) + min_val

    return orig_scale

# Mask augmentation using erosion and dilation
def mask_morph_trans(mask, p=0.5, min_iters=1, max_iters=10):

    rand = random.uniform(0, 1)

    if rand > p:
        return mask

    kernel = np.ones((3, 3), np.uint8)
    morph_op = random.choice([cv2.erode, cv2.dilate])
    iters = random.randint(min_iters, max_iters)

    mask_copy = mask.copy()
    mask_morph = morph_op(mask_copy, kernel, iterations=iters)

    # Reject the transform if erosion (almost) removed the mask
    if mask.sum() == 0 or mask_morph.sum() / mask.sum() < 0.1:
        return mask

    return mask_morph

# Augmentations
dem_aug = A.Compose([
    A.RandomRotate90(p=0.25),
    A.Flip(p=0.25),
    A.RandomResizedCrop(p=0.25, height=256, width=256, scale=(0.5, 1), interpolation=cv2.INTER_LINEAR)
])

mask_aug = A.Compose([
    A.ShiftScaleRotate(p=0.25, scale_limit=0.2, shift_limit=0.2,
                       interpolation=cv2.INTER_NEAREST, border_mode=cv2.BORDER_CONSTANT),
    A.RandomRotate90(p=0.25),
    A.Flip(p=0.25)
])

lake_aug = A.Compose([
    A.ShiftScaleRotate(p=0.25, border_mode=cv2.BORDER_CONSTANT, interpolation=cv2.INTER_LINEAR),
    A.RandomRotate90(p=0.25),
    A.Flip(p=0.25),
    A.GaussNoise(p=0.25, var_limit=(0, 1e-4)),
    A.GaussianBlur(p=0.25)
])
```
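The two scaling helpers are exact inverses on the default range; a quick sanity check (illustrative only):

```python
import numpy as np
from helpers import dem_scale, dem_inv_scale

dem = np.array([-25.0, 75.0, 175.0], dtype="float32")
scaled = dem_scale(dem)           # -> [-1.0, 0.0, 1.0]
restored = dem_inv_scale(scaled)  # -> [-25.0, 75.0, 175.0]
assert np.allclose(restored, dem)
```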

lakes_random_shapes_list.py

+25
```python
import cv2
import os
import numpy as np

# Load random lake masks created in the "lakes_random_shapes.R" script and save as a numpy array

# Load lake masks and convert to array
mask_random_dir = "data/lakes_random/"
mask_random_paths = os.listdir(mask_random_dir)

mask_random_list = []

for i in mask_random_paths:
    lake_path = os.path.join(mask_random_dir, i)
    lake_array = cv2.imread(lake_path, cv2.IMREAD_UNCHANGED)
    # Convert 8-bit mask images to float32 in [0, 1]
    lake_array_float = lake_array.astype("float32") / 255.0
    mask_random_list.append(lake_array_float)

mask_random_np = np.array(mask_random_list)

# Save file (np.savez appends the .npz extension)
file_name = "data/lakes_random"

np.savez(file_name, mask=mask_random_np)
```
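The saved archive can then be read back the way dem_model_train.py does:

```python
import numpy as np

masks = np.load("data/lakes_random.npz")["mask"]
print(masks.shape, masks.dtype)  # (n_masks, H, W), float32 in [0, 1]
```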

lightning_module.py

+37
```python
import pytorch_lightning as pl
from loss import MAELossHole
import torch
from model import UNet

class AutoEncoder(pl.LightningModule):
    def __init__(self, init_features=8, lr=1e-4):
        super().__init__()

        self.save_hyperparameters()

        self.lr = lr

        self.unet = UNet(in_channels=1, out_channels=1, mask_channels=1, init_features=init_features)

        self.loss = MAELossHole()

    def forward(self, x_in, x_mask):
        x_hat = self.unet(x_in, x_mask)
        return x_hat

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        x_in, x_obs, x_mask = train_batch
        x_hat = self.unet(x_in, x_mask)
        loss = self.loss(x_hat, x_obs, x_mask)
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        return {'loss': loss}

    def validation_step(self, val_batch, batch_idx):
        x_in, x_obs, x_mask = val_batch
        x_hat = self.unet(x_in, x_mask)
        loss = self.loss(x_hat, x_obs, x_mask)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
```
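A hypothetical inference sketch (not part of the commit), assuming the UNet in model.py accepts (N, 1, H, W) tensors as the step methods suggest; masked pixels are filled from the reconstruction:

```python
import torch
from lightning_module import AutoEncoder

model = AutoEncoder(init_features=8)
model.eval()

x_in = torch.zeros(1, 1, 256, 256)    # DEM tile with holes zeroed out
x_mask = torch.zeros(1, 1, 256, 256)  # 1 where data is missing
with torch.no_grad():
    x_hat = model(x_in, x_mask)

# Keep observed pixels, take predictions only inside the holes
filled = x_in * (1 - x_mask) + x_hat * x_mask
```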

loss.py

+46
```python
import torch.nn as nn

# Loss functions

class MSELossWeighted(nn.Module):
    def __init__(self, w_hole=10, w_nonhole=1):
        super().__init__()
        self.l2 = nn.MSELoss()
        self.w_hole = w_hole
        self.w_nonhole = w_nonhole

    def forward(self, hat, obs, mask):
        # Weight errors inside the hole (mask == 1) more heavily
        l2_hole = self.l2(hat[mask == 1], obs[mask == 1])
        l2_nonhole = self.l2(hat[mask == 0], obs[mask == 0])

        l2_total = (l2_hole * self.w_hole) + (l2_nonhole * self.w_nonhole)

        return l2_total

class MAELossWeighted(nn.Module):
    def __init__(self, w_hole=10, w_nonhole=1):
        super().__init__()
        self.l1 = nn.L1Loss()
        self.w_hole = w_hole
        self.w_nonhole = w_nonhole

    def forward(self, hat, obs, mask):
        l1_hole = self.l1(hat[mask == 1], obs[mask == 1])
        l1_nonhole = self.l1(hat[mask == 0], obs[mask == 0])

        l1_total = (l1_hole * self.w_hole) + (l1_nonhole * self.w_nonhole)

        return l1_total

class MAELossHole(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.L1Loss()

    def forward(self, hat, obs, mask):
        # Mean absolute error over hole pixels only
        l1_hole = self.l1(hat[mask == 1], obs[mask == 1])

        return l1_hole
```
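A toy check of MAELossHole (illustrative): only pixels where mask == 1 contribute, so the loss below is the absolute error at the single hole pixel. Note the losses assume the mask contains the relevant values, since reducing over an empty selection yields NaN.

```python
import torch
from loss import MAELossHole

loss_fn = MAELossHole()
hat = torch.tensor([[0.5, 0.0], [0.0, 0.0]])
obs = torch.tensor([[1.0, 0.0], [0.0, 0.0]])
mask = torch.tensor([[1.0, 0.0], [0.0, 0.0]])

print(loss_fn(hat, obs, mask))  # tensor(0.5000)
```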
File renamed without changes.
