diff --git a/dask_sql/physical/rel/logical/join.py b/dask_sql/physical/rel/logical/join.py
index b0c079f3c..57037b782 100644
--- a/dask_sql/physical/rel/logical/join.py
+++ b/dask_sql/physical/rel/logical/join.py
@@ -116,7 +116,10 @@ def merge_single_partitions(lhs_partition, rhs_partition):
             # which is definitely not possible (java dependency, JVM start...)
             lhs_partition = lhs_partition.assign(common=1)
             rhs_partition = rhs_partition.assign(common=1)
-            merged_data = lhs_partition.merge(rhs_partition, on=["common"])
+            # Need to drop "common" here, otherwise the metadata mismatches
+            merged_data = lhs_partition.merge(rhs_partition, on=["common"]).drop(
+                columns=["common"]
+            )
 
             return merged_data
 
@@ -179,7 +182,20 @@
             )
             logger.debug(f"Additionally applying filter {filter_condition}")
             df = filter_or_scalar(df, filter_condition)
+        # make sure we recover any lost rows in case of left, right or outer joins
+        if join_type in ["left", "outer"]:
+            df = df.merge(
+                df_lhs_renamed, on=list(df_lhs_renamed.columns), how="right"
+            )
+        elif join_type in ["right", "outer"]:
+            df = df.merge(
+                df_rhs_renamed, on=list(df_rhs_renamed.columns), how="right"
+            )
         dc = DataContainer(df, cc)
+        # Caveat: columns of int may be cast to float if NaN is introduced
+        # for unmatched rows. Since we don't know which columns would be
+        # cast without triggering compute(), we have to either leave them
+        # alone or forcibly cast all int columns to nullable int.
         dc = self.fix_dtype_to_row_type(dc, rel.getRowType())
         return dc
 
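The dtype caveat noted above is easy to reproduce outside dask-sql. A minimal, self-contained pandas sketch (illustrative names, not part of this diff) of why unmatched rows promote integer columns to float, and how a nullable integer dtype avoids it:

    import pandas as pd

    left = pd.DataFrame({"key": [1, 2], "lhs_int": [10, 20]})
    right = pd.DataFrame({"key": [1, 3], "rhs_int": [30, 40]})

    # The unmatched key=2 row is filled with NaN, which silently promotes
    # the int64 column "rhs_int" to float64.
    joined = left.merge(right, on="key", how="left")
    print(joined["rhs_int"].dtype)  # float64

    # Casting to a nullable integer dtype keeps the column integral;
    # the missing value becomes pd.NA instead of NaN.
    joined["rhs_int"] = joined["rhs_int"].astype("Int64")
    print(joined["rhs_int"].dtype)  # Int64

This is also why the fixtures and expected frames below build their columns with the nullable "Int32"/"Int64" dtypes.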
diff --git a/tests/integration/fixtures.py b/tests/integration/fixtures.py
index 4566d3690..fc1aefd66 100644
--- a/tests/integration/fixtures.py
+++ b/tests/integration/fixtures.py
@@ -86,6 +86,50 @@ def datetime_table():
     )
 
 
+@pytest.fixture
+def user_table_lk():
+    # Link table identified by id and startdate
+    # Used for queries with both equality and inequality conditions
+    out = pd.DataFrame(
+        [[0, 5, 11, 111], [1, 2, pd.NA, 112], [1, 4, 13, 113], [3, 1, 14, 114]],
+        columns=["id", "startdate", "lk_nullint", "lk_int"],
+    )
+    out["lk_nullint"] = out["lk_nullint"].astype("Int32")
+    return out
+
+
+@pytest.fixture
+def user_table_lk2(user_table_lk):
+    # Link table identified by startdate only
+    # Used for queries with inequality conditions
+    out = pd.DataFrame(
+        [[2, pd.NA, 112], [4, 13, 113]], columns=["startdate", "lk_nullint", "lk_int"],
+    )
+    out["lk_nullint"] = out["lk_nullint"].astype("Int32")
+    return out
+
+
+@pytest.fixture
+def user_table_ts():
+    # A table of time-series data identified by dates
+    out = pd.DataFrame(
+        [[1, 21], [3, pd.NA], [7, 23]], columns=["dates", "ts_nullint"],
+    )
+    out["ts_nullint"] = out["ts_nullint"].astype("Int32")
+    return out
+
+
+@pytest.fixture
+def user_table_pn():
+    # A panel table identified by id and dates
+    out = pd.DataFrame(
+        [[0, 1, pd.NA], [1, 5, 32], [2, 1, 33]],
+        columns=["ids", "dates", "pn_nullint"],
+    )
+    out["pn_nullint"] = out["pn_nullint"].astype("Int32")
+    return out
+
+
 @pytest.fixture()
 def c(
     df_simple,
@@ -97,6 +141,10 @@ def c(
     user_table_nan,
     string_table,
     datetime_table,
+    user_table_lk,
+    user_table_lk2,
+    user_table_ts,
+    user_table_pn,
 ):
     dfs = {
         "df_simple": df_simple,
@@ -108,6 +156,10 @@
         "user_table_nan": user_table_nan,
         "string_table": string_table,
         "datetime_table": datetime_table,
+        "user_table_lk": user_table_lk,
+        "user_table_lk2": user_table_lk2,
+        "user_table_ts": user_table_ts,
+        "user_table_pn": user_table_pn,
     }
 
     # Lazy import, otherwise the pytest framework has problems
@@ -134,7 +186,9 @@ def temporary_data_file():
 
 
 @pytest.fixture()
-def assert_query_gives_same_result(engine):
+def assert_query_gives_same_result(
+    engine, user_table_lk, user_table_lk2, user_table_ts, user_table_pn,
+):
     np.random.seed(42)
 
     df1 = dd.from_pandas(
@@ -191,12 +245,22 @@
     c.create_table("df1", df1)
     c.create_table("df2", df2)
     c.create_table("df3", df3)
+    c.create_table("user_table_ts", user_table_ts)
+    c.create_table("user_table_pn", user_table_pn)
+    c.create_table("user_table_lk", user_table_lk)
+    c.create_table("user_table_lk2", user_table_lk2)
 
     df1.compute().to_sql("df1", engine, index=False, if_exists="replace")
     df2.compute().to_sql("df2", engine, index=False, if_exists="replace")
     df3.compute().to_sql("df3", engine, index=False, if_exists="replace")
-
-    def _assert_query_gives_same_result(query, sort_columns=None, **kwargs):
+    user_table_ts.to_sql("user_table_ts", engine, index=False, if_exists="replace")
+    user_table_pn.to_sql("user_table_pn", engine, index=False, if_exists="replace")
+    user_table_lk.to_sql("user_table_lk", engine, index=False, if_exists="replace")
+    user_table_lk2.to_sql("user_table_lk2", engine, index=False, if_exists="replace")
+
+    def _assert_query_gives_same_result(
+        query, sort_columns=None, force_dtype=None, check_dtype=False, **kwargs,
+    ):
         sql_result = pd.read_sql_query(query, engine)
         dask_result = c.sql(query).compute()
@@ -211,7 +275,15 @@ def _assert_query_gives_same_result(query, sort_columns=None, **kwargs):
         sql_result = sql_result.reset_index(drop=True)
         dask_result = dask_result.reset_index(drop=True)
 
-        assert_frame_equal(sql_result, dask_result, check_dtype=False, **kwargs)
+        # Align the dtypes of the two results before comparing
+        if force_dtype == "sql":
+            for col, dtype in sql_result.dtypes.iteritems():
+                dask_result[col] = dask_result[col].astype(dtype)
+        elif force_dtype == "dask":
+            for col, dtype in dask_result.dtypes.iteritems():
+                sql_result[col] = sql_result[col].astype(dtype)
+
+        assert_frame_equal(sql_result, dask_result, check_dtype=check_dtype, **kwargs)
 
     return _assert_query_gives_same_result
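The new force_dtype and check_dtype arguments exist because the same query can come back with different dtypes from the two engines; for example, a column that Dask keeps as nullable Int32 arrives from Postgres as float64 once NULLs are involved. A standalone sketch of what the alignment loop above does (toy frames, illustrative only):

    import pandas as pd

    # What the Dask side might return: a nullable integer column
    dask_result = pd.DataFrame({"v": pd.array([1, pd.NA, 3], dtype="Int32")})
    # The same data read back from SQL: float64, because of the NULL
    sql_result = pd.DataFrame({"v": [1.0, float("nan"), 3.0]})

    # force_dtype="dask": cast every SQL column to the matching Dask dtype
    for col, dtype in dask_result.dtypes.items():
        sql_result[col] = sql_result[col].astype(dtype)

    # A strict comparison now passes
    pd.testing.assert_frame_equal(sql_result, dask_result, check_dtype=True)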
diff --git a/tests/integration/test_join.py b/tests/integration/test_join.py
index 6437cde0f..e5b142ce1 100644
--- a/tests/integration/test_join.py
+++ b/tests/integration/test_join.py
@@ -184,3 +184,149 @@ def test_join_literal(c):
     df_expected = pd.DataFrame({"user_id": [], "b": [], "user_id0": [], "c": []})
 
     assert_frame_equal(df.reset_index(), df_expected.reset_index(), check_dtype=False)
+
+
+def test_join_lricomplex(c):
+    # ---------- Panel data (equality and inequality conditions)
+
+    # Correct answer
+    dfcorrpn = pd.DataFrame(
+        [
+            [0, 1, pd.NA, pd.NA, pd.NA, pd.NA],
+            [1, 5, 32, 2, pd.NA, 112],
+            [1, 5, 32, 4, 13, 113],
+            [2, 1, 33, pd.NA, pd.NA, pd.NA],
+        ],
+        columns=["ids", "dates", "pn_nullint", "startdate", "lk_nullint", "lk_int"],
+    )
+    change_types = {
+        "pn_nullint": "Int32",
+        "lk_nullint": "Int32",
+        "startdate": "Int64",
+        "lk_int": "Int64",
+    }
+    for k, v in change_types.items():
+        dfcorrpn[k] = dfcorrpn[k].astype(v)
+
+    # Left Join
+    querypnl = """
+        select a.*, b.startdate, b.lk_nullint, b.lk_int
+        from user_table_pn a left join user_table_lk b
+        on a.ids=b.id and b.startdate<=a.dates
+        """
+    dftestpnl = (
+        c.sql(querypnl)
+        .compute()
+        .sort_values(["ids", "dates", "startdate"])
+        .reset_index(drop=True)
+    )
+    assert_frame_equal(dftestpnl, dfcorrpn, check_dtype=False)
+
+    # Right Join
+    querypnr = """
+        select b.*, a.startdate, a.lk_nullint, a.lk_int
+        from user_table_lk a right join user_table_pn b
+        on b.ids=a.id and a.startdate<=b.dates
+        """
+    dftestpnr = (
+        c.sql(querypnr)
+        .compute()
+        .sort_values(["ids", "dates", "startdate"])
+        .reset_index(drop=True)
+    )
+    assert_frame_equal(dftestpnr, dfcorrpn, check_dtype=False)
+
+    # Inner Join
+    querypni = """
+        select a.*, b.startdate, b.lk_nullint, b.lk_int
+        from user_table_pn a inner join user_table_lk b
+        on a.ids=b.id and b.startdate<=a.dates
+        """
+    dftestpni = (
+        c.sql(querypni)
+        .compute()
+        .sort_values(["ids", "dates", "startdate"])
+        .reset_index(drop=True)
+    )
+    assert_frame_equal(
+        dftestpni,
+        dfcorrpn.dropna(subset=["startdate"])
+        .assign(
+            startdate=lambda x: x["startdate"].astype("int64"),
+            lk_int=lambda x: x["lk_int"].astype("int64"),
+        )
+        .reset_index(drop=True),
+        check_dtype=False,
+    )
+
+    # ---------- Time-series data (inequality condition only)
+
+    # Correct answer
+    dfcorrts = pd.DataFrame(
+        [
+            [1, 21, pd.NA, pd.NA, pd.NA],
+            [3, pd.NA, 2, pd.NA, 112],
+            [7, 23, 2, pd.NA, 112],
+            [7, 23, 4, 13, 113],
+        ],
+        columns=["dates", "ts_nullint", "startdate", "lk_nullint", "lk_int"],
+    )
+    change_types = {
+        "ts_nullint": "Int32",
+        "lk_nullint": "Int32",
+        "startdate": "Int64",
+        "lk_int": "Int64",
+    }
+    for k, v in change_types.items():
+        dfcorrts[k] = dfcorrts[k].astype(v)
+
+    # Left Join
+    querytsl = """
+        select a.*, b.startdate, b.lk_nullint, b.lk_int
+        from user_table_ts a left join user_table_lk2 b
+        on b.startdate<=a.dates
+        """
+    dftesttsl = (
+        c.sql(querytsl)
+        .compute()
+        .sort_values(["dates", "startdate"])
+        .reset_index(drop=True)
+    )
+    assert_frame_equal(dftesttsl, dfcorrts, check_dtype=False)
+
+    # Right Join
+    querytsr = """
+        select b.*, a.startdate, a.lk_nullint, a.lk_int
+        from user_table_lk2 a right join user_table_ts b
+        on a.startdate<=b.dates
+        """
+    dftesttsr = (
+        c.sql(querytsr)
+        .compute()
+        .sort_values(["dates", "startdate"])
+        .reset_index(drop=True)
+    )
+    assert_frame_equal(dftesttsr, dfcorrts, check_dtype=False)
+
+    # Inner Join
+    querytsi = """
+        select a.*, b.startdate, b.lk_nullint, b.lk_int
+        from user_table_ts a inner join user_table_lk2 b
+        on b.startdate<=a.dates
+        """
+    dftesttsi = (
+        c.sql(querytsi)
+        .compute()
+        .sort_values(["dates", "startdate"])
+        .reset_index(drop=True)
+    )
+    assert_frame_equal(
+        dftesttsi,
+        dfcorrts.dropna(subset=["startdate"])
+        .assign(
+            startdate=lambda x: x["startdate"].astype("int64"),
+            lk_int=lambda x: x["lk_int"].astype("int64"),
+        )
+        .reset_index(drop=True),
+        check_dtype=False,
+    )
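The ON clauses in these tests mix an equality condition with an inequality, which cannot be handed to a single pandas merge. As the join.py hunks above show, dask-sql joins first, applies the inequality as a filter, and then re-merges to recover rows that a left or right join must keep. A condensed pandas sketch of that strategy, using toy data shaped like user_table_pn and user_table_lk:

    import pandas as pd

    pn = pd.DataFrame({"ids": [0, 1, 2], "dates": [1, 5, 1]})
    lk = pd.DataFrame({"id": [0, 1, 1, 3], "startdate": [5, 2, 4, 1]})

    # Step 1: join on the equality condition a.ids = b.id
    merged = pn.merge(lk, left_on="ids", right_on="id")

    # Step 2: the inequality b.startdate <= a.dates becomes a filter
    filtered = merged[merged["startdate"] <= merged["dates"]]

    # Step 3: a LEFT join must keep the pn rows the filter dropped, so
    # merge the original left table back in (how="right" keeps all of pn).
    # The reintroduced NaNs are exactly the dtype caveat discussed above.
    recovered = filtered.merge(pn, on=list(pn.columns), how="right")
    print(recovered)  # ids 0 and 2 reappear with missing link columns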
diff --git a/tests/integration/test_postgres.py b/tests/integration/test_postgres.py
index f1614d5ad..cd3689c09 100644
--- a/tests/integration/test_postgres.py
+++ b/tests/integration/test_postgres.py
@@ -10,12 +10,14 @@ def engine():
     client = docker.from_env()
     network = client.networks.create("dask-sql", driver="bridge")
 
+    # For local testing, you may need to add ports={"5432/tcp": "5432"} to expose the port
     postgres = client.containers.run(
         "postgres:latest",
         detach=True,
         remove=True,
         network="dask-sql",
         environment={"POSTGRES_HOST_AUTH_METHOD": "trust"},
+        # ports={"5432/tcp": "5432"},
     )
 
     try:
@@ -32,6 +34,8 @@ def engine():
     # get the address and create the connection
     postgres.reload()
     address = postgres.attrs["NetworkSettings"]["Networks"]["dask-sql"]["IPAddress"]
+    # For local testing, you may need to assign address = "localhost" instead
+    # address = "localhost"
     port = 5432
 
     engine = sqlalchemy.create_engine(
@@ -126,6 +130,86 @@ def test_join(assert_query_gives_same_result):
     )
 
 
+def test_join_lricomplex(
+    assert_query_gives_same_result,
+    engine,
+    user_table_ts,
+    user_table_pn,
+    user_table_lk,
+    user_table_lk2,
+    c,
+):
+    # ---------- Panel data
+    # Left Join
+    assert_query_gives_same_result(
+        """
+        select a.*, b.startdate, b.lk_nullint, b.lk_int
+        from user_table_pn a left join user_table_lk b
+        on a.ids=b.id and b.startdate<=a.dates
+        """,
+        ["ids", "dates", "startdate"],
+        force_dtype="dask",
+        check_dtype=True,
+    )
+    # Right Join
+    assert_query_gives_same_result(
+        """
+        select b.*, a.startdate, a.lk_nullint, a.lk_int
+        from user_table_lk a right join user_table_pn b
+        on b.ids=a.id and a.startdate<=b.dates
+        """,
+        ["ids", "dates", "startdate"],
+        force_dtype="dask",
+        check_dtype=True,
+    )
+    # Inner Join
+    assert_query_gives_same_result(
+        """
+        select a.*, b.startdate, b.lk_nullint, b.lk_int
+        from user_table_pn a inner join user_table_lk b
+        on a.ids=b.id and b.startdate<=a.dates
+        """,
+        ["ids", "dates", "startdate"],
+        force_dtype="dask",
+        check_dtype=True,
+    )
+
+    # ---------- Time-series data
+    # Left Join
+    assert_query_gives_same_result(
+        """
+        select a.*, b.startdate, b.lk_nullint, b.lk_int
+        from user_table_ts a left join user_table_lk2 b
+        on b.startdate<=a.dates
+        """,
+        ["dates", "startdate"],
+        force_dtype="dask",
+        check_dtype=True,
+    )
+    # Right Join
+    assert_query_gives_same_result(
+        """
+        select b.*, a.startdate, a.lk_nullint, a.lk_int
+        from user_table_lk2 a right join user_table_ts b
+        on a.startdate<=b.dates
+        """,
+        ["dates", "startdate"],
+        force_dtype="dask",
+        check_dtype=True,
+    )
+    # Inner Join
+    assert_query_gives_same_result(
+        """
+        select a.*, b.startdate, b.lk_nullint, b.lk_int
+        from user_table_ts a inner join user_table_lk2 b
+        on b.startdate<=a.dates
+        """,
+        ["dates", "startdate"],
+        force_dtype="dask",
+        check_dtype=True,
+    )
+
+
 def test_sort(assert_query_gives_same_result):
     assert_query_gives_same_result(
         """
diff --git a/tests/integration/test_show.py b/tests/integration/test_show.py
index 2165699ca..76749ef53 100644
--- a/tests/integration/test_show.py
+++ b/tests/integration/test_show.py
@@ -35,6 +35,10 @@ def test_tables(c):
                 "user_table_nan",
                 "string_table",
                 "datetime_table",
+                "user_table_lk",
+                "user_table_lk2",
+                "user_table_ts",
+                "user_table_pn",
             ]
         }
     )