Skip to content

Commit

Permalink
Issue #346 some more ProcessArgs porting
Browse files Browse the repository at this point in the history
for less boilerplate code and better error messages
  • Loading branch information
soxofaan committed Jan 21, 2025
1 parent a99837e commit 27506db
Showing 1 changed file with 23 additions and 28 deletions.
51 changes: 23 additions & 28 deletions openeo_driver/ProcessGraphDeserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,22 +936,18 @@ def save_result(args: Dict, env: EvalEnv) -> SaveResult: # TODO: return type no

@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/save_ml_model.json"))
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/save_ml_model.json"))
def save_ml_model(args: dict, env: EvalEnv) -> MlModelResult:
data: DriverMlModel = extract_arg(args, "data", process_id="save_ml_model")
if not isinstance(data, DriverMlModel):
raise ProcessParameterInvalidException(
parameter="data", process="save_ml_model", reason=f"Invalid data type {type(data)!r} expected raster-cube."
)
options = args.get("options", {})
def save_ml_model(args: ProcessArgs, env: EvalEnv) -> MlModelResult:
data = args.get_required("data", expected_type=DriverMlModel)
options = args.get_optional("options", default={}, expected_type=dict)
return MlModelResult(ml_model=data, options=options)


@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/load_ml_model.json"))
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/load_ml_model.json"))
def load_ml_model(args: dict, env: EvalEnv) -> DriverMlModel:
def load_ml_model(args: ProcessArgs, env: EvalEnv) -> DriverMlModel:
if env.get(ENV_DRY_RUN_TRACER):
return DriverMlModel()
job_id = extract_arg(args, "id")
job_id = args.get_required("id", expected_type=str)
return env.backend_implementation.load_ml_model(job_id)


Expand Down Expand Up @@ -1430,17 +1426,17 @@ def filter_temporal(args: dict, env: EvalEnv) -> DriverDataCube:
extent = _extract_temporal_extent(args, field="extent", process_id="filter_temporal")
return cube.filter_temporal(start=extent[0], end=extent[1])


@process_registry_100.add_function(spec=read_spec("openeo-processes/1.x/proposals/filter_labels.json"))
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/2.x/proposals/filter_labels.json"))
def filter_labels(args: dict, env: EvalEnv) -> DriverDataCube:
cube = extract_arg(args, 'data')
if not isinstance(cube, DriverDataCube):
raise ProcessParameterInvalidException(
parameter="data", process="filter_labels",
reason=f"Invalid data type {type(cube)!r} expected cube."
)
def filter_labels(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:
cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube)
# TODO: validation that condition is a process graph construct
condition = args.get_required("condition", expected_type=dict)
dimension = args.get_required("dimension", expected_type=str)
context = args.get_optional("context", default=None)
return cube.filter_labels(condition=condition, dimension=dimension, context=context, env=env)

return cube.filter_labels(condition=extract_arg(args,"condition"),dimension=extract_arg(args,"dimension"),context=args.get("context",None),env=env)

def _extract_bbox_extent(args: dict, field="extent", process_id="filter_bbox", handle_geojson=False) -> dict:
extent = extract_arg(args, name=field, process_id=process_id)
Expand Down Expand Up @@ -1505,13 +1501,11 @@ def filter_spatial(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:


@process
def filter_bands(args: Dict, env: EvalEnv) -> Union[DriverDataCube, DriverVectorCube]:
cube: Union[DriverDataCube, DriverVectorCube] = extract_arg(args, "data")
if not isinstance(cube, DriverDataCube) and not isinstance(cube, DriverVectorCube):
raise ProcessParameterInvalidException(
parameter="data", process="filter_bands", reason=f"Invalid data type {type(cube)!r} expected raster-cube."
)
bands = extract_arg(args, "bands", process_id="filter_bands")
def filter_bands(args: ProcessArgs, env: EvalEnv) -> Union[DriverDataCube, DriverVectorCube]:
cube: Union[DriverDataCube, DriverVectorCube] = args.get_required(
"data", expected_type=(DriverDataCube, DriverVectorCube)
)
bands = args.get_required("bands", expected_type=list)
return cube.filter_bands(bands=bands)


Expand Down Expand Up @@ -2325,10 +2319,11 @@ def load_result(args: ProcessArgs, env: EvalEnv) -> DriverDataCube:

@process_registry_100.add_function(spec=read_spec("openeo-processes/1.x/proposals/inspect.json"))
@process_registry_2xx.add_function(spec=read_spec("openeo-processes/2.x/proposals/inspect.json"))
def inspect(args: dict, env: EvalEnv):
data = extract_arg(args, "data")
message = args.get("message", "")
level = args.get("level", "info")
def inspect(args: ProcessArgs, env: EvalEnv):
data = args.get_required("data")
message = args.get_optional("message", default="")
code = args.get_optional("code", default="User")
level = args.get_optional("level", default="info")
if message:
_log.log(level=logging.getLevelName(level.upper()), msg=message)
data_message = str(data)
Expand Down

0 comments on commit 27506db

Please sign in to comment.