Skip to content

Commit 5898816

Browse files
authored
Merge pull request #4790 from FBruzzesi/plotly-with-narwhals
feat: make plotly-express dataframe agnostic via narwhals
2 parents ffb571b + 9f2c55b commit 5898816

38 files changed

+1689
-834
lines changed

CHANGELOG.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ This project adheres to [Semantic Versioning](http://semver.org/).
88

99
### Updated
1010

11-
- Updated plotly.py to use base64 encoding of arrays in plotly JSON to improve performance.
11+
- Updated plotly.py to use base64 encoding of arrays in plotly JSON to improve performance.
1212
- Add `subtitle` attribute to all Plotly Express traces
13+
- Make Plotly Express dataframe-agnostic via Narwhals [#4790](https://github.com/plotly/plotly.py/pull/4790)
1314

1415
## [5.24.1] - 2024-09-12
1516

packages/python/plotly/_plotly_utils/basevalidators.py

+22-30
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import re
99
import sys
1010
import warnings
11+
import narwhals.stable.v1 as nw
1112

1213
from _plotly_utils.optional_imports import get_module
1314

@@ -72,8 +73,6 @@ def copy_to_readonly_numpy_array(v, kind=None, force_numeric=False):
7273
"""
7374
np = get_module("numpy")
7475

75-
# Don't force pandas to be loaded, we only want to know if it's already loaded
76-
pd = get_module("pandas", should_load=False)
7776
assert np is not None
7877

7978
# ### Process kind ###
@@ -93,34 +92,26 @@ def copy_to_readonly_numpy_array(v, kind=None, force_numeric=False):
9392
"O": "object",
9493
}
9594

96-
# Handle pandas Series and Index objects
97-
if pd and isinstance(v, (pd.Series, pd.Index)):
98-
if v.dtype.kind in numeric_kinds:
99-
# Get the numeric numpy array so we use fast path below
100-
v = v.values
101-
elif v.dtype.kind == "M":
102-
# Convert datetime Series/Index to numpy array of datetimes
103-
if isinstance(v, pd.Series):
104-
with warnings.catch_warnings():
105-
warnings.simplefilter("ignore", FutureWarning)
106-
# Series.dt.to_pydatetime will return Index[object]
107-
# https://github.com/pandas-dev/pandas/pull/52459
108-
v = np.array(v.dt.to_pydatetime())
109-
else:
110-
# DatetimeIndex
111-
v = v.to_pydatetime()
112-
elif pd and isinstance(v, pd.DataFrame) and len(set(v.dtypes)) == 1:
113-
dtype = v.dtypes.tolist()[0]
114-
if dtype.kind in numeric_kinds:
115-
v = v.values
116-
elif dtype.kind == "M":
117-
with warnings.catch_warnings():
118-
warnings.simplefilter("ignore", FutureWarning)
119-
# Series.dt.to_pydatetime will return Index[object]
120-
# https://github.com/pandas-dev/pandas/pull/52459
121-
v = [
122-
np.array(row.dt.to_pydatetime()).tolist() for i, row in v.iterrows()
123-
]
95+
# With `pass_through=True`, the original object is returned unchanged if it cannot be
96+
# converted to a Narwhals DataFrame or Series.
97+
v = nw.from_native(v, allow_series=True, pass_through=True)
98+
99+
if isinstance(v, nw.Series):
100+
if v.dtype == nw.Datetime and v.dtype.time_zone is not None:
101+
# Remove time zone so that local time is displayed
102+
v = v.dt.replace_time_zone(None).to_numpy()
103+
else:
104+
v = v.to_numpy()
105+
elif isinstance(v, nw.DataFrame):
106+
schema = v.schema
107+
overrides = {}
108+
for key, val in schema.items():
109+
if val == nw.Datetime and val.time_zone is not None:
110+
# Remove time zone so that local time is displayed
111+
overrides[key] = nw.col(key).dt.replace_time_zone(None)
112+
if overrides:
113+
v = v.with_columns(**overrides)
114+
v = v.to_numpy()
124115

125116
if not isinstance(v, np.ndarray):
126117
# v has its own logic on how to convert itself into a numpy array
@@ -193,6 +184,7 @@ def is_homogeneous_array(v):
193184
np
194185
and isinstance(v, np.ndarray)
195186
or (pd and isinstance(v, (pd.Series, pd.Index)))
187+
or (isinstance(v, nw.Series))
196188
):
197189
return True
198190
if is_numpy_convertable(v):

packages/python/plotly/_plotly_utils/tests/validators/test_pandas_series_input.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,13 @@ def color_categorical_pandas(request, pandas_type):
7373
def dates_array(request):
7474
return np.array(
7575
[
76-
datetime(year=2013, month=10, day=10),
77-
datetime(year=2013, month=11, day=10),
78-
datetime(year=2013, month=12, day=10),
79-
datetime(year=2014, month=1, day=10),
80-
datetime(year=2014, month=2, day=10),
81-
]
76+
"2013-10-10",
77+
"2013-11-10",
78+
"2013-12-10",
79+
"2014-01-10",
80+
"2014-02-10",
81+
],
82+
dtype="datetime64[ns]",
8283
)
8384

8485

@@ -183,7 +184,7 @@ def test_data_array_validator_dates_series(
183184
assert isinstance(res, np.ndarray)
184185

185186
# Check dtype
186-
assert res.dtype == "object"
187+
assert res.dtype == "<M8[ns]"
187188

188189
# Check values
189190
np.testing.assert_array_equal(res, dates_array)
@@ -200,7 +201,7 @@ def test_data_array_validator_dates_dataframe(
200201
assert isinstance(res, np.ndarray)
201202

202203
# Check dtype
203-
assert res.dtype == "object"
204+
assert res.dtype == "<M8[ns]"
204205

205206
# Check values
206207
np.testing.assert_array_equal(res, dates_array.reshape(len(dates_array), 1))

packages/python/plotly/optional-requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ ipython
3939

4040
## pandas deps for some matplotlib functionality ##
4141
pandas
42+
narwhals>=1.13.3
4243

4344
## scipy deps for some FigureFactory functions ##
4445
scipy

0 commit comments

Comments
 (0)