Skip to content

Commit 598d3ff

Browse files
committed
Ensure missing val table works for all input types
1 parent 778ff82 commit 598d3ff

File tree

1 file changed

+108
-16
lines changed

1 file changed

+108
-16
lines changed

pointblank/validate.py

Lines changed: 108 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,12 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
824824
# Get the number of rows in the table
825825
n_rows = get_row_count(data)
826826

827+
# Define the number of cut points for the missing values table
828+
n_cut_points = 11
829+
830+
# Get the cut points for the table preview
831+
cut_points = _get_cut_points(n_rows=n_rows, n_cuts=n_cut_points)
832+
827833
# Determine if the table is a DataFrame or an Ibis table
828834
tbl_type = _get_tbl_type(data=data)
829835
ibis_tbl = "ibis.expr.types.relations.Table" in str(type(data))
@@ -850,18 +856,9 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
850856
# the proportion of missing values in each 'sector' in each column
851857
if ibis_tbl:
852858

853-
# Get the row count for the table
854-
ibis_rows = data.count()
855-
n_rows = ibis_rows.to_polars() if df_lib_name_gt == "polars" else int(ibis_rows.to_pandas())
856-
857859
# Get the column names from the table
858860
col_names = list(data.columns)
859861

860-
n_cut_points = 11
861-
862-
# Get the cut points for the table preview
863-
cut_points = _get_cut_points(n_rows=n_rows, n_cuts=n_cut_points)
864-
865862
# Iterate over the cut points and get the proportion of missing values in each 'sector'
866863
# for each column
867864
missing_vals = {
@@ -877,24 +874,97 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
877874
for col in data.columns
878875
}
879876

880-
# Get a dictionary of counts of missing values in each column
881-
missing_val_counts = {col: data[col].isnull().sum().to_polars() for col in data.columns}
882-
877+
# Pivot the `missing_vals` dictionary to create a table with the missing value proportions
883878
missing_vals = {
884879
"columns": list(missing_vals.keys()),
885880
**{
886881
str(i + 1): [missing_vals[col][i] for col in missing_vals.keys()] for i in range(10)
887882
},
888883
}
889884

890-
# From `missing_vals`, create a DataFrame with the missing value proportions
885+
# Get a dictionary of counts of missing values in each column
886+
missing_val_counts = {col: data[col].isnull().sum().to_polars() for col in data.columns}
887+
888+
if pl_pb_tbl:
889+
890+
# Get the column names from the table
891+
col_names = list(data.columns)
892+
893+
# Iterate over the cut points and get the proportion of missing values in each 'sector'
894+
# for each column
895+
if "polars" in tbl_type:
896+
897+
# Polars case
898+
missing_vals = {
899+
col: [
900+
(
901+
data[(cut_points[i] - 1) : cut_points[i]][col].is_null().sum()
902+
/ (cut_points[i] - (cut_points[i - 1] if i > 0 else 0))
903+
if cut_points[i] > (cut_points[i - 1] if i > 0 else 0)
904+
else 0
905+
)
906+
for i in range(len(cut_points))
907+
]
908+
for col in data.columns
909+
}
910+
911+
# Pivot the `missing_vals` dictionary to create a table with the missing
912+
# value proportions
913+
missing_vals = {
914+
"columns": list(missing_vals.keys()),
915+
**{
916+
str(i + 1): [missing_vals[col][i] for col in missing_vals.keys()]
917+
for i in range(10)
918+
},
919+
}
920+
921+
# Get a dictionary of counts of missing values in each column
922+
missing_val_counts = {col: data[col].is_null().sum() for col in data.columns}
923+
924+
if "pandas" in tbl_type:
925+
926+
# Pandas case (for this case, if final values are zero then use pd.NA)
927+
missing_vals = {
928+
col: [
929+
(
930+
data[(cut_points[i] - 1) : cut_points[i]][col].isnull().sum()
931+
/ (cut_points[i] - (cut_points[i - 1] if i > 0 else 0))
932+
if cut_points[i] > (cut_points[i - 1] if i > 0 else 0)
933+
else 0
934+
)
935+
for i in range(len(cut_points))
936+
]
937+
for col in data.columns
938+
}
939+
940+
# Pivot the `missing_vals` dictionary to create a table with the missing
941+
# value proportions
942+
missing_vals = {
943+
"columns": list(missing_vals.keys()),
944+
**{
945+
str(i + 1): [missing_vals[col][i] for col in missing_vals.keys()]
946+
for i in range(10)
947+
},
948+
}
949+
950+
# Get a dictionary of counts of missing values in each column
951+
missing_val_counts = {col: data[col].isnull().sum() for col in data.columns}
952+
953+
# From `missing_vals`, create a DataFrame with the missing value proportions
891954
if df_lib_name_gt == "polars":
892955

893956
import polars as pl
894957

895958
# Create a Polars DataFrame from the `missing_vals` dictionary
896959
missing_vals_df = pl.DataFrame(missing_vals)
897960

961+
else:
962+
963+
import pandas as pd
964+
965+
# Create a Pandas DataFrame from the `missing_vals` dictionary
966+
missing_vals_df = pd.DataFrame(missing_vals)
967+
898968
# Get a count of total missing values
899969
n_missing_total = sum(missing_val_counts.values())
900970

@@ -924,8 +994,6 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
924994
"</div>"
925995
)
926996

927-
import polars.selectors as cs
928-
929997
missing_vals_tbl = (
930998
GT(missing_vals_df)
931999
.tab_header(title=html(combined_title), subtitle=html(combined_subtitle))
@@ -975,9 +1043,33 @@ def missing_vals_tbl(data: FrameT | Any) -> GT:
9751043
locations=loc.column_labels(),
9761044
)
9771045
.fmt(fns=lambda x: "", columns=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"])
978-
.tab_style(style=style.fill(color="lightblue"), locations=loc.body(mask=cs.numeric().eq(0)))
9791046
)
9801047

1048+
#
1049+
# Highlight sectors of the table where there are no missing values
1050+
#
1051+
1052+
if df_lib_name_gt == "polars":
1053+
1054+
import polars.selectors as cs
1055+
1056+
missing_vals_tbl = missing_vals_tbl.tab_style(
1057+
style=style.fill(color="lightblue"), locations=loc.body(mask=cs.numeric().eq(0))
1058+
)
1059+
1060+
if df_lib_name_gt == "pandas":
1061+
1062+
# For every column in the DataFrame, determine the indices of the rows where the value is 0
1063+
# and use tab_style to fill the cell with a light blue color
1064+
for col in missing_vals_df.columns:
1065+
1066+
row_indices = list(missing_vals_df[missing_vals_df[col] == 0].index)
1067+
1068+
missing_vals_tbl = missing_vals_tbl.tab_style(
1069+
style=style.fill(color="lightblue"),
1070+
locations=loc.body(columns=col, rows=row_indices),
1071+
)
1072+
9811073
return missing_vals_tbl
9821074

9831075

0 commit comments

Comments
 (0)