Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) ->
def _join_inner(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
) -> dd.DataFrame:
return self.native.merge(
return self.native.dropna(subset=left_on, how="any").merge(
other.native,
left_on=left_on,
right_on=right_on,
Expand All @@ -311,7 +311,7 @@ def _join_left(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
) -> dd.DataFrame:
result_native = self.native.merge(
other.native,
other.native.dropna(subset=right_on, how="any"),
how="left",
left_on=left_on,
right_on=right_on,
Expand All @@ -329,18 +329,33 @@ def _join_full(
) -> dd.DataFrame:
# dask does not retain keys post-join
# we must append the suffix to each key before-hand

self_native = self.native
right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix)
other_native = other.native.rename(columns=right_on_mapper)
check_column_names_are_unique(other_native.columns)
right_suffixed = list(right_on_mapper.values())
return self.native.merge(
other_native,

left_null_mask = self_native[list(left_on)].isna().any(axis=1)
right_null_mask = other_native[right_suffixed].isna().any(axis=1)

# We need to add suffix to `other` columns overlapping in `self` if not in keys
to_rename = set(other.columns).intersection(self.columns).difference(right_on)
right_null_rows = other_native[right_null_mask].rename(
columns={col: f"{col}{suffix}" for col in to_rename}
)

join_result = self_native[~left_null_mask].merge(
other_native[~right_null_mask],
left_on=left_on,
right_on=right_suffixed,
how="outer",
suffixes=("", suffix),
)
return dd.concat(
[join_result, self_native[left_null_mask], right_null_rows],
axis=0,
join="outer",
)

def _join_cross(self, other: Self, *, suffix: str) -> dd.DataFrame:
key_token = generate_temporary_column_name(
Expand All @@ -366,7 +381,7 @@ def _join_semi(
columns_to_select=list(right_on),
columns_mapping=dict(zip(right_on, left_on)),
)
return self.native.merge(
return self.native.dropna(subset=left_on, how="any").merge(
other_native, how="inner", left_on=left_on, right_on=left_on
)

Expand All @@ -382,7 +397,7 @@ def _join_anti(
columns_mapping=dict(zip(right_on, left_on)),
)
df = self.native.merge(
other_native,
other_native.dropna(subset=left_on, how="any"),
how="left",
indicator=indicator_token, # pyright: ignore[reportArgumentType]
left_on=left_on,
Expand Down
33 changes: 26 additions & 7 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ def group_by(
def _join_inner(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
) -> pd.DataFrame:
return self.native.merge(
return self.native.dropna(subset=left_on, how="any").merge(
other.native,
left_on=left_on,
right_on=right_on,
Expand All @@ -615,7 +615,7 @@ def _join_left(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
) -> pd.DataFrame:
result_native = self.native.merge(
other.native,
other.native.dropna(subset=right_on, how="any"),
how="left",
left_on=left_on,
right_on=right_on,
Expand All @@ -635,18 +635,34 @@ def _join_full(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
) -> pd.DataFrame:
# Pandas coalesces keys in full joins unless there's no collision
ns = self.__narwhals_namespace__()
self_native = self.native
right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix)
other_native = other.native.rename(columns=right_on_mapper)
check_column_names_are_unique(other_native.columns)
right_suffixed = list(right_on_mapper.values())
return self.native.merge(
other_native,

left_null_mask = self_native[list(left_on)].isna().any(axis=1)
right_null_mask = other_native[right_suffixed].isna().any(axis=1)

# We need to add suffix to `other` columns overlapping in `self` if not in keys
to_rename = set(other.columns).intersection(self.columns).difference(right_on)
right_null_rows = other_native[right_null_mask].rename(
columns={col: f"{col}{suffix}" for col in to_rename}
)

join_result = self_native[~left_null_mask].merge(
other_native[~right_null_mask],
left_on=left_on,
right_on=right_suffixed,
how="outer",
suffixes=("", suffix),
)

return ns._concat_diagonal(
[join_result, self_native[left_null_mask], right_null_rows]
)

def _join_cross(self, other: Self, *, suffix: str) -> pd.DataFrame:
implementation = self._implementation
backend_version = self._backend_version
Expand Down Expand Up @@ -677,7 +693,7 @@ def _join_semi(
columns_to_select=list(right_on),
columns_mapping=dict(zip(right_on, left_on)),
)
return self.native.merge(
return self.native.dropna(subset=left_on, how="any").merge(
other_native, how="inner", left_on=left_on, right_on=left_on
)

Expand All @@ -688,7 +704,10 @@ def _join_anti(

if implementation.is_cudf():
return self.native.merge(
other.native, how="leftanti", left_on=left_on, right_on=right_on
other.native.dropna(subset=left_on, how="any"),
how="leftanti",
left_on=left_on,
right_on=right_on,
)

indicator_token = generate_temporary_column_name(
Expand All @@ -701,7 +720,7 @@ def _join_anti(
columns_mapping=dict(zip(right_on, left_on)),
)
result_native = self.native.merge(
other_native,
other_native.dropna(subset=left_on, how="any"),
# TODO(FBruzzesi): See https://github.com/modin-project/modin/issues/7384
how="left" if implementation.is_pandas() else "outer",
indicator=indicator_token,
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,8 @@ filterwarnings = [
# Warning raised when calling PandasLikeNamespace.from_arrow with old pyarrow
"ignore:.*is_sparse is deprecated and will be removed in a future version.*:DeprecationWarning:pyarrow",

'ignore:.*invalid value encountered in cast:RuntimeWarning:pandas',
"ignore:.*invalid value encountered in cast:RuntimeWarning:pandas",
"ignore:.*The behavior of DataFrame concatenation with empty or all-NA entries is deprecated:FutureWarning",
]
xfail_strict = true
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
Expand Down
97 changes: 97 additions & 0 deletions tests/frame/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,3 +826,100 @@ def test_join_same_laziness(constructor: Constructor) -> None:
other = nw.from_native(frame_pl)
with pytest.raises(TypeError, match=msg):
frame.join(other, on="id") # type: ignore[arg-type]


# fmt: off
@pytest.mark.parametrize(
("how", "expected"),
[
("inner", {"a": [1], "b": [1], "x": [1], "y": [1.2]}),
(
"left",
{
"a": [1, 1, None, None],
"b": [1, None, 5, None],
"x": [1, 2, 3, 4],
"y": [1.2, None, None, None],
},
),
(
"full",
{
"a": [1, 1, None, None, None, None, None],
"b": [1, None, 5, None, None, None, None],
"x": [1, 2, 3, 4, None, None, None],
"a_right": [1, None, None, None, 1, None, None],
"b_right": [1, None, None, None, None, 5, None],
"y": [1.2, None, None, None, 3.4, 5.6, 7.8],
},
),
(
"cross",
{
"a": [1, 1, 1, 1, 1, 1, 1, 1, None, None, None, None, None, None, None, None],
"b": [ 1, 1, 1, 1, None, None, None, None, 5, 5, 5, 5, None, None, None, None],
"x": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
"a_right": [ 1, 1, None, None, 1, 1, None, None, 1, 1, None, None, 1, 1, None, None],
"b_right": [ 1, None, 5, None, 1, None, 5, None, 1, None, 5, None, 1, None, 5, None],
"y": [ 1.2, 3.4, 5.6, 7.8, 1.2, 3.4, 5.6, 7.8, 1.2, 3.4, 5.6, 7.8, 1.2, 3.4, 5.6, 7.8],
},

),
("semi", {"a": [1], "b": [1], "x": [1]}),
("anti", {"a": [1, None, None], "b": [None, 5, None], "x": [2, 3, 4]}),
],
)
def test_join_on_null_values(
constructor: Constructor, how: JoinStrategy, expected: dict[str, list[Any]]
) -> None:
if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 1, 4) and how=="cross":
pytest.skip()
# See https://github.com/narwhals-dev/narwhals/issues/3307
keys = {"a": [1, 1, None, None], "b": [1, None, 5, None]}
data_left = {**keys, "x": [1, 2, 3, 4]}
data_right = {**keys, "y": [1.2, 3.4, 5.6, 7.8]}

df_left = from_native_lazy(constructor(data_left))
df_right = from_native_lazy(constructor(data_right))

on = None if how == "cross" else list(keys)
sort_by = ["a", "x", "y"] if how in {"cross", "full"} else ["a", "x"]
result = df_left.join(df_right, on=on, how=how).sort(sort_by, nulls_last=True)
assert_equal_data(result, expected)
# fmt: on


@pytest.mark.filterwarnings(
"ignore:.*Merging dataframes with merge column data type mismatches:UserWarning:dask"
)
def test_full_join_with_overlapping_non_key_columns_and_nulls(
constructor: Constructor,
) -> None:
data_left = {
"id": [1, 2, 3],
"shared_col": ["a", "b", "c"], # Overlapping, not a join key
"left_only": [10, 20, 30],
}
data_right = {
"id": [2, 3, None], # Has null in join key
"shared_col": ["x", "y", "z"], # Overlapping, not a join key
"right_only": [100, 200, 300],
}

df_left = from_native_lazy(constructor(data_left))
df_right = from_native_lazy(constructor(data_right))

result = df_left.join(df_right, on="id", how="full", suffix="_r").sort(
"id", nulls_last=True
)

expected = {
"id": [1, 2, 3, None],
"shared_col": ["a", "b", "c", None],
"left_only": [10, 20, 30, None],
"id_r": [None, 2, 3, None],
"shared_col_r": [None, "x", "y", "z"],
"right_only": [None, 100, 200, 300],
}

assert_equal_data(result, expected)
Loading