diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index f5b2c26d684..cf6d557f4a6 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1836,23 +1836,25 @@ def virtualfile_in( # noqa: PLR0912 ) ncols = 3 + # Specify either data or x/y/z. + if data is not None and any(v is not None for v in (x, y, z)): + msg = "Too much data. Use either data or x/y/z." + raise GMTInvalidInput(msg) + + # Determine the kind of data. kind = data_kind(data, required=required_data) - _validate_data_input( - data=data, - x=x, - y=y, - z=z, - ncols=ncols, - required_data=required_data, - kind=kind, - ) + # Check if the kind of data is valid. if check_kind: valid_kinds = ("file", "arg") if required_data is False else ("file",) - if check_kind == "raster": - valid_kinds += ("grid", "image") - elif check_kind == "vector": - valid_kinds += ("empty", "matrix", "vectors", "geojson") + match check_kind: + case "raster": + valid_kinds += ("grid", "image") + case "vector": + valid_kinds += ("empty", "matrix", "vectors", "geojson") + case _: + msg = f"Invalid value for check_kind: '{check_kind}'." + raise GMTInvalidInput(msg) if kind not in valid_kinds: msg = f"Unrecognized data type for {check_kind}: {type(data)}." raise GMTInvalidInput(msg) @@ -1886,6 +1888,7 @@ def virtualfile_in( # noqa: PLR0912 _data = [x, y] if z is not None: _data.append(z) + # TODO(PyGMT>=0.20.0): Remove the deprecated parameter 'extra_arrays'. if extra_arrays: msg = ( "The parameter 'extra_arrays' will be removed in v0.20.0. " @@ -1911,6 +1914,9 @@ def virtualfile_in( # noqa: PLR0912 _virtualfile_from = self.virtualfile_from_vectors _data = data.T + # Check if _data to be passed to the virtualfile_from_ function is valid. + _validate_data_input(data=_data, kind=kind, ncols=ncols) + # Finally create the virtualfile from the data, to be passed into GMT file_context = _virtualfile_from(_data) return file_context diff --git a/pygmt/helpers/utils.py b/pygmt/helpers/utils.py index 5c4fa66c614..002c3274638 100644 --- a/pygmt/helpers/utils.py +++ b/pygmt/helpers/utils.py @@ -40,36 +40,42 @@ "ISO-8859-15", "ISO-8859-16", ] +# Type hints for the list of possible data kinds. +Kind = Literal[ + "arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors" +] -def _validate_data_input( # noqa: PLR0912 - data=None, x=None, y=None, z=None, ncols=2, required_data=True, kind=None -) -> None: +def _validate_data_input(data: Any, kind: Kind, ncols=2) -> None: """ - Check if the combination of data/x/y/z is valid. + Check if the data to be passed to the virtualfile_from_ functions is valid. Examples -------- - >>> _validate_data_input(data="infile") - >>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6]) - >>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], z=[7, 8, 9]) - >>> _validate_data_input(data=None, required_data=False) - >>> _validate_data_input() + The "empty" kind means the data is given via a series of vectors like x/y/z. + + >>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6]], kind="empty") + >>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]], kind="empty") + >>> _validate_data_input(data=[None, [4, 5, 6]], kind="empty") Traceback (most recent call last): ... - pygmt.exceptions.GMTInvalidInput: No input data provided. - >>> _validate_data_input(x=[1, 2, 3]) + pygmt.exceptions.GMTInvalidInput: Must provide both x and y. + >>> _validate_data_input(data=[[1, 2, 3], None], kind="empty") Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: Must provide both x and y. - >>> _validate_data_input(y=[4, 5, 6]) + >>> _validate_data_input(data=[None, None], kind="empty") Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: Must provide both x and y. + >>> _validate_data_input(data=[[1, 2, 3], [4, 5, 6]], kind="empty", ncols=3) >>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], ncols=3) Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z. + + The "matrix" kind means the data is given via a 2-D numpy.ndarray. + >>> import numpy as np >>> import pandas as pd >>> import xarray as xr @@ -77,7 +83,11 @@ def _validate_data_input( # noqa: PLR0912 >>> _validate_data_input(data=data, ncols=3, kind="matrix") Traceback (most recent call last): ... - pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. + pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given. + + The "vectors" kind means the original data is either dictionary, list, tuple, + pandas.DataFrame, pandas.Series, xarray.Dataset, or xarray.DataArray. + >>> _validate_data_input( ... data=pd.DataFrame(data, columns=["x", "y"]), ... ncols=3, @@ -85,7 +95,7 @@ def _validate_data_input( # noqa: PLR0912 ... ) Traceback (most recent call last): ... - pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. + pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given. >>> _validate_data_input( ... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])), ... ncols=3, @@ -93,49 +103,32 @@ def _validate_data_input( # noqa: PLR0912 ... ) Traceback (most recent call last): ... - pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns. - >>> _validate_data_input(data="infile", x=[1, 2, 3]) - Traceback (most recent call last): - ... - pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z. - >>> _validate_data_input(data="infile", y=[4, 5, 6]) - Traceback (most recent call last): - ... - pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z. - >>> _validate_data_input(data="infile", x=[1, 2, 3], y=[4, 5, 6]) - Traceback (most recent call last): - ... - pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z. - >>> _validate_data_input(data="infile", z=[7, 8, 9]) - Traceback (most recent call last): - ... - pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z. + pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given. Raises ------ GMTInvalidInput If the data input is not valid. """ - required_z = ncols >= 3 - if data is None: # data is None - if x is None and y is None: # both x and y are None - if required_data: # data is not optional - msg = "No input data provided." + match kind: + case "empty": # data = [x, y], [x, y, z], [x, y, z, ...] + if len(data) < 2 or any(v is None for v in data[:2]): + msg = "Must provide both x and y." raise GMTInvalidInput(msg) - elif x is None or y is None: # either x or y is None - msg = "Must provide both x and y." - raise GMTInvalidInput(msg) - if required_z and z is None: # both x and y are not None, now check z - msg = "Must provide x, y, and z." - raise GMTInvalidInput(msg) - else: # data is not None - if x is not None or y is not None or z is not None: - msg = "Too much data. Use either data or x/y/z." - raise GMTInvalidInput(msg) - # check if data has the required z column - if required_z: - msg = "data must provide x, y, and z columns." - if kind == "matrix" and data.shape[1] < 3: + if ncols >= 3 and (len(data) < 3 or data[:3] is None): + msg = "Must provide x, y, and z." + raise GMTInvalidInput(msg) + case "matrix": # 2-D numpy.ndarray + if (actual_cols := data.shape[1]) < ncols: + msg = f"Need at least {ncols} columns but {actual_cols} column(s) are given." + raise GMTInvalidInput(msg) + case "vectors": + # "vectors" means the original data is either dictionary, list, tuple, + # pandas.DataFrame, pandas.Series, xarray.Dataset, or xarray.DataArray. + # The original data is converted to a list of vectors or a 2-D numpy.ndarray + # in the virtualfile_in function. + if (actual_cols := len(data)) < ncols: + msg = f"Need at least {ncols} columns but {actual_cols} column(s) are given." raise GMTInvalidInput(msg) if kind == "vectors": if hasattr(data, "shape") and ( @@ -145,15 +138,15 @@ def _validate_data_input( # noqa: PLR0912 raise GMTInvalidInput(msg) if hasattr(data, "data_vars") and len(data.data_vars) < 3: # xr.Dataset raise GMTInvalidInput(msg) - if kind == "vectors" and isinstance(data, dict): - # Iterator over the up-to-3 first elements. - arrays = list(islice(data.values(), 3)) - if len(arrays) < 2 or any(v is None for v in arrays[:2]): # Check x/y - msg = "Must provide x and y." - raise GMTInvalidInput(msg) - if required_z and (len(arrays) < 3 or arrays[2] is None): # Check z - msg = "Must provide x, y, and z." - raise GMTInvalidInput(msg) + if kind == "vectors" and isinstance(data, dict): + # Iterator over the up-to-3 first elements. + arrays = list(islice(data.values(), 3)) + if len(arrays) < 2 or any(v is None for v in arrays[:2]): # Check x/y + msg = "Must provide x and y." + raise GMTInvalidInput(msg) + if required_z and (len(arrays) < 3 or arrays[2] is None): # Check z + msg = "Must provide x, y, and z." + raise GMTInvalidInput(msg) def _is_printable_ascii(argstr: str) -> bool: @@ -272,11 +265,7 @@ def _check_encoding(argstr: str) -> Encoding: return "ISOLatin1+" -def data_kind( - data: Any, required: bool = True -) -> Literal[ - "arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors" -]: +def data_kind(data: Any, required: bool = True) -> Kind: r""" Check the kind of data that is provided to a module.