Skip to content

Commit de2bcb5

Browse files
authored
Merge pull request #4784 from plotly/b64-before-render
Convert base64 in BaseFigure.to_dict instead of validate_coerce
2 parents 6364d4e + aabfa6e commit de2bcb5

File tree

4 files changed

+112
-90
lines changed

4 files changed

+112
-90
lines changed

packages/python/plotly/_plotly_utils/basevalidators.py

Lines changed: 1 addition & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -51,83 +51,6 @@ def to_scalar_or_list(v):
5151
return v
5252

5353

54-
plotlyjsShortTypes = {
55-
"int8": "i1",
56-
"uint8": "u1",
57-
"int16": "i2",
58-
"uint16": "u2",
59-
"int32": "i4",
60-
"uint32": "u4",
61-
"float32": "f4",
62-
"float64": "f8",
63-
}
64-
65-
int8min = -128
66-
int8max = 127
67-
int16min = -32768
68-
int16max = 32767
69-
int32min = -2147483648
70-
int32max = 2147483647
71-
72-
uint8max = 255
73-
uint16max = 65535
74-
uint32max = 4294967295
75-
76-
77-
def to_typed_array_spec(v):
78-
"""
79-
Convert numpy array to plotly.js typed array spec
80-
If not possible return the original value
81-
"""
82-
v = copy_to_readonly_numpy_array(v)
83-
84-
np = get_module("numpy", should_load=False)
85-
if not isinstance(v, np.ndarray):
86-
return v
87-
88-
dtype = str(v.dtype)
89-
90-
# convert default Big Ints until we could support them in plotly.js
91-
if dtype == "int64":
92-
max = v.max()
93-
min = v.min()
94-
if max <= int8max and min >= int8min:
95-
v = v.astype("int8")
96-
elif max <= int16max and min >= int16min:
97-
v = v.astype("int16")
98-
elif max <= int32max and min >= int32min:
99-
v = v.astype("int32")
100-
else:
101-
return v
102-
103-
elif dtype == "uint64":
104-
max = v.max()
105-
min = v.min()
106-
if max <= uint8max and min >= 0:
107-
v = v.astype("uint8")
108-
elif max <= uint16max and min >= 0:
109-
v = v.astype("uint16")
110-
elif max <= uint32max and min >= 0:
111-
v = v.astype("uint32")
112-
else:
113-
return v
114-
115-
dtype = str(v.dtype)
116-
117-
if dtype in plotlyjsShortTypes:
118-
arrObj = {
119-
"dtype": plotlyjsShortTypes[dtype],
120-
"bdata": base64.b64encode(v).decode("ascii"),
121-
}
122-
123-
if v.ndim > 1:
124-
arrObj["shape"] = str(v.shape)[1:-1]
125-
126-
return arrObj
127-
128-
return v
129-
130-
13154
def copy_to_readonly_numpy_array(v, kind=None, force_numeric=False):
13255
"""
13356
Convert an array-like value into a read-only numpy array
@@ -292,15 +215,6 @@ def is_typed_array_spec(v):
292215
return isinstance(v, dict) and "bdata" in v and "dtype" in v
293216

294217

295-
def has_skipped_key(all_parent_keys):
296-
"""
297-
Return whether any keys in the parent hierarchy are in the list of keys that
298-
are skipped for conversion to the typed array spec
299-
"""
300-
skipped_keys = ["geojson", "layer", "range"]
301-
return any(skipped_key in all_parent_keys for skipped_key in skipped_keys)
302-
303-
304218
def is_none_or_typed_array_spec(v):
305219
return v is None or is_typed_array_spec(v)
306220

@@ -500,10 +414,8 @@ def description(self):
500414
def validate_coerce(self, v):
501415
if is_none_or_typed_array_spec(v):
502416
pass
503-
elif has_skipped_key(self.parent_name):
504-
v = to_scalar_or_list(v)
505417
elif is_homogeneous_array(v):
506-
v = to_typed_array_spec(v)
418+
v = copy_to_readonly_numpy_array(v)
507419
elif is_simple_array(v):
508420
v = to_scalar_or_list(v)
509421
else:

packages/python/plotly/_plotly_utils/utils.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,116 @@
1+
import base64
12
import decimal
23
import json as _json
34
import sys
45
import re
56
from functools import reduce
67

78
from _plotly_utils.optional_imports import get_module
8-
from _plotly_utils.basevalidators import ImageUriValidator
9+
from _plotly_utils.basevalidators import (
10+
ImageUriValidator,
11+
copy_to_readonly_numpy_array,
12+
is_homogeneous_array,
13+
)
14+
15+
16+
int8min = -128
17+
int8max = 127
18+
int16min = -32768
19+
int16max = 32767
20+
int32min = -2147483648
21+
int32max = 2147483647
22+
23+
uint8max = 255
24+
uint16max = 65535
25+
uint32max = 4294967295
26+
27+
plotlyjsShortTypes = {
28+
"int8": "i1",
29+
"uint8": "u1",
30+
"int16": "i2",
31+
"uint16": "u2",
32+
"int32": "i4",
33+
"uint32": "u4",
34+
"float32": "f4",
35+
"float64": "f8",
36+
}
37+
38+
39+
def to_typed_array_spec(v):
40+
"""
41+
Convert numpy array to plotly.js typed array spec
42+
If not possible return the original value
43+
"""
44+
v = copy_to_readonly_numpy_array(v)
45+
46+
np = get_module("numpy", should_load=False)
47+
if not isinstance(v, np.ndarray):
48+
return v
49+
50+
dtype = str(v.dtype)
51+
52+
# convert default Big Ints until we could support them in plotly.js
53+
if dtype == "int64":
54+
max = v.max()
55+
min = v.min()
56+
if max <= int8max and min >= int8min:
57+
v = v.astype("int8")
58+
elif max <= int16max and min >= int16min:
59+
v = v.astype("int16")
60+
elif max <= int32max and min >= int32min:
61+
v = v.astype("int32")
62+
else:
63+
return v
64+
65+
elif dtype == "uint64":
66+
max = v.max()
67+
min = v.min()
68+
if max <= uint8max and min >= 0:
69+
v = v.astype("uint8")
70+
elif max <= uint16max and min >= 0:
71+
v = v.astype("uint16")
72+
elif max <= uint32max and min >= 0:
73+
v = v.astype("uint32")
74+
else:
75+
return v
76+
77+
dtype = str(v.dtype)
78+
79+
if dtype in plotlyjsShortTypes:
80+
arrObj = {
81+
"dtype": plotlyjsShortTypes[dtype],
82+
"bdata": base64.b64encode(v).decode("ascii"),
83+
}
84+
85+
if v.ndim > 1:
86+
arrObj["shape"] = str(v.shape)[1:-1]
87+
88+
return arrObj
89+
90+
return v
91+
92+
93+
def is_skipped_key(key):
94+
"""
95+
Return whether any keys in the parent hierarchy are in the list of keys that
96+
are skipped for conversion to the typed array spec
97+
"""
98+
skipped_keys = ["geojson", "layer", "range"]
99+
return any(skipped_key in key for skipped_key in skipped_keys)
100+
101+
102+
def convert_to_base64(obj):
103+
if isinstance(obj, dict):
104+
for key, value in obj.items():
105+
if is_skipped_key(key):
106+
continue
107+
elif is_homogeneous_array(value):
108+
obj[key] = to_typed_array_spec(value)
109+
else:
110+
convert_to_base64(value)
111+
elif isinstance(obj, list) or isinstance(obj, tuple):
112+
for i, value in enumerate(obj):
113+
convert_to_base64(value)
9114

10115

11116
def cumsum(x):

packages/python/plotly/plotly/basedatatypes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
display_string_positions,
1616
chomp_empty_strings,
1717
find_closest_string,
18+
convert_to_base64,
1819
)
1920
from _plotly_utils.exceptions import PlotlyKeyError
2021
from .optional_imports import get_module
@@ -3310,6 +3311,9 @@ def to_dict(self):
33103311
if frames:
33113312
res["frames"] = frames
33123313

3314+
# Add base64 conversion before sending to the front-end
3315+
convert_to_base64(res)
3316+
33133317
return res
33143318

33153319
def to_plotly_json(self):

packages/python/plotly/plotly/io/_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def validate_coerce_fig_to_dict(fig, validate):
2424
typ=type(fig), v=fig
2525
)
2626
)
27+
2728
return fig_dict
2829

2930

0 commit comments

Comments
 (0)