Skip to content

Commit b7c30cf

Browse files
committed
contrast_ratio() in pymatviz/utils/plotting.py to calculate color contrast according to WCAG 2.0
fix text color contrast for dark elements in pmv.structure_(2|3)d_plotly this required adjusting the `pick_max_contrast_color` function to use a lower contrast threshold (2.0 instead of 2.5) to ensure dark-colored elements like Nickel and Barium get white text labels for better visibility also fix the color comments in the ALLOY color scheme to accurately reflect the actual RGB values minor cleanup in structure_viz/plotly.py
1 parent c8655e7 commit b7c30cf

File tree

6 files changed

+240
-115
lines changed

6 files changed

+240
-115
lines changed

assets/scripts/structure_viz/structure_3d_plotly.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
title = "CoCrFeNiMn High-Entropy Alloy"
5959
fig.layout.title = title
6060
fig.show()
61-
pmv.io.save_and_compress_svg(fig, "hea-structure-3d-plotly")
61+
# pmv.io.save_and_compress_svg(fig, "hea-structure-3d-plotly")
6262

6363

6464
# %% Li-ion battery cathode material with Li vacancies: Li0.8CoO2
@@ -78,4 +78,4 @@
7878
title = "Li0.8CoO2 with Li Vacancies"
7979
fig.layout.title = title
8080
fig.show()
81-
pmv.io.save_and_compress_svg(fig, "lco-structure-3d-plotly")
81+
# pmv.io.save_and_compress_svg(fig, "lco-structure-3d-plotly")

pymatviz/colors.py

+37-37
Original file line numberDiff line numberDiff line change
@@ -273,47 +273,47 @@
273273
# overriding metal colors.
274274
ELEM_COLORS_ALLOY_256: dict[str, Rgb256ColorType] = ELEM_COLORS_VESTA_256 | {
275275
# Alkali metals - bright purples
276-
"Li": (0, 53, 0), # Bright purple
277-
"Na": (0, 41, 255), # Deep purple
278-
"K": (0, 255, 0), # Royal purple
279-
"Rb": (0, 255, 255), # Dark purple
280-
"Cs": (255, 0, 0), # Deep violet
276+
"Li": (0, 53, 0), # Dark green
277+
"Na": (0, 41, 255), # Deep blue
278+
"K": (0, 255, 0), # Bright green
279+
"Rb": (0, 255, 255), # Cyan
280+
"Cs": (255, 0, 0), # Bright red
281281
# Alkaline earth metals - yellows/oranges
282-
"Be": (255, 0, 255), # Golden yellow
283-
"Mg": (255, 255, 0), # Dark orange
284-
"Ca": (255, 255, 255), # Bright orange
285-
"Sr": (38, 154, 0), # Red-orange
286-
"Ba": (0, 150, 255), # Pure red
282+
"Be": (255, 0, 255), # Magenta
283+
"Mg": (255, 255, 0), # Yellow
284+
"Ca": (255, 255, 255), # White
285+
"Sr": (38, 154, 0), # Green
286+
"Ba": (0, 150, 255), # Blue
287287
# Transition metals - maximizing contrast
288-
"Sc": (207, 26, 128), # Light gray (from JMOL)
289-
"Ti": (216, 219, 127), # Purple (changed from blue for more contrast with Zr)
290-
"V": (255, 150, 0), # Pink
291-
"Cr": (197, 163, 255), # Bright green
292-
"Mn": (0, 46, 133), # Magenta
293-
"Fe": (0, 151, 134), # Bright orange (changed from JMOL)
294-
"Co": (0, 255, 121), # Deep blue
295-
"Ni": (99, 0, 62), # Orange (changed from green for contrast with Zr)
296-
"Cu": (129, 0, 255), # Brown (changed from JMOL)
297-
"Zn": (168, 74, 0), # Light blue
298-
"Zr": (108, 96, 208), # Cyan (kept)
299-
"Nb": (134, 228, 15), # Purple (new)
288+
"Sc": (207, 26, 128), # Pink
289+
"Ti": (216, 219, 127), # Light yellow-green
290+
"V": (255, 150, 0), # Orange
291+
"Cr": (197, 163, 255), # Light purple
292+
"Mn": (0, 46, 133), # Dark blue
293+
"Fe": (0, 151, 134), # Teal
294+
"Co": (0, 255, 121), # Bright green
295+
"Ni": (99, 0, 62), # Dark red/burgundy
296+
"Cu": (129, 0, 255), # Purple
297+
"Zn": (168, 74, 0), # Brown
298+
"Zr": (108, 96, 208), # Medium blue-purple
299+
"Nb": (134, 228, 15), # Lime green
300300
# Post-transition metals - earth tones
301-
"Al": (102, 211, 188), # Gray (from JMOL)
302-
"Ga": (255, 121, 143), # Rose
303-
"In": (131, 143, 93), # Dusty rose
304-
"Sn": (197, 163, 255), # Dark orange (changed from blue-gray for contrast with Zr)
305-
"Tl": (0, 46, 133), # Terra cotta
306-
"Pb": (0, 151, 134), # Dark gray
307-
"Bi": (0, 255, 121), # Purple
301+
"Al": (102, 211, 188), # Light teal
302+
"Ga": (255, 121, 143), # Pink
303+
"In": (131, 143, 93), # Olive green
304+
"Sn": (197, 163, 255), # Light purple
305+
"Tl": (0, 46, 133), # Dark blue
306+
"Pb": (0, 151, 134), # Teal
307+
"Bi": (0, 255, 121), # Bright green
308308
# Noble metals - preserving traditional colors
309-
"Ru": (99, 0, 62), # Teal
310-
"Rh": (129, 0, 255), # Hot pink
311-
"Pd": (168, 74, 0), # Blue (from JMOL)
312-
"Ag": (108, 96, 208), # Silver (from JMOL)
313-
"Os": (134, 228, 15), # Blue (from JMOL)
314-
"Ir": (102, 211, 188), # Dark blue (from JMOL)
315-
"Pt": (255, 121, 143), # Light gray (from JMOL)
316-
"Au": (131, 143, 93), # Gold (from JMOL)
309+
"Ru": (99, 0, 62), # Dark red/burgundy
310+
"Rh": (129, 0, 255), # Purple
311+
"Pd": (168, 74, 0), # Brown
312+
"Ag": (108, 96, 208), # Medium blue-purple
313+
"Os": (134, 228, 15), # Lime green
314+
"Ir": (102, 211, 188), # Light teal
315+
"Pt": (255, 121, 143), # Pink
316+
"Au": (131, 143, 93), # Olive green
317317
}
318318

319319
ELEM_COLORS_ALLOY: dict[str, RgbColorType] = {

pymatviz/structure_viz/helpers.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
from pymatviz.colors import ELEM_COLORS_ALLOY, ELEM_COLORS_JMOL, ELEM_COLORS_VESTA
1616
from pymatviz.enums import ElemColorScheme, Key, SiteCoords
17-
from pymatviz.utils import df_ptable, pick_max_contrast_color
17+
from pymatviz.utils import df_ptable
18+
from pymatviz.utils.plotting import pick_max_contrast_color
1819

1920

2021
if TYPE_CHECKING:
@@ -305,12 +306,15 @@ def draw_site(
305306
txt = generate_site_label(site_labels, site_idx, majority_species)
306307

307308
marker = dict(
308-
size=site_radius * atom_size * (0.8 if is_image else 1),
309+
size=site_radius * atom_size,
309310
color=color,
310-
opacity=0.5 if is_image else 1,
311+
opacity=0.8 if is_image else 1,
312+
line=dict(width=1, color="gray"),
311313
)
312314
marker.update(site_kwargs)
313315

316+
# Calculate text color based on background color for maximum contrast
317+
text_color = pick_max_contrast_color(color)
314318
scatter_kwargs = dict(
315319
x=[coords[0]],
316320
y=[coords[1]],
@@ -319,7 +323,7 @@ def draw_site(
319323
text=txt,
320324
textposition="middle center",
321325
textfont=dict(
322-
color=pick_max_contrast_color(color),
326+
color=text_color,
323327
size=np.clip(atom_size * site_radius * (0.8 if is_image else 1), 10, 18),
324328
),
325329
hoverinfo="text" if hover_text else None,

pymatviz/structure_viz/plotly.py

+6-40
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,6 @@ def structure_2d_plotly(
204204

205205
# Plot atoms and vectors
206206
if show_sites:
207-
site_kwargs = dict(line=dict(width=0.3, color="gray"))
208-
if isinstance(show_sites, dict):
209-
site_kwargs |= show_sites
210-
211207
for site_idx, (site, coords) in enumerate(
212208
zip(struct_i, rotated_coords, strict=False)
213209
):
@@ -221,7 +217,7 @@ def structure_2d_plotly(
221217
_atomic_radii,
222218
atom_size,
223219
scale,
224-
site_kwargs,
220+
{} if show_sites is True else show_sites,
225221
is_3d=False,
226222
row=row,
227223
col=col,
@@ -251,17 +247,6 @@ def structure_2d_plotly(
251247

252248
# Add image sites
253249
if show_image_sites:
254-
image_site_kwargs = dict(
255-
size=_atomic_radii[site.species.elements[0].symbol]
256-
* scale
257-
* atom_size
258-
* 0.8,
259-
color=_elem_colors.get(site.species.elements[0].symbol, "gray"),
260-
opacity=0.5,
261-
)
262-
if isinstance(show_image_sites, dict):
263-
image_site_kwargs |= show_image_sites
264-
265250
image_atoms = get_image_sites(site, struct_i.lattice)
266251
if len(image_atoms) > 0:
267252
rotated_image_atoms = np.dot(image_atoms, rotation_matrix)
@@ -277,7 +262,7 @@ def structure_2d_plotly(
277262
_atomic_radii,
278263
atom_size,
279264
scale,
280-
image_site_kwargs,
265+
{} if show_image_sites is True else show_image_sites,
281266
is_image=True,
282267
is_3d=False,
283268
row=row,
@@ -289,9 +274,7 @@ def structure_2d_plotly(
289274
draw_unit_cell(
290275
fig,
291276
struct_i,
292-
unit_cell_kwargs=show_unit_cell
293-
if isinstance(show_unit_cell, dict)
294-
else {},
277+
unit_cell_kwargs={} if show_unit_cell is True else show_unit_cell,
295278
is_3d=False,
296279
row=row,
297280
col=col,
@@ -484,10 +467,6 @@ def structure_3d_plotly(
484467

485468
# Plot atoms and vectors
486469
if show_sites:
487-
site_kwargs = dict(line=dict(width=0.3, color="gray"))
488-
if isinstance(show_sites, dict):
489-
site_kwargs |= show_sites
490-
491470
for site_idx, site in enumerate(struct_i):
492471
draw_site(
493472
fig,
@@ -499,7 +478,7 @@ def structure_3d_plotly(
499478
_atomic_radii,
500479
atom_size,
501480
scale,
502-
site_kwargs,
481+
{} if show_sites is True else show_sites,
503482
is_3d=True,
504483
scene=f"scene{idx}",
505484
name=f"site{site_idx}",
@@ -527,17 +506,6 @@ def structure_3d_plotly(
527506

528507
# Add image sites
529508
if show_image_sites:
530-
image_site_kwargs = dict(
531-
size=_atomic_radii[site.species.elements[0].symbol]
532-
* scale
533-
* atom_size
534-
* 0.8,
535-
color=_elem_colors.get(site.species.elements[0].symbol, "gray"),
536-
opacity=0.5,
537-
)
538-
if isinstance(show_image_sites, dict):
539-
image_site_kwargs |= show_image_sites
540-
541509
image_atoms = get_image_sites(site, struct_i.lattice)
542510
if len(image_atoms) > 0:
543511
for image_coords in image_atoms:
@@ -551,7 +519,7 @@ def structure_3d_plotly(
551519
_atomic_radii,
552520
atom_size,
553521
scale,
554-
image_site_kwargs,
522+
{} if show_image_sites is True else show_image_sites,
555523
is_image=True,
556524
is_3d=True,
557525
scene=f"scene{idx}",
@@ -562,9 +530,7 @@ def structure_3d_plotly(
562530
draw_unit_cell(
563531
fig,
564532
struct_i,
565-
unit_cell_kwargs=show_unit_cell
566-
if isinstance(show_unit_cell, dict)
567-
else {},
533+
unit_cell_kwargs={} if show_unit_cell is True else show_unit_cell,
568534
is_3d=True,
569535
scene=f"scene{idx}",
570536
)

pymatviz/utils/plotting.py

+61-13
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,8 @@ def luminance(color: ColorType) -> float:
203203
"""Compute the relative luminance of a color using the WCAG 2.0 formula.
204204
205205
Args:
206-
color (ColorType): RGB color tuple with values in [0, 1] or a color string
207-
that can be converted to RGB.
206+
color (ColorType): RGB color tuple with values in [0, 1] or [0, 255], or a color
207+
string that can be converted to RGB.
208208
209209
Returns:
210210
float: Relative luminance of the color in range [0, 1].
@@ -213,36 +213,84 @@ def luminance(color: ColorType) -> float:
213213
r, g, b, *_a = map(float, color.strip("rgb()").split(","))
214214
if r > 1 or g > 1 or b > 1:
215215
r, g, b = r / 255, g / 255, b / 255
216+
elif isinstance(color, tuple) and len(color) >= 3:
217+
# Check if any value is > 1, indicating 0-255 range
218+
if any(c > 1 for c in color[:3]):
219+
r, g, b = color[0] / 255, color[1] / 255, color[2] / 255
220+
else:
221+
r, g, b = color[:3]
216222
else:
217223
# raises ValueError if color invalid
218224
r, g, b, *_a = matplotlib.colors.to_rgba(color)
219225

226+
def _convert_rgb_to_linear(rgb: float) -> float:
227+
"""Convert an RGB value to linear RGB (remove gamma correction)."""
228+
return rgb / 12.92 if rgb <= 0.03928 else ((rgb + 0.055) / 1.055) ** 2.4
229+
230+
# Convert RGB to linear RGB (remove gamma correction)
231+
r, g, b = map(_convert_rgb_to_linear, (r, g, b))
232+
220233
# Calculate relative luminance using WCAG 2.0 coefficients
221234
return 0.2126 * r + 0.7152 * g + 0.0722 * b
222235

223236

237+
def contrast_ratio(color1: ColorType, color2: ColorType) -> float:
238+
"""Calculate the contrast ratio between two colors according to WCAG 2.0.
239+
240+
Args:
241+
color1 (ColorType): First color (RGB tuple with values in [0, 1] or [0, 255],
242+
or a color string that can be converted to RGB).
243+
color2 (ColorType): Second color (RGB tuple with values in [0, 1] or [0, 255],
244+
or a color string that can be converted to RGB).
245+
246+
Returns:
247+
float: Contrast ratio between the two colors, ranging from 1:1 to 21:1.
248+
"""
249+
lum1 = luminance(color1)
250+
lum2 = luminance(color2)
251+
252+
# Ensure lighter color is first for the formula
253+
lighter = max(lum1, lum2)
254+
darker = min(lum1, lum2)
255+
256+
# Calculate contrast ratio: (L1 + 0.05) / (L2 + 0.05)
257+
return (lighter + 0.05) / (darker + 0.05)
258+
259+
224260
def pick_max_contrast_color(
225261
bg_color: ColorType,
226-
luminance_threshold: float = 0.3, # Threshold for light/dark color distinction
227262
colors: tuple[ColorType, ColorType] = ("white", "black"),
263+
min_contrast_ratio: float = 2.0, # Lower threshold makes dark colors get white text
228264
) -> ColorType:
229-
"""Choose dark or light text color for a given background color based on WCAG 2.0.
265+
"""Choose text color for a given background color based on WCAG 2.0 contrast ratio.
266+
267+
This function calculates the contrast ratio between the background color and each
268+
of the provided text colors, then returns the color with the highest contrast ratio.
269+
If the contrast ratio with white is above the minimum contrast ratio, white will be
270+
chosen even if black has a slightly higher contrast ratio. This ensures that darker
271+
colors always get white text, which is often more readable in 3D visualizations.
230272
231273
Args:
232274
bg_color (ColorType): Background color.
233-
luminance_threshold (float, optional): Luminance threshold for choosing text
234-
color. Defaults to 0.5 to distinguish between light and dark colors.
235-
colors (tuple[ColorType, ColorType], optional): One light and one dark text
236-
color to choose from in that order. Defaults to ("white", "black").
275+
colors (tuple[ColorType, ColorType], optional): Text colors to choose
276+
from. Defaults to ("white", "black").
277+
min_contrast_ratio (float, optional): Minimum contrast ratio to prefer white
278+
over black text. Defaults to 2.0 (lower than WCAG AA standard to ensure
279+
dark colors get white text).
237280
238281
Returns:
239-
ColorType: The color that provides better contrast, usually "black" or "white".
282+
ColorType: item in `colors` that provides the best contrast with bg_color.
240283
"""
241-
# Calculate luminance of the background color
242-
bg_luminance = luminance(bg_color)
284+
# Calculate contrast ratios for each potential text color
285+
contrast_ratios = [contrast_ratio(bg_color, color) for color in colors]
286+
287+
# If the contrast ratio with white is above the minimum contrast ratio,
288+
# prefer white text even if black has a slightly higher contrast ratio
289+
if contrast_ratios[0] >= min_contrast_ratio:
290+
return colors[0]
243291

244-
# Use black text on light colors (luminance > threshold)
245-
return colors[1] if bg_luminance > luminance_threshold else colors[0]
292+
# Otherwise, return the color with the highest contrast ratio
293+
return colors[contrast_ratios.index(max(contrast_ratios))]
246294

247295

248296
def pretty_label(key: str, backend: Backend) -> str:

0 commit comments

Comments
 (0)