Rename embeddings file to include MGRS code and store GeoTIFF source_url (#86)

* 🗃️ Store source_url of GeoTIFF to GeoParquet file

Pass the URL or path of the GeoTIFF file through the datapipe and into the model's prediction loop. The geopandas.GeoDataFrame now has an extra 'source_url' string column, which is saved to the GeoParquet file too.
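
A minimal sketch of reading the new column back, assuming a saved file named 12ABC_v01.gpq:

    import geopandas as gpd

    # Load a saved GeoParquet file and inspect the new string column
    gdf = gpd.read_parquet(path="12ABC_v01.gpq")
    print(gdf.source_url.iloc[0])  # e.g. 's3://.../claytile_12ABC_20201231_v0_0200.tif'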

* 🚚 Save one GeoParquet file for each unique MGRS tile

For each MGRS code (e.g. 12ABC), save a GeoParquet file with a name formatted like `{MGRS:5}_v{VERSION:2}.gpq`, e.g. 12ABC_v01.gpq. Updated the unit test to check that rows with different MGRS codes are saved to different files.
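
For illustration, the filename-to-MGRS logic amounts to roughly this (the claytile_{MGRS}_{date}_{version}_{counter}.tif layout and the bucket name are assumed for the example):

    # Hypothetical source_url; the MGRS code is the second underscore-separated part
    source_url = "s3://bucket/claytile_12ABC_20201231_v02_0001.tif"
    mgrs_code = source_url.split("/")[-1].split("_")[1]  # -> '12ABC'
    outpath = f"{mgrs_code}_v01.gpq"  # -> '12ABC_v01.gpq'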

* ⚡ Save GeoParquet file with ZSTD compression

Use ZStandard compression instead of Parquet's default Snappy compression. This should result in slightly smaller file sizes, and slightly faster data transfer and compression (especially over the network). Also changed an assert statement to an if-then-raise.
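
A self-contained sketch of the new write call and the if-then-raise pattern (the stand-in table and MGRS code are hypothetical, not the real embeddings):

    import re

    import geopandas as gpd
    import shapely

    # Stand-in table; the real one holds embeddings plus spatiotemporal metadata
    gdf = gpd.GeoDataFrame(
        data={"embeddings": [[0.1, 0.2]]},
        geometry=[shapely.box(0, 0, 1, 1)],
        crs="OGC:CRS84",
    )

    mgrs_code = "12ABC"
    # if-then-raise instead of assert, so the check also runs under `python -O`
    if re.match(pattern=r"(\d{2}[A-Z]{3})", string=mgrs_code) is None:
        raise ValueError(f"Invalid MGRS code: {mgrs_code}")

    # ZStandard codec instead of Parquet's default Snappy
    gdf.to_parquet(
        path=f"{mgrs_code}_v01.gpq", schema_version="1.0.0", compression="ZSTD"
    )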

* ♻️ Predict with multiple workers and gather results to save

Speed up embedding generation by enabling multiple workers to fetch and load mini-batches of GeoTIFF files independently and run the prediction. The predictions, i.e. the generated embeddings from each worker (a geopandas.GeoDataFrame), are then concatenated row-wise before being passed to the GeoParquet output step. This is done via LightningModule's `on_predict_epoch_end` hook. Also documented these new processing steps in the docstring.
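
The core of the gather step looks like this (sketch of the hook only; the full version in src/model_vit.py below also handles the empty case and writes one file per MGRS code):

    import geopandas as gpd
    import lightning as L
    import pandas as pd


    class SketchModule(L.LightningModule):
        def on_predict_epoch_end(self) -> gpd.GeoDataFrame:
            # Each predict_step returned a GeoDataFrame; stack them row-wise
            results: list[gpd.GeoDataFrame] = self.trainer.predict_loop.predictions
            return pd.concat(objs=results, axis="index", ignore_index=True)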
weiji14 authored Dec 19, 2023
1 parent a9cfeff commit 776cce8
Showing 4 changed files with 106 additions and 21 deletions.
9 changes: 8 additions & 1 deletion src/datamodule.py
@@ -189,6 +189,7 @@ def _array_to_torch(filepath: str) -> dict[str, torch.Tensor | str]:
- bbox: torch.Tensor - spatial bounding box as (xmin, ymin, xmax, ymax)
- epsg: torch.Tensor - coordinate reference system as an EPSG code
- date: str - the date the image was acquired in YYYY-MM-DD format
- source_url: str - the URL or path to the source GeoTIFF file
"""
# GeoTIFF - Rasterio
with rasterio.open(fp=filepath) as dataset:
@@ -205,7 +206,13 @@ def _array_to_torch(filepath: str) -> dict[str, torch.Tensor | str]:
# Get date
date: str = dataset.tags()["date"] # YYYY-MM-DD format

return {"image": tensor, "bbox": bbox, "epsg": epsg, "date": date}
return {
"image": tensor, # shape (13, 512, 512)
"bbox": bbox, # bounds [xmin, ymin, xmax, ymax]
"epsg": epsg, # e.g. 32632
"date": date, # e.g. 2020-12-31
"source_url": filepath, # e.g. s3://.../claytile_12ABC_20201231_v0_0200.tif
}


class GeoTIFFDataPipeModule(L.LightningDataModule):
88 changes: 74 additions & 14 deletions src/model_vit.py
@@ -5,10 +5,12 @@
https://github.com/Lightning-AI/deep-learning-project-template
"""
import os
import re

import geopandas as gpd
import lightning as L
import numpy as np
import pandas as pd
import pyarrow as pa
import shapely
import torch
@@ -175,7 +177,7 @@ def predict_step(
Logic for the neural network's prediction loop.
Takes batches of image inputs, generate the embeddings, and store them
in a geopandas.GeoDataFrame with spatiotemporal metadata.
Steps:
1. Image inputs are passed through the encoder model to produce raw
@@ -200,20 +202,22 @@
3. Embeddings are joined with spatiotemporal metadata (date and
bounding box polygon) in a geopandas.GeoDataFrame table. The
coordinates of the bounding box are in an OGC:CRS84 projection (i.e.
longitude/latitude). The table is as follows:
| source_url | date | embeddings | geometry |
|--------------|------------|----------------------|--------------|
| s3://1A.tif | 2021-01-01 | [0.1, 0.4, ... x768] | POLYGON(...) |
| s3://2B.tif | 2021-06-30 | [0.2, 0.5, ... x768] | POLYGON(...) |
| s3://3C.tif | 2021-12-31 | [0.3, 0.6, ... x768] | POLYGON(...) |
"""
# Get image, bounding box, EPSG code, and date inputs
x: torch.Tensor = batch["image"] # image of shape (1, 13, 256, 256) # BCHW
bboxes: np.ndarray = batch["bbox"].cpu().__array__() # bounding boxes
epsgs: torch.Tensor = batch["epsg"] # coordinate reference systems as EPSG code
dates: list[str] = batch["date"] # dates, e.g. ['2022-12-12', '2022-12-12']
source_urls: list[str] = batch[ # URLs, e.g. ['s3://1.tif', 's3://2.tif']
"source_url"
]

# Forward encoder
self.vit.config.mask_ratio = 0 # disable masking
@@ -244,7 +248,8 @@ def predict_step(

gdf = gpd.GeoDataFrame(
data={
"date": gpd.pd.to_datetime(arg=dates, format="%Y-%m-%d").astype(
"source_url": pd.Series(data=source_urls, dtype="string[pyarrow]"),
"date": pd.to_datetime(arg=dates, format="%Y-%m-%d").astype(
dtype="date32[day][pyarrow]"
),
"embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray(
@@ -261,12 +266,67 @@ def predict_step(
)
gdf = gdf.to_crs(crs="OGC:CRS84") # reproject from UTM to lonlat coordinates

return gdf

def on_predict_epoch_end(self) -> gpd.GeoDataFrame:
"""
Logic to gather all the results from one epoch in a prediction loop.
This is where we combine all the vector embeddings generated from each
mini-batch prediction (stored in geopandas.GeoDataFrame tables) in a
row-wise manner. The embeddings are then output to GeoParquet files
according to their MGRS code.
Steps:
1. Concatenate all geopandas.GeoDataFrame objects row-wise
2. Find all unique MGRS names in the big GeoDataFrame.
3. Split the GeoDataFrame by MGRS name, and output each MGRS-subsetted
table to a GeoParquet file.
________________ ________________ ________________
| | | | | |
| GeoDataFrame 1 | | GeoDataFrame 2 | | GeoDataFrame 3 |
| for MGRS 12ABC | | for MGRS 23DEF | | for MGRS 45GHI |
|________________| |________________| |________________|
| | |
v v v
12ABC_v01.gpq 23DEF_v01.gpq 45GHI_v01.gpq
"""
# Combine list of geopandas.GeoDataFrame objects
results: list[gpd.GeoDataFrame] = self.trainer.predict_loop.predictions
if results:
gdf: gpd.GeoDataFrame = pd.concat(
objs=results, axis="index", ignore_index=True
)
else:
print(
"No embeddings generated, "
f"possibly no GeoTIFF files in {self.trainer.datamodule.data_path}"
)
return

# Save embeddings in GeoParquet format, one file for each MGRS code
outfolder: str = f"{self.trainer.default_root_dir}/data/embeddings"
os.makedirs(name=outfolder, exist_ok=True)

# Find unique MGRS names (e.g. '12ABC'), e.g.
# from 's3://.../.../claytile_12ABC_20201231_v02_0001.tif', get 12ABC
mgrs_codes = gdf.source_url.str.split("/").str[-1].str.split("_").str[1]
unique_mgrs_codes = mgrs_codes.unique()
for mgrs_code in unique_mgrs_codes:
if re.match(pattern=r"(\d{2}[A-Z]{3})", string=mgrs_code) is None:
raise ValueError(
"MGRS code should have 2 numbers and 3 letters (e.g. 12ABC), "
f"but got {mgrs_code} instead"
)

# Output to a GeoParquet filename like {MGRS:5}_v{VERSION:2}.gpq
outpath = f"{outfolder}/{mgrs_code}_v01.gpq"
_gdf: gpd.GeoDataFrame = gdf.loc[mgrs_codes == mgrs_code]
_gdf.to_parquet(path=outpath, schema_version="1.0.0", compression="ZSTD")
print(
f"Saved {len(_gdf)} rows of embeddings of "
f"shape {gdf.embeddings.iloc[0].shape} to {outpath}"
)

return gdf

5 changes: 5 additions & 0 deletions src/tests/test_datamodule.py
@@ -65,6 +65,7 @@ def test_geotiffdatapipemodule(geotiff_folder, stage, dataloader):
bbox = batch["bbox"]
epsg = batch["epsg"]
date = batch["date"]
source_url = batch["source_url"]

assert image.shape == torch.Size([2, 3, 256, 256])
assert image.dtype == torch.float16
@@ -80,6 +81,10 @@ def test_geotiffdatapipemodule(geotiff_folder, stage, dataloader):
actual=epsg, expected=torch.tensor(data=[32646, 32646], dtype=torch.int32)
)
assert date == ["2022-12-31", "2023-12-31"]
assert source_url == [
f"{geotiff_folder}/claytile-12ABC-2022-12-31-01-1.tif",
f"{geotiff_folder}/claytile-12ABC-2023-12-31-01-2.tif",
]


def test_geotiffdatapipemodule_list_from_s3_bucket(monkeypatch):
25 changes: 19 additions & 6 deletions src/tests/test_model.py
@@ -27,15 +27,21 @@ def fixture_datapipe() -> torchdata.datapipes.iter.IterDataPipe:
datapipe = torchdata.datapipes.iter.IterableWrapper(
iterable=[
{
"image": torch.randn(2, 13, 512, 512).to(dtype=torch.float16),
"image": torch.randn(3, 13, 512, 512).to(dtype=torch.float16),
"bbox": torch.tensor(
data=[
[499975.0, 3397465.0, 502535.0, 3400025.0],
[530695.0, 3397465.0, 533255.0, 3400025.0],
[561415.0, 3397465.0, 563975.0, 3400025.0],
]
),
"date": ["2020-01-01", "2020-12-31"],
"epsg": torch.tensor(data=[32646, 32646]),
"date": ["2020-01-01", "2020-12-31", "2020-12-31"],
"epsg": torch.tensor(data=[32760, 32760, 32760]),
"source_url": [
"s3://claytile_60HTE_1.tif",
"s3://claytile_60GUV_2.tif",
"s3://claytile_60GUV_3.tif",
],
},
]
)
@@ -67,11 +73,18 @@ def test_model_vit(datapipe):

# Prediction
trainer.predict(model=model, dataloaders=dataloader)
assert os.path.exists(path := f"{tmpdirname}/data/embeddings/embeddings_0.gpq")
assert (
len(os.listdir(path=f"{tmpdirname}/data/embeddings")) == 2 # noqa: PLR2004
)
assert os.path.exists(path := f"{tmpdirname}/data/embeddings/60HTE_v01.gpq")
assert os.path.exists(path := f"{tmpdirname}/data/embeddings/60GUV_v01.gpq")
geodataframe: gpd.GeoDataFrame = gpd.read_parquet(path=path)

assert geodataframe.shape == (2, 4) # 2 rows, 4 columns
assert all(
geodataframe.columns == ["source_url", "date", "embeddings", "geometry"]
)
assert geodataframe.source_url.dtype == "string"
assert geodataframe.date.dtype == "date32[day][pyarrow]"
assert geodataframe.embeddings.dtype == "object"
assert geodataframe.geometry.dtype == gpd.array.GeometryDtype()
