Improve Left/Right/Inner Join #223
base: main
Changes from 6 commits: d54b2a6, 43edc68, c199fe7, 29a201d, a7c1e30, 0fc00db, e32023c, 39980fd
@@ -5,6 +5,9 @@
 from typing import List, Tuple

 import dask.dataframe as dd
+
+# Need pd.NA
+import pandas as pd
 from dask.base import tokenize
 from dask.highlevelgraph import HighLevelGraph
@@ -92,12 +95,32 @@ def convert(
         # 4. dask can only merge on the same column names.
         # We therefore create new columns on purpose, which have a distinct name.
         assert len(lhs_on) == len(rhs_on)
+        # Add two columns (1,2,...) to keep track of observations in left and
+        # right tables. They must be at the end of the columns since
+        # _join_on_columns needs the relative order of columns (lhs_on and rhs_on).
+        # Only dask-supported functions are used (assign and cumsum) so that a
+        # compute() is not triggered.
+        df_lhs_renamed = df_lhs_renamed.assign(left_idx=1)
+        df_lhs_renamed = df_lhs_renamed.assign(
+            left_idx=df_lhs_renamed["left_idx"].cumsum()
+        )
+        df_rhs_renamed = df_rhs_renamed.assign(right_idx=1)
+        df_rhs_renamed = df_rhs_renamed.assign(
+            right_idx=df_rhs_renamed["right_idx"].cumsum()
+        )
+
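As an aside, the assign/cumsum pattern added above can be sketched in isolation. A minimal toy frame (illustrative only, not part of the PR):

    import dask.dataframe as dd
    import pandas as pd

    # A constant column is added, then cumsum turns it into 1, 2, 3, ...
    # Everything stays lazy; nothing runs until the final .compute().
    ddf = dd.from_pandas(pd.DataFrame({"a": ["x", "y", "z"]}), npartitions=2)
    ddf = ddf.assign(left_idx=1)
    ddf = ddf.assign(left_idx=ddf["left_idx"].cumsum())
    print(ddf.compute())  # left_idx is 1, 2, 3 across both partitions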
         if lhs_on:
             # 5. Now we can finally merge on these columns
             # The resulting dataframe will contain all (renamed) columns from the lhs and rhs
             # plus the added columns
+            # Need the indicator for left/right join
             df = self._join_on_columns(
-                df_lhs_renamed, df_rhs_renamed, lhs_on, rhs_on, join_type,
+                df_lhs_renamed,
+                df_rhs_renamed,
+                lhs_on,
+                rhs_on,
+                join_type,
+                indicator=True,
             )
         else:
             # 5. We are in the complex join case
@@ -148,10 +171,28 @@ def merge_single_partitions(lhs_partition, rhs_partition):
                 ResourceWarning,
             )

+            # Add _merge to be consistent with the case lhs_on=True
+            df["_merge"] = "both"
+            df["_merge"] = df["_merge"].astype("category")
+            # Put newly added columns to the end
+            df = df[
+                df.columns.drop("left_idx").insert(
+                    df.columns.get_loc("right_idx") - 1, "left_idx"
+                )
+            ]
+
+        # Completely reset index to uniquely identify each row since there
+        # could be duplicates. (Yeah. It may be better to inform users that
+        # the index will break. After all, it is expected to be broken since
+        # the number of rows changes.)
+        df = df.assign(uniqid=1)
Review comment: I think I did not understand why this is needed, as you are grouping and joining later anyway. Probably I am just too stupid to see it, but maybe it is wise to add more documentation on why you do it.
+        df = df.assign(uniqid=df["uniqid"].cumsum()).set_index("uniqid")
+
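A possible answer to the question above, sketched on a toy frame (illustrative only, not part of the PR): a dask merge can emit duplicate index values across partitions, and the groupby/merge steps further down rely on a unique index to tell rows apart:

    import dask.dataframe as dd
    import pandas as pd

    left = dd.from_pandas(pd.DataFrame({"k": [1, 1]}), npartitions=2)
    right = dd.from_pandas(pd.DataFrame({"k": [1, 1]}), npartitions=1)
    df = dd.merge(left, right, on="k")
    print(df.compute().index.is_unique)  # False: per-partition indexes repeat

    # The uniqid trick from the diff restores a globally unique index
    df = df.assign(uniqid=1)
    df = df.assign(uniqid=df["uniqid"].cumsum()).set_index("uniqid")
    print(df.compute().index.is_unique)  # True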
         # 6. So the next step is to make sure
         # we have the correct column order (and to remove the temporary join columns)
-        correct_column_order = list(df_lhs_renamed.columns) + list(
-            df_rhs_renamed.columns
+        # Need to exclude temporary columns left_idx and right_idx
+        correct_column_order = list(df_lhs_renamed.columns.drop("left_idx")) + list(
+            df_rhs_renamed.columns.drop("right_idx")
         )
         cc = ColumnContainer(df.columns).limit_to(correct_column_order)
@@ -177,8 +218,91 @@ def merge_single_partitions(lhs_partition, rhs_partition):
                     for rex in filter_condition
                 ],
             )
-            logger.debug(f"Additionally applying filter {filter_condition}")
-            df = filter_or_scalar(df, filter_condition)
+            # Three cases to deal with inequality conditions (left join as an example):
Review comment: Can you add more documentation here? From the PR context I know that we are dealing with complex join conditions, which consist of an equality and an inequality join (but there could be more than just inequalities), but someone just reading the code will not know that. Just describe the setting you are dealing with and the actual problem with the naive implementation before going into the details.
+            # Case 1 [eq_unmatched] (Not matched by equality):
+            #     Left-only from equality join (_merge=='left_only')
+            #     => Keep all
+            # Case 2 [ineq_unmatched] (Not matched by inequality):
+            #     For unique left_idx, there are no True in filter_condition
+            #     => Set values from right/left table to missing (NaN or NaT)
+            #     => Keep 1 copy and drop duplicates over left_idx (there could
+            #        be duplicates now due to equality match).
+            # Case 3 (Matched by inequality):
+            #     For unique left_idx, there are 1 or more True in filter_condition
+            #     => Keep obs with True in filter_condition
+            # This has to be added to df as a column, since partitioning would
+            # otherwise break the groupby
+            df["filter_condition"] = filter_condition
+            if join_type in ["left", "right"]:
+                # ----- Case 1 (Not matched by equality)
+                if join_type == "left":
+                    # Flag obs unmatched in equality join
+                    df["eq_unmatched"] = df["_merge"] == "left_only"
+                    idx_varname = "left_idx"
+                    other_varpre = "rhs_"
+                else:
+                    # Flag obs unmatched in equality join
+                    df["eq_unmatched"] = df["_merge"] == "right_only"
+                    idx_varname = "right_idx"
+                    other_varpre = "lhs_"
+
+                # ----- Case 2 (Not matched by inequality)
+
+                # Set NA (pd.NA)
+                # Flag obs not matched by inequality
+                df = df.merge(
+                    (df.groupby(idx_varname)["filter_condition"].agg("sum") < 1)
+                    .rename("ineq_unmatched")
+                    .to_frame(),
+                    left_on=idx_varname,
+                    right_index=True,
+                    how="left",
+                )
+                # Assign pd.NA
+                for v in df.columns[df.columns.str.startswith(other_varpre)]:
Review comment: Just tried running the tests and the same code in a Jupyter notebook, and got the following error
Reply: Which test triggers the error? It seems to be the conversion
Reply: For me, the first query (left join) was failing in
Reply: OK. BTW, I use Python 3.8.10 and Pandas 1.3.2 to run the tests.
Reply: That's a good catch @flcong, the above error was raised on pandas version
+                    df[v] = df[v].mask(
+                        df["ineq_unmatched"] & (~df["eq_unmatched"]), pd.NA
+                    )
+
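Continuing the toy example, the groupby-sum flag and the mask step work like this (using "id" as a stand-in for left_idx; note that the review thread above reports the pd.NA fill raising an error on some pandas versions):

    # Per id, a boolean sum of zero means the inequality never held (Case 2)
    flags = (df.groupby("id")["filter_condition"].agg("sum") < 1).rename(
        "ineq_unmatched"
    )
    df = df.merge(flags.to_frame(), left_on="id", right_index=True, how="left")

    # Blank out rhs-side values, but only for rows that did match on equality
    df["y"] = df["y"].mask(
        df["ineq_unmatched"] & (df["_merge"] != "left_only"), pd.NA
    )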
+                # Drop duplicates
+                # Flag the first obs for each unique left_idx
+                # (or right_idx for right join) in order to remove duplicates
+                df = df.merge(
+                    df[[idx_varname]]
+                    .drop_duplicates()
+                    .assign(first_elem=True)
+                    .drop(columns=[idx_varname]),
+                    left_index=True,
+                    right_index=True,
+                    how="left",
+                )
+                df["first_elem"] = df["first_elem"].fillna(False)
+
+                # ----- The full condition to keep observations
+                filter_condition_all = (
+                    df["filter_condition"]
+                    | df["eq_unmatched"]
+                    | (df["ineq_unmatched"] & df["first_elem"])
+                )
+                # Drop added temporary columns
+                df = df.drop(
+                    columns=[
+                        "left_idx",
+                        "right_idx",
+                        "_merge",
+                        "filter_condition",
+                        "eq_unmatched",
+                        "ineq_unmatched",
+                        "first_elem",
+                    ]
+                )
+            elif join_type == "inner":
+                filter_condition_all = filter_condition
+            # TODO: Full Join
+
+            logger.debug(f"Additionally applying filter {filter_condition_all}")
+            df = filter_or_scalar(df, filter_condition_all)
+        # Reset index (maybe notify users that dask-sql may break the index)
+        df = df.reset_index(drop=True)
         dc = DataContainer(df, cc)

         dc = self.fix_dtype_to_row_type(dc, rel.getRowType())
@@ -191,6 +315,7 @@ def _join_on_columns(
         lhs_on: List[str],
         rhs_on: List[str],
         join_type: str,
+        indicator: bool = False,
     ) -> dd.DataFrame:
         lhs_columns_to_add = {
             f"common_{i}": df_lhs_renamed.iloc[:, index]
@@ -222,7 +347,13 @@ def _join_on_columns(
         df_rhs_with_tmp = df_rhs_renamed.assign(**rhs_columns_to_add)
         added_columns = list(lhs_columns_to_add.keys())

-        df = dd.merge(df_lhs_with_tmp, df_rhs_with_tmp, on=added_columns, how=join_type)
+        df = dd.merge(
+            df_lhs_with_tmp,
+            df_rhs_with_tmp,
+            on=added_columns,
+            how=join_type,
+            indicator=indicator,
+        )

         return df
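For reference, indicator=True is forwarded to pandas and adds a categorical _merge column; a minimal sketch with made-up frames:

    import dask.dataframe as dd
    import pandas as pd

    left = dd.from_pandas(pd.DataFrame({"k": [1, 2]}), npartitions=1)
    right = dd.from_pandas(pd.DataFrame({"k": [2, 3]}), npartitions=1)
    print(dd.merge(left, right, on="k", how="left", indicator=True).compute())
    #    k     _merge
    # 0  1  left_only
    # 1  2       both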
@@ -86,6 +86,88 @@ def datetime_table():
     )


+@pytest.fixture
+def user_table_lk():
Review comment: I love those tests, they are super cool because they seem to come from a real use case, which is absolutely brilliant. However, can we also have a very simple one with just 3-4 lines and two columns (e.g. the one I used in my comments)? This makes debugging much easier than skimming through multiple lines which (because the columns are so wide) even span a lot of space in the editor. I can also take care of this if you want!
Reply: Yeah. I realize that. I simplified the new tests.
+    # Link table identified by id and date range (startdate and enddate)
+    # Used for query with both equality and inequality conditions
+    out = pd.DataFrame(
+        [
+            [0, 0, 2, pd.NA, 110, "a1", 1.1, pd.Timestamp("2001-01-01")],
+            [0, 4, 6, pd.NA, 111, "a2", 1.2, pd.Timestamp("2001-02-01")],
+            [1, 2, 5, pd.NA, 112, "a3", np.nan, pd.Timestamp("2001-03-01")],
+            [1, 4, 6, 13, 113, "a4", np.nan, pd.Timestamp("2001-04-01")],
+            [3, 1, 2, 14, 114, "a5", np.nan, pd.NaT],
+            [3, 2, 3, 15, 115, "a6", 1.6, pd.NaT],
+        ],
+        columns=[
+            "id",
+            "startdate",
+            "enddate",
+            "lk_nullint",
+            "lk_int",
+            "lk_str",
+            "lk_float",
+            "lk_date",
+        ],
+    )
+    out["lk_nullint"] = out["lk_nullint"].astype("Int32")
+    out["lk_str"] = out["lk_str"].astype("string")
+    return out
+
+
+@pytest.fixture
+def user_table_lk2(user_table_lk):
+    # Link table identified by only date range (startdate and enddate)
+    # Used for query with inequality conditions
+    return user_table_lk.set_index("id").loc[1].reset_index(drop=True)
+
+
+@pytest.fixture
+def user_table_ts():
+    # A table of time-series data identified by dates
+    out = pd.DataFrame(
+        [
+            [3, pd.NA, 221, "b1", 2.1, pd.Timestamp("2002-01-01")],
+            [4, 22, 222, "b2", np.nan, pd.Timestamp("2002-02-01")],
+            [7, 23, 223, "b3", 2.3, pd.NaT],
+        ],
+        columns=["dates", "ts_nullint", "ts_int", "ts_str", "ts_float", "ts_date"],
+    )
+    out["ts_nullint"] = out["ts_nullint"].astype("Int32")
+    out["ts_str"] = out["ts_str"].astype("string")
+    return out
+
+
+@pytest.fixture
+def user_table_pn():
+    # A panel table identified by id and dates
+    out = pd.DataFrame(
+        [
+            [0, 1, pd.NA, 331, "c1", 3.1, pd.Timestamp("2003-01-01")],
+            [0, 2, pd.NA, 332, "c2", 3.2, pd.Timestamp("2003-02-01")],
+            [0, 3, pd.NA, 333, "c3", 3.3, pd.Timestamp("2003-03-01")],
+            [1, 3, pd.NA, 334, "c4", np.nan, pd.Timestamp("2003-04-01")],
+            [1, 4, 35, 335, "c5", np.nan, pd.Timestamp("2003-05-01")],
+            [2, 1, 36, 336, "c6", np.nan, pd.Timestamp("2003-06-01")],
+            [2, 3, 37, 337, "c7", np.nan, pd.NaT],
+            [3, 2, 38, 338, "c8", 3.8, pd.NaT],
+            [3, 2, 39, 339, "c9", 3.9, pd.NaT],
+        ],
+        columns=[
+            "ids",
+            "dates",
+            "pn_nullint",
+            "pn_int",
+            "pn_str",
+            "pn_float",
+            "pn_date",
+        ],
+    )
+    out["pn_nullint"] = out["pn_nullint"].astype("Int32")
+    out["pn_str"] = out["pn_str"].astype("string")
+    return out
+
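For orientation, the kind of query these fixtures are built for, per the fixture comments (a hypothetical example combining an equality condition with a date-range inequality):

    query = """
    SELECT pn.ids, pn.dates, lk.lk_int
    FROM user_table_pn AS pn
    LEFT JOIN user_table_lk AS lk
        ON pn.ids = lk.id
        AND pn.dates >= lk.startdate
        AND pn.dates <= lk.enddate
    """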

 @pytest.fixture()
 def c(
     df_simple,
@@ -97,6 +179,10 @@ def c(
     user_table_nan,
     string_table,
     datetime_table,
+    user_table_lk,
+    user_table_lk2,
+    user_table_ts,
+    user_table_pn,
 ):
     dfs = {
         "df_simple": df_simple,
@@ -108,6 +194,10 @@ def c(
         "user_table_nan": user_table_nan,
         "string_table": string_table,
         "datetime_table": datetime_table,
+        "user_table_lk": user_table_lk,
+        "user_table_lk2": user_table_lk2,
+        "user_table_ts": user_table_ts,
+        "user_table_pn": user_table_pn,
     }

     # Lazy import, otherwise the pytest framework has problems
@@ -134,7 +224,9 @@ def temporary_data_file():


 @pytest.fixture()
-def assert_query_gives_same_result(engine):
+def assert_query_gives_same_result(
+    engine, user_table_lk, user_table_lk2, user_table_ts, user_table_pn,
+):
     np.random.seed(42)

     df1 = dd.from_pandas(
@@ -191,12 +283,22 @@ def assert_query_gives_same_result(engine):
     c.create_table("df1", df1)
     c.create_table("df2", df2)
     c.create_table("df3", df3)
+    c.create_table("user_table_ts", user_table_ts)
+    c.create_table("user_table_pn", user_table_pn)
+    c.create_table("user_table_lk", user_table_lk)
+    c.create_table("user_table_lk2", user_table_lk2)

     df1.compute().to_sql("df1", engine, index=False, if_exists="replace")
     df2.compute().to_sql("df2", engine, index=False, if_exists="replace")
     df3.compute().to_sql("df3", engine, index=False, if_exists="replace")
+    user_table_ts.to_sql("user_table_ts", engine, index=False, if_exists="replace")
+    user_table_pn.to_sql("user_table_pn", engine, index=False, if_exists="replace")
+    user_table_lk.to_sql("user_table_lk", engine, index=False, if_exists="replace")
+    user_table_lk2.to_sql("user_table_lk2", engine, index=False, if_exists="replace")

-    def _assert_query_gives_same_result(query, sort_columns=None, **kwargs):
+    def _assert_query_gives_same_result(
+        query, sort_columns=None, force_dtype=None, check_dtype=False, **kwargs,
+    ):
         sql_result = pd.read_sql_query(query, engine)
         dask_result = c.sql(query).compute()
@@ -211,7 +313,15 @@ def _assert_query_gives_same_result(query, sort_columns=None, **kwargs):
         sql_result = sql_result.reset_index(drop=True)
         dask_result = dask_result.reset_index(drop=True)

+        # Change dtypes
+        if force_dtype == "sql":
+            for col, dtype in sql_result.dtypes.iteritems():
+                dask_result[col] = dask_result[col].astype(dtype)
+        elif force_dtype == "dask":
+            for col, dtype in dask_result.dtypes.iteritems():
+                sql_result[col] = sql_result[col].astype(dtype)
+
-        assert_frame_equal(sql_result, dask_result, check_dtype=False, **kwargs)
+        assert_frame_equal(sql_result, dask_result, check_dtype=check_dtype, **kwargs)

     return _assert_query_gives_same_result
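A hypothetical call site for the new knobs: coerce the dask result to the SQL engine's dtypes before comparing, since a left join can upcast integer columns on the dask side:

    def test_left_join_inequality(assert_query_gives_same_result):
        assert_query_gives_same_result(
            """
            SELECT lk.id, lk.startdate, pn.dates
            FROM user_table_lk AS lk
            LEFT JOIN user_table_pn AS pn
                ON lk.id = pn.ids AND pn.dates <= lk.enddate
            """,
            sort_columns=["id", "startdate", "dates"],
            force_dtype="sql",
        )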
Review comment: These changes, as well as the index reset further down, are only used when there is a filter condition. I do not expect this to be the default, so it would be nice if we did not touch the "normal" use case and did not introduce another performance drawback for the "normal" user. I think a simple `if filter_condition` guard and a short comment should be enough, or what do you think?