diff --git a/src/model_clay.py b/src/model_clay.py index 7b9dc134..ea2b9397 100644 --- a/src/model_clay.py +++ b/src/model_clay.py @@ -792,6 +792,7 @@ def __init__( # noqa: PLR0913 b2=0.95, embeddings_level: Literal["mean", "patch", "group"] = "mean", band_groups=None, + save_loss: Literal["False", "Input", "Patch"] = "Patch", ): super().__init__() self.save_hyperparameters(logger=True) @@ -925,6 +926,25 @@ def predict_step( assert embeddings_output.shape == torch.Size(expected_size) + # Calculate the loss for the current input if save_loss is "Input" or "Patch" + if self.hparams.save_loss in ["Input", "Patch"]: + loss = self(batch) + else: + loss = None + + # Calculate the loss for each patch if save_loss is "Patch" + if self.hparams.save_loss == "Patch": + patch_losses = self.model.per_pixel_loss( + batch["pixels"], + self.model.decoder( + outputs_encoder[0], outputs_encoder[1], outputs_encoder[2] + ), + outputs_encoder[3], + ) + patch_losses = patch_losses.detach().cpu().numpy() + else: + patch_losses = None + # Create table to store the embeddings with spatiotemporal metadata unique_epsg_codes = set(int(epsg) for epsg in epsgs) if len(unique_epsg_codes) == 1: # check that there's only 1 unique EPSG @@ -934,16 +954,25 @@ def predict_step( f"More than 1 EPSG code detected: {unique_epsg_codes}" ) + data = { + "source_url": pd.Series(data=source_urls, dtype="string[pyarrow]"), + "date": pd.to_datetime(arg=dates, format="%Y-%m-%d").astype( + dtype="date32[day][pyarrow]" + ), + "embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray( + np.ascontiguousarray(embeddings_output.cpu().detach().__array__()) + ), + } + + if loss is not None: + data["loss"] = loss.item() + if patch_losses is not None: + data["patch_losses"] = pa.FixedSizeListArray.from_arrays( + np.ascontiguousarray(patch_losses), 256 + ) + gdf = gpd.GeoDataFrame( - data={ - "source_url": pd.Series(data=source_urls, dtype="string[pyarrow]"), - "date": pd.to_datetime(arg=dates, format="%Y-%m-%d").astype( - dtype="date32[day][pyarrow]" - ), - "embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray( - np.ascontiguousarray(embeddings_output.cpu().detach().__array__()) - ), - }, + data=data, geometry=shapely.box( xmin=bboxes[:, 0], ymin=bboxes[:, 1],