From 776cce8aa7de895ed8e25b7dd1174100f8bea2db Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Tue, 19 Dec 2023 13:15:52 +1300 Subject: [PATCH] Rename embeddings file to include MGRS code and store GeoTIFF source_url (#86) * :card_file_box: Store source_url of GeoTIFF to GeoParquet file Passing the URL or path of the GeoTIFF file through the datapipe, and into the model's prediction loop. The geopandas.GeoDataFrame now has an extra 'source_url' string column, and this is saved to the GeoParquet file too. * :truck: Save one GeoParquet file for each unique MGRS tile For each MGRS code (e.g. 12ABC), save a GeoParquet file with a name formatted like `{MGRS:5}_v{VERSION:2}.gpq`, e.g. 12ABC_v01.gpq. Have updated the unit test to check that rows with different MGRS codes are saved to different files. * :zap: Save GeoParquet file with ZSTD compression Using ZStandard compression instead of Parquet's default Snappy compression. Should result is slightly smaller filesizes, and slightly faster data transfer and compression (especially over the network). Also changed an assert statement to an if-then-raise instead. * :recycle: Predict with multiple workers and gather results to save Speed up embedding generation by enabling multiple workers to fetch and load mini-batches of GeoTIFF files independently, and run the prediction. The prediction or generated embeddings from each worker (a geopandas.GeoDataFrame) is then concatenated together row-wise, before getting passed to the GeoParquet output script. This is done via LightningModule's `on_predict_epoch_end` hook. Also documented these new processing steps in the docstring. --- src/datamodule.py | 9 +++- src/model_vit.py | 88 ++++++++++++++++++++++++++++++------ src/tests/test_datamodule.py | 5 ++ src/tests/test_model.py | 25 +++++++--- 4 files changed, 106 insertions(+), 21 deletions(-) diff --git a/src/datamodule.py b/src/datamodule.py index 50a7b49c..e2360fba 100644 --- a/src/datamodule.py +++ b/src/datamodule.py @@ -189,6 +189,7 @@ def _array_to_torch(filepath: str) -> dict[str, torch.Tensor | str]: - bbox: torch.Tensor - spatial bounding box as (xmin, ymin, xmax, ymax) - epsg: torch.Tensor - coordinate reference system as an EPSG code - date: str - the date the image was acquired in YYYY-MM-DD format + - source_url: str - the URL or path to the source GeoTIFF file """ # GeoTIFF - Rasterio with rasterio.open(fp=filepath) as dataset: @@ -205,7 +206,13 @@ def _array_to_torch(filepath: str) -> dict[str, torch.Tensor | str]: # Get date date: str = dataset.tags()["date"] # YYYY-MM-DD format - return {"image": tensor, "bbox": bbox, "epsg": epsg, "date": date} + return { + "image": tensor, # shape (13, 512, 512) + "bbox": bbox, # bounds [xmin, ymin, xmax, ymax] + "epsg": epsg, # e.g. 32632 + "date": date, # e.g. 2020-12-31 + "source_url": filepath, # e.g. s3://.../claytile_12ABC_20201231_v0_0200.tif + } class GeoTIFFDataPipeModule(L.LightningDataModule): diff --git a/src/model_vit.py b/src/model_vit.py index a9e96775..ae9f727e 100644 --- a/src/model_vit.py +++ b/src/model_vit.py @@ -5,10 +5,12 @@ https://github.com/Lightning-AI/deep-learning-project-template """ import os +import re import geopandas as gpd import lightning as L import numpy as np +import pandas as pd import pyarrow as pa import shapely import torch @@ -175,7 +177,7 @@ def predict_step( Logic for the neural network's prediction loop. Takes batches of image inputs, generate the embeddings, and store them - in a GeoParquet file with spatiotemporal metadata. + in a geopandas.GeoDataFrame with spatiotemporal metadata. Steps: 1. Image inputs are passed through the encoder model to produce raw @@ -200,20 +202,22 @@ def predict_step( 3. Embeddings are joined with spatiotemporal metadata (date and bounding box polygon) in a geopandas.GeoDataFrame table. The coordinates of the bounding box are in an OGC:CRS84 projection (i.e. - longitude/latitude). - 4. The geodataframe table is saved out to a GeoParquet file. - - | date | embeddings | geometry | - |------------|----------------------|--------------| - | 2021-01-01 | [0.1, 0.4, ... x768] | POLYGON(...) | ---> *.gpq - | 2021-06-30 | [0.2, 0.5, ... x768] | POLYGON(...) | - | 2021-12-31 | [0.3, 0.6, ... x768] | POLYGON(...) | + longitude/latitude). The table is as follows: + + | source_url | date | embeddings | geometry | + |--------------|------------|----------------------|--------------| + | s3://1A.tif | 2021-01-01 | [0.1, 0.4, ... x768] | POLYGON(...) | + | s3://2B.tif | 2021-06-30 | [0.2, 0.5, ... x768] | POLYGON(...) | + | s3://3C.tif | 2021-12-31 | [0.3, 0.6, ... x768] | POLYGON(...) | """ # Get image, bounding box, EPSG code, and date inputs x: torch.Tensor = batch["image"] # image of shape (1, 13, 256, 256) # BCHW bboxes: np.ndarray = batch["bbox"].cpu().__array__() # bounding boxes epsgs: torch.Tensor = batch["epsg"] # coordinate reference systems as EPSG code dates: list[str] = batch["date"] # dates, e.g. ['2022-12-12', '2022-12-12'] + source_urls: list[str] = batch[ # URLs, e.g. ['s3://1.tif', 's3://2.tif'] + "source_url" + ] # Forward encoder self.vit.config.mask_ratio = 0 # disable masking @@ -244,7 +248,8 @@ def predict_step( gdf = gpd.GeoDataFrame( data={ - "date": gpd.pd.to_datetime(arg=dates, format="%Y-%m-%d").astype( + "source_url": pd.Series(data=source_urls, dtype="string[pyarrow]"), + "date": pd.to_datetime(arg=dates, format="%Y-%m-%d").astype( dtype="date32[day][pyarrow]" ), "embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray( @@ -261,12 +266,67 @@ def predict_step( ) gdf = gdf.to_crs(crs="OGC:CRS84") # reproject from UTM to lonlat coordinates - # Save embeddings in GeoParquet format + return gdf + + def on_predict_epoch_end(self) -> gpd.GeoDataFrame: + """ + Logic to gather all the results from one epoch in a prediction loop. + + This is where we combine all the vector embeddings generated from each + mini-batch prediction (stored in geopandas.GeoDataFrame tables) in a + row-wise manner. The embeddings are then output to GeoParquet files + according to their MGRS code. + + Steps: + 1. Concatenate all geopandas.GeoDataFrame objects row-wise + 2. Find all unique MGRS names in the big GeoDataFrame. + 3. Split the GeoDataFrame by MGRS name, and output each MGRS-subsetted + table to a GeoParquet file. + ________________ ________________ ________________ + | | | | | | + | GeoDataFrame 1 | | GeoDataFrame 2 | | GeoDataFrame 3 | + | for MGRS 12ABC | | for MGRS 23DEF | | for MGRS 45GHI | + |________________| |________________| |________________| + | | | + v v v + 12ABC_v01.gpq 23DEF_v01.gpq 45GHI_v01.gpq + """ + # Combine list of geopandas.GeoDataFrame objects + results: list[gpd.GeoDataFrame] = self.trainer.predict_loop.predictions + if results: + gdf: gpd.GeoDataFrame = pd.concat( + objs=results, axis="index", ignore_index=True + ) + else: + print( + "No embeddings generated, " + f"possibly no GeoTIFF files in {self.trainer.datamodule.data_path}" + ) + return + + # Save embeddings in GeoParquet format, one file for each MGRS code outfolder: str = f"{self.trainer.default_root_dir}/data/embeddings" os.makedirs(name=outfolder, exist_ok=True) - outpath = f"{outfolder}/embeddings_{batch_idx}.gpq" - gdf.to_parquet(path=outpath, schema_version="1.0.0") - print(f"Saved embeddings of shape {tuple(embeddings_mean.shape)} to {outpath}") + + # Find unique MGRS names (e.g. '12ABC'), e.g. + # from 's3://.../.../claytile_12ABC_20201231_v02_0001.tif', get 12ABC + mgrs_codes = gdf.source_url.str.split("/").str[-1].str.split("_").str[1] + unique_mgrs_codes = mgrs_codes.unique() + for mgrs_code in unique_mgrs_codes: + if re.match(pattern=r"(\d{2}[A-Z]{3})", string=mgrs_code) is None: + raise ValueError( + "MGRS code should have 2 numbers and 3 letters (e.g. 12ABC), " + f"but got {mgrs_code} instead" + ) + + # Output to a GeoParquet filename like {MGRS:5}_v{VERSION:2}.gpq + outpath = f"{outfolder}/{mgrs_code}_v01.gpq" + _gdf: gpd.GeoDataFrame = gdf.loc[mgrs_codes == mgrs_code] + _gdf.to_parquet(path=outpath, schema_version="1.0.0", compression="ZSTD") + print( + f"Saved {len(_gdf)} rows of embeddings of " + f"shape {gdf.embeddings.iloc[0].shape} to {outpath}" + ) return gdf diff --git a/src/tests/test_datamodule.py b/src/tests/test_datamodule.py index 84b3575f..0977ac19 100644 --- a/src/tests/test_datamodule.py +++ b/src/tests/test_datamodule.py @@ -65,6 +65,7 @@ def test_geotiffdatapipemodule(geotiff_folder, stage, dataloader): bbox = batch["bbox"] epsg = batch["epsg"] date = batch["date"] + source_url = batch["source_url"] assert image.shape == torch.Size([2, 3, 256, 256]) assert image.dtype == torch.float16 @@ -80,6 +81,10 @@ def test_geotiffdatapipemodule(geotiff_folder, stage, dataloader): actual=epsg, expected=torch.tensor(data=[32646, 32646], dtype=torch.int32) ) assert date == ["2022-12-31", "2023-12-31"] + assert source_url == [ + f"{geotiff_folder}/claytile-12ABC-2022-12-31-01-1.tif", + f"{geotiff_folder}/claytile-12ABC-2023-12-31-01-2.tif", + ] def test_geotiffdatapipemodule_list_from_s3_bucket(monkeypatch): diff --git a/src/tests/test_model.py b/src/tests/test_model.py index 42350855..ee17ecfe 100644 --- a/src/tests/test_model.py +++ b/src/tests/test_model.py @@ -27,15 +27,21 @@ def fixture_datapipe() -> torchdata.datapipes.iter.IterDataPipe: datapipe = torchdata.datapipes.iter.IterableWrapper( iterable=[ { - "image": torch.randn(2, 13, 512, 512).to(dtype=torch.float16), + "image": torch.randn(3, 13, 512, 512).to(dtype=torch.float16), "bbox": torch.tensor( data=[ [499975.0, 3397465.0, 502535.0, 3400025.0], [530695.0, 3397465.0, 533255.0, 3400025.0], + [561415.0, 3397465.0, 563975.0, 3400025.0], ] ), - "date": ["2020-01-01", "2020-12-31"], - "epsg": torch.tensor(data=[32646, 32646]), + "date": ["2020-01-01", "2020-12-31", "2020-12-31"], + "epsg": torch.tensor(data=[32760, 32760, 32760]), + "source_url": [ + "s3://claytile_60HTE_1.tif", + "s3://claytile_60GUV_2.tif", + "s3://claytile_60GUV_3.tif", + ], }, ] ) @@ -67,11 +73,18 @@ def test_model_vit(datapipe): # Prediction trainer.predict(model=model, dataloaders=dataloader) - assert os.path.exists(path := f"{tmpdirname}/data/embeddings/embeddings_0.gpq") + assert ( + len(os.listdir(path=f"{tmpdirname}/data/embeddings")) == 2 # noqa: PLR2004 + ) + assert os.path.exists(path := f"{tmpdirname}/data/embeddings/60HTE_v01.gpq") + assert os.path.exists(path := f"{tmpdirname}/data/embeddings/60GUV_v01.gpq") geodataframe: gpd.GeoDataFrame = gpd.read_parquet(path=path) - assert geodataframe.shape == (2, 3) - assert all(geodataframe.columns == ["date", "embeddings", "geometry"]) + assert geodataframe.shape == (2, 4) # 2 rows, 4 columns + assert all( + geodataframe.columns == ["source_url", "date", "embeddings", "geometry"] + ) + assert geodataframe.source_url.dtype == "string" assert geodataframe.date.dtype == "date32[day][pyarrow]" assert geodataframe.embeddings.dtype == "object" assert geodataframe.geometry.dtype == gpd.array.GeometryDtype()