Skip to content

Commit 15b293e

Browse files
authored
Support mask= argument in LocBody (#566)
* A rough idea for supporting the `mask` argument in `GT.tab_style()` * Ensure the `mask=` argument is used exclusively without specifying `columns` or `rows` * Add tests for `resolve_mask_i()` * Update test name * Replace ambiguous variable name `masks` with `cellpos_data` * Update the `resolve_mask_i()` logic based on team feedback * Replace `assert` with `raise ValueError()` * Update test cases for `GT.tab_style()` * Add additional test cases for `resolve_mask_i()` * Add docstring for `mask=` in `LocBody` * Use `df.height` to get the number of rows in a DataFrame * Rename `resolve_mask_i` to `resolve_mask` * Apply code review suggestions for the `mask=` implementation in `LocBody`
1 parent ef7d2ea commit 15b293e

File tree

2 files changed

+146
-8
lines changed

2 files changed

+146
-8
lines changed

great_tables/_locations.py

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,14 @@ class LocBody(Loc):
502502
rows
503503
The rows to target. Can either be a single row name or a series of row names provided in a
504504
list.
505+
mask
506+
The cells to target. If the underlying wrapped DataFrame is a Polars DataFrame,
507+
you can pass a Polars expression for cell-based selection. This argument must be used
508+
exclusively and cannot be combined with the `columns=` or `rows=` arguments.
509+
510+
:::{.callout-warning}
511+
`mask=` is still experimental.
512+
:::
505513
506514
Returns
507515
-------
@@ -539,6 +547,7 @@ class LocBody(Loc):
539547

540548
columns: SelectExpr = None
541549
rows: RowSelectExpr = None
550+
mask: PlExpr | None = None
542551

543552

544553
@dataclass
@@ -823,6 +832,52 @@ def resolve_rows_i(
823832
)
824833

825834

835+
def resolve_mask(
836+
data: GTData | list[str],
837+
expr: PlExpr,
838+
excl_stub: bool = True,
839+
excl_group: bool = True,
840+
) -> list[tuple[int, int, str]]:
841+
"""Return data for creating `CellPos`, based on expr"""
842+
if not isinstance(expr, PlExpr):
843+
raise ValueError("Only Polars expressions can be passed to the `mask` argument.")
844+
845+
frame: PlDataFrame = data._tbl_data
846+
frame_cols = frame.columns
847+
848+
stub_var = data._boxhead.vars_from_type(ColInfoTypeEnum.stub)
849+
group_var = data._boxhead.vars_from_type(ColInfoTypeEnum.row_group)
850+
cols_excl = [*(stub_var if excl_stub else []), *(group_var if excl_group else [])]
851+
852+
# `df.select()` raises `ColumnNotFoundError` if columns are missing from the original DataFrame.
853+
masked = frame.select(expr).drop(cols_excl, strict=False)
854+
855+
# Validate that `masked.columns` exist in the `frame_cols`
856+
missing = set(masked.columns) - set(frame_cols)
857+
if missing:
858+
raise ValueError(
859+
"The `mask` expression produces extra columns, with names not in the original DataFrame."
860+
f"\n\nExtra columns: {missing}"
861+
)
862+
863+
# Validate that row lengths are equal
864+
if masked.height != frame.height:
865+
raise ValueError(
866+
"The DataFrame length after applying `mask` differs from the original."
867+
"\n\n* Original length: {frame.height}"
868+
"\n* Mask length: {masked.height}"
869+
)
870+
871+
cellpos_data: list[tuple[int, int, str]] = [] # column, row, colname for `CellPos`
872+
col_idx_map = {colname: frame_cols.index(colname) for colname in frame_cols}
873+
for row_idx, row_dict in enumerate(masked.iter_rows(named=True)):
874+
for colname, value in row_dict.items():
875+
if value: # select only when `value` is True
876+
col_idx = col_idx_map[colname]
877+
cellpos_data.append((col_idx, row_idx, colname))
878+
return cellpos_data
879+
880+
826881
# Resolve generic ======================================================================
827882

828883

@@ -868,15 +923,22 @@ def _(loc: LocStub, data: GTData) -> set[int]:
868923

869924
@resolve.register
870925
def _(loc: LocBody, data: GTData) -> list[CellPos]:
871-
cols = resolve_cols_i(data=data, expr=loc.columns)
872-
rows = resolve_rows_i(data=data, expr=loc.rows)
873-
874-
# TODO: dplyr arranges by `Var1`, and does distinct (since you can tidyselect the same
875-
# thing multiple times
876-
cell_pos = [
877-
CellPos(col[1], row[1], colname=col[0]) for col, row in itertools.product(cols, rows)
878-
]
926+
if (loc.columns is not None or loc.rows is not None) and loc.mask is not None:
927+
raise ValueError(
928+
"Cannot specify the `mask` argument along with `columns` or `rows` in `loc.body()`."
929+
)
879930

931+
if loc.mask is None:
932+
rows = resolve_rows_i(data=data, expr=loc.rows)
933+
cols = resolve_cols_i(data=data, expr=loc.columns)
934+
# TODO: dplyr arranges by `Var1`, and does distinct (since you can tidyselect the same
935+
# thing multiple times
936+
cell_pos = [
937+
CellPos(col[1], row[1], colname=col[0]) for col, row in itertools.product(cols, rows)
938+
]
939+
else:
940+
cellpos_data = resolve_mask(data=data, expr=loc.mask)
941+
cell_pos = [CellPos(*cellpos) for cellpos in cellpos_data]
880942
return cell_pos
881943

882944

tests/test_tab_create_modify.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,19 @@
55
from great_tables._locations import LocBody
66
from great_tables._styles import CellStyleFill
77
from great_tables._tab_create_modify import tab_style
8+
from polars import selectors as cs
89

910

1011
@pytest.fixture
1112
def gt():
1213
return GT(pd.DataFrame({"x": [1, 2], "y": [4, 5]}))
1314

1415

16+
@pytest.fixture
17+
def gt2():
18+
return GT(pl.DataFrame({"x": [1, 2], "y": [4, 5]}))
19+
20+
1521
def test_tab_style(gt: GT):
1622
style = CellStyleFill(color="blue")
1723
new_gt = tab_style(gt, style, LocBody(["x"], [0]))
@@ -71,3 +77,73 @@ def test_tab_style_font_from_column():
7177

7278
assert rendered_html.find('<td style="font-family: Helvetica;" class="gt_row gt_right">1</td>')
7379
assert rendered_html.find('<td style="font-family: Courier;" class="gt_row gt_right">2</td>')
80+
81+
82+
def test_tab_style_loc_body_mask(gt2: GT):
83+
style = CellStyleFill(color="blue")
84+
new_gt = tab_style(gt2, style, LocBody(mask=cs.numeric().gt(1.5)))
85+
86+
assert len(gt2._styles) == 0
87+
assert len(new_gt._styles) == 3
88+
89+
xy_0y, xy_1x, xy_1y = new_gt._styles
90+
91+
assert xy_0y.styles[0] is style
92+
assert xy_1x.styles[0] is style
93+
assert xy_1y.styles[0] is style
94+
95+
assert xy_0y.rownum == 0
96+
assert xy_0y.colname == "y"
97+
98+
assert xy_1x.rownum == 1
99+
assert xy_1x.colname == "x"
100+
101+
assert xy_1y.rownum == 1
102+
assert xy_1y.colname == "y"
103+
104+
105+
def test_tab_style_loc_body_raises(gt2: GT):
106+
style = CellStyleFill(color="blue")
107+
mask = cs.numeric().gt(1.5)
108+
err_msg = "Cannot specify the `mask` argument along with `columns` or `rows` in `loc.body()`."
109+
110+
with pytest.raises(ValueError) as exc_info:
111+
tab_style(gt2, style, LocBody(columns=["x"], mask=mask))
112+
assert err_msg in exc_info.value.args[0]
113+
114+
with pytest.raises(ValueError) as exc_info:
115+
tab_style(gt2, style, LocBody(rows=[0], mask=mask))
116+
117+
assert err_msg in exc_info.value.args[0]
118+
119+
120+
def test_tab_style_loc_body_mask_not_polars_expression_raises(gt2: GT):
121+
style = CellStyleFill(color="blue")
122+
mask = "fake expression"
123+
err_msg = "Only Polars expressions can be passed to the `mask` argument."
124+
125+
with pytest.raises(ValueError) as exc_info:
126+
tab_style(gt2, style, LocBody(mask=mask))
127+
assert err_msg in exc_info.value.args[0]
128+
129+
130+
def test_tab_style_loc_body_mask_columns_not_inside_raises(gt2: GT):
131+
style = CellStyleFill(color="blue")
132+
mask = pl.len()
133+
err_msg = (
134+
"The `mask` expression produces extra columns, with names not in the original DataFrame."
135+
)
136+
137+
with pytest.raises(ValueError) as exc_info:
138+
tab_style(gt2, style, LocBody(mask=mask))
139+
assert err_msg in exc_info.value.args[0]
140+
141+
142+
def test_tab_style_loc_body_mask_rows_not_equal_raises(gt2: GT):
143+
style = CellStyleFill(color="blue")
144+
mask = pl.len().alias("x")
145+
err_msg = "The DataFrame length after applying `mask` differs from the original."
146+
147+
with pytest.raises(ValueError) as exc_info:
148+
tab_style(gt2, style, LocBody(mask=mask))
149+
assert err_msg in exc_info.value.args[0]

0 commit comments

Comments
 (0)