Skip to content

Commit 3e7967c

Browse files
authored
Use x and y parameters for Image trace in imshow (for RGB or binary_string=True) (#2761)
* take x and y into account when using Image trace * x and y parameters are now used for Image trace in imshow * raise ValueError when x and y don't have numerical dtype for Image trace * better error message * black
1 parent 9c9b98e commit 3e7967c

File tree

2 files changed

+86
-11
lines changed

2 files changed

+86
-11
lines changed

Diff for: packages/python/plotly/plotly/express/_imshow.py

+54-11
Original file line numberDiff line numberDiff line change
@@ -204,23 +204,19 @@ def imshow(
204204
args = locals()
205205
apply_default_cascade(args)
206206
labels = labels.copy()
207+
img_is_xarray = False
207208
# ----- Define x and y, set labels if img is an xarray -------------------
208209
if xarray_imported and isinstance(img, xarray.DataArray):
209-
if binary_string:
210-
raise ValueError(
211-
"It is not possible to use binary image strings for xarrays."
212-
"Please pass your data as a numpy array instead using"
213-
"`img.values`"
214-
)
210+
img_is_xarray = True
215211
y_label, x_label = img.dims[0], img.dims[1]
216212
# np.datetime64 is not handled correctly by go.Heatmap
217213
for ax in [x_label, y_label]:
218214
if np.issubdtype(img.coords[ax].dtype, np.datetime64):
219215
img.coords[ax] = img.coords[ax].astype(str)
220216
if x is None:
221-
x = img.coords[x_label]
217+
x = img.coords[x_label].values
222218
if y is None:
223-
y = img.coords[y_label]
219+
y = img.coords[y_label].values
224220
if aspect is None:
225221
aspect = "auto"
226222
if labels.get("x", None) is None:
@@ -330,6 +326,42 @@ def imshow(
330326
_vectorize_zvalue(zmin, mode="min"),
331327
_vectorize_zvalue(zmax, mode="max"),
332328
)
329+
x0, y0, dx, dy = (None,) * 4
330+
error_msg_xarray = (
331+
"Non-numerical coordinates were passed with xarray `img`, but "
332+
"the Image trace cannot handle it. Please use `binary_string=False` "
333+
"for 2D data or pass instead the numpy array `img.values` to `px.imshow`."
334+
)
335+
if x is not None:
336+
x = np.asanyarray(x)
337+
if np.issubdtype(x.dtype, np.number):
338+
x0 = x[0]
339+
dx = x[1] - x[0]
340+
else:
341+
error_msg = (
342+
error_msg_xarray
343+
if img_is_xarray
344+
else (
345+
"Only numerical values are accepted for the `x` parameter "
346+
"when an Image trace is used."
347+
)
348+
)
349+
raise ValueError(error_msg)
350+
if y is not None:
351+
y = np.asanyarray(y)
352+
if np.issubdtype(y.dtype, np.number):
353+
y0 = y[0]
354+
dy = y[1] - y[0]
355+
else:
356+
error_msg = (
357+
error_msg_xarray
358+
if img_is_xarray
359+
else (
360+
"Only numerical values are accepted for the `y` parameter "
361+
"when an Image trace is used."
362+
)
363+
)
364+
raise ValueError(error_msg)
333365
if binary_string:
334366
if zmin is None and zmax is None: # no rescaling, faster
335367
img_rescaled = img
@@ -355,13 +387,24 @@ def imshow(
355387
compression=binary_compression_level,
356388
ext=binary_format,
357389
)
358-
trace = go.Image(source=img_str)
390+
trace = go.Image(source=img_str, x0=x0, y0=y0, dx=dx, dy=dy)
359391
else:
360392
colormodel = "rgb" if img.shape[-1] == 3 else "rgba256"
361-
trace = go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)
393+
trace = go.Image(
394+
z=img,
395+
zmin=zmin,
396+
zmax=zmax,
397+
colormodel=colormodel,
398+
x0=x0,
399+
y0=y0,
400+
dx=dx,
401+
dy=dy,
402+
)
362403
layout = {}
363-
if origin == "lower":
404+
if origin == "lower" or (dy is not None and dy < 0):
364405
layout["yaxis"] = dict(autorange=True)
406+
if dx is not None and dx < 0:
407+
layout["xaxis"] = dict(autorange="reversed")
365408
else:
366409
raise ValueError(
367410
"px.imshow only accepts 2D single-channel, RGB or RGBA images. "

Diff for: packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py

+32
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from PIL import Image
66
from io import BytesIO
77
import base64
8+
import datetime
89
from plotly.express.imshow_utils import rescale_intensity
910

1011
img_rgb = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255]]], dtype=np.uint8)
@@ -204,6 +205,37 @@ def test_imshow_labels_and_ranges():
204205
with pytest.raises(ValueError):
205206
fig = px.imshow([[1, 2], [3, 4], [5, 6]], x=["a"])
206207

208+
img = np.ones((2, 2), dtype=np.uint8)
209+
fig = px.imshow(img, x=["a", "b"])
210+
assert fig.data[0].x == ("a", "b")
211+
212+
with pytest.raises(ValueError):
213+
img = np.ones((2, 2, 3), dtype=np.uint8)
214+
fig = px.imshow(img, x=["a", "b"])
215+
216+
img = np.ones((2, 2), dtype=np.uint8)
217+
base = datetime.datetime(2000, 1, 1)
218+
fig = px.imshow(img, x=[base, base + datetime.timedelta(hours=1)])
219+
assert fig.data[0].x == (
220+
datetime.datetime(2000, 1, 1, 0, 0),
221+
datetime.datetime(2000, 1, 1, 1, 0),
222+
)
223+
224+
with pytest.raises(ValueError):
225+
img = np.ones((2, 2, 3), dtype=np.uint8)
226+
base = datetime.datetime(2000, 1, 1)
227+
fig = px.imshow(img, x=[base, base + datetime.timedelta(hours=1)])
228+
229+
230+
def test_imshow_ranges_image_trace():
231+
fig = px.imshow(img_rgb, x=[1, 11, 21])
232+
assert fig.data[0].dx == 10
233+
assert fig.data[0].x0 == 1
234+
fig = px.imshow(img_rgb, x=[21, 11, 1])
235+
assert fig.data[0].dx == -10
236+
assert fig.data[0].x0 == 21
237+
assert fig.layout.xaxis.autorange == "reversed"
238+
207239

208240
def test_imshow_dataframe():
209241
df = px.data.medals_wide(indexed=False)

0 commit comments

Comments
 (0)