diff --git a/sklbench/report/implementation.py b/sklbench/report/implementation.py index 28fa2bb0..3d419048 100644 --- a/sklbench/report/implementation.py +++ b/sklbench/report/implementation.py @@ -16,8 +16,10 @@ import argparse import json +from functools import reduce from typing import Dict, List +import numpy as np import openpyxl as xl import pandas as pd from openpyxl.formatting.rule import ColorScaleRule @@ -165,7 +167,12 @@ def select_comparison(i, j, diffs_selection): df = input_df.set_index(index_columns) unique_indices = df.index.unique() splitted_dfs = split_df_by_columns(input_df, diff_columns) - splitted_dfs = {key: df.set_index(index_columns) for key, df in splitted_dfs.items()} + common_cols = reduce(np.intersect1d, [df.columns for df in splitted_dfs.values()]) + df_specific_cols = np.setdiff1d(index_columns, common_cols) + splitted_dfs = { + key: df.assign(**{col: None for col in df_specific_cols}).set_index(index_columns) + for key, df in splitted_dfs.items() + } # drop results with duplicated indices (keep first entry only) for key, splitted_df in splitted_dfs.items(): @@ -181,6 +188,9 @@ def select_comparison(i, j, diffs_selection): # compared values for i, (key_ith, df_ith) in enumerate(splitted_dfs.items()): for j, (key_jth, df_jth) in enumerate(splitted_dfs.items()): + common_indexes = np.intersect1d(df_ith.index, df_jth.index) + df_ith = df_ith.loc[common_indexes] + df_jth = df_jth.loc[common_indexes] if select_comparison(i, j, diffs_selection): comparison_name = f"{key_jth} vs {key_ith}" for column in df_ith.columns: @@ -196,7 +206,9 @@ def select_comparison(i, j, diffs_selection): df[f"{comparison_name}\n{column} is equal"] = ( df_ith[column] == df_jth[column] ) - df = df.reset_index() + if len(df_specific_cols): + df.index = df.index.droplevel(list(df_specific_cols)) + df = df.dropna(axis=0, how="all", ignore_index=False).reset_index() # move to multi-index df = df[reorder_columns(list(df.columns))] df.columns = [