Skip to content

Commit 3a26b96

Browse files
committed
feat: infer datetime format for pyarrow
1 parent 3f186da commit 3a26b96

File tree

3 files changed

+131
-15
lines changed

3 files changed

+131
-15
lines changed

narwhals/_arrow/series.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from narwhals._arrow.utils import floordiv_compat
1313
from narwhals._arrow.utils import narwhals_to_native_dtype
1414
from narwhals._arrow.utils import native_to_narwhals_dtype
15+
from narwhals._arrow.utils import parse_datetime_format
1516
from narwhals._arrow.utils import validate_column_comparand
1617
from narwhals.utils import Implementation
1718
from narwhals.utils import generate_unique_token
@@ -1062,8 +1063,7 @@ def to_datetime(self: Self, format: str | None) -> ArrowSeries: # noqa: A002
10621063
import pyarrow.compute as pc # ignore-banned-import()
10631064

10641065
if format is None:
1065-
msg = "`format` is required for pyarrow backend."
1066-
raise ValueError(msg)
1066+
format = parse_datetime_format(self._arrow_series._native_series)
10671067

10681068
return self._arrow_series._from_native_series(
10691069
pc.strptime(self._arrow_series._native_series, format=format, unit="us")

narwhals/_arrow/utils.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,88 @@ def convert_str_slice_to_int_slice(
326326
stop = columns.index(str_slice.stop) + 1 if str_slice.stop is not None else None
327327
step = str_slice.step
328328
return (start, stop, step)
329+
330+
331+
# Building blocks used to split a datetime string into its components.
# Every component is optional; FULL_RE is anchored at both ends (`^` ... `$`).

# Three digit groups joined by '-', '/' or '.', e.g. '2024-01-01' or '01/02/2024'.
DATE_RE = r"^(?P<date>\d{1,4}[-/.]\d{1,2}[-/.]\d{1,4})?"
# A single whitespace character or a literal 'T' between date and time.
SEP_RE = r"(?P<sep>\s|T)?"
# Two-digit HH:MM:SS only; an AM/PM suffix is not handled.
TIME_RE = r"(?P<time>\d{2}:\d{2}:\d{2})?"
# 'Z', or a signed offset with four digits and an optional colon
# ('+02:00', '+0200'). A bare two-digit offset such as '+02' does not match.
TZ_RE = r"(?P<tz>Z|[+-]\d{2}:?\d{2})?$"
FULL_RE = "".join((DATE_RE, SEP_RE, TIME_RE, TZ_RE))

# One fully anchored pattern per supported field order. Each captures the two
# separators individually (`sep1`, `sep2`) so callers can check they agree.
YMD_RE = (
    r"^(?P<year>(?:[12][0-9])?[0-9]{2})"
    r"(?P<sep1>[-/.])"
    r"(?P<month>0[1-9]|1[0-2])"
    r"(?P<sep2>[-/.])"
    r"(?P<day>0[1-9]|[12][0-9]|3[01])$"
)
DMY_RE = (
    r"^(?P<day>0[1-9]|[12][0-9]|3[01])"
    r"(?P<sep1>[-/.])"
    r"(?P<month>0[1-9]|1[0-2])"
    r"(?P<sep2>[-/.])"
    r"(?P<year>(?:[12][0-9])?[0-9]{2})$"
)
MDY_RE = (
    r"^(?P<month>0[1-9]|1[0-2])"
    r"(?P<sep1>[-/.])"
    r"(?P<day>0[1-9]|[12][0-9]|3[01])"
    r"(?P<sep2>[-/.])"
    r"(?P<year>(?:[12][0-9])?[0-9]{2})$"
)

# (pattern, strptime template) pairs, listed in the order they are tried.
# The '-' in each template is a placeholder for the separator found in the data.
DATE_FORMATS = (
    (YMD_RE, "%Y-%m-%d"),
    (DMY_RE, "%d-%m-%Y"),
    (MDY_RE, "%m-%d-%Y"),
)
348+
349+
350+
def parse_datetime_format(arr: pa.StringArray) -> str:
    """Try to infer datetime format from StringArray.

    Every element is split with ``FULL_RE`` into date, separator, time and
    timezone parts; the returned strptime-style format is the concatenation of
    the format inferred for each part.

    NOTE(review): the result of ``extract_regex`` is read through ``.chunks``,
    so ``arr`` is presumably a chunked array despite the annotation - confirm.

    Raises:
        NotImplementedError: if any element does not match ``FULL_RE``.
        ValueError: if more than one distinct separator or timezone value is
            found, or if the date part matches none of ``DATE_FORMATS``.
    """
    import pyarrow as pa  # ignore-banned-import
    import pyarrow.compute as pc  # ignore-banned-import

    # Flatten the chunked extraction result into a single StructArray.
    extracted = pc.extract_regex(arr, pattern=FULL_RE)
    parsed = pa.concat_arrays(extracted.chunks)

    every_row_matched = pc.all(parsed.is_valid()).as_py()
    if not every_row_matched:
        msg = (
            "Unable to infer datetime format, provided format is not supported. "
            "Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
        )
        raise NotImplementedError(msg)

    sep_col = parsed.field("sep")
    tz_col = parsed.field("tz")

    # A single format string can only describe one separator / one timezone.
    n_distinct_seps = pc.count(pc.unique(sep_col)).as_py()
    if n_distinct_seps > 1:
        msg = "Found multiple separator values while inferring datetime format."
        raise ValueError(msg)

    n_distinct_tzs = pc.count(pc.unique(tz_col)).as_py()
    if n_distinct_tzs > 1:
        msg = "Found multiple timezone values while inferring datetime format."
        raise ValueError(msg)

    date_value = _parse_date_format(parsed.field("date"))
    time_value = _parse_time_format(parsed.field("time"))

    # The first row is representative once the uniqueness checks have passed.
    sep_value = sep_col[0].as_py()
    first_tz = tz_col[0].as_py()
    tz_value = "%z" if first_tz else ""

    return f"{date_value}{sep_value}{time_value}{tz_value}"
387+
388+
389+
def _parse_date_format(arr: pa.Array) -> str:
    """Return the strptime date format shared by every value in ``arr``.

    The candidates in ``DATE_FORMATS`` are tried in order. A candidate is
    accepted when all values match its pattern and one single separator is
    used throughout, in both positions. The '-' placeholder of the template is
    then replaced by that separator.

    Raises:
        ValueError: if no candidate is accepted.
    """
    import pyarrow.compute as pc  # ignore-banned-import

    for pattern, template in DATE_FORMATS:
        parsed = pc.extract_regex(arr, pattern=pattern)
        if not pc.all(parsed.is_valid()).as_py():
            continue

        first_seps = parsed.field("sep1")
        if pc.count(pc.unique(first_seps)).as_py() != 1:
            continue

        second_seps = parsed.field("sep2")
        if pc.count(pc.unique(second_seps)).as_py() != 1:
            continue

        separator = first_seps[0].as_py()
        if separator == second_seps[0].as_py():
            return template.replace("-", separator)

    msg = (
        "Unable to infer datetime format. "
        "Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
    )
    raise ValueError(msg)
407+
408+
409+
def _parse_time_format(arr: pa.Array) -> str:
    """Return ``"%H:%M:%S"`` if every value in ``arr`` is an HH:MM:SS time, else ``""``.

    ``arr`` holds the time component already extracted from each datetime
    string; for date-only input that component is an empty string.
    """
    import pyarrow.compute as pc  # ignore-banned-import

    # The module-level TIME_RE is a wholly optional group, so it also matches
    # an empty string and would report a time format for date-only input.
    # Use an anchored, non-optional pattern so values without a time part
    # produce a null match instead.
    matches = pc.extract_regex(arr, pattern=r"^(?P<time>\d{2}:\d{2}:\d{2})$")
    return "%H:%M:%S" if pc.all(matches.is_valid()).as_py() else ""

tests/expr_and_series/str/to_datetime_test.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
14
from typing import Any
25

6+
import pyarrow as pa
37
import pytest
48

59
import narwhals.stable.v1 as nw
6-
from tests.utils import Constructor
10+
from narwhals._arrow.utils import parse_datetime_format
11+
12+
if TYPE_CHECKING:
13+
from tests.utils import Constructor
714

815
data = {"a": ["2020-01-01T12:34:56"]}
916

@@ -38,12 +45,7 @@ def test_to_datetime_series(constructor_eager: Any) -> None:
3845
assert str(result) == expected
3946

4047

41-
def test_to_datetime_infer_fmt(
42-
request: pytest.FixtureRequest, constructor: Constructor
43-
) -> None:
44-
if "pyarrow_table" in str(constructor):
45-
request.applymarker(pytest.mark.xfail)
46-
48+
def test_to_datetime_infer_fmt(constructor: Constructor) -> None:
4749
if "cudf" in str(constructor): # pragma: no cover
4850
expected = "2020-01-01T12:34:56.000000000"
4951
else:
@@ -59,12 +61,7 @@ def test_to_datetime_infer_fmt(
5961
assert str(result) == expected
6062

6163

62-
def test_to_datetime_series_infer_fmt(
63-
request: pytest.FixtureRequest, constructor_eager: Any
64-
) -> None:
65-
if "pyarrow_table" in str(constructor_eager):
66-
request.applymarker(pytest.mark.xfail)
67-
64+
def test_to_datetime_series_infer_fmt(constructor_eager: Any) -> None:
6865
if "cudf" in str(constructor_eager): # pragma: no cover
6966
expected = "2020-01-01T12:34:56.000000000"
7067
else:
@@ -74,3 +71,37 @@ def test_to_datetime_series_infer_fmt(
7471
nw.from_native(constructor_eager(data), eager_only=True)["a"].str.to_datetime()
7572
).item(0)
7673
assert str(result) == expected
74+
75+
76+
@pytest.mark.parametrize(
    "data",
    [
        ["2024-01-01", "abc"],  # second value is not a datetime string
        ["2024-01-01", None],  # second value is null
    ],
)
def test_pyarrow_infer_datetime_raise_invalid(data: list[str | None]) -> None:
    """Inference raises NotImplementedError when a value cannot be matched."""
    expected_msg = "Unable to infer datetime format, provided format is not supported."
    chunked = pa.chunked_array([data])

    with pytest.raises(NotImplementedError, match=expected_msg):
        parse_datetime_format(chunked)
83+
84+
85+
@pytest.mark.parametrize(
    ("data", "duplicate"),
    [
        # 'T' in the first value, a space in the second
        (["2024-01-01T00:00:00", "2024-01-01 01:00:00"], "separator"),
        # '+01:00' in the first value, '+02:00' in the second
        (["2024-01-01 00:00:00+01:00", "2024-01-01 01:00:00+02:00"], "timezone"),
    ],
)
def test_pyarrow_infer_datetime_raise_not_unique(
    data: list[str | None], duplicate: str
) -> None:
    """Inference raises ValueError when separators or timezones differ between rows."""
    expected_msg = f"Found multiple {duplicate} values while inferring datetime format."
    chunked = pa.chunked_array([data])

    with pytest.raises(ValueError, match=expected_msg):
        parse_datetime_format(chunked)
100+
101+
102+
@pytest.mark.parametrize(
    "data",
    [
        # the last value puts the four-digit group last, unlike the first two
        ["2024-01-01", "2024-12-01", "02-02-2024"],
    ],
)
def test_pyarrow_infer_datetime_raise_inconsistent_date_fmt(
    data: list[str | None],
) -> None:
    """Inference raises ValueError when no single date layout fits every row."""
    expected_msg = "Unable to infer datetime format. "
    chunked = pa.chunked_array([data])

    with pytest.raises(ValueError, match=expected_msg):
        parse_datetime_format(chunked)

0 commit comments

Comments
 (0)