Improve Left/Right/Inner Join #223

Open · wants to merge 8 commits into base: main
143 changes: 137 additions & 6 deletions dask_sql/physical/rel/logical/join.py
@@ -5,6 +5,9 @@
from typing import List, Tuple

import dask.dataframe as dd

# Need pd.NA
import pandas as pd
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph

@@ -92,12 +95,32 @@ def convert(
# 4. dask can only merge on the same column names.
# We therefore create new columns on purpose, which have a distinct name.
assert len(lhs_on) == len(rhs_on)
# Add two columns (1,2,...) to keep track of observations in left and
Collaborator:
These changes, as well as the index reset further down, are only used when there is a filter condition, and I do not expect that to be the default. It would therefore be nice if we did not touch the "normal" use case and did not introduce another performance drawback for the "normal" user. I think a simple "if filter_condition" check and a short comment should be enough, or what do you think?
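For illustration, a minimal standalone sketch of the suggested guard; the frame, the column name, and the has_filter_condition flag are made-up placeholders, not the PR's real objects.

```python
import dask.dataframe as dd
import pandas as pd

# Stand-ins (hypothetical): a renamed lhs frame and a flag that says
# "this join carries an extra (inequality) filter condition".
df_lhs_renamed = dd.from_pandas(pd.DataFrame({"lhs_a": [10, 20, 30]}), npartitions=2)
has_filter_condition = True

if has_filter_condition:
    # Only pay for the row-tracking column when it is needed later on.
    df_lhs_renamed = df_lhs_renamed.assign(left_idx=1)
    df_lhs_renamed = df_lhs_renamed.assign(
        left_idx=df_lhs_renamed["left_idx"].cumsum()
    )

print(df_lhs_renamed.compute())
```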

# right tables. They must be at the end of the columns since
# _join_on_columns needs the relative order of columns (lhs_on and rhs_on)
# Only dask-supported functions are used (assign and cumsum) so that a
# compute() is not triggered.
df_lhs_renamed = df_lhs_renamed.assign(left_idx=1)
df_lhs_renamed = df_lhs_renamed.assign(
left_idx=df_lhs_renamed["left_idx"].cumsum()
)
df_rhs_renamed = df_rhs_renamed.assign(right_idx=1)
df_rhs_renamed = df_rhs_renamed.assign(
right_idx=df_rhs_renamed["right_idx"].cumsum()
)

if lhs_on:
# 5. Now we can finally merge on these columns
# The resulting dataframe will contain all (renamed) columns from the lhs and rhs
# plus the added columns
# Need the indicator for left/right join
df = self._join_on_columns(
df_lhs_renamed,
df_rhs_renamed,
lhs_on,
rhs_on,
join_type,
indicator=True,
)
else:
# 5. We are in the complex join case
@@ -148,10 +171,28 @@ def merge_single_partitions(lhs_partition, rhs_partition):
ResourceWarning,
)

# Add _merge to be consistent with the case where lhs_on is non-empty
df["_merge"] = "both"
df["_merge"] = df["_merge"].astype("category")
# Move left_idx so that it sits right before right_idx (temporary columns at the end)
df = df[
df.columns.drop("left_idx").insert(
df.columns.get_loc("right_idx") - 1, "left_idx"
)
]

# Completely reset the index to uniquely identify each row, since there
# could be duplicates. (It may be better to inform users that the index
# will break; after all, it is expected to break since the number of
# rows changes.)
df = df.assign(uniqid=1)
Collaborator:
I think I did not understand why this is needed, as you are grouping and joining later anyway. Probably I am just too stupid to see it, but maybe it is wise to add more documentation on why you do it.
(Also, I am not 100% sure that this does not trigger a computation, as Dask needs to know about the divisions. But I did not check.)
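For context, a small standalone sketch (toy data, assumed) of the problem a unique index avoids: after a merge the index can contain duplicate labels, so a later merge(..., left_index=True) against it becomes ambiguous; re-indexing with a cumulative counter gives every row a unique label. Whether set_index triggers a computation here is a separate question that this sketch does not settle.

```python
import dask.dataframe as dd
import pandas as pd

# Toy frame with duplicated index labels, as can happen after a merge.
pdf = pd.DataFrame({"a": [1, 1, 2]}, index=[0, 0, 1])
df = dd.from_pandas(pdf, npartitions=1)

# Same pattern as in the diff: a cumulative counter used as a unique index.
df = df.assign(uniqid=1)
df = df.assign(uniqid=df["uniqid"].cumsum()).set_index("uniqid")

print(df.compute())  # index is now 1, 2, 3 and unique
```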

df = df.assign(uniqid=df["uniqid"].cumsum()).set_index("uniqid")

# 6. So the next step is to make sure
# we have the correct column order (and to remove the temporary join columns)
# Need to exclude temporary columns left_idx and right_idx
correct_column_order = list(df_lhs_renamed.columns.drop("left_idx")) + list(
df_rhs_renamed.columns.drop("right_idx")
)
cc = ColumnContainer(df.columns).limit_to(correct_column_order)

@@ -177,8 +218,91 @@ def merge_single_partitions(lhs_partition, rhs_partition):
for rex in filter_condition
],
)
logger.debug(f"Additionally applying filter {filter_condition}")
df = filter_or_scalar(df, filter_condition)
# Three cases to deal with inequality conditions (left join as an example):
Collaborator:
Can you add more documentation here? From the PR context I know that we are dealing with complex join conditions, which consist of an equality and an inequality join (but there could be more than just inequalities); someone just reading the code will not know that. Just describe the setting you are dealing with and the actual problem with the naive implementation before going into the details.
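Purely for illustration, this is the kind of query this branch targets; the query is hypothetical (it reuses the fixture tables added later in this PR) and is not necessarily the exact one used in the tests.

```python
# An equality condition combined with inequality (range) conditions in one join.
query = """
    SELECT pn.ids, pn.dates, lk.lk_int
    FROM user_table_pn AS pn
    LEFT JOIN user_table_lk AS lk
        ON pn.ids = lk.id                -- equality part
       AND pn.dates >= lk.startdate      -- inequality part
       AND pn.dates <= lk.enddate        -- inequality part
"""
# df = c.sql(query)  # with c being a dask_sql.Context that has these tables registered
```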

# Case 1 [eq_unmatched] (not matched by equality):
# Left-only rows from the equality join (_merge == 'left_only')
# => Keep all of them
# Case 2 [ineq_unmatched] (not matched by inequality):
# For a unique left_idx, there is no True value in filter_condition
# => Set the values coming from the right/left table to missing (NaN or NaT)
# => Keep one copy and drop duplicates over left_idx (there can be
# duplicates now due to the equality match)
# Case 3 (matched by inequality):
# For a unique left_idx, there is at least one True value in filter_condition
# => Keep the observations with True in filter_condition
# filter_condition has to be added to df as a column since partitioning
# would otherwise break the groupby below
df["filter_condition"] = filter_condition
if join_type in ["left", "right"]:
# ----- Case 1 (Not matched by equality)
if join_type == "left":
# Flag obs unmatched in equality join
df["eq_unmatched"] = df["_merge"] == "left_only"
idx_varname = "left_idx"
other_varpre = "rhs_"
else:
# Flag obs unmatched in equality join
df["eq_unmatched"] = df["_merge"] == "right_only"
idx_varname = "right_idx"
other_varpre = "lhs_"

# ----- Case 2 (Not matched by inequality)

# Flag obs not matched by inequality; their values from the other
# table will be set to missing (pd.NA) below
df = df.merge(
(df.groupby(idx_varname)["filter_condition"].agg("sum") < 1)
.rename("ineq_unmatched")
.to_frame(),
left_on=idx_varname,
right_index=True,
how="left",
)
# Assign pd.NA
for v in df.columns[df.columns.str.startswith(other_varpre)]:
Collaborator:
I just tried running the tests and the same code in a Jupyter notebook and got the following error: TypeError: float() argument must be a string or a number, not 'NAType'. If you are not getting this error, it is probably an environment mismatch, I guess. Any other guess regarding this error? Maybe it is due to setting pd.NA? What do you think, @flcong?

Author:
Which test triggers the error? It seems to be the conversion float(pd.NA)?

Collaborator:
For me, the first query (left join) was failing in test_join_lricomplex. I will try to provide more context by tomorrow :)

Author:
OK. BTW, I use Python 3.8.10 and pandas 1.3.2 to run the tests.

Collaborator:
That's a good catch, @flcong. The above error was raised on pandas version 1.2.4 and works fine with pandas==1.3.2.
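A minimal sketch of the suspected mechanism (an assumption based on the discussion above, not a verified bisect): masking a float64 column with pd.NA makes older pandas attempt float(pd.NA), which raises exactly this TypeError.

```python
import numpy as np
import pandas as pd

s = pd.Series([1.1, np.nan, 3.3])          # float64, like the *_float columns here
out = s.mask([True, False, False], pd.NA)  # reported to raise TypeError on pandas 1.2.x,
print(out)                                 # fine on pandas 1.3.x per the thread above
```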

df[v] = df[v].mask(
df["ineq_unmatched"] & (~df["eq_unmatched"]), pd.NA
)

# Drop duplicates
# Flag the first obs for each unique left_idx
# (or right_idx for right join) in order to remove duplicates
df = df.merge(
df[[idx_varname]]
.drop_duplicates()
.assign(first_elem=True)
.drop(columns=[idx_varname]),
left_index=True,
right_index=True,
how="left",
)
df["first_elem"] = df["first_elem"].fillna(False)

# ----- The full condition to keep observations
filter_condition_all = (
df["filter_condition"]
| df["eq_unmatched"]
| (df["ineq_unmatched"] & df["first_elem"])
)
# Drop added temporary columns
df = df.drop(
columns=[
"left_idx",
"right_idx",
"_merge",
"filter_condition",
"eq_unmatched",
"ineq_unmatched",
"first_elem",
]
)
elif join_type == "inner":
filter_condition_all = filter_condition
# TODO: Full Join

logger.debug(f"Additionally applying filter {filter_condition_all}")
df = filter_or_scalar(df, filter_condition_all)
# Reset index (maybe notify users that dask-sql may break index)
df = df.reset_index(drop=True)
dc = DataContainer(df, cc)

dc = self.fix_dtype_to_row_type(dc, rel.getRowType())
@@ -191,6 +315,7 @@ def _join_on_columns(
lhs_on: List[str],
rhs_on: List[str],
join_type: str,
indicator: bool = False,
) -> dd.DataFrame:
lhs_columns_to_add = {
f"common_{i}": df_lhs_renamed.iloc[:, index]
@@ -222,7 +347,13 @@ def merge_single_partitions(lhs_partition, rhs_partition):
df_rhs_with_tmp = df_rhs_renamed.assign(**rhs_columns_to_add)
added_columns = list(lhs_columns_to_add.keys())

df = dd.merge(
df_lhs_with_tmp,
df_rhs_with_tmp,
on=added_columns,
how=join_type,
indicator=indicator,
)

return df

118 changes: 114 additions & 4 deletions tests/integration/fixtures.py
@@ -86,6 +86,88 @@ def datetime_table():
)


@pytest.fixture
def user_table_lk():
Collaborator:
I love these tests, they are super cool because they seem to come from a real use case, which is absolutely brilliant.

However, can we also have a very simple one with just 3-4 rows and two columns (e.g. the one I used in my comments)? That makes debugging much easier than skimming through multiple lines which (because the columns are so wide) even span a lot of space in the editor. I can also take care of this if you want!

Author:
Yeah, I realize that. I simplified the new tests.
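For illustration, a hypothetical minimal fixture of the kind suggested above (name and values are made up, not part of this diff):

```python
import pandas as pd
import pytest


@pytest.fixture
def user_table_tiny():
    # Two columns, four rows: small enough to eyeball join results.
    return pd.DataFrame({"id": [0, 1, 1, 2], "x": [10, 20, 30, 40]})
```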

# Link table identified by id and date range (startdate and enddate)
# Used for query with both equality and inequality conditions
out = pd.DataFrame(
[
[0, 0, 2, pd.NA, 110, "a1", 1.1, pd.Timestamp("2001-01-01")],
[0, 4, 6, pd.NA, 111, "a2", 1.2, pd.Timestamp("2001-02-01")],
[1, 2, 5, pd.NA, 112, "a3", np.nan, pd.Timestamp("2001-03-01")],
[1, 4, 6, 13, 113, "a4", np.nan, pd.Timestamp("2001-04-01")],
[3, 1, 2, 14, 114, "a5", np.nan, pd.NaT],
[3, 2, 3, 15, 115, "a6", 1.6, pd.NaT],
],
columns=[
"id",
"startdate",
"enddate",
"lk_nullint",
"lk_int",
"lk_str",
"lk_float",
"lk_date",
],
)
out["lk_nullint"] = out["lk_nullint"].astype("Int32")
out["lk_str"] = out["lk_str"].astype("string")
return out


@pytest.fixture
def user_table_lk2(user_table_lk):
# Link table identified by only date range (startdate and enddate)
# Used for query with inequality conditions
return user_table_lk.set_index("id").loc[1].reset_index(drop=True)


@pytest.fixture
def user_table_ts():
# A table of time-series data identified by dates
out = pd.DataFrame(
[
[3, pd.NA, 221, "b1", 2.1, pd.Timestamp("2002-01-01")],
[4, 22, 222, "b2", np.nan, pd.Timestamp("2002-02-01")],
[7, 23, 223, "b3", 2.3, pd.NaT],
],
columns=["dates", "ts_nullint", "ts_int", "ts_str", "ts_float", "ts_date"],
)
out["ts_nullint"] = out["ts_nullint"].astype("Int32")
out["ts_str"] = out["ts_str"].astype("string")
return out


@pytest.fixture
def user_table_pn():
# A panel table identified by id and dates
out = pd.DataFrame(
[
[0, 1, pd.NA, 331, "c1", 3.1, pd.Timestamp("2003-01-01")],
[0, 2, pd.NA, 332, "c2", 3.2, pd.Timestamp("2003-02-01")],
[0, 3, pd.NA, 333, "c3", 3.3, pd.Timestamp("2003-03-01")],
[1, 3, pd.NA, 334, "c4", np.nan, pd.Timestamp("2003-04-01")],
[1, 4, 35, 335, "c5", np.nan, pd.Timestamp("2003-05-01")],
[2, 1, 36, 336, "c6", np.nan, pd.Timestamp("2003-06-01")],
[2, 3, 37, 337, "c7", np.nan, pd.NaT],
[3, 2, 38, 338, "c8", 3.8, pd.NaT],
[3, 2, 39, 339, "c9", 3.9, pd.NaT],
],
columns=[
"ids",
"dates",
"pn_nullint",
"pn_int",
"pn_str",
"pn_float",
"pn_date",
],
)
out["pn_nullint"] = out["pn_nullint"].astype("Int32")
out["pn_str"] = out["pn_str"].astype("string")
return out


@pytest.fixture()
def c(
df_simple,
@@ -97,6 +179,10 @@ def c(
user_table_nan,
string_table,
datetime_table,
user_table_lk,
user_table_lk2,
user_table_ts,
user_table_pn,
):
dfs = {
"df_simple": df_simple,
@@ -108,6 +194,10 @@ def c(
"user_table_nan": user_table_nan,
"string_table": string_table,
"datetime_table": datetime_table,
"user_table_lk": user_table_lk,
"user_table_lk2": user_table_lk2,
"user_table_ts": user_table_ts,
"user_table_pn": user_table_pn,
}

# Lazy import, otherwise the pytest framework has problems
@@ -134,7 +224,9 @@ def temporary_data_file():


@pytest.fixture()
def assert_query_gives_same_result(
engine, user_table_lk, user_table_lk2, user_table_ts, user_table_pn,
):
np.random.seed(42)

df1 = dd.from_pandas(
@@ -191,12 +283,22 @@ def assert_query_gives_same_result(engine):
c.create_table("df1", df1)
c.create_table("df2", df2)
c.create_table("df3", df3)
c.create_table("user_table_ts", user_table_ts)
c.create_table("user_table_pn", user_table_pn)
c.create_table("user_table_lk", user_table_lk)
c.create_table("user_table_lk2", user_table_lk2)

df1.compute().to_sql("df1", engine, index=False, if_exists="replace")
df2.compute().to_sql("df2", engine, index=False, if_exists="replace")
df3.compute().to_sql("df3", engine, index=False, if_exists="replace")

user_table_ts.to_sql("user_table_ts", engine, index=False, if_exists="replace")
user_table_pn.to_sql("user_table_pn", engine, index=False, if_exists="replace")
user_table_lk.to_sql("user_table_lk", engine, index=False, if_exists="replace")
user_table_lk2.to_sql("user_table_lk2", engine, index=False, if_exists="replace")

def _assert_query_gives_same_result(
query, sort_columns=None, force_dtype=None, check_dtype=False, **kwargs,
):
sql_result = pd.read_sql_query(query, engine)
dask_result = c.sql(query).compute()

@@ -211,7 +313,15 @@ def _assert_query_gives_same_result(query, sort_columns=None, **kwargs):
sql_result = sql_result.reset_index(drop=True)
dask_result = dask_result.reset_index(drop=True)

# Change dtypes
if force_dtype == "sql":
for col, dtype in sql_result.dtypes.iteritems():
dask_result[col] = dask_result[col].astype(dtype)
elif force_dtype == "dask":
for col, dtype in dask_result.dtypes.iteritems():
sql_result[col] = sql_result[col].astype(dtype)

assert_frame_equal(sql_result, dask_result, check_dtype=check_dtype, **kwargs)

return _assert_query_gives_same_result
