|
10 | 10 | pl.seed_everything(9999)
|
11 | 11 | random.seed(9999)
|
12 | 12 |
|
13 |
| -#Load DEM data |
14 |
| -data = np.load("data/data.npz") |
15 |
| -dem_data = dem_scale(data["dem"]) |
| 13 | +#Define main function |
| 14 | +def dem_training(): |
16 | 15 |
|
17 |
| -train = dem_data[:8000] |
18 |
| -valid = dem_data[8000:] |
| 16 | + #Load DEM data |
| 17 | + data = np.load("data/data.npz") |
| 18 | + dem_data = dem_scale(data["dem"]) |
19 | 19 |
|
20 |
| -#Load mask data |
21 |
| -random_mask = np.load("data/lakes_random.npz") |
22 |
| -mask_data = random_mask["mask"] |
| 20 | + train = dem_data[:8000] |
| 21 | + valid = dem_data[8000:] |
23 | 22 |
|
24 |
| -random_mask_idx = random.choices(range(mask_data.shape[0]), k=valid.shape[0]) |
25 |
| -valid_mask = mask_data[random_mask_idx] |
| 23 | + #Load mask data |
| 24 | + random_mask = np.load("data/lakes_random.npz") |
| 25 | + mask_data = random_mask["mask"] |
26 | 26 |
|
27 |
| -#Create datasets |
28 |
| -train_dataset = DEMTrain(train, mask_data) |
29 |
| -valid_dataset = DEMValid(valid, valid_mask) |
| 27 | + random_mask_idx = random.choices(range(mask_data.shape[0]), k=valid.shape[0]) |
| 28 | + valid_mask = mask_data[random_mask_idx] |
30 | 29 |
|
31 |
| -#Create dataloaders |
32 |
| -train_loader = DataLoader(train_dataset, batch_size=32, num_workers=8, shuffle = True) |
33 |
| -val_loader = DataLoader(valid_dataset, batch_size=32, num_workers=8, shuffle = False) |
| 30 | + #Create datasets |
| 31 | + train_dataset = DEMTrain(train, mask_data) |
| 32 | + valid_dataset = DEMValid(valid, valid_mask) |
34 | 33 |
|
35 |
| -#Initiate model |
36 |
| -lake_model = AutoEncoder(init_features=8) |
| 34 | + #Create dataloaders |
| 35 | + train_loader = DataLoader(train_dataset, batch_size=32, num_workers=0, shuffle = True) |
| 36 | + val_loader = DataLoader(valid_dataset, batch_size=32, num_workers=0, shuffle = False) |
37 | 37 |
|
38 |
| -#Initiate trainer |
39 |
| -checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True) |
| 38 | + #Initiate models of differing complexity |
| 39 | + dem_model_8 = AutoEncoder(init_features=8) |
| 40 | + dem_model_16 = AutoEncoder(init_features=16) |
| 41 | + dem_model_32 = AutoEncoder(init_features=32) |
| 42 | + dem_model_64 = AutoEncoder(init_features=64) |
40 | 43 |
|
41 |
| -trainer = pl.Trainer(gpus=1, max_epochs=1000, callbacks=checkpoint_callback, precision=16) |
| 44 | + #Initiate callbacks |
| 45 | + checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_last=True) |
| 46 | + |
| 47 | + #Initiate trainer |
| 48 | + trainer = pl.Trainer(gpus=1, max_epochs=1000, callbacks=checkpoint_callback) |
| 49 | + |
| 50 | + #Train model |
| 51 | + trainer.fit(dem_model_8, train_loader, val_loader) |
| 52 | + trainer.fit(dem_model_16, train_loader, val_loader) |
| 53 | + trainer.fit(dem_model_32, train_loader, val_loader) |
| 54 | + trainer.fit(dem_model_64, train_loader, val_loader) |
| 55 | + |
| 56 | +if __name__ == "__main__": |
| 57 | + dem_training() |
42 | 58 |
|
43 |
| -#Train model |
44 |
| -trainer.fit(lake_model, train_loader, val_loader) |
|
0 commit comments