
Commit 9511318: Initial upload (parent: d2b5bf5)

20 files changed, +1725 / -1 lines

README.md (+51 / -1)

# Estimating Canopy Height at Scale [ICML2024]

[Jan Pauls](https://www.wi.uni-muenster.de/de/institut/dasc/personen/jan-pauls), [Max Zimmer](https://maxzimmer.org), [Una M. Kelly](https://www.wi.uni-muenster.de/de/institut/dasc/personen/una-kelly), [Martin Schwartz](https://www.researchgate.net/profile/Martin-Schwartz-6), [Sassan Saatchi](https://science.jpl.nasa.gov/people/Saatchi/), [Philippe Ciais](https://www.lsce.ipsl.fr/Phocea/Pisp/index.php?nom=philippe.ciais), [Sebastian Pokutta](https://pokutta.com), [Martin Brandt](https://www.researchgate.net/profile/Martin-Brandt-2), [Fabian Gieseke](https://www.wi.uni-muenster.de/department/dasc/people/fabian-gieseke)

[[`Paper`](http://arxiv.org/abs/2406.01076)] [[`Google Earth Engine viewer`](https://worldwidemap.projects.earthengine.app/view/canopy-height-2020)] [[`BibTeX`](#citing-the-paper)]

![Global canopy height map](figures/global_canopy_height.png)

We propose a framework for **global-scale canopy height estimation** based on satellite data. Our model leverages advanced data preprocessing techniques, uses a novel loss function designed to counter geolocation inaccuracies inherent in the ground-truth height measurements, and employs data from the Shuttle Radar Topography Mission to filter out erroneous labels in mountainous regions, improving the reliability of our predictions in those areas. A comparison between predictions and ground-truth labels yields an MAE/RMSE of 2.43/4.73 m overall and 4.45/6.72 m for trees taller than five meters, a substantial improvement over existing global-scale maps. The resulting height map, as well as the underlying framework, will facilitate and enhance ecological analyses at a global scale, including, but not limited to, large-scale forest and biomass monitoring.
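
The idea behind a geolocation-robust loss can be illustrated with a toy sketch (this is a simplified illustration, not the exact loss from the paper): rather than comparing prediction and label at a fixed alignment, take the minimum error over small spatial shifts of the label, so that slightly misaligned ground truth is not over-penalized.

```
import torch
import torch.nn.functional as F

def shift_tolerant_l1(pred, target, max_shift=1):
    # Minimum L1 error over all label shifts of up to `max_shift` pixels.
    # Illustrative sketch only; see the paper for the actual loss.
    losses = []
    for dy in range(-max_shift, max_shift + 1):
        for dx in range(-max_shift, max_shift + 1):
            shifted = torch.roll(target, shifts=(dy, dx), dims=(-2, -1))
            losses.append(F.l1_loss(pred, shifted))
    return torch.stack(losses).min()
```
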
![Global canopy height map](figures/pipeline.png)

A comparison between our map and two existing global height maps (Lang et al., Potapov et al.), as well as a regional map for France, shows that the visual quality has improved considerably. Our map closely matches the regional map, though some regions still exhibit quality differences (e.g., column 8).

![Global and regional comparison](figures/global_and_regional_comparison.png)
## Interactive Google Earth Engine viewer

We uploaded the resulting canopy height map to Google Earth Engine and created a [GEE app](https://worldwidemap.projects.earthengine.app/view/canopy-height-2020) that lets users explore the map globally and compare it with other existing products. If you want to build your own app or download/use the map in another way, you can access it under the following asset ID:

```
var canopy_height_2020 = ee.ImageCollection('projects/worldwidemap/assets/canopyheight2020')

// To display on the map, create the mosaic:
var canopy_height_2020 = ee.ImageCollection('projects/worldwidemap/assets/canopyheight2020').mosaic()
```
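
The same asset can also be accessed from the Earth Engine Python API; a minimal sketch, assuming the `earthengine-api` package is installed and authenticated:

```
import ee

ee.Initialize()
canopy_height_2020 = ee.ImageCollection('projects/worldwidemap/assets/canopyheight2020').mosaic()
print(canopy_height_2020.getInfo()['bands'])
```
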
## Acknowledgements

This paper is part of the project *AI4Forest*, which is funded by the German Aerospace Center ([DLR](https://www.dlr.de/en/)), the German Federal Ministry of Education and Research ([BMBF](https://www.bmbf.de/bmbf/en/home/home_node.html)), and the French National Research Agency ([ANR](https://anr.fr/en/)). Furthermore, calculations (or parts of them) for this publication were performed on the HPC cluster PALMA II of the University of Münster, subsidised by the DFG (INST 211/667-1).
## Citing the paper

If you use our map in your research, please cite it using the following BibTeX:

```
@article{pauls2024estimating,
  title={Estimating Canopy Height at Scale},
  author={Jan Pauls and Max Zimmer and Una M. Kelly and Martin Schwartz and Sassan Saatchi and Philippe Ciais and Sebastian Pokutta and Martin Brandt and Fabian Gieseke},
  year={2024},
  eprint={2406.01076},
  archivePrefix={arXiv}
}
```

figures/global_and_regional_comparison.png (1.14 MB)

figures/global_canopy_height.png (915 KB)

figures/pipeline.png (413 KB)

(+88)
```
import os

import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# PreprocessedSatelliteDataset and Runner are assumed to be defined in your project
from config import PreprocessedSatelliteDataset
from runner import Runner


def update_extremes(values, extremes, num_extremes, largest=True):
    """
    Update the list of extreme values (either largest or smallest) based on the new batch.
    """
    combined = torch.cat((extremes, values))
    sorted_values, _ = torch.sort(combined, descending=largest)
    return sorted_values[:num_extremes]


def compute_percentiles(dataset_name, split, percentiles, num_workers_default=4):
    # Set up the dataset and DataLoader
    rootPath = Runner.get_dataset_root(dataset_name=dataset_name)
    dataframe = os.path.join(rootPath, f'{split}.csv')

    train_transforms = transforms.Compose([
        transforms.ToTensor(),
    ])

    dataset = PreprocessedSatelliteDataset(data_path=rootPath, dataframe=dataframe,
                                           image_transforms=train_transforms,
                                           use_weighted_sampler=None, use_memmap=True)
    total_data_points = len(dataset)
    num_channels = 14

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_workers = num_workers_default * torch.cuda.device_count()
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=num_workers,
                            pin_memory=torch.cuda.is_available())

    # Initialize the per-channel buffers that track the extreme values
    extremes = {channel: {p: torch.tensor([]).to(device) for p in percentiles} for channel in range(num_channels)}

    # Process each batch
    with torch.no_grad():
        for data, _ in tqdm(dataloader):
            data = data.to(device=device, non_blocking=True)
            # Move the channel dimension to the front (it is currently at dim 1)
            data = data.permute(1, 0, 2, 3)
            # Flatten each channel into a 1D vector of pixel values
            data = data.flatten(start_dim=1)

            for channel in range(num_channels):
                channel_data = data[channel, :]

                for percentile in percentiles:
                    if percentile < 50:
                        num_extremes = int(total_data_points * percentile / 100)
                        largest = False
                    else:
                        # E.g. for percentile == 95, we track the top 5% from the other side
                        num_extremes = int(total_data_points * (100 - percentile) / 100)
                        largest = True
                    current_extremes = extremes[channel][percentile]
                    new_extremes = update_extremes(values=channel_data, extremes=current_extremes,
                                                   num_extremes=num_extremes, largest=largest)
                    extremes[channel][percentile] = new_extremes

    # Compute the final percentile values
    percentile_values = {channel: {} for channel in range(num_channels)}
    for channel in range(num_channels):
        for percentile in percentiles:
            if percentile > 50:
                percentile_values[channel][percentile] = extremes[channel][percentile].min().item()
            else:
                percentile_values[channel][percentile] = extremes[channel][percentile].max().item()

    # Save the results
    dump_path = os.path.join(os.getcwd(), f'{dataset_name}_{split}_percentiles.txt')
    with open(dump_path, 'w') as f:
        for percentile in percentiles:
            percentile_values_for_all_channels = tuple(percentile_values[channel][percentile] for channel in percentile_values)
            f.write(f'{percentile}: {percentile_values_for_all_channels},\n')

    return percentile_values


# Usage example
percentiles = [1, 2, 5, 95, 98, 99]
dataset_name = 'ai4forest_camera'
split = 'train'
percentile_values = compute_percentiles(dataset_name, split, percentiles)
```
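
The script above avoids loading all pixel values into memory at once: for each channel and percentile it only keeps a running buffer of the most extreme values seen so far. A small self-contained sanity check of this idea on synthetic data (all names here are illustrative):

```
import torch

values = torch.randn(10_000)
num_extremes = int(len(values) * 0.05)  # keep the top 5% for the 95th percentile

extremes = torch.tensor([])
for batch in values.split(1_000):
    extremes = torch.sort(torch.cat((extremes, batch)), descending=True).values[:num_extremes]

print(extremes.min())                # streaming estimate of the 95th percentile
print(torch.quantile(values, 0.95))  # reference value; should be close
```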

scripts/compute_dataset_statistics.py (+64)
```
import os

import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from config import PreprocessedSatelliteDataset
from runner import Runner


def compute_mean_std(dataset, split):
    rootPath = Runner.get_dataset_root(dataset_name=dataset)
    if split == 'train':
        dataframe = os.path.join(rootPath, 'train.csv')
    elif split == 'val':
        dataframe = os.path.join(rootPath, 'val.csv')
    else:
        raise ValueError("Invalid split value. Expected 'train' or 'val'.")
    # Convert to tensor (this changes the order of the channels)
    train_transforms = transforms.Compose([
        transforms.ToTensor(),
    ])
    dataset = PreprocessedSatelliteDataset(data_path=rootPath, dataframe=dataframe, image_transforms=train_transforms,
                                           use_weighted_sampler=None, use_memmap=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_workers_default = 4
    num_workers = num_workers_default * torch.cuda.device_count()
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=num_workers,
                            pin_memory=torch.cuda.is_available())
    mean = 0.
    std = 0.
    nb_samples = 0.
    with torch.no_grad():
        for data, _ in tqdm(dataloader):
            data = data.to(device=device, non_blocking=True)
            batch_samples = data.size(0)
            data = data.view(batch_samples, data.size(1), -1)
            mean += data.mean(2).sum(0)
            # Note: averaging per-sample stds only approximates the global std
            std += data.std(2).sum(0)
            nb_samples += batch_samples

    mean /= nb_samples
    std /= nb_samples
    return mean, std


# Load the dataset
dataset = 'ai4forest_camera'
split = 'train'

# Compute and print the mean and std
mean, std = compute_mean_std(dataset=dataset, split=split)
print(f'Mean: {mean}')
print(f'Std: {std}')

# Dump the mean and std to a file in the current working directory
dump_path = os.path.join(os.getcwd(), f'{dataset}_{split}_mean_std.txt')
with open(dump_path, 'w') as f:
    f.write(f'Mean: {mean}\n')
    f.write(f'Std: {std}\n')
```
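
Note that averaging per-sample standard deviations, as done above, only approximates the true per-channel standard deviation over all pixels. If the exact value is needed, it can be computed from running sums instead; a minimal sketch, assuming a `dataloader` built as in the script above and 14 channels:

```
import torch

num_channels = 14
channel_sum = torch.zeros(num_channels)
channel_sq_sum = torch.zeros(num_channels)
n_pixels = 0
with torch.no_grad():
    for data, _ in dataloader:
        data = data.view(data.size(0), data.size(1), -1)
        channel_sum += data.sum(dim=(0, 2))
        channel_sq_sum += (data ** 2).sum(dim=(0, 2))
        n_pixels += data.size(0) * data.size(2)

mean = channel_sum / n_pixels
std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()
```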

training/.DS_Store (6 KB, binary file not shown)

training/config.py (+170)
```
import os
import sys

import numpy as np
import pandas as pd
from torch.utils.data import Dataset

means = {
    'ai4forest_camera': (10782.3223, 3304.7444, 1999.6086, 7276.4209, 1186.4460, 1884.6165,
                         2645.6113, 3128.2588, 3806.2808, 4134.6855, 4113.4883, 4259.1885,
                         4683.5879, 3838.2222),  # Not the true values, change for your dataset
}

stds = {
    'ai4forest_camera': (907.7484, 472.1412, 423.8558, 1086.0916, 175.0936, 226.6303,
                         299.4834, 313.0911, 388.1186, 434.4579, 455.7314, 455.0303,
                         388.5127, 374.1260),  # Not the true values, change for your dataset
}

percentiles = {
    'ai4forest_camera': {
        1: (-7542.0, -8126.0, -16659.0, -14187.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
        2: (-6834.0, -7255.0, -14468.0, -13537.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
        5: (-5694.0, -5963.0, -12383.0, -12601.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
        95: (24995.0, 24556.0, 22124.0, 20120.0, 15016.0, 15116.0, 15212.0, 15181.0, 14946.0, 14406.0, 14660.0, 13810.0, 12082.0, 13041.0),
        98: (25969.0, 26078.0, 23632.0, 21934.0, 15648.0, 15608.0, 15487.0, 15449.0, 15296.0, 15155.0, 15264.0, 14943.0, 13171.0, 14064.0),
        99: (27044.0, 27349.0, 24868.0, 23266.0, 15970.0, 15680.0, 15548.0, 15494.0, 15432.0, 15368.0, 15385.0, 15219.0, 13590.0, 14657.0),
    }  # Not the true values, change for your dataset
}


class FixValDataset(Dataset):
    """
    Dataset class to load the fixval dataset.
    """
    def __init__(self, data_path, dataframe, image_transforms=None):
        self.data_path = data_path
        self.df = pd.read_csv(dataframe, index_col=False)
        self.files = list(self.df["paths"].apply(lambda x: os.path.join(data_path, x)))
        self.image_transforms = image_transforms

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        file = self.files[index].replace(r"'", "")
        fileName = file[file.rfind('data_') + 5: file.rfind('.npz')]
        data = np.load(file)

        image = data["data"].astype(np.float32)
        # Move the channel axis to the last position (required for torchvision transforms)
        image = np.moveaxis(image, 0, -1)
        if self.image_transforms:
            image = self.image_transforms(image)

        return image, fileName


class PreprocessedSatelliteDataset(Dataset):
    """
    Dataset class for preprocessed satellite imagery.
    """

    def __init__(self, data_path, dataframe=None, image_transforms=None, label_transforms=None,
                 joint_transforms=None, use_weighted_sampler=False, use_weighting_quantile=None,
                 use_memmap=False, remove_corrupt=True, load_labels=True, patch_size=512):
        self.use_memmap = use_memmap
        self.patch_size = patch_size
        self.load_labels = load_labels  # If False, we only load the images and not the labels
        df = pd.read_csv(dataframe)

        if remove_corrupt:
            old_len = len(df)
            # Use only the rows that are not corrupt, i.e. those where df["has_corrupt_s2_channel_flag"] == False
            df = df[df["has_corrupt_s2_channel_flag"] == False]
            sys.stdout.write(f"Removed {old_len - len(df)} corrupt rows.\n")

        self.files = list(df["paths"].apply(lambda x: os.path.join(data_path, x)))

        if use_weighted_sampler not in [False, None]:
            assert use_weighted_sampler in ['g5', 'g10', 'g15', 'g20', 'g25', 'g30']
            weighting_quantile = use_weighting_quantile
            assert weighting_quantile in [None, 'None'] or int(weighting_quantile) == weighting_quantile, "weighting_quantile must be an integer."
            if weighting_quantile in [None, 'None']:
                self.weights = (df[use_weighted_sampler] / df["totals"]).values.clip(0., 1.)
            else:
                # We do not clip between 0 and 1, but rather between the weighting_quantile and 1.
                weighting_quantile = float(weighting_quantile)
                self.weights = (df[use_weighted_sampler] / df["totals"]).values

                # Compute the quantile, ignoring nan values and zero values
                tmp_weights = self.weights.copy()
                tmp_weights[np.isnan(tmp_weights)] = 0.
                tmp_weights = tmp_weights[tmp_weights > 0.]

                quantile_min = np.nanquantile(tmp_weights, weighting_quantile / 100)
                sys.stdout.write(f"Computed weighting {weighting_quantile}-quantile lower bound: {quantile_min}.\n")

                # Clip the weights
                self.weights = self.weights.clip(quantile_min, 1.0)

                # Set the nan values to 0.
                self.weights[np.isnan(self.weights)] = 0.
        else:
            self.weights = None
        self.image_transforms, self.label_transforms, self.joint_transforms = image_transforms, label_transforms, joint_transforms

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        if self.use_memmap:
            item = self.getitem_memmap(index)
        else:
            item = self.getitem_classic(index)

        return item

    def getitem_memmap(self, index):
        file = self.files[index]
        with np.load(file, mmap_mode='r') as npz_file:
            image = npz_file['data'].astype(np.float32)
            # Move the channel axis to the last position (required for torchvision transforms)
            image = np.moveaxis(image, 0, -1)
            if self.image_transforms:
                image = self.image_transforms(image)
            if self.load_labels:
                label = npz_file['labels'].astype(np.float32)

                # Process the label
                label = label[:3]  # Everything after index 3 is irrelevant
                label = label / 100  # Convert from cm to m
                label = np.moveaxis(label, 0, -1)

                if self.label_transforms:
                    label = self.label_transforms(label)
                if self.joint_transforms:
                    image, label = self.joint_transforms(image, label)
                return image, label

            return image

    def getitem_classic(self, index):
        file = self.files[index]
        data = np.load(file)

        image = data["data"].astype(np.float32)
        # Move the channel axis to the last position (required for torchvision transforms)
        image = np.moveaxis(image, 0, -1)[:self.patch_size, :self.patch_size]
        if self.image_transforms:
            image = self.image_transforms(image)
        if self.load_labels:
            label = data["labels"].astype(np.float32)

            # Process the label
            label = label[:3]  # Everything after index 3 is irrelevant
            label = label[:, :self.patch_size, :self.patch_size]
            label = label / 100  # Convert from cm to m
            label = np.moveaxis(label, 0, -1)

            if self.label_transforms:
                label = self.label_transforms(label)
            if self.joint_transforms:
                image, label = self.joint_transforms(image, label)
            return image, label

        return image
```
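
The `weights` computed by `PreprocessedSatelliteDataset` are intended for weighted sampling during training. A hypothetical usage sketch (the paths are placeholders, and `'g10'` is assumed to be a dataframe column counting pixels with canopy taller than 10 m):

```
from torch.utils.data import DataLoader, WeightedRandomSampler

from config import PreprocessedSatelliteDataset

dataset = PreprocessedSatelliteDataset(data_path='/path/to/dataset',
                                       dataframe='/path/to/train.csv',
                                       use_weighted_sampler='g10')
sampler = WeightedRandomSampler(weights=dataset.weights, num_samples=len(dataset))
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
```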
