Skip to content

Commit 4166afa

Browse files
authored
Merge pull request #49 from posit-dev/fix-preview-refactor
fix: refactor step report for column values based validation
2 parents 78514b6 + 70cc6b4 commit 4166afa

File tree

3 files changed

+108
-48
lines changed

3 files changed

+108
-48
lines changed

pointblank/validate.py

Lines changed: 98 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -410,9 +410,44 @@ def preview(
410410
if incl_header is None:
411411
incl_header = global_config.preview_incl_header
412412

413+
return _generate_display_table(
414+
data=data,
415+
columns_subset=columns_subset,
416+
n_head=n_head,
417+
n_tail=n_tail,
418+
limit=limit,
419+
show_row_numbers=show_row_numbers,
420+
max_col_width=max_col_width,
421+
incl_header=incl_header,
422+
mark_missing_values=True,
423+
)
424+
425+
426+
def _generate_display_table(
427+
data: FrameT | Any,
428+
columns_subset: str | list[str] | Column | None = None,
429+
n_head: int = 5,
430+
n_tail: int = 5,
431+
limit: int | None = 50,
432+
show_row_numbers: bool = True,
433+
max_col_width: int | None = 250,
434+
incl_header: bool = None,
435+
mark_missing_values: bool = True,
436+
row_number_list: list[int] | None = None,
437+
) -> GT:
438+
413439
# Make a copy of the data to avoid modifying the original
414440
data = copy.deepcopy(data)
415441

442+
# Does the data table already have a leading row number column?
443+
if "_row_num_" in data.columns:
444+
if data.columns[0] == "_row_num_":
445+
has_leading_row_num_col = True
446+
else:
447+
has_leading_row_num_col = False
448+
else:
449+
has_leading_row_num_col = False
450+
416451
# Check that the n_head and n_tail aren't greater than the limit
417452
if n_head + n_tail > limit:
418453
raise ValueError(f"The sum of `n_head=` and `n_tail=` cannot exceed the limit ({limit}).")
@@ -482,15 +517,19 @@ def preview(
482517
if n_head + n_tail > n_rows:
483518
full_dataset = True
484519
data_subset = data
485-
row_number_list = range(1, n_rows + 1)
520+
521+
if row_number_list is None:
522+
row_number_list = range(1, n_rows + 1)
486523
else:
487524
# Get the first and last n rows of the table
488525
data_head = data.head(n=n_head)
489526
row_numbers_head = range(1, n_head + 1)
490527
data_tail = data[(n_rows - n_tail) : n_rows]
491528
row_numbers_tail = range(n_rows - n_tail + 1, n_rows + 1)
492529
data_subset = data_head.union(data_tail)
493-
row_number_list = list(row_numbers_head) + list(row_numbers_tail)
530+
531+
if row_number_list is None:
532+
row_number_list = list(row_numbers_head) + list(row_numbers_tail)
494533

495534
# Convert either to Polars or Pandas depending on the available library
496535
if df_lib_name_gt == "polars":
@@ -514,13 +553,17 @@ def preview(
514553
# If n_head + n_tail is greater than the row count, display the entire table
515554
if n_head + n_tail >= n_rows:
516555
full_dataset = True
517-
row_number_list = range(1, n_rows + 1)
556+
557+
if row_number_list is None:
558+
row_number_list = range(1, n_rows + 1)
559+
518560
else:
519561
data = pl.concat([data.head(n=n_head), data.tail(n=n_tail)])
520562

521-
row_number_list = list(range(1, n_head + 1)) + list(
522-
range(n_rows - n_tail + 1, n_rows + 1)
523-
)
563+
if row_number_list is None:
564+
row_number_list = list(range(1, n_head + 1)) + list(
565+
range(n_rows - n_tail + 1, n_rows + 1)
566+
)
524567

525568
if tbl_type == "pandas":
526569

@@ -591,7 +634,7 @@ def preview(
591634
for i in range(len(col_dtype_dict.keys()))
592635
]
593636

594-
# Set the column width to the col_widths list
637+
# Set the column width to the `col_widths`` list
595638
col_width_dict = {k: v for k, v in zip(col_names, col_widths)}
596639

597640
# For each of the values in the dictionary, prepend the column name to the data type
@@ -605,19 +648,28 @@ def preview(
605648
for k, v in col_dtype_dict_short.items()
606649
}
607650

651+
if has_leading_row_num_col:
652+
# Remove the first entry col_width_dict and col_dtype_labels_dict dictionaries
653+
col_width_dict.pop("_row_num_")
654+
col_dtype_labels_dict.pop("_row_num_")
655+
608656
# Prepend a column that contains the row numbers if `show_row_numbers=True`
609-
if show_row_numbers:
657+
if show_row_numbers or has_leading_row_num_col:
610658

611-
if df_lib_name_gt == "polars":
659+
if has_leading_row_num_col:
660+
row_number_list = data["_row_num_"].to_list()
612661

613-
import polars as pl
662+
else:
663+
if df_lib_name_gt == "polars":
664+
665+
import polars as pl
614666

615-
row_number_series = pl.Series("_row_num_", row_number_list)
616-
data = data.insert_column(0, row_number_series)
667+
row_number_series = pl.Series("_row_num_", row_number_list)
668+
data = data.insert_column(0, row_number_series)
617669

618-
if df_lib_name_gt == "pandas":
670+
if df_lib_name_gt == "pandas":
619671

620-
data.insert(0, "_row_num_", row_number_list)
672+
data.insert(0, "_row_num_", row_number_list)
621673

622674
# Get the highest number in the `row_number_list` and calculate a width that will
623675
# safely fit a number of that magnitude
@@ -626,27 +678,24 @@ def preview(
626678

627679
# Update the col_width_dict to include the row number column
628680
col_width_dict = {"_row_num_": f"{max_row_num_width}px"} | col_width_dict
629-
# Update the col_dtype_labels_dict to include the row number column (use empty string)
681+
682+
# Update the `col_dtype_labels_dict` to include the row number column (use empty string)
630683
col_dtype_labels_dict = {"_row_num_": ""} | col_dtype_labels_dict
631684

632-
# Create the label, table type, and thresholds HTML fragments
633-
table_type_html = _create_table_type_html(
634-
tbl_type=tbl_type, tbl_name=None, font_size="10px"
635-
)
685+
# Create the label, table type, and thresholds HTML fragments
686+
table_type_html = _create_table_type_html(tbl_type=tbl_type, tbl_name=None, font_size="10px")
636687

637-
tbl_dims_html = _create_table_dims_html(
638-
columns=len(col_names), rows=n_rows, font_size="10px"
639-
)
688+
tbl_dims_html = _create_table_dims_html(columns=len(col_names), rows=n_rows, font_size="10px")
640689

641-
# Compose the subtitle HTML fragment
642-
combined_subtitle = (
643-
"<div>"
644-
'<div style="padding-top: 0; padding-bottom: 7px;">'
645-
f"{table_type_html}"
646-
f"{tbl_dims_html}"
647-
"</div>"
648-
"</div>"
649-
)
690+
# Compose the subtitle HTML fragment
691+
combined_subtitle = (
692+
"<div>"
693+
'<div style="padding-top: 0; padding-bottom: 7px;">'
694+
f"{table_type_html}"
695+
f"{tbl_dims_html}"
696+
"</div>"
697+
"</div>"
698+
)
650699

651700
gt_tbl = (
652701
GT(data=data, id="pb_preview_tbl")
@@ -690,7 +739,7 @@ def preview(
690739
gt_tbl = gt_tbl.tab_header(title=html(combined_subtitle))
691740
gt_tbl = gt_tbl.tab_options(heading_subtitle_font_size="12px")
692741

693-
if none_values:
742+
if none_values and mark_missing_values:
694743
for column, none_index in none_values:
695744
gt_tbl = gt_tbl.tab_style(
696745
style=[style.text(color="#B22222"), style.fill(color="#FFC1C159")],
@@ -4335,12 +4384,17 @@ def interrogate(
43354384
and tbl_type not in IBIS_BACKENDS
43364385
):
43374386

4387+
# Add row numbers to the results table
43384388
validation_extract_nw = (
43394389
nw.from_native(results_tbl)
4390+
.with_row_index(name="_row_num_")
43404391
.filter(nw.col("pb_is_good_") == False) # noqa
43414392
.drop("pb_is_good_")
43424393
)
43434394

4395+
# Add 1 to the row numbers to make them 1-indexed
4396+
validation_extract_nw = validation_extract_nw.with_columns(nw.col("_row_num_") + 1)
4397+
43444398
# Apply any sampling or limiting to the number of rows to extract
43454399
if get_first_n is not None:
43464400
validation_extract_nw = validation_extract_nw.head(get_first_n)
@@ -6454,9 +6508,6 @@ def get_step_report(self, i: int) -> GT:
64546508
if not active:
64556509
return "This validation step is inactive."
64566510

6457-
# Get the extracted data for the step
6458-
extract = self.get_data_extracts(i=i, frame=True)
6459-
64606511
# Create a table with a sample of ten rows, highlighting the column of interest
64616512
tbl_preview = preview(data=self.data, n_head=5, n_tail=5, limit=10, incl_header=False)
64626513

@@ -6466,6 +6517,9 @@ def get_step_report(self, i: int) -> GT:
64666517

64676518
if assertion_type in ROW_BASED_VALIDATION_TYPES:
64686519

6520+
# Get the extracted data for the step
6521+
extract = self.get_data_extracts(i=i, frame=True)
6522+
64696523
step_report = _step_report_row_based(
64706524
assertion_type=assertion_type,
64716525
i=i,
@@ -7148,13 +7202,19 @@ def _step_report_row_based(
71487202
)
71497203

71507204
else:
7205+
71517206
# Create a preview of the extracted data
7152-
extract_preview = preview(
7153-
data=extract, n_head=1000, n_tail=1000, limit=2000, incl_header=False
7207+
extract_tbl = _generate_display_table(
7208+
data=extract,
7209+
n_head=1000,
7210+
n_tail=1000,
7211+
limit=2000,
7212+
incl_header=False,
7213+
mark_missing_values=False,
71547214
)
71557215

71567216
step_report = (
7157-
extract_preview.tab_header(
7217+
extract_tbl.tab_header(
71587218
title=f"Report for Validation Step {i}",
71597219
subtitle=html(
71607220
"<div>"

0 commit comments

Comments
 (0)