diff --git a/src/datamodule.py b/src/datamodule.py
index 50a7b49c..e2360fba 100644
--- a/src/datamodule.py
+++ b/src/datamodule.py
@@ -189,6 +189,7 @@ def _array_to_torch(filepath: str) -> dict[str, torch.Tensor | str]:
         - bbox: torch.Tensor - spatial bounding box as (xmin, ymin, xmax, ymax)
         - epsg: torch.Tensor - coordinate reference system as an EPSG code
         - date: str - the date the image was acquired in YYYY-MM-DD format
+        - source_url: str - the URL or path to the source GeoTIFF file
     """
     # GeoTIFF - Rasterio
     with rasterio.open(fp=filepath) as dataset:
@@ -205,7 +206,13 @@ def _array_to_torch(filepath: str) -> dict[str, torch.Tensor | str]:
         # Get date
         date: str = dataset.tags()["date"]  # YYYY-MM-DD format
 
-    return {"image": tensor, "bbox": bbox, "epsg": epsg, "date": date}
+    return {
+        "image": tensor,  # shape (13, 512, 512)
+        "bbox": bbox,  # bounds [xmin, ymin, xmax, ymax]
+        "epsg": epsg,  # e.g. 32632
+        "date": date,  # e.g. 2020-12-31
+        "source_url": filepath,  # e.g. s3://.../claytile_12ABC_20201231_v0_0200.tif
+    }
 
 
 class GeoTIFFDataPipeModule(L.LightningDataModule):
diff --git a/src/model_vit.py b/src/model_vit.py
index a9e96775..ae9f727e 100644
--- a/src/model_vit.py
+++ b/src/model_vit.py
@@ -5,10 +5,12 @@ https://github.com/Lightning-AI/deep-learning-project-template
 """
 
 import os
+import re
 
 import geopandas as gpd
 import lightning as L
 import numpy as np
+import pandas as pd
 import pyarrow as pa
 import shapely
 import torch
@@ -175,7 +177,7 @@
         Logic for the neural network's prediction loop.
 
         Takes batches of image inputs, generate the embeddings, and store them
-        in a GeoParquet file with spatiotemporal metadata.
+        in a geopandas.GeoDataFrame with spatiotemporal metadata.
 
         Steps:
         1. Image inputs are passed through the encoder model to produce raw
@@ -200,20 +202,22 @@
         3. Embeddings are joined with spatiotemporal metadata (date and
            bounding box polygon) in a geopandas.GeoDataFrame table. The
            coordinates of the bounding box are in an OGC:CRS84 projection (i.e.
-           longitude/latitude).
-        4. The geodataframe table is saved out to a GeoParquet file.
-
-           | date       | embeddings           | geometry     |
-           |------------|----------------------|--------------|
-           | 2021-01-01 | [0.1, 0.4, ... x768] | POLYGON(...) | ---> *.gpq
-           | 2021-06-30 | [0.2, 0.5, ... x768] | POLYGON(...) |
-           | 2021-12-31 | [0.3, 0.6, ... x768] | POLYGON(...) |
+           longitude/latitude). The table is as follows:
+
+           | source_url  | date       | embeddings           | geometry     |
+           |-------------|------------|----------------------|--------------|
+           | s3://1A.tif | 2021-01-01 | [0.1, 0.4, ... x768] | POLYGON(...) |
+           | s3://2B.tif | 2021-06-30 | [0.2, 0.5, ... x768] | POLYGON(...) |
+           | s3://3C.tif | 2021-12-31 | [0.3, 0.6, ... x768] | POLYGON(...) |
         """
         # Get image, bounding box, EPSG code, and date inputs
         x: torch.Tensor = batch["image"]  # image of shape (1, 13, 256, 256) # BCHW
         bboxes: np.ndarray = batch["bbox"].cpu().__array__()  # bounding boxes
         epsgs: torch.Tensor = batch["epsg"]  # coordinate reference systems as EPSG code
         dates: list[str] = batch["date"]  # dates, e.g. ['2022-12-12', '2022-12-12']
+        source_urls: list[str] = batch[  # URLs, e.g. ['s3://1.tif', 's3://2.tif']
+            "source_url"
+        ]
 
         # Forward encoder
         self.vit.config.mask_ratio = 0  # disable masking
@@ -244,7 +248,8 @@
 
         gdf = gpd.GeoDataFrame(
             data={
-                "date": gpd.pd.to_datetime(arg=dates, format="%Y-%m-%d").astype(
+                "source_url": pd.Series(data=source_urls, dtype="string[pyarrow]"),
+                "date": pd.to_datetime(arg=dates, format="%Y-%m-%d").astype(
                     dtype="date32[day][pyarrow]"
                 ),
                 "embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray(
@@ -261,12 +266,67 @@
         )
         gdf = gdf.to_crs(crs="OGC:CRS84")  # reproject from UTM to lonlat coordinates
 
-        # Save embeddings in GeoParquet format
+        return gdf
+
+    def on_predict_epoch_end(self) -> gpd.GeoDataFrame:
+        """
+        Logic to gather all the results from one epoch in a prediction loop.
+
+        This is where we combine all the vector embeddings generated from each
+        mini-batch prediction (stored in geopandas.GeoDataFrame tables) in a
+        row-wise manner. The embeddings are then output to GeoParquet files
+        according to their MGRS code.
+
+        Steps:
+        1. Concatenate all geopandas.GeoDataFrame objects row-wise.
+        2. Find all unique MGRS names in the combined GeoDataFrame.
+        3. Split the GeoDataFrame by MGRS name, and output each MGRS-subsetted
+           table to a GeoParquet file.
+         ________________   ________________   ________________
+        |                | |                | |                |
+        | GeoDataFrame 1 | | GeoDataFrame 2 | | GeoDataFrame 3 |
+        | for MGRS 12ABC | | for MGRS 23DEF | | for MGRS 45GHI |
+        |________________| |________________| |________________|
+                |                  |                  |
+                v                  v                  v
+          12ABC_v01.gpq      23DEF_v01.gpq      45GHI_v01.gpq
+        """
+        # Combine list of geopandas.GeoDataFrame objects
+        results: list[gpd.GeoDataFrame] = self.trainer.predict_loop.predictions
+        if results:
+            gdf: gpd.GeoDataFrame = pd.concat(
+                objs=results, axis="index", ignore_index=True
+            )
+        else:
+            print(
+                "No embeddings generated, "
+                f"possibly no GeoTIFF files in {self.trainer.datamodule.data_path}"
+            )
+            return
+
+        # Save embeddings in GeoParquet format, one file for each MGRS code
         outfolder: str = f"{self.trainer.default_root_dir}/data/embeddings"
         os.makedirs(name=outfolder, exist_ok=True)
-        outpath = f"{outfolder}/embeddings_{batch_idx}.gpq"
-        gdf.to_parquet(path=outpath, schema_version="1.0.0")
-        print(f"Saved embeddings of shape {tuple(embeddings_mean.shape)} to {outpath}")
+
+        # Find unique MGRS codes (e.g. '12ABC') from the source URLs, e.g.
+        # from 's3://.../.../claytile_12ABC_20201231_v02_0001.tif', get 12ABC
+        mgrs_codes = gdf.source_url.str.split("/").str[-1].str.split("_").str[1]
+        unique_mgrs_codes = mgrs_codes.unique()
+        for mgrs_code in unique_mgrs_codes:
+            if re.match(pattern=r"(\d{2}[A-Z]{3})", string=mgrs_code) is None:
+                raise ValueError(
+                    "MGRS code should have 2 digits and 3 letters (e.g. 12ABC), "
+                    f"but got {mgrs_code} instead"
+                )
+
+            # Output to a GeoParquet filename like {MGRS:5}_v{VERSION:2}.gpq
+            outpath = f"{outfolder}/{mgrs_code}_v01.gpq"
+            _gdf: gpd.GeoDataFrame = gdf.loc[mgrs_codes == mgrs_code]
+            _gdf.to_parquet(path=outpath, schema_version="1.0.0", compression="ZSTD")
+            print(
+                f"Saved {len(_gdf)} rows of embeddings of "
+                f"shape {gdf.embeddings.iloc[0].shape} to {outpath}"
+            )
 
         return gdf
 
diff --git a/src/tests/test_datamodule.py b/src/tests/test_datamodule.py
index 84b3575f..0977ac19 100644
--- a/src/tests/test_datamodule.py
+++ b/src/tests/test_datamodule.py
@@ -65,6 +65,7 @@ def test_geotiffdatapipemodule(geotiff_folder, stage, dataloader):
     bbox = batch["bbox"]
     epsg = batch["epsg"]
     date = batch["date"]
+    source_url = batch["source_url"]
 
     assert image.shape == torch.Size([2, 3, 256, 256])
     assert image.dtype == torch.float16
@@ -80,6 +81,10 @@ def test_geotiffdatapipemodule(geotiff_folder, stage, dataloader):
         actual=epsg, expected=torch.tensor(data=[32646, 32646], dtype=torch.int32)
     )
     assert date == ["2022-12-31", "2023-12-31"]
+    assert source_url == [
+        f"{geotiff_folder}/claytile-12ABC-2022-12-31-01-1.tif",
+        f"{geotiff_folder}/claytile-12ABC-2023-12-31-01-2.tif",
+    ]
 
 
 def test_geotiffdatapipemodule_list_from_s3_bucket(monkeypatch):
diff --git a/src/tests/test_model.py b/src/tests/test_model.py
index 42350855..ee17ecfe 100644
--- a/src/tests/test_model.py
+++ b/src/tests/test_model.py
@@ -27,15 +27,21 @@ def fixture_datapipe() -> torchdata.datapipes.iter.IterDataPipe:
     datapipe = torchdata.datapipes.iter.IterableWrapper(
         iterable=[
             {
-                "image": torch.randn(2, 13, 512, 512).to(dtype=torch.float16),
+                "image": torch.randn(3, 13, 512, 512).to(dtype=torch.float16),
                 "bbox": torch.tensor(
                     data=[
                         [499975.0, 3397465.0, 502535.0, 3400025.0],
                         [530695.0, 3397465.0, 533255.0, 3400025.0],
+                        [561415.0, 3397465.0, 563975.0, 3400025.0],
                     ]
                 ),
-                "date": ["2020-01-01", "2020-12-31"],
-                "epsg": torch.tensor(data=[32646, 32646]),
+                "date": ["2020-01-01", "2020-12-31", "2020-12-31"],
+                "epsg": torch.tensor(data=[32760, 32760, 32760]),
+                "source_url": [
+                    "s3://claytile_60HTE_1.tif",
+                    "s3://claytile_60GUV_2.tif",
+                    "s3://claytile_60GUV_3.tif",
+                ],
             },
         ]
     )
@@ -67,11 +73,18 @@ def test_model_vit(datapipe):
 
         # Prediction
         trainer.predict(model=model, dataloaders=dataloader)
-        assert os.path.exists(path := f"{tmpdirname}/data/embeddings/embeddings_0.gpq")
+        assert (
+            len(os.listdir(path=f"{tmpdirname}/data/embeddings")) == 2  # noqa: PLR2004
+        )
+        assert os.path.exists(path := f"{tmpdirname}/data/embeddings/60HTE_v01.gpq")
+        assert os.path.exists(path := f"{tmpdirname}/data/embeddings/60GUV_v01.gpq")
 
         geodataframe: gpd.GeoDataFrame = gpd.read_parquet(path=path)
-        assert geodataframe.shape == (2, 3)
-        assert all(geodataframe.columns == ["date", "embeddings", "geometry"])
+        assert geodataframe.shape == (2, 4)  # 2 rows, 4 columns
+        assert all(
+            geodataframe.columns == ["source_url", "date", "embeddings", "geometry"]
+        )
+        assert geodataframe.source_url.dtype == "string"
         assert geodataframe.date.dtype == "date32[day][pyarrow]"
         assert geodataframe.embeddings.dtype == "object"
         assert geodataframe.geometry.dtype == gpd.array.GeometryDtype()