Skip to content

Commit

Permalink
Issue #346 Some more ProcessArgs porting
Browse files Browse the repository at this point in the history
for less boilerplate code and better/earlier error messages
  • Loading branch information
soxofaan committed Dec 17, 2024
1 parent cff6869 commit f5e17b8
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 116 deletions.
187 changes: 73 additions & 114 deletions openeo_driver/ProcessGraphDeserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import time
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, Sequence
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, Sequence, Optional

import geopandas as gpd
import numpy as np
Expand Down Expand Up @@ -140,7 +140,7 @@ def wrapped(args: dict, env: EvalEnv):

# Type hint alias for a "process function":
# a Python function that implements some openEO process (as used in `apply_process`)
ProcessFunction = Callable[[dict, EvalEnv], Any]
ProcessFunction = Callable[[Union[dict, ProcessArgs], EvalEnv], Any]


def process(f: ProcessFunction) -> ProcessFunction:
Expand Down Expand Up @@ -750,14 +750,15 @@ def load_collection(args: dict, env: EvalEnv) -> DriverDataCube:
.param(name='options', description="options specific to the file format", schema={"type": "object"})
.returns(description="the data as a data cube", schema={})
)
def load_disk_data(args: Dict, env: EvalEnv) -> DriverDataCube:
def load_disk_data(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
"""
Deprecated, use load_uploaded_files or load_stac
"""
_log.warning("Deprecated: usage of load_disk_data")
kwargs = dict(
glob_pattern=extract_arg(args, 'glob_pattern'),
format=extract_arg(args, 'format'),
options=args.get('options', {}),
glob_pattern=args.get_required("glob_pattern", expected_type=str),
format=args.get_required("format", expected_type=str),
options=args.get_optional("options", default={}, expected_type=dict),
)
dry_run_tracer: DryRunDataTracer = env.get(ENV_DRY_RUN_TRACER)
if dry_run_tracer:
Expand Down Expand Up @@ -916,22 +917,18 @@ def save_result(args: Dict, env: EvalEnv) -> SaveResult: # TODO: return type no

@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/save_ml_model.json"))
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/save_ml_model.json"))
def save_ml_model(args: dict, env: EvalEnv) -> MlModelResult:
data: DriverMlModel = extract_arg(args, "data", process_id="save_ml_model")
if not isinstance(data, DriverMlModel):
raise ProcessParameterInvalidException(
parameter="data", process="save_ml_model", reason=f"Invalid data type {type(data)!r} expected raster-cube."
)
options = args.get("options", {})
def save_ml_model(args: ProcessArgs, env: EvalEnv) -> MlModelResult:
data = args.get_required("data", expected_type=DriverMlModel)
options = args.get_optional("options", default={}, expected_type=dict)
return MlModelResult(ml_model=data, options=options)


@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/load_ml_model.json"))
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/load_ml_model.json"))
def load_ml_model(args: dict, env: EvalEnv) -> DriverMlModel:
def load_ml_model(args: ProcessArgs, env: EvalEnv) -> DriverMlModel:
if env.get(ENV_DRY_RUN_TRACER):
return DriverMlModel()
job_id = extract_arg(args, "id")
job_id = args.get_required("id", expected_type=str)
return env.backend_implementation.load_ml_model(job_id)


Expand Down Expand Up @@ -1138,19 +1135,19 @@ def get_validated_parameter(args, param_name, default_value, expected_type, min_

@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/predict_random_forest.json"))
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/predict_random_forest.json"))
def predict_random_forest(args: dict, env: EvalEnv):
def predict_random_forest(args: ProcessArgs, env: EvalEnv):
raise NotImplementedError


@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/predict_catboost.json"))
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/predict_catboost.json"))
def predict_catboost(args: dict, env: EvalEnv):
def predict_catboost(args: ProcessArgs, env: EvalEnv):
raise NotImplementedError


@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/predict_probabilities.json"))
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/predict_probabilities.json"))
def predict_probabilities(args: dict, env: EvalEnv):
def predict_probabilities(args: ProcessArgs, env: EvalEnv):
raise NotImplementedError


Expand All @@ -1165,51 +1162,34 @@ def add_dimension(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:


@process
def drop_dimension(args: dict, env: EvalEnv) -> DriverDataCube:
data_cube = extract_arg(args, 'data')
if not isinstance(data_cube, DriverDataCube):
raise ProcessParameterInvalidException(
parameter="data", process="drop_dimension",
reason=f"Invalid data type {type(data_cube)!r} expected raster-cube."
)
return data_cube.drop_dimension(name=extract_arg(args, 'name'))
def drop_dimension(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
name: str = args.get_required("name", expected_type=str)
return cube.drop_dimension(name=name)


@process
def dimension_labels(args: dict, env: EvalEnv) -> DriverDataCube:
data_cube = extract_arg(args, 'data')
if not isinstance(data_cube, DriverDataCube):
raise ProcessParameterInvalidException(
parameter="data", process="dimension_labels",
reason=f"Invalid data type {type(data_cube)!r} expected raster-cube."
)
return data_cube.dimension_labels(dimension=extract_arg(args, 'dimension'))
def dimension_labels(args: ProcessArgs, env: EvalEnv) -> List[str]:
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
dimension: str = args.get_required("dimension", expected_type=str)
return cube.dimension_labels(dimension=dimension)


@process
def rename_dimension(args: dict, env: EvalEnv) -> DriverDataCube:
data_cube = extract_arg(args, 'data')
if not isinstance(data_cube, DriverDataCube):
raise ProcessParameterInvalidException(
parameter="data", process="rename_dimension",
reason=f"Invalid data type {type(data_cube)!r} expected raster-cube."
)
return data_cube.rename_dimension(source=extract_arg(args, 'source'),target=extract_arg(args, 'target'))
def rename_dimension(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
source: str = args.get_required("source", expected_type=str)
target: str = args.get_required("target", expected_type=str)
return cube.rename_dimension(source=source, target=target)


@process
def rename_labels(args: dict, env: EvalEnv) -> DriverDataCube:
data_cube = extract_arg(args, 'data')
if not isinstance(data_cube, DriverDataCube):
raise ProcessParameterInvalidException(
parameter="data", process="rename_labels",
reason=f"Invalid data type {type(data_cube)!r} expected raster-cube."
)
return data_cube.rename_labels(
dimension=extract_arg(args, 'dimension'),
target=extract_arg(args, 'target'),
source=args.get('source',[])
)
def rename_labels(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
dimension: str = args.get_required("dimension", expected_type=str)
target: List[str] = args.get_required("target", expected_type=list)
source: Optional[str] = args.get_optional("source", default=None)
return cube.rename_labels(dimension=dimension, target=target, source=source)


@process
Expand Down Expand Up @@ -1355,14 +1335,10 @@ def aggregate_spatial(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:


@process
def mask(args: dict, env: EvalEnv) -> DriverDataCube:
cube = extract_arg(args, 'data')
if not isinstance(cube, DriverDataCube):
raise ProcessParameterInvalidException(
parameter="data", process="mask", reason=f"Invalid data type {type(cube)!r} expected raster-cube."
)
mask = extract_arg(args, 'mask')
replacement = args.get('replacement', None)
def mask(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
mask = args.get_required("mask")
replacement = args.get_optional("replacement", default=None)
return cube.mask(mask=mask, replacement=replacement)


Expand Down Expand Up @@ -1394,7 +1370,10 @@ def mask_polygon(args: dict, env: EvalEnv) -> DriverDataCube:
return image_collection


def _extract_temporal_extent(args: dict, field="extent", process_id="filter_temporal") -> Tuple[str, str]:
def _extract_temporal_extent(
args: Union[dict, ProcessArgs], field="extent", process_id="filter_temporal"
) -> Tuple[str, str]:
# TODO: make this a ProcessArgs method?
extent = extract_arg(args, name=field, process_id=process_id)
if len(extent) != 2:
raise ProcessParameterInvalidException(
Expand All @@ -1419,29 +1398,27 @@ def _extract_temporal_extent(args: dict, field="extent", process_id="filter_temp


@process
def filter_temporal(args: dict, env: EvalEnv) -> DriverDataCube:
cube = extract_arg(args, 'data')
if not isinstance(cube, DriverDataCube):
raise ProcessParameterInvalidException(
parameter="data", process="filter_temporal",
reason=f"Invalid data type {type(cube)!r} expected raster-cube."
)
def filter_temporal(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
extent = _extract_temporal_extent(args, field="extent", process_id="filter_temporal")
return cube.filter_temporal(start=extent[0], end=extent[1])


@process_registry_100.add_function(spec=read_spec("openeo-processes/1.x/proposals/filter_labels.json"))
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/2.x/proposals/filter_labels.json"))
def filter_labels(args: dict, env: EvalEnv) -> DriverDataCube:
cube = extract_arg(args, 'data')
if not isinstance(cube, DriverDataCube):
raise ProcessParameterInvalidException(
parameter="data", process="filter_labels",
reason=f"Invalid data type {type(cube)!r} expected cube."
)
def filter_labels(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
# TODO: validation that condition is a process graph construct
condition = args.get_required("condition", expected_type=dict)
dimension = args.get_required("dimension", expected_type=str)
context = args.get_optional("context", default=None)
return cube.filter_labels(condition=condition, dimension=dimension, context=context, env=env)

return cube.filter_labels(condition=extract_arg(args,"condition"),dimension=extract_arg(args,"dimension"),context=args.get("context",None),env=env)

def _extract_bbox_extent(args: dict, field="extent", process_id="filter_bbox", handle_geojson=False) -> dict:
def _extract_bbox_extent(
args: Union[dict, ProcessArgs], field="extent", process_id="filter_bbox", handle_geojson=False
) -> dict:
# TODO: make this a ProcessArgs method?
extent = extract_arg(args, name=field, process_id=process_id)
if handle_geojson and extent.get("type") in [
"Polygon",
Expand All @@ -1466,24 +1443,16 @@ def _extract_bbox_extent(args: dict, field="extent", process_id="filter_bbox", h


@process
def filter_bbox(args: Dict, env: EvalEnv) -> DriverDataCube:
cube = extract_arg(args, 'data')
if not isinstance(cube, DriverDataCube):
raise ProcessParameterInvalidException(
parameter="data", process="filter_bbox", reason=f"Invalid data type {type(cube)!r} expected raster-cube."
)
def filter_bbox(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
spatial_extent = _extract_bbox_extent(args, "extent", process_id="filter_bbox")
return cube.filter_bbox(**spatial_extent)


@process
def filter_spatial(args: Dict, env: EvalEnv) -> DriverDataCube:
cube = extract_arg(args, 'data')
geometries = extract_arg(args, 'geometries')
if not isinstance(cube, DriverDataCube):
raise ProcessParameterInvalidException(
parameter="data", process="filter_spatial", reason=f"Invalid data type {type(cube)!r} expected raster-cube."
)
def filter_spatial(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
geometries = args.get_required("geometries")

if isinstance(geometries, dict):
if "type" in geometries and geometries["type"] != "GeometryCollection":
Expand Down Expand Up @@ -1512,32 +1481,22 @@ def filter_spatial(args: Dict, env: EvalEnv) -> DriverDataCube:


@process
def filter_bands(args: Dict, env: EvalEnv) -> Union[DriverDataCube, DriverVectorCube]:
cube: Union[DriverDataCube, DriverVectorCube] = extract_arg(args, "data")
if not isinstance(cube, DriverDataCube) and not isinstance(cube, DriverVectorCube):
raise ProcessParameterInvalidException(
parameter="data", process="filter_bands", reason=f"Invalid data type {type(cube)!r} expected raster-cube."
)
bands = extract_arg(args, "bands", process_id="filter_bands")
def filter_bands(args: ProcessArgs, env: EvalEnv) -> Union[DriverDataCube, DriverVectorCube]:
cube: Union[DriverDataCube, DriverVectorCube] = args.get_required(
"data", expected_type=(DriverDataCube, DriverVectorCube)
)
bands = args.get_required("bands", expected_type=list)
return cube.filter_bands(bands=bands)


@process
def apply_kernel(args: Dict, env: EvalEnv) -> DriverDataCube:
image_collection = extract_arg(args, 'data')
kernel = np.asarray(extract_arg(args, 'kernel'))
factor = args.get('factor', 1.0)
border = args.get('border', 0)
if not isinstance(image_collection, DriverDataCube):
raise ProcessParameterInvalidException(
parameter="data", process="apply_kernel",
reason=f"Invalid data type {type(image_collection)!r} expected raster-cube."
)
if border == "0":
# R-client sends `0` border as a string
border = 0
replace_invalid = args.get('replace_invalid', 0)
return image_collection.apply_kernel(kernel=kernel, factor=factor, border=border, replace_invalid=replace_invalid)
def apply_kernel(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
kernel = np.asarray(args.get_required("kernel", expected_type=list))
factor = args.get_optional("factor", default=1.0, expected_type=(int, float))
border = args.get_optional("border", default=0, expected_type=int)
replace_invalid = args.get_optional("replace_invalid", default=0, expected_type=(int, float))
return cube.apply_kernel(kernel=kernel, factor=factor, border=border, replace_invalid=replace_invalid)


@process
Expand Down
6 changes: 4 additions & 2 deletions openeo_driver/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,15 @@ def filter_spatial(self, geometries) -> 'DriverDataCube':
def filter_bands(self, bands) -> 'DriverDataCube':
self._not_implemented()

def filter_labels(self, condition: dict,dimensin: str, context: Optional[dict] = None, env: EvalEnv = None ) -> 'DriverDataCube':
def filter_labels(
self, condition: dict, dimension: str, context: Optional[dict] = None, env: EvalEnv = None
) -> "DriverDataCube":
self._not_implemented()

def apply(self, process: dict, *, context: Optional[dict] = None, env: EvalEnv) -> "DriverDataCube":
self._not_implemented()

def apply_kernel(self, kernel: list, factor=1, border=0, replace_invalid=0) -> 'DriverDataCube':
def apply_kernel(self, kernel: numpy.ndarray, factor=1, border=0, replace_invalid=0) -> "DriverDataCube":
self._not_implemented()

def apply_neighborhood(
Expand Down
3 changes: 3 additions & 0 deletions openeo_driver/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Union

from openeo_driver.datacube import DriverDataCube
from openeo_driver.errors import (
OpenEOApiException,
ProcessParameterInvalidException,
Expand Down Expand Up @@ -325,6 +326,8 @@ def _check_value(
):
if expected_type:
if not isinstance(value, expected_type):
if expected_type is DriverDataCube:
expected_type = "raster cube"
raise ProcessParameterInvalidException(
parameter=name, process=self.process_id, reason=f"Expected {expected_type} but got {type(value)}."
)
Expand Down

0 comments on commit f5e17b8

Please sign in to comment.