Skip to content

Commit 9b162f5

Browse files
committed
Add a few tests and enfore np.float64 if no dtype is given
1 parent cd672c1 commit 9b162f5

2 files changed

Lines changed: 51 additions & 1 deletion

File tree

src/optimagic/visualization/plotting_utilities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def _ensure_array_from_plotly_data(data: Any) -> np.ndarray:
355355
return _decode_base64_data(data["bdata"], dtype=data["dtype"])
356356
elif isinstance(data, collections.abc.Sequence):
357357
try:
358-
return np.array(data)
358+
return np.array(data, dtype=np.float64)
359359
except Exception:
360360
pass
361361
raise ValueError("Failed to convert input to numpy array.")
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import base64
2+
3+
import numpy as np
4+
import pytest
5+
from numpy.testing import assert_array_equal
6+
7+
from optimagic.visualization.plotting_utilities import (
8+
_decode_base64_data,
9+
_ensure_array_from_plotly_data,
10+
)
11+
12+
13+
def test_decode_base64_data():
14+
expected = np.arange(10, dtype=float)
15+
encoded = base64.b64encode(expected.tobytes()).decode("ascii")
16+
got = _decode_base64_data(encoded, dtype="float")
17+
assert_array_equal(expected, got)
18+
19+
20+
def test_ensure_array_from_plotly_data_case_array():
21+
expected = np.arange(10, dtype=float)
22+
got = _ensure_array_from_plotly_data(expected)
23+
assert_array_equal(expected, got)
24+
25+
26+
def test_ensure_array_from_plotly_data_case_list():
27+
expected = np.arange(10, dtype=float)
28+
got = _ensure_array_from_plotly_data(expected.tolist())
29+
assert_array_equal(expected, got)
30+
31+
32+
def test_ensure_array_from_plotly_data_case_base64():
33+
expected = np.arange(10, dtype=float)
34+
encoded = base64.b64encode(expected.tobytes()).decode("ascii")
35+
got = _ensure_array_from_plotly_data({"bdata": encoded, "dtype": "float"})
36+
assert_array_equal(expected, got)
37+
38+
39+
@pytest.mark.parametrize(
40+
"invalid_input",
41+
[
42+
None,
43+
"not a valid input",
44+
1234,
45+
[{"a": 1}, {"b": 2}],
46+
],
47+
)
48+
def test_ensure_array_from_plotly_data_case_invalid(invalid_input):
49+
with pytest.raises(ValueError, match="Failed to convert input to numpy array."):
50+
_ensure_array_from_plotly_data(invalid_input)

0 commit comments

Comments
 (0)