@@ -502,6 +502,14 @@ class LocBody(Loc):
502
502
rows
503
503
The rows to target. Can either be a single row name or a series of row names provided in a
504
504
list.
505
+ mask
506
+ The cells to target. If the underlying wrapped DataFrame is a Polars DataFrame,
507
+ you can pass a Polars expression for cell-based selection. This argument must be used
508
+ exclusively and cannot be combined with the `columns=` or `rows=` arguments.
509
+
510
+ :::{.callout-warning}
511
+ `mask=` is still experimental.
512
+ :::
505
513
506
514
Returns
507
515
-------
@@ -539,6 +547,7 @@ class LocBody(Loc):
539
547
540
548
columns : SelectExpr = None
541
549
rows : RowSelectExpr = None
550
+ mask : PlExpr | None = None
542
551
543
552
544
553
@dataclass
@@ -823,6 +832,52 @@ def resolve_rows_i(
823
832
)
824
833
825
834
835
def resolve_mask(
    data: GTData | list[str],
    expr: PlExpr,
    excl_stub: bool = True,
    excl_group: bool = True,
) -> list[tuple[int, int, str]]:
    """Return data for creating `CellPos`, based on expr

    The expression is evaluated with `select()` against the table's underlying frame, and
    every cell of the result holding a truthy value is reported as a selected body cell.

    Parameters
    ----------
    data
        The object whose `_tbl_data` frame and `_boxhead` are read.
        NOTE(review): the annotation also admits `list[str]`, but the body accesses
        `data._tbl_data` and `data._boxhead`, so a plain list cannot work here -- confirm
        whether the annotation should be narrowed.
    expr
        The Polars expression used as the mask.
    excl_stub
        If `True`, columns reported by the boxhead as stub columns are dropped from the
        masked result before cells are collected.
    excl_group
        If `True`, columns reported by the boxhead as row-group columns are dropped from the
        masked result before cells are collected.

    Returns
    -------
    list[tuple[int, int, str]]
        One `(column index, row index, column name)` tuple per selected cell, in row-major
        order of the masked result. The column index is the column's position in the
        original frame.

    Raises
    ------
    ValueError
        If `expr` is not a Polars expression, if the masked result contains column names
        that are absent from the original frame, or if the masked result does not have the
        same number of rows as the original frame.
    """
    if not isinstance(expr, PlExpr):
        raise ValueError("Only Polars expressions can be passed to the `mask` argument.")

    frame: PlDataFrame = data._tbl_data
    frame_cols = frame.columns

    stub_var = data._boxhead.vars_from_type(ColInfoTypeEnum.stub)
    group_var = data._boxhead.vars_from_type(ColInfoTypeEnum.row_group)
    cols_excl = [*(stub_var if excl_stub else []), *(group_var if excl_group else [])]

    # `df.select()` raises `ColumnNotFoundError` if columns are missing from the original DataFrame.
    # `strict=False` lets the drop succeed when an excluded column is not part of the result.
    masked = frame.select(expr).drop(cols_excl, strict=False)

    # Validate that `masked.columns` exist in the `frame_cols`
    missing = set(masked.columns) - set(frame_cols)
    if missing:
        raise ValueError(
            "The `mask` expression produces extra columns, with names not in the original DataFrame."
            f"\n \n Extra columns: { missing } "
        )

    # Validate that row lengths are equal
    if masked.height != frame.height:
        # These continuation strings must be f-strings; otherwise the `{...}` placeholders
        # are shown verbatim instead of the actual lengths.
        raise ValueError(
            "The DataFrame length after applying `mask` differs from the original."
            f"\n \n * Original length: {frame.height}"
            f"\n * Mask length: {masked.height}"
        )

    cellpos_data: list[tuple[int, int, str]] = []  # column, row, colname for `CellPos`
    # Map each column name to its position in the original frame in a single pass.
    col_idx_map = {colname: col_idx for col_idx, colname in enumerate(frame_cols)}
    for row_idx, row_dict in enumerate(masked.iter_rows(named=True)):
        for colname, value in row_dict.items():
            if value:  # select only when `value` is True
                cellpos_data.append((col_idx_map[colname], row_idx, colname))
    return cellpos_data
879
+
880
+
826
881
# Resolve generic ======================================================================
827
882
828
883
@@ -868,15 +923,22 @@ def _(loc: LocStub, data: GTData) -> set[int]:
868
923
869
924
@resolve.register
def _(loc: LocBody, data: GTData) -> list[CellPos]:
    """Resolve a `LocBody` location into the list of body cells it targets.

    Cells are selected either through `loc.mask`, or through the cross product of the
    resolved `loc.columns` and `loc.rows`. The two modes are mutually exclusive: a
    `ValueError` is raised when `mask` is supplied together with `columns` or `rows`.
    """
    has_mask = loc.mask is not None
    has_cols_or_rows = not (loc.columns is None and loc.rows is None)

    if has_cols_or_rows and has_mask:
        raise ValueError(
            "Cannot specify the `mask` argument along with `columns` or `rows` in `loc.body()`."
        )

    if has_mask:
        # Mask mode: each resolved entry is passed positionally to `CellPos`.
        return [CellPos(*entry) for entry in resolve_mask(data=data, expr=loc.mask)]

    # Column/row mode: rows are resolved first, then columns.
    row_matches = resolve_rows_i(data=data, expr=loc.rows)
    col_matches = resolve_cols_i(data=data, expr=loc.columns)

    # TODO: dplyr arranges by `Var1`, and does distinct (since you can tidyselect the same
    # thing multiple times
    pairs = itertools.product(col_matches, row_matches)
    return [
        CellPos(col_match[1], row_match[1], colname=col_match[0])
        for col_match, row_match in pairs
    ]
881
943
882
944
0 commit comments