Skip to content

Commit

Permalink
Merge pull request #49 from posit-dev/fix-preview-refactor
Browse files Browse the repository at this point in the history
fix: refactor step report for column values based validation
  • Loading branch information
rich-iannone authored Feb 5, 2025
2 parents 78514b6 + 70cc6b4 commit 4166afa
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 48 deletions.
136 changes: 98 additions & 38 deletions pointblank/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,44 @@ def preview(
if incl_header is None:
incl_header = global_config.preview_incl_header

return _generate_display_table(
data=data,
columns_subset=columns_subset,
n_head=n_head,
n_tail=n_tail,
limit=limit,
show_row_numbers=show_row_numbers,
max_col_width=max_col_width,
incl_header=incl_header,
mark_missing_values=True,
)


def _generate_display_table(
data: FrameT | Any,
columns_subset: str | list[str] | Column | None = None,
n_head: int = 5,
n_tail: int = 5,
limit: int | None = 50,
show_row_numbers: bool = True,
max_col_width: int | None = 250,
incl_header: bool = None,
mark_missing_values: bool = True,
row_number_list: list[int] | None = None,
) -> GT:

# Make a copy of the data to avoid modifying the original
data = copy.deepcopy(data)

# Does the data table already have a leading row number column?
if "_row_num_" in data.columns:
if data.columns[0] == "_row_num_":
has_leading_row_num_col = True
else:
has_leading_row_num_col = False
else:
has_leading_row_num_col = False

# Check that the n_head and n_tail aren't greater than the limit
if n_head + n_tail > limit:
raise ValueError(f"The sum of `n_head=` and `n_tail=` cannot exceed the limit ({limit}).")
Expand Down Expand Up @@ -482,15 +517,19 @@ def preview(
if n_head + n_tail > n_rows:
full_dataset = True
data_subset = data
row_number_list = range(1, n_rows + 1)

if row_number_list is None:
row_number_list = range(1, n_rows + 1)
else:
# Get the first and last n rows of the table
data_head = data.head(n=n_head)
row_numbers_head = range(1, n_head + 1)
data_tail = data[(n_rows - n_tail) : n_rows]
row_numbers_tail = range(n_rows - n_tail + 1, n_rows + 1)
data_subset = data_head.union(data_tail)
row_number_list = list(row_numbers_head) + list(row_numbers_tail)

if row_number_list is None:
row_number_list = list(row_numbers_head) + list(row_numbers_tail)

# Convert either to Polars or Pandas depending on the available library
if df_lib_name_gt == "polars":
Expand All @@ -514,13 +553,17 @@ def preview(
# If n_head + n_tail is greater than or equal to the row count, display the entire table
if n_head + n_tail >= n_rows:
full_dataset = True
row_number_list = range(1, n_rows + 1)

if row_number_list is None:
row_number_list = range(1, n_rows + 1)

else:
data = pl.concat([data.head(n=n_head), data.tail(n=n_tail)])

row_number_list = list(range(1, n_head + 1)) + list(
range(n_rows - n_tail + 1, n_rows + 1)
)
if row_number_list is None:
row_number_list = list(range(1, n_head + 1)) + list(
range(n_rows - n_tail + 1, n_rows + 1)
)

if tbl_type == "pandas":

Expand Down Expand Up @@ -591,7 +634,7 @@ def preview(
for i in range(len(col_dtype_dict.keys()))
]

# Set the column width to the col_widths list
# Set the column width to the `col_widths` list
col_width_dict = {k: v for k, v in zip(col_names, col_widths)}

# For each of the values in the dictionary, prepend the column name to the data type
Expand All @@ -605,19 +648,28 @@ def preview(
for k, v in col_dtype_dict_short.items()
}

if has_leading_row_num_col:
# Remove the `_row_num_` entry from the `col_width_dict` and `col_dtype_labels_dict` dictionaries
col_width_dict.pop("_row_num_")
col_dtype_labels_dict.pop("_row_num_")

# Prepend a column that contains the row numbers if `show_row_numbers=True`
if show_row_numbers:
if show_row_numbers or has_leading_row_num_col:

if df_lib_name_gt == "polars":
if has_leading_row_num_col:
row_number_list = data["_row_num_"].to_list()

import polars as pl
else:
if df_lib_name_gt == "polars":

import polars as pl

row_number_series = pl.Series("_row_num_", row_number_list)
data = data.insert_column(0, row_number_series)
row_number_series = pl.Series("_row_num_", row_number_list)
data = data.insert_column(0, row_number_series)

if df_lib_name_gt == "pandas":
if df_lib_name_gt == "pandas":

data.insert(0, "_row_num_", row_number_list)
data.insert(0, "_row_num_", row_number_list)

# Get the highest number in the `row_number_list` and calculate a width that will
# safely fit a number of that magnitude
Expand All @@ -626,27 +678,24 @@ def preview(

# Update the col_width_dict to include the row number column
col_width_dict = {"_row_num_": f"{max_row_num_width}px"} | col_width_dict
# Update the col_dtype_labels_dict to include the row number column (use empty string)

# Update the `col_dtype_labels_dict` to include the row number column (use empty string)
col_dtype_labels_dict = {"_row_num_": ""} | col_dtype_labels_dict

# Create the label, table type, and thresholds HTML fragments
table_type_html = _create_table_type_html(
tbl_type=tbl_type, tbl_name=None, font_size="10px"
)
# Create the table type and table dimensions HTML fragments
table_type_html = _create_table_type_html(tbl_type=tbl_type, tbl_name=None, font_size="10px")

tbl_dims_html = _create_table_dims_html(
columns=len(col_names), rows=n_rows, font_size="10px"
)
tbl_dims_html = _create_table_dims_html(columns=len(col_names), rows=n_rows, font_size="10px")

# Compose the subtitle HTML fragment
combined_subtitle = (
"<div>"
'<div style="padding-top: 0; padding-bottom: 7px;">'
f"{table_type_html}"
f"{tbl_dims_html}"
"</div>"
"</div>"
)
# Compose the subtitle HTML fragment
combined_subtitle = (
"<div>"
'<div style="padding-top: 0; padding-bottom: 7px;">'
f"{table_type_html}"
f"{tbl_dims_html}"
"</div>"
"</div>"
)

gt_tbl = (
GT(data=data, id="pb_preview_tbl")
Expand Down Expand Up @@ -690,7 +739,7 @@ def preview(
gt_tbl = gt_tbl.tab_header(title=html(combined_subtitle))
gt_tbl = gt_tbl.tab_options(heading_subtitle_font_size="12px")

if none_values:
if none_values and mark_missing_values:
for column, none_index in none_values:
gt_tbl = gt_tbl.tab_style(
style=[style.text(color="#B22222"), style.fill(color="#FFC1C159")],
Expand Down Expand Up @@ -4335,12 +4384,17 @@ def interrogate(
and tbl_type not in IBIS_BACKENDS
):

# Add row numbers to the results table
validation_extract_nw = (
nw.from_native(results_tbl)
.with_row_index(name="_row_num_")
.filter(nw.col("pb_is_good_") == False) # noqa
.drop("pb_is_good_")
)

# Add 1 to the row numbers to make them 1-indexed
validation_extract_nw = validation_extract_nw.with_columns(nw.col("_row_num_") + 1)

# Apply any sampling or limiting to the number of rows to extract
if get_first_n is not None:
validation_extract_nw = validation_extract_nw.head(get_first_n)
Expand Down Expand Up @@ -6454,9 +6508,6 @@ def get_step_report(self, i: int) -> GT:
if not active:
return "This validation step is inactive."

# Get the extracted data for the step
extract = self.get_data_extracts(i=i, frame=True)

# Create a table with a sample of ten rows, highlighting the column of interest
tbl_preview = preview(data=self.data, n_head=5, n_tail=5, limit=10, incl_header=False)

Expand All @@ -6466,6 +6517,9 @@ def get_step_report(self, i: int) -> GT:

if assertion_type in ROW_BASED_VALIDATION_TYPES:

# Get the extracted data for the step
extract = self.get_data_extracts(i=i, frame=True)

step_report = _step_report_row_based(
assertion_type=assertion_type,
i=i,
Expand Down Expand Up @@ -7148,13 +7202,19 @@ def _step_report_row_based(
)

else:

# Create a preview of the extracted data
extract_preview = preview(
data=extract, n_head=1000, n_tail=1000, limit=2000, incl_header=False
extract_tbl = _generate_display_table(
data=extract,
n_head=1000,
n_tail=1000,
limit=2000,
incl_header=False,
mark_missing_values=False,
)

step_report = (
extract_preview.tab_header(
extract_tbl.tab_header(
title=f"Report for Validation Step {i}",
subtitle=html(
"<div>"
Expand Down
Loading

0 comments on commit 4166afa

Please sign in to comment.