Skip to content

Commit 9d153a8

Browse files
authored
fix: PySpark backend was raising during `collect` when the frame contained no rows and a void dtype column (#2032)
1 parent d2d55da commit 9d153a8

File tree

4 files changed

+39
-9
lines changed

4 files changed

+39
-9
lines changed

.github/workflows/downstream_tests.yml

+11-2
Original file line numberDiff line numberDiff line change
@@ -461,18 +461,27 @@ jobs:
461461
- name: install-validoopsie-dev
462462
run: |
463463
cd validoopsie
464+
uv venv
465+
. .venv/bin/activate
464466
uv sync --dev
465467
uv pip install pytest-env
466468
which python
467469
- name: show-deps
468-
run: uv pip freeze
470+
run: |
471+
cd validoopsie
472+
. .venv/bin/activate
473+
uv pip freeze
469474
- name: install-narwhals-dev
470475
run: |
471476
cd validoopsie
477+
. .venv/bin/activate
472478
uv pip uninstall narwhals
473479
uv pip install -e ./..
474480
- name: Run tests
475481
run: |
476482
cd validoopsie
477-
uv run pytest
483+
. .venv/bin/activate
484+
touch tests/__init__.py
485+
touch tests/utils/__init__.py
486+
pytest tests
478487
timeout-minutes: 15

narwhals/_duckdb/expr.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -494,9 +494,7 @@ def is_finite(self: Self) -> Self:
494494

495495
def is_in(self: Self, other: Sequence[Any]) -> Self:
496496
return self._from_call(
497-
lambda _input: lit(False) # noqa: FBT003
498-
if not other
499-
else _input.isin(*[lit(x) for x in other]),
497+
lambda _input: FunctionExpression("contains", lit(other), _input),
500498
"is_in",
501499
expr_kind=self._expr_kind,
502500
)

narwhals/_spark_like/dataframe.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import warnings
34
from typing import TYPE_CHECKING
45
from typing import Any
56
from typing import Literal
@@ -13,6 +14,7 @@
1314
from narwhals.typing import CompliantLazyFrame
1415
from narwhals.utils import Implementation
1516
from narwhals.utils import check_column_exists
17+
from narwhals.utils import find_stacklevel
1618
from narwhals.utils import import_dtypes_module
1719
from narwhals.utils import parse_columns_to_drop
1820
from narwhals.utils import parse_version
@@ -124,13 +126,27 @@ def _collect_to_arrow(self) -> pa.Table:
124126
from narwhals._arrow.utils import narwhals_to_native_dtype
125127

126128
data: dict[str, list[Any]] = {}
127-
schema = []
129+
schema: list[tuple[str, pa.DataType]] = []
128130
current_schema = self.collect_schema()
129131
for key, value in current_schema.items():
130132
data[key] = []
131-
schema.append(
132-
(key, narwhals_to_native_dtype(value, self._version))
133-
)
133+
try:
134+
native_dtype = narwhals_to_native_dtype(value, self._version)
135+
except Exception as exc: # noqa: BLE001
136+
native_spark_dtype = self._native_frame.schema[key].dataType
137+
# If we can't convert the type, just set it to `pa.null`, and warn.
138+
# Avoid the warning if we're starting from PySpark's void type.
139+
# We can avoid the check when we introduce `nw.Null` dtype.
140+
if not isinstance(
141+
native_spark_dtype, self._native_dtypes.NullType
142+
):
143+
warnings.warn(
144+
f"Could not convert dtype {native_spark_dtype} to PyArrow dtype, {exc!r}",
145+
stacklevel=find_stacklevel(),
146+
)
147+
schema.append((key, pa.null()))
148+
else:
149+
schema.append((key, native_dtype))
134150
native_pyarrow_frame = pa.Table.from_pydict(
135151
data, schema=pa.schema(schema)
136152
)

tests/frame/collect_test.py

+7
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,10 @@ def test_collect_with_kwargs(constructor: Constructor) -> None:
105105

106106
expected = {"a": [3], "b": [7]}
107107
assert_equal_data(result, expected)
108+
109+
110+
def test_collect_empty_pyspark(constructor: Constructor) -> None:
111+
df = nw_v1.from_native(constructor({"a": [1, 2, 3]}))
112+
df = df.filter(nw.col("a").is_null()).with_columns(b=nw.lit(None)).lazy()
113+
result = df.collect()
114+
assert_equal_data(result, {"a": [], "b": []})

0 commit comments

Comments (0)