Skip to content

Commit e4cc57d

Browse files
committed
Refactor _validate_data_input
1 parent 3d4baa6 commit e4cc57d

File tree

1 file changed

+59
-34
lines changed

1 file changed

+59
-34
lines changed

pygmt/helpers/utils.py

+59-34
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pathlib import Path
1616
from typing import Any, Literal
1717

18+
import numpy as np
1819
import xarray as xr
1920
from pygmt.encodings import charset
2021
from pygmt.exceptions import GMTInvalidInput
@@ -39,10 +40,20 @@
3940
"ISO-8859-15",
4041
"ISO-8859-16",
4142
]
43+
# Type hints for the list of possible data kinds.
44+
Kind = Literal[
45+
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
46+
]
4247

4348

4449
def _validate_data_input(
45-
data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None
50+
data=None,
51+
x=None,
52+
y=None,
53+
z=None,
54+
required_z: bool = False,
55+
required_data: bool = True,
56+
kind: Kind | None = None,
4657
) -> None:
4758
"""
4859
Check if the combination of data/x/y/z is valid.
@@ -76,23 +87,23 @@ def _validate_data_input(
7687
>>> _validate_data_input(data=data, required_z=True, kind="matrix")
7788
Traceback (most recent call last):
7889
...
79-
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
90+
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.
8091
>>> _validate_data_input(
8192
... data=pd.DataFrame(data, columns=["x", "y"]),
8293
... required_z=True,
8394
... kind="vectors",
8495
... )
8596
Traceback (most recent call last):
8697
...
87-
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
98+
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.
8899
>>> _validate_data_input(
89100
... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])),
90101
... required_z=True,
91102
... kind="vectors",
92103
... )
93104
Traceback (most recent call last):
94105
...
95-
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
106+
pygmt.exceptions.GMTInvalidInput: Need at least 3 columns but 2 column(s) are given.
96107
>>> _validate_data_input(data="infile", x=[1, 2, 3])
97108
Traceback (most recent call last):
98109
...
@@ -115,34 +126,52 @@ def _validate_data_input(
115126
GMTInvalidInput
116127
If the data input is not valid.
117128
"""
118-
if data is None: # data is None
119-
if x is None and y is None: # both x and y are None
120-
if required_data: # data is not optional
129+
# Check if too much data is provided.
130+
if data is not None and any(v is not None for v in (x, y, z)):
131+
msg = "Too much data. Use either data or x/y/z."
132+
raise GMTInvalidInput(msg)
133+
134+
# Determine the required number of columns based on the required_z flag.
135+
required_cols = 3 if required_z else 2
136+
137+
# Determine the data kind if not provided.
138+
kind = kind or data_kind(data, required=required_data)
139+
140+
# Check based on the data kind.
141+
match kind:
142+
case "empty": # data is given via a series of vectors like x/y/z.
143+
if x is None and y is None:
121144
msg = "No input data provided."
122145
raise GMTInvalidInput(msg)
123-
elif x is None or y is None: # either x or y is None
124-
msg = "Must provide both x and y."
125-
raise GMTInvalidInput(msg)
126-
if required_z and z is None: # both x and y are not None, now check z
127-
msg = "Must provide x, y, and z."
128-
raise GMTInvalidInput(msg)
129-
else: # data is not None
130-
if x is not None or y is not None or z is not None:
131-
msg = "Too much data. Use either data or x/y/z."
132-
raise GMTInvalidInput(msg)
133-
# check if data has the required z column
134-
if required_z:
135-
msg = "data must provide x, y, and z columns."
136-
if kind == "matrix" and data.shape[1] < 3:
146+
if x is None or y is None:
147+
msg = "Must provide both x and y."
148+
raise GMTInvalidInput(msg)
149+
if required_z and z is None:
150+
msg = "Must provide x, y, and z."
151+
raise GMTInvalidInput(msg)
152+
case "matrix": # 2-D numpy.ndarray
153+
if (actual_cols := data.shape[1]) < required_cols:
154+
msg = (
155+
f"Need at least {required_cols} columns but {actual_cols} column(s) "
156+
"are given."
157+
)
158+
raise GMTInvalidInput(msg)
159+
case "vectors":
160+
# The if-else block should match the code in the virtualfile_in function.
161+
if hasattr(data, "items") and not hasattr(data, "to_frame"):
162+
# Dict, pandas.DataFrame, or xarray.Dataset, but not pd.Series.
163+
_data = [array for _, array in data.items()]
164+
else:
165+
# Python list, tuple, numpy.ndarray, and pandas.Series types
166+
_data = np.atleast_2d(np.asanyarray(data).T)
167+
168+
# Check if the number of columns is sufficient.
169+
if (actual_cols := len(_data)) < required_cols:
170+
msg = (
171+
f"Need at least {required_cols} columns but {actual_cols} "
172+
"column(s) are given."
173+
)
137174
raise GMTInvalidInput(msg)
138-
if kind == "vectors":
139-
if hasattr(data, "shape") and (
140-
(len(data.shape) == 1 and data.shape[0] < 3)
141-
or (len(data.shape) > 1 and data.shape[1] < 3)
142-
): # np.ndarray or pd.DataFrame
143-
raise GMTInvalidInput(msg)
144-
if hasattr(data, "data_vars") and len(data.data_vars) < 3: # xr.Dataset
145-
raise GMTInvalidInput(msg)
146175

147176

148177
def _is_printable_ascii(argstr: str) -> bool:
@@ -261,11 +290,7 @@ def _check_encoding(argstr: str) -> Encoding:
261290
return "ISOLatin1+"
262291

263292

264-
def data_kind(
265-
data: Any, required: bool = True
266-
) -> Literal[
267-
"arg", "empty", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
268-
]:
293+
def data_kind(data: Any, required: bool = True) -> Kind:
269294
r"""
270295
Check the kind of data that is provided to a module.
271296

0 commit comments

Comments
 (0)