from pointblank._typing import AbsoluteTolBounds


+ def _safe_modify_datetime_compare_val(data_frame: Any, column: str, compare_val: Any) -> Any:
+     """
+     Safely modify datetime comparison values for LazyFrame compatibility.
+
+     This function handles the case where we can't directly slice LazyFrames
+     to get column dtypes for datetime conversion.
+     """
+     try:
+         # First try to get the column dtype from the schema (works for LazyFrames)
+         column_dtype = None
+
+         if hasattr(data_frame, "collect_schema"):
+             schema = data_frame.collect_schema()
+             column_dtype = schema.get(column)
+         elif hasattr(data_frame, "schema"):
+             schema = data_frame.schema
+             column_dtype = schema.get(column)
+
+         # If we got a dtype from the schema, use it
+         if column_dtype is not None:
+             # Create a mock column object for _modify_datetime_compare_val()
+             class MockColumn:
+                 def __init__(self, dtype):
+                     self.dtype = dtype
+
+             mock_column = MockColumn(column_dtype)
+             return _modify_datetime_compare_val(tgt_column=mock_column, compare_val=compare_val)
+
+         # Fallback: try collecting a one-row sample if possible
+         try:
+             sample = data_frame.head(1).collect()
+             if hasattr(sample, "dtypes") and column in sample.columns:
+                 # For pandas-like dtypes
+                 column_dtype = sample.dtypes[column]
+                 if column_dtype is not None:
+
+                     class MockColumn:
+                         def __init__(self, dtype):
+                             self.dtype = dtype
+
+                     mock_column = MockColumn(column_dtype)
+                     return _modify_datetime_compare_val(
+                         tgt_column=mock_column, compare_val=compare_val
+                     )
+         except Exception:
+             pass
+
+         # Final fallback: try direct access (for eager DataFrames)
+         try:
+             if hasattr(data_frame, "dtypes") and column in data_frame.columns:
+                 column_dtype = data_frame.dtypes[column]
+
+                 class MockColumn:
+                     def __init__(self, dtype):
+                         self.dtype = dtype
+
+                 mock_column = MockColumn(column_dtype)
+                 return _modify_datetime_compare_val(tgt_column=mock_column, compare_val=compare_val)
+         except Exception:
+             pass
+
+     except Exception:
+         pass
+
+     # If all else fails, return the original compare_val unchanged
+     return compare_val
+
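Why the schema route matters: a LazyFrame has no materialized rows to slice, but its schema can be resolved cheaply. Here is a minimal sketch of that idea, shown directly on Polars (an assumption; the helper above actually receives a narwhals-wrapped frame, which exposes the same `collect_schema()` method):

import datetime
import polars as pl

# Hypothetical lazy table: nothing is computed until collect() is called
lf = pl.LazyFrame({"d": [datetime.date(2024, 1, 1)], "n": [1]})

# Slicing (e.g. lf["d"]) would raise, but the schema resolves
# without materializing any rows
schema = lf.collect_schema()
print(schema.get("d"))  # Date
print(schema.get("n"))  # Int64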
@dataclass
class Interrogator:
    """
@@ -136,9 +204,7 @@ def gt(self) -> FrameT | Any:
        compare_expr = _get_compare_expr_nw(compare=self.compare)

-         compare_expr = _modify_datetime_compare_val(
-             tgt_column=self.x[self.column], compare_val=compare_expr
-         )
+         compare_expr = _safe_modify_datetime_compare_val(self.x, self.column, compare_expr)

        return (
            self.x.with_columns(
@@ -211,9 +277,7 @@ def lt(self) -> FrameT | Any:
        compare_expr = _get_compare_expr_nw(compare=self.compare)

-         compare_expr = _modify_datetime_compare_val(
-             tgt_column=self.x[self.column], compare_val=compare_expr
-         )
+         compare_expr = _safe_modify_datetime_compare_val(self.x, self.column, compare_expr)

        return (
            self.x.with_columns(
@@ -329,9 +393,7 @@ def eq(self) -> FrameT | Any:
        else:
            compare_expr = _get_compare_expr_nw(compare=self.compare)

-             compare_expr = _modify_datetime_compare_val(
-                 tgt_column=self.x[self.column], compare_val=compare_expr
-             )
+             compare_expr = _safe_modify_datetime_compare_val(self.x, self.column, compare_expr)

            tbl = self.x.with_columns(
                pb_is_good_1=nw.col(self.column).is_null() & self.na_pass,
@@ -421,9 +483,7 @@ def ne(self) -> FrameT | Any:
            ).to_native()

        else:
-             compare_expr = _modify_datetime_compare_val(
-                 tgt_column=self.x[self.column], compare_val=self.compare
-             )
+             compare_expr = _safe_modify_datetime_compare_val(self.x, self.column, self.compare)

            return self.x.with_columns(
                pb_is_good_=nw.col(self.column) != nw.lit(compare_expr),
@@ -544,9 +604,7 @@ def ne(self) -> FrameT | Any:
        if ref_col_has_null_vals:
            # Create individual cases for Pandas and Polars

-             compare_expr = _modify_datetime_compare_val(
-                 tgt_column=self.x[self.column], compare_val=self.compare
-             )
+             compare_expr = _safe_modify_datetime_compare_val(self.x, self.column, self.compare)

            if is_pandas_dataframe(self.x.to_native()):
                tbl = self.x.with_columns(
@@ -584,6 +642,25 @@ def ne(self) -> FrameT | Any:

                return tbl

+             else:
+                 # Generic case for other DataFrame types (PySpark, etc.); use logic
+                 # similar to the Polars case but tolerate backend differences
+                 tbl = self.x.with_columns(
+                     pb_is_good_1=nw.col(self.column).is_null(),  # val is Null in Column
+                     pb_is_good_2=nw.lit(self.na_pass),  # pass Null rows when na_pass is set
+                 )
+
+                 tbl = tbl.with_columns(pb_is_good_3=nw.col(self.column) != nw.lit(compare_expr))
+
+                 tbl = tbl.with_columns(
+                     pb_is_good_=(
+                         (nw.col("pb_is_good_1") & nw.col("pb_is_good_2"))
+                         | (nw.col("pb_is_good_3") & ~nw.col("pb_is_good_1"))
+                     )
+                 )
+
+                 return tbl.drop("pb_is_good_1", "pb_is_good_2", "pb_is_good_3").to_native()
+
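The three-flag combination reads: a row passes if its value is null and `na_pass` is set, or if it is non-null and differs from the comparison value. A small sketch of the same logic on an eager toy table (assumes narwhals and Polars are installed; the data is hypothetical):

import narwhals as nw
import polars as pl

df = nw.from_native(pl.DataFrame({"x": [1, 2, None]}))
na_pass = True
compare = 2

out = (
    df.with_columns(
        pb_is_good_1=nw.col("x").is_null(),
        pb_is_good_2=nw.lit(na_pass),
        pb_is_good_3=nw.col("x") != nw.lit(compare),
    )
    .with_columns(
        pb_is_good_=(nw.col("pb_is_good_1") & nw.col("pb_is_good_2"))
        | (nw.col("pb_is_good_3") & ~nw.col("pb_is_good_1"))
    )
    .drop("pb_is_good_1", "pb_is_good_2", "pb_is_good_3")
)
print(out.to_native())  # x=1 passes, x=2 fails the `!= 2` check, the null row passes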

    def ge(self) -> FrameT | Any:
        # Ibis backends ---------------------------------------------
@@ -629,9 +706,7 @@ def ge(self) -> FrameT | Any:
        compare_expr = _get_compare_expr_nw(compare=self.compare)

-         compare_expr = _modify_datetime_compare_val(
-             tgt_column=self.x[self.column], compare_val=compare_expr
-         )
+         compare_expr = _safe_modify_datetime_compare_val(self.x, self.column, compare_expr)

        tbl = (
            self.x.with_columns(
@@ -702,9 +777,7 @@ def le(self) -> FrameT | Any:
        compare_expr = _get_compare_expr_nw(compare=self.compare)

-         compare_expr = _modify_datetime_compare_val(
-             tgt_column=self.x[self.column], compare_val=compare_expr
-         )
+         compare_expr = _safe_modify_datetime_compare_val(self.x, self.column, compare_expr)

        return (
            self.x.with_columns(
@@ -834,10 +907,8 @@ def between(self) -> FrameT | Any:
        low_val = _get_compare_expr_nw(compare=self.low)
        high_val = _get_compare_expr_nw(compare=self.high)

-         low_val = _modify_datetime_compare_val(tgt_column=self.x[self.column], compare_val=low_val)
-         high_val = _modify_datetime_compare_val(
-             tgt_column=self.x[self.column], compare_val=high_val
-         )
+         low_val = _safe_modify_datetime_compare_val(self.x, self.column, low_val)
+         high_val = _safe_modify_datetime_compare_val(self.x, self.column, high_val)

        tbl = self.x.with_columns(
            pb_is_good_1=nw.col(self.column).is_null(),  # val is Null in Column
@@ -1026,10 +1097,8 @@ def outside(self) -> FrameT | Any:
        low_val = _get_compare_expr_nw(compare=self.low)
        high_val = _get_compare_expr_nw(compare=self.high)

-         low_val = _modify_datetime_compare_val(tgt_column=self.x[self.column], compare_val=low_val)
-         high_val = _modify_datetime_compare_val(
-             tgt_column=self.x[self.column], compare_val=high_val
-         )
+         low_val = _safe_modify_datetime_compare_val(self.x, self.column, low_val)
+         high_val = _safe_modify_datetime_compare_val(self.x, self.column, high_val)

        tbl = self.x.with_columns(
            pb_is_good_1=nw.col(self.column).is_null(),  # val is Null in Column
@@ -1209,14 +1278,15 @@ def rows_distinct(self) -> FrameT | Any:
        else:
            columns_subset = self.columns_subset

-         # Create a subset of the table with only the columns of interest
-         subset_tbl = tbl.select(columns_subset)
+         # Count duplicates via a group_by approach, mirroring the Ibis backend:
+         # group by the columns of interest and count occurrences
+         count_tbl = tbl.group_by(columns_subset).agg(nw.len().alias("pb_count_"))

-         # Check for duplicates in the subset table, creating a series of booleans
-         pb_is_good_series = subset_tbl.is_duplicated()
+         # Join back to the original table to attach a count to each row
+         tbl = tbl.join(count_tbl, on=columns_subset, how="left")

-         # Add the series to the input table
-         tbl = tbl.with_columns(pb_is_good_=~pb_is_good_series)
+         # Passing rows have a count of `1` (no duplicates) and get True; all others get False
+         tbl = tbl.with_columns(pb_is_good_=nw.col("pb_count_") == 1).drop("pb_count_")

        return tbl.to_native()
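The switch from `is_duplicated()` to a group-by count is what makes this path LazyFrame-friendly: the old code needed an eager boolean series, while the count can be expressed lazily. A sketch of the approach in isolation (assumes narwhals and Polars; the toy data is hypothetical):

import narwhals as nw
import polars as pl

tbl = nw.from_native(pl.DataFrame({"a": [1, 1, 2], "b": ["x", "x", "y"]}))
columns_subset = ["a", "b"]

count_tbl = tbl.group_by(columns_subset).agg(nw.len().alias("pb_count_"))
tbl = tbl.join(count_tbl, on=columns_subset, how="left")
tbl = tbl.with_columns(pb_is_good_=nw.col("pb_count_") == 1).drop("pb_count_")
print(tbl.to_native())  # the two (1, "x") rows get False; (2, "y") gets True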
@@ -2088,6 +2158,8 @@ def get_test_results(self):
            return self._get_pandas_results()
        elif "duckdb" in self.tbl_type or "ibis" in self.tbl_type:
            return self._get_ibis_results()
+         elif "pyspark" in self.tbl_type:
+             return self._get_pyspark_results()
        else:  # pragma: no cover
            raise NotImplementedError(f"Support for {self.tbl_type} is not yet implemented")
@@ -2247,6 +2319,53 @@ def _get_ibis_results(self):
        results_tbl = self.data_tbl.mutate(pb_is_good_=ibis.literal(True))
        return results_tbl

+     def _get_pyspark_results(self):
+         """Process expressions for PySpark DataFrames."""
+         from pyspark.sql import functions as F
+
+         pyspark_columns = []
+
+         for expr_fn in self.expressions:
+             try:
+                 # First try direct evaluation with the PySpark DataFrame
+                 expr_result = expr_fn(self.data_tbl)
+
+                 # Check whether it's a PySpark Column
+                 if hasattr(expr_result, "_jc"):  # a PySpark Column has a `_jc` attribute
+                     pyspark_columns.append(expr_result)
+                 else:
+                     raise TypeError(
+                         f"Expression returned {type(expr_result)}, expected PySpark Column"
+                     )
+
+             except Exception as e:
+                 try:
+                     # Try as a ColumnExpression (for the `pb.expr_col()` style)
+                     col_expr = expr_fn(None)
+
+                     if hasattr(col_expr, "to_pyspark_expr"):
+                         # Convert to a PySpark expression
+                         pyspark_expr = col_expr.to_pyspark_expr(self.data_tbl)
+                         pyspark_columns.append(pyspark_expr)
+                     else:
+                         raise TypeError(f"Cannot convert {type(col_expr)} to PySpark Column")
+                 except Exception as nested_e:
+                     print(f"Error evaluating PySpark expression: {e} -> {nested_e}")
+
+         # Combine the results with AND logic
+         if pyspark_columns:
+             final_result = pyspark_columns[0]
+             for col in pyspark_columns[1:]:
+                 final_result = final_result & col
+
+             # Create the results table with a boolean column
+             results_tbl = self.data_tbl.withColumn("pb_is_good_", final_result)
+             return results_tbl
+
+         # Default case: no expressions could be evaluated, so every row passes
+         results_tbl = self.data_tbl.withColumn("pb_is_good_", F.lit(True))
+         return results_tbl
+
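For reference, this is how PySpark `Column` predicates combine with `&` and attach as a boolean mask via `withColumn`, shown in isolation (a sketch that assumes a running SparkSession; the toy data and checks are hypothetical):

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 5), (2, 0)], ["a", "b"])

# Two Column predicates, folded into a single mask with AND logic
checks = [F.col("a") > 0, F.col("b") > 1]
mask = checks[0]
for c in checks[1:]:
    mask = mask & c

df.withColumn("pb_is_good_", mask).show()  # (1, 5) -> true; (2, 0) -> false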

class SpeciallyValidation:
    def __init__(self, data_tbl, expression, threshold, tbl_type):
@@ -2359,13 +2478,22 @@ class NumberOfTestUnits:
    column: str

    def get_test_units(self, tbl_type: str) -> int:
-         if tbl_type == "pandas" or tbl_type == "polars":
+         if (
+             tbl_type == "pandas"
+             or tbl_type == "polars"
+             or tbl_type == "pyspark"
+             or tbl_type == "local"
+         ):
            # Convert the DataFrame to a format that narwhals can work with and:
            # - check if the column exists
            dfn = _column_test_prep(
                df=self.df, column=self.column, allowed_types=None, check_exists=False
            )

+             # Handle LazyFrames, which don't have len()
+             if hasattr(dfn, "collect"):
+                 dfn = dfn.collect()
+
            return len(dfn)

        if tbl_type in IBIS_BACKENDS:
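The `collect()` guard is needed because a lazy table has no row count until it is materialized; `len()` only works on the eager result. A sketch, assuming narwhals and Polars:

import narwhals as nw
import polars as pl

dfn = nw.from_native(pl.LazyFrame({"a": [1, 2, 3]}))
if hasattr(dfn, "collect"):  # eager frames have no collect() and pass through
    dfn = dfn.collect()
print(len(dfn))  # 3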
@@ -2383,7 +2511,22 @@ def _get_compare_expr_nw(compare: Any) -> Any:


def _column_has_null_values(table: FrameT, column: str) -> bool:
-     null_count = (table.select(column).null_count())[column][0]
+     try:
+         # Try the standard null_count() method
+         null_count = (table.select(column).null_count())[column][0]
+     except AttributeError:
+         # For LazyFrames, collect first and then take the null count
+         try:
+             collected = table.select(column).collect()
+             null_count = (collected.null_count())[column][0]
+         except Exception:
+             # Fallback: count the nulls with an is_null() expression
+             try:
+                 result = table.select(nw.col(column).is_null().sum().alias("null_count")).collect()
+                 null_count = result["null_count"][0]
+             except Exception:
+                 # Last resort: return False (assume no nulls)
+                 return False

    if null_count is None or null_count == 0:
        return False
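The expression-based fallback is the most portable of the three attempts, since it stays entirely inside the narwhals expression API. In isolation (a sketch assuming narwhals and Polars):

import narwhals as nw
import polars as pl

lf = nw.from_native(pl.LazyFrame({"x": [1, None, 3]}))

result = lf.select(nw.col("x").is_null().sum().alias("null_count")).collect()
print(result["null_count"][0])  # 1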
@@ -2414,7 +2557,7 @@ def _check_nulls_across_columns_nw(table, columns_subset):

    # Build the expression by combining each column's `is_null()` with OR operations
    null_expr = functools.reduce(
-         lambda acc, col: acc | table[col].is_null() if acc is not None else table[col].is_null(),
+         lambda acc, col: acc | nw.col(col).is_null() if acc is not None else nw.col(col).is_null(),
        column_names,
        None,
    )
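This change swaps item access (`table[col]`), which lazy tables don't support, for the backend-agnostic expression `nw.col(col)`. The reduce folds the per-column null checks into one OR chain; a sketch of the equivalent expansion (assumes narwhals):

import functools
import narwhals as nw

column_names = ["a", "b", "c"]
null_expr = functools.reduce(
    lambda acc, col: acc | nw.col(col).is_null() if acc is not None else nw.col(col).is_null(),
    column_names,
    None,
)
# Equivalent to: nw.col("a").is_null() | nw.col("b").is_null() | nw.col("c").is_null()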