Skip to content

Commit 99d39bd

Browse files
authored
Merge pull request matplotlib#24970 from greglucas/cmap-cast
FIX: Handle uint8 indices properly for colormap lookups
2 parents 235b01f + 9229e5a commit 99d39bd

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

lib/matplotlib/colors.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -715,16 +715,17 @@ def __call__(self, X, alpha=None, bytes=False):
715715
if not xa.dtype.isnative:
716716
xa = xa.byteswap().newbyteorder() # Native byteorder is faster.
717717
if xa.dtype.kind == "f":
718-
with np.errstate(invalid="ignore"):
719-
xa *= self.N
720-
# Negative values are out of range, but astype(int) would
721-
# truncate them towards zero.
722-
xa[xa < 0] = -1
723-
# xa == 1 (== N after multiplication) is not out of range.
724-
xa[xa == self.N] = self.N - 1
725-
# Avoid converting large positive values to negative integers.
726-
np.clip(xa, -1, self.N, out=xa)
727-
xa = xa.astype(int)
718+
xa *= self.N
719+
# Negative values are out of range, but astype(int) would
720+
# truncate them towards zero.
721+
xa[xa < 0] = -1
722+
# xa == 1 (== N after multiplication) is not out of range.
723+
xa[xa == self.N] = self.N - 1
724+
# Avoid converting large positive values to negative integers.
725+
np.clip(xa, -1, self.N, out=xa)
726+
with np.errstate(invalid="ignore"):
727+
# We need this cast for unsigned ints as well as floats
728+
xa = xa.astype(int)
728729
# Set the over-range indices before the under-range;
729730
# otherwise the under-range values get converted to over-range.
730731
xa[xa > self.N - 1] = self._i_over

lib/matplotlib/tests/test_colors.py

+7
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ def test_create_lookup_table(N, result):
3030
assert_array_almost_equal(mcolors._create_lookup_table(N, data), result)
3131

3232

33+
@pytest.mark.parametrize("dtype", [np.uint8, int, np.float16, float])
34+
def test_index_dtype(dtype):
35+
# We use subtraction in the indexing, so need to verify that uint8 works
36+
cm = mpl.colormaps["viridis"]
37+
assert_array_equal(cm(dtype(0)), cm(0))
38+
39+
3340
def test_resampled():
3441
"""
3542
GitHub issue #6025 pointed to incorrect ListedColormap.resampled;

0 commit comments

Comments
 (0)