Commit a7f2646

Merge pull request #256 from posit-dev/feat-validate-spark-df-without-ibis
feat: validate Spark DFs without using Ibis
2 parents b15ebad + 18bb562 commit a7f2646

File tree

14 files changed: +1423 -106 lines changed

.github/workflows/ci-tests.yaml

Lines changed: 10 additions & 0 deletions
@@ -24,11 +24,21 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
+      - name: Set up Java for PySpark
+        uses: actions/setup-java@v4
+        with:
+          distribution: "temurin"
+          java-version: "11"
       - name: Install uv
         uses: astral-sh/setup-uv@v5
       - name: pytest unit tests
         run: |
           make test
+        env:
+          # Optimize PySpark for CI environment
+          PYSPARK_DRIVER_MEMORY: 1g
+          PYSPARK_EXECUTOR_MEMORY: 1g
+          SPARK_LOCAL_IP: 127.0.0.1
       - name: Upload coverage reports to Codecov
         uses: codecov/codecov-action@v5
         with:
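For context, one way a local test fixture might wire up these CI settings when building a SparkSession; this is a hedged sketch, not pointblank's actual test setup, and the `local[*]` master plus the explicit `os.environ.get` wiring are assumptions here (Spark reads `SPARK_LOCAL_IP` on its own, but the memory values are passed through explicitly for clarity):

import os
from pyspark.sql import SparkSession

# Read the CI environment variables set above, falling back to the same defaults
spark = (
    SparkSession.builder.master("local[*]")
    .config("spark.driver.memory", os.environ.get("PYSPARK_DRIVER_MEMORY", "1g"))
    .config("spark.executor.memory", os.environ.get("PYSPARK_EXECUTOR_MEMORY", "1g"))
    .config("spark.driver.host", os.environ.get("SPARK_LOCAL_IP", "127.0.0.1"))
    .getOrCreate()
)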

pointblank/_constants.py

Lines changed: 0 additions & 1 deletion
@@ -118,7 +118,6 @@
     "mysql",
     "parquet",
     "postgres",
-    "pyspark",
     "snowflake",
     "sqlite",
 ]

pointblank/_interrogation.py

Lines changed: 181 additions & 38 deletions
@@ -23,6 +23,74 @@
 from pointblank._typing import AbsoluteTolBounds


+def _safe_modify_datetime_compare_val(data_frame: Any, column: str, compare_val: Any) -> Any:
+    """
+    Safely modify datetime comparison values for LazyFrame compatibility.
+
+    This function handles the case where we can't directly slice LazyFrames
+    to get column dtypes for datetime conversion.
+    """
+    try:
+        # First try to get column dtype from schema for LazyFrames
+        column_dtype = None
+
+        if hasattr(data_frame, "collect_schema"):
+            schema = data_frame.collect_schema()
+            column_dtype = schema.get(column)
+        elif hasattr(data_frame, "schema"):
+            schema = data_frame.schema
+            column_dtype = schema.get(column)
+
+        # If we got a dtype from schema, use it
+        if column_dtype is not None:
+            # Create a mock column object for _modify_datetime_compare_val
+            class MockColumn:
+                def __init__(self, dtype):
+                    self.dtype = dtype
+
+            mock_column = MockColumn(column_dtype)
+            return _modify_datetime_compare_val(tgt_column=mock_column, compare_val=compare_val)
+
+        # Fallback: try collecting a small sample if possible
+        try:
+            sample = data_frame.head(1).collect()
+            if hasattr(sample, "dtypes") and column in sample.columns:
+                # For pandas-like dtypes
+                column_dtype = sample.dtypes[column] if hasattr(sample, "dtypes") else None
+                if column_dtype:
+
+                    class MockColumn:
+                        def __init__(self, dtype):
+                            self.dtype = dtype
+
+                    mock_column = MockColumn(column_dtype)
+                    return _modify_datetime_compare_val(
+                        tgt_column=mock_column, compare_val=compare_val
+                    )
+        except Exception:
+            pass
+
+        # Final fallback: try direct access (for eager DataFrames)
+        try:
+            if hasattr(data_frame, "dtypes") and column in data_frame.columns:
+                column_dtype = data_frame.dtypes[column]
+
+                class MockColumn:
+                    def __init__(self, dtype):
+                        self.dtype = dtype
+
+                mock_column = MockColumn(column_dtype)
+                return _modify_datetime_compare_val(tgt_column=mock_column, compare_val=compare_val)
+        except Exception:
+            pass
+
+    except Exception:
+        pass
+
+    # If all else fails, return the original compare_val
+    return compare_val
+
+
 @dataclass
 class Interrogator:
     """
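The motivation for the schema-first lookup: slicing a LazyFrame column (`data_frame[column]`) is not supported, but the dtype is still reachable through `collect_schema()` without materializing any rows. A minimal sketch, assuming Polars and narwhals are installed:

from datetime import date

import narwhals as nw
import polars as pl

lf = nw.from_native(pl.LazyFrame({"d": [date(2024, 1, 1)]}))

# No materialization needed: the schema alone carries the dtype
schema = lf.collect_schema()
print(schema.get("d"))  # Date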
@@ -136,9 +204,7 @@ def gt(self) -> FrameT | Any:

         compare_expr = _get_compare_expr_nw(compare=self.compare)

-        compare_expr = _modify_datetime_compare_val(
-            tgt_column=self.x[self.column], compare_val=compare_expr
-        )
+        compare_expr = _safe_modify_datetime_compare_val(self.x, self.column, compare_expr)

         return (
             self.x.with_columns(
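Why every call site switches away from `self.x[self.column]`: item access needs an eager frame, so it fails on LazyFrame inputs. A quick illustration of the failure mode (assumes Polars and narwhals; the exact exception type may vary by version):

import narwhals as nw
import polars as pl

lf = nw.from_native(pl.LazyFrame({"a": [1, 2]}))

try:
    lf["a"]  # LazyFrames cannot be sliced into columns
except Exception as err:
    print(f"expected failure: {type(err).__name__}")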
@@ -211,9 +277,7 @@ def lt(self) -> FrameT | Any:

         compare_expr = _get_compare_expr_nw(compare=self.compare)

-        compare_expr = _modify_datetime_compare_val(
-            tgt_column=self.x[self.column], compare_val=compare_expr
-        )
+        compare_expr = _safe_modify_datetime_compare_val(self.x, self.column, compare_expr)

         return (
             self.x.with_columns(
@@ -329,9 +393,7 @@ def eq(self) -> FrameT | Any:
         else:
             compare_expr = _get_compare_expr_nw(compare=self.compare)

-            compare_expr = _modify_datetime_compare_val(
-                tgt_column=self.x[self.column], compare_val=compare_expr
-            )
+            compare_expr = _safe_modify_datetime_compare_val(self.x, self.column, compare_expr)

         tbl = self.x.with_columns(
             pb_is_good_1=nw.col(self.column).is_null() & self.na_pass,
@@ -421,9 +483,7 @@ def ne(self) -> FrameT | Any:
             ).to_native()

         else:
-            compare_expr = _modify_datetime_compare_val(
-                tgt_column=self.x[self.column], compare_val=self.compare
-            )
+            compare_expr = _safe_modify_datetime_compare_val(self.x, self.column, self.compare)

             return self.x.with_columns(
                 pb_is_good_=nw.col(self.column) != nw.lit(compare_expr),
@@ -544,9 +604,7 @@ def ne(self) -> FrameT | Any:
         if ref_col_has_null_vals:
             # Create individual cases for Pandas and Polars

-            compare_expr = _modify_datetime_compare_val(
-                tgt_column=self.x[self.column], compare_val=self.compare
-            )
+            compare_expr = _safe_modify_datetime_compare_val(self.x, self.column, self.compare)

             if is_pandas_dataframe(self.x.to_native()):
                 tbl = self.x.with_columns(
@@ -584,6 +642,25 @@ def ne(self) -> FrameT | Any:

                 return tbl

+            else:
+                # Generic case for other DataFrame types (PySpark, etc.)
+                # Use similar logic to Polars but handle potential differences
+                tbl = self.x.with_columns(
+                    pb_is_good_1=nw.col(self.column).is_null(),  # val is Null in Column
+                    pb_is_good_2=nw.lit(self.na_pass),  # Pass if any Null in val or compare
+                )
+
+                tbl = tbl.with_columns(pb_is_good_3=nw.col(self.column) != nw.lit(compare_expr))
+
+                tbl = tbl.with_columns(
+                    pb_is_good_=(
+                        (nw.col("pb_is_good_1") & nw.col("pb_is_good_2"))
+                        | (nw.col("pb_is_good_3") & ~nw.col("pb_is_good_1"))
+                    )
+                )
+
+                return tbl.drop("pb_is_good_1", "pb_is_good_2", "pb_is_good_3").to_native()
+
     def ge(self) -> FrameT | Any:
         # Ibis backends ---------------------------------------------

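The three helper columns encode a null-aware inequality test. The same predicate written out for a single scalar value (a hypothetical helper, not part of the codebase, shown only to make the boolean logic explicit):

def ne_passes(value, compare, na_pass: bool) -> bool:
    is_null = value is None              # pb_is_good_1
    if is_null:
        return na_pass                   # pb_is_good_1 & pb_is_good_2
    return value != compare              # pb_is_good_3 & ~pb_is_good_1

assert ne_passes(None, 5, na_pass=True) is True    # Null passes when na_pass=True
assert ne_passes(None, 5, na_pass=False) is False  # Null fails otherwise
assert ne_passes(3, 5, na_pass=False) is True      # 3 != 5 passes
assert ne_passes(5, 5, na_pass=False) is False     # equal values fail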
@@ -629,9 +706,7 @@ def ge(self) -> FrameT | Any:

         compare_expr = _get_compare_expr_nw(compare=self.compare)

-        compare_expr = _modify_datetime_compare_val(
-            tgt_column=self.x[self.column], compare_val=compare_expr
-        )
+        compare_expr = _safe_modify_datetime_compare_val(self.x, self.column, compare_expr)

         tbl = (
             self.x.with_columns(
@@ -702,9 +777,7 @@ def le(self) -> FrameT | Any:

         compare_expr = _get_compare_expr_nw(compare=self.compare)

-        compare_expr = _modify_datetime_compare_val(
-            tgt_column=self.x[self.column], compare_val=compare_expr
-        )
+        compare_expr = _safe_modify_datetime_compare_val(self.x, self.column, compare_expr)

         return (
             self.x.with_columns(
@@ -834,10 +907,8 @@ def between(self) -> FrameT | Any:
         low_val = _get_compare_expr_nw(compare=self.low)
         high_val = _get_compare_expr_nw(compare=self.high)

-        low_val = _modify_datetime_compare_val(tgt_column=self.x[self.column], compare_val=low_val)
-        high_val = _modify_datetime_compare_val(
-            tgt_column=self.x[self.column], compare_val=high_val
-        )
+        low_val = _safe_modify_datetime_compare_val(self.x, self.column, low_val)
+        high_val = _safe_modify_datetime_compare_val(self.x, self.column, high_val)

         tbl = self.x.with_columns(
             pb_is_good_1=nw.col(self.column).is_null(),  # val is Null in Column
@@ -1026,10 +1097,8 @@ def outside(self) -> FrameT | Any:
         low_val = _get_compare_expr_nw(compare=self.low)
         high_val = _get_compare_expr_nw(compare=self.high)

-        low_val = _modify_datetime_compare_val(tgt_column=self.x[self.column], compare_val=low_val)
-        high_val = _modify_datetime_compare_val(
-            tgt_column=self.x[self.column], compare_val=high_val
-        )
+        low_val = _safe_modify_datetime_compare_val(self.x, self.column, low_val)
+        high_val = _safe_modify_datetime_compare_val(self.x, self.column, high_val)

         tbl = self.x.with_columns(
             pb_is_good_1=nw.col(self.column).is_null(),  # val is Null in Column
@@ -1209,14 +1278,15 @@ def rows_distinct(self) -> FrameT | Any:
         else:
             columns_subset = self.columns_subset

-        # Create a subset of the table with only the columns of interest
-        subset_tbl = tbl.select(columns_subset)
+        # Count duplicates using a group_by approach, as in the Ibis backend:
+        # group by the columns of interest and count occurrences
+        count_tbl = tbl.group_by(columns_subset).agg(nw.len().alias("pb_count_"))

-        # Check for duplicates in the subset table, creating a series of booleans
-        pb_is_good_series = subset_tbl.is_duplicated()
+        # Join back to the original table to get the count for each row
+        tbl = tbl.join(count_tbl, on=columns_subset, how="left")

-        # Add the series to the input table
-        tbl = tbl.with_columns(pb_is_good_=~pb_is_good_series)
+        # Passing rows have a count of `1` (no duplicates), so they get True; all others get False
+        tbl = tbl.with_columns(pb_is_good_=nw.col("pb_count_") == 1).drop("pb_count_")

         return tbl.to_native()
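The group_by/join rewrite replaces `is_duplicated()`, which requires a materialized frame, with operations that also run lazily. A small eager demonstration of the same steps (assumes Polars and narwhals; data is illustrative):

import narwhals as nw
import polars as pl

tbl = nw.from_native(pl.DataFrame({"a": [1, 1, 2], "b": ["x", "x", "y"]}))
columns_subset = ["a", "b"]

count_tbl = tbl.group_by(columns_subset).agg(nw.len().alias("pb_count_"))
tbl = tbl.join(count_tbl, on=columns_subset, how="left")
tbl = tbl.with_columns(pb_is_good_=nw.col("pb_count_") == 1).drop("pb_count_")

print(tbl.to_native())  # the two duplicated rows get False; the unique row gets True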

@@ -2088,6 +2158,8 @@ def get_test_results(self):
             return self._get_pandas_results()
         elif "duckdb" in self.tbl_type or "ibis" in self.tbl_type:
             return self._get_ibis_results()
+        elif "pyspark" in self.tbl_type:
+            return self._get_pyspark_results()
         else:  # pragma: no cover
             raise NotImplementedError(f"Support for {self.tbl_type} is not yet implemented")

@@ -2247,6 +2319,53 @@ def _get_ibis_results(self):
         results_tbl = self.data_tbl.mutate(pb_is_good_=ibis.literal(True))
         return results_tbl

+    def _get_pyspark_results(self):
+        """Process expressions for PySpark DataFrames."""
+        from pyspark.sql import functions as F
+
+        pyspark_columns = []
+
+        for expr_fn in self.expressions:
+            try:
+                # First try direct evaluation with PySpark DataFrame
+                expr_result = expr_fn(self.data_tbl)
+
+                # Check if it's a PySpark Column
+                if hasattr(expr_result, "_jc"):  # PySpark Column has _jc attribute
+                    pyspark_columns.append(expr_result)
+                else:
+                    raise TypeError(
+                        f"Expression returned {type(expr_result)}, expected PySpark Column"
+                    )
+
+            except Exception as e:
+                try:
+                    # Try as a ColumnExpression (for pb.expr_col style)
+                    col_expr = expr_fn(None)
+
+                    if hasattr(col_expr, "to_pyspark_expr"):
+                        # Convert to PySpark expression
+                        pyspark_expr = col_expr.to_pyspark_expr(self.data_tbl)
+                        pyspark_columns.append(pyspark_expr)
+                    else:
+                        raise TypeError(f"Cannot convert {type(col_expr)} to PySpark Column")
+                except Exception as nested_e:
+                    print(f"Error evaluating PySpark expression: {e} -> {nested_e}")
+
+        # Combine results with AND logic
+        if pyspark_columns:
+            final_result = pyspark_columns[0]
+            for col in pyspark_columns[1:]:
+                final_result = final_result & col
+
+            # Create results table with boolean column
+            results_tbl = self.data_tbl.withColumn("pb_is_good_", final_result)
+            return results_tbl
+
+        # Default case
+        results_tbl = self.data_tbl.withColumn("pb_is_good_", F.lit(True))
+        return results_tbl
+

 class SpeciallyValidation:
     def __init__(self, data_tbl, expression, threshold, tbl_type):
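What `_get_pyspark_results()` expects from each expression: a callable that, given the PySpark DataFrame, returns a boolean Column. A hedged usage sketch showing the happy path (the session setup, data, and column name are illustrative only):

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
data_tbl = spark.createDataFrame([(1,), (2,), (3,)], ["a"])

expr_fn = lambda tbl: F.col("a") > 1  # returns a PySpark Column (has `_jc`)

results_tbl = data_tbl.withColumn("pb_is_good_", expr_fn(data_tbl))
results_tbl.show()  # rows with a > 1 are flagged True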
@@ -2359,13 +2478,22 @@ class NumberOfTestUnits:
     column: str

     def get_test_units(self, tbl_type: str) -> int:
-        if tbl_type == "pandas" or tbl_type == "polars":
+        if (
+            tbl_type == "pandas"
+            or tbl_type == "polars"
+            or tbl_type == "pyspark"
+            or tbl_type == "local"
+        ):
             # Convert the DataFrame to a format that narwhals can work with and:
             # - check if the column exists
             dfn = _column_test_prep(
                 df=self.df, column=self.column, allowed_types=None, check_exists=False
             )

+            # Handle LazyFrames which don't have len()
+            if hasattr(dfn, "collect"):
+                dfn = dfn.collect()
+
             return len(dfn)

         if tbl_type in IBIS_BACKENDS:
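The `collect()` guard exists because LazyFrames don't implement `len()`, while eager frames don't carry `collect()`, so the `hasattr` check distinguishes the two. A minimal sketch (assumes Polars and narwhals):

import narwhals as nw
import polars as pl

dfn = nw.from_native(pl.LazyFrame({"x": [1, 2, 3]}))

# Materialize only when the input is lazy
if hasattr(dfn, "collect"):
    dfn = dfn.collect()

print(len(dfn))  # 3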
@@ -2383,7 +2511,22 @@ def _get_compare_expr_nw(compare: Any) -> Any:


 def _column_has_null_values(table: FrameT, column: str) -> bool:
-    null_count = (table.select(column).null_count())[column][0]
+    try:
+        # Try the standard null_count() method
+        null_count = (table.select(column).null_count())[column][0]
+    except AttributeError:
+        # For LazyFrames, collect first then get null count
+        try:
+            collected = table.select(column).collect()
+            null_count = (collected.null_count())[column][0]
+        except Exception:
+            # Fallback: check if any values are null
+            try:
+                result = table.select(nw.col(column).is_null().sum().alias("null_count")).collect()
+                null_count = result["null_count"][0]
+            except Exception:
+                # Last resort: return False (assume no nulls)
+                return False

     if null_count is None or null_count == 0:
         return False
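The expression-based fallback works on both eager and lazy inputs because it never calls the frame-level `null_count()` method. A sketch of that path in isolation (assumes Polars and narwhals; data is illustrative):

import narwhals as nw
import polars as pl

table = nw.from_native(pl.LazyFrame({"x": [1, None, 3]}))

# Sum the per-row null flags, then materialize the single-row result
result = table.select(nw.col("x").is_null().sum().alias("null_count")).collect()
print(result["null_count"][0])  # 1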
@@ -2414,7 +2557,7 @@ def _check_nulls_across_columns_nw(table, columns_subset):

     # Build the expression by combining each column's `is_null()` with OR operations
     null_expr = functools.reduce(
-        lambda acc, col: acc | table[col].is_null() if acc is not None else table[col].is_null(),
+        lambda acc, col: acc | nw.col(col).is_null() if acc is not None else nw.col(col).is_null(),
         column_names,
         None,
     )
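Switching from `table[col].is_null()` to `nw.col(col).is_null()` keeps the reduction purely expression-based, so it no longer needs a materialized column and works for lazy backends. A standalone sketch of the same fold (column names are illustrative):

import functools

import narwhals as nw

column_names = ["a", "b", "c"]

# OR together per-column null checks without touching any data
null_expr = functools.reduce(
    lambda acc, col: acc | nw.col(col).is_null() if acc is not None else nw.col(col).is_null(),
    column_names,
    None,
)
# null_expr can now be passed to select()/with_columns() on eager or lazy frames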

pointblank/_utils.py

Lines changed: 4 additions & 2 deletions
@@ -66,11 +66,13 @@ def _get_tbl_type(data: FrameT | Any) -> str:
     except Exception as e:
         raise TypeError("The `data` object is not a DataFrame or Ibis Table.") from e

-    # Detect through regex if the table is a polars or pandas DataFrame
+    # Detect through regex if the table is a polars, pandas, or Spark DataFrame
     if re.search(r"polars", df_ns_str, re.IGNORECASE):
         return "polars"
     elif re.search(r"pandas", df_ns_str, re.IGNORECASE):
         return "pandas"
+    elif re.search(r"pyspark", df_ns_str, re.IGNORECASE):
+        return "pyspark"

     # If ibis is present, then get the table's backend name
     ibis_present = _is_lib_present(lib_name="ibis")
@@ -164,7 +166,7 @@ def _check_any_df_lib(method_used: str) -> None:
 def _is_value_a_df(value: Any) -> bool:
     try:
         ns = nw.get_native_namespace(value)
-        if "polars" in str(ns) or "pandas" in str(ns):
+        if "polars" in str(ns) or "pandas" in str(ns) or "pyspark" in str(ns):
             return True
         else:  # pragma: no cover
             return False
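How the namespace-string detection behaves end to end, using pandas as a safe local example; a Spark DataFrame's namespace string contains "pyspark" and takes the new branch the same way (sketch assumes narwhals and pandas are installed):

import narwhals as nw
import pandas as pd

ns = nw.get_native_namespace(nw.from_native(pd.DataFrame({"a": [1]})))
print(str(ns))  # "<module 'pandas' ...>" -> matches r"pandas" -> returns "pandas"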
