Skip to content

Commit 6852001

Browse files
committed
make legend less complex
1 parent db866fb commit 6852001

File tree

1 file changed

+97
-130
lines changed

1 file changed

+97
-130
lines changed

matplotlib2tikz/legend.py

+97-130
Original file line numberDiff line numberDiff line change
@@ -13,140 +13,28 @@ def draw_legend(data, obj):
1313
texts = []
1414
children_alignment = []
1515
for text in obj.texts:
16-
texts.append("%s" % text.get_text())
17-
children_alignment.append("%s" % text.get_horizontalalignment())
16+
texts.append("{}".format(text.get_text()))
17+
children_alignment.append("{}".format(text.get_horizontalalignment()))
1818

1919
cont = "legend entries={{%s}}" % "},{".join(texts)
2020
data["extra axis options"].add(cont)
2121

2222
# Get the location.
2323
# http://matplotlib.org/api/legend_api.html
24+
loc = obj._loc if obj._loc != 0 else _get_location_from_best(obj)
2425
pad = 0.03
25-
loc = obj._loc
26-
if loc == 0:
27-
# best
28-
# Create a renderer
29-
from matplotlib.backends import backend_agg
30-
31-
renderer = backend_agg.RendererAgg(
32-
width=obj.figure.get_figwidth(),
33-
height=obj.figure.get_figheight(),
34-
dpi=obj.figure.dpi,
35-
)
36-
37-
# Rectangles of the legend and of the axes
38-
# Lower left and upper right points
39-
x0_legend, x1_legend = obj._legend_box.get_window_extent(renderer).get_points()
40-
x0_axes, x1_axes = obj.axes.get_window_extent(renderer).get_points()
41-
dimension_legend = x1_legend - x0_legend
42-
dimension_axes = x1_axes - x0_axes
43-
44-
# To determine the actual position of the legend, check which corner
45-
# (or center) of the legend is closest to the corresponding corner
46-
# (or center) of the axes box.
47-
# 1. Key points of the legend
48-
lower_left_legend = x0_legend
49-
lower_right_legend = numpy.array(
50-
[x1_legend[0], x0_legend[1]], dtype=numpy.float_
51-
)
52-
upper_left_legend = numpy.array(
53-
[x0_legend[0], x1_legend[1]], dtype=numpy.float_
54-
)
55-
upper_right_legend = x1_legend
56-
center_legend = x0_legend + dimension_legend / 2.
57-
center_left_legend = numpy.array(
58-
[x0_legend[0], x0_legend[1] + dimension_legend[1] / 2.], dtype=numpy.float_
59-
)
60-
center_right_legend = numpy.array(
61-
[x1_legend[0], x0_legend[1] + dimension_legend[1] / 2.], dtype=numpy.float_
62-
)
63-
lower_center_legend = numpy.array(
64-
[x0_legend[0] + dimension_legend[0] / 2., x0_legend[1]], dtype=numpy.float_
65-
)
66-
upper_center_legend = numpy.array(
67-
[x0_legend[0] + dimension_legend[0] / 2., x1_legend[1]], dtype=numpy.float_
68-
)
69-
70-
# 2. Key points of the axes
71-
lower_left_axes = x0_axes
72-
lower_right_axes = numpy.array([x1_axes[0], x0_axes[1]], dtype=numpy.float_)
73-
upper_left_axes = numpy.array([x0_axes[0], x1_axes[1]], dtype=numpy.float_)
74-
upper_right_axes = x1_axes
75-
center_axes = x0_axes + dimension_axes / 2.
76-
center_left_axes = numpy.array(
77-
[x0_axes[0], x0_axes[1] + dimension_axes[1] / 2.], dtype=numpy.float_
78-
)
79-
center_right_axes = numpy.array(
80-
[x1_axes[0], x0_axes[1] + dimension_axes[1] / 2.], dtype=numpy.float_
81-
)
82-
lower_center_axes = numpy.array(
83-
[x0_axes[0] + dimension_axes[0] / 2., x0_axes[1]], dtype=numpy.float_
84-
)
85-
upper_center_axes = numpy.array(
86-
[x0_axes[0] + dimension_axes[0] / 2., x1_axes[1]], dtype=numpy.float_
87-
)
88-
89-
# 3. Compute the distances between comparable points.
90-
distances = {
91-
1: upper_right_axes - upper_right_legend, # upper right
92-
2: upper_left_axes - upper_left_legend, # upper left
93-
3: lower_left_axes - lower_left_legend, # lower left
94-
4: lower_right_axes - lower_right_legend, # lower right
95-
# 5:, Not Implemented # right
96-
6: center_left_axes - center_left_legend, # center left
97-
7: center_right_axes - center_right_legend, # center right
98-
8: lower_center_axes - lower_center_legend, # lower center
99-
9: upper_center_axes - upper_center_legend, # upper center
100-
10: center_axes - center_legend, # center
101-
}
102-
for k, v in distances.items():
103-
distances[k] = numpy.linalg.norm(v, ord=2)
104-
105-
# 4. Take the shortest distance between key points as the final
106-
# location
107-
loc = min(distances, key=distances.get)
108-
109-
if loc == 1:
110-
# upper right
111-
position = None
112-
anchor = None
113-
elif loc == 2:
114-
# upper left
115-
position = [pad, 1.0 - pad]
116-
anchor = "north west"
117-
elif loc == 3:
118-
# lower left
119-
position = [pad, pad]
120-
anchor = "south west"
121-
elif loc == 4:
122-
# lower right
123-
position = [1.0 - pad, pad]
124-
anchor = "south east"
125-
elif loc == 5:
126-
# right
127-
position = [1.0 - pad, 0.5]
128-
anchor = "east"
129-
elif loc == 6:
130-
# center left
131-
position = [3 * pad, 0.5]
132-
anchor = "west"
133-
elif loc == 7:
134-
# center right
135-
position = [1.0 - 3 * pad, 0.5]
136-
anchor = "east"
137-
elif loc == 8:
138-
# lower center
139-
position = [0.5, 3 * pad]
140-
anchor = "south"
141-
elif loc == 9:
142-
# upper center
143-
position = [0.5, 1.0 - 3 * pad]
144-
anchor = "north"
145-
else:
146-
assert loc == 10
147-
# center
148-
position = [0.5, 0.5]
149-
anchor = "center"
26+
position, anchor = {
27+
1: (None, None), # upper right
28+
2: ([pad, 1.0 - pad], "north west"), # upper left
29+
3: ([pad, pad], "south west"), # lower left
30+
4: ([1.0 - pad, pad], "south east"), # lower right
31+
5: ([1.0 - pad, 0.5], "east"), # right
32+
6: ([3 * pad, 0.5], "west"), # center left
33+
7: ([1.0 - 3 * pad, 0.5], "east"), # center right
34+
8: ([0.5, 3 * pad], "south"), # lower center
35+
9: ([0.5, 1.0 - 3 * pad], "north"), # upper center
36+
10: ([0.5, 0.5], "center"), # center
37+
}[loc]
15038

15139
# In case of given position via bbox_to_anchor parameter the center
15240
# of legend is changed as follows:
@@ -165,15 +53,15 @@ def draw_legend(data, obj):
16553
edgecolor = obj.get_frame().get_edgecolor()
16654
data, frame_xcolor, _ = mycol.mpl_color2xcolor(data, edgecolor)
16755
if frame_xcolor != "black": # black is default
168-
legend_style.append("draw=%s" % frame_xcolor)
56+
legend_style.append("draw={}".format(frame_xcolor))
16957
else:
17058
legend_style.append("draw=none")
17159

17260
# Get the facecolor of the box
17361
facecolor = obj.get_frame().get_facecolor()
17462
data, fill_xcolor, _ = mycol.mpl_color2xcolor(data, facecolor)
17563
if fill_xcolor != "white": # white is default
176-
legend_style.append("fill=%s" % fill_xcolor)
64+
legend_style.append("fill={}".format(fill_xcolor))
17765

17866
# Get the horizontal alignment
17967
try:
@@ -214,7 +102,86 @@ def draw_legend(data, obj):
214102

215103
# Write styles to data
216104
if legend_style:
217-
style = "legend style={%s}" % ", ".join(legend_style)
105+
style = "legend style={{{}}}".format(", ".join(legend_style))
218106
data["extra axis options"].add(style)
219107

220108
return data
109+
110+
111+
def _get_location_from_best(obj):
112+
# Create a renderer
113+
from matplotlib.backends import backend_agg
114+
115+
renderer = backend_agg.RendererAgg(
116+
width=obj.figure.get_figwidth(),
117+
height=obj.figure.get_figheight(),
118+
dpi=obj.figure.dpi,
119+
)
120+
121+
# Rectangles of the legend and of the axes
122+
# Lower left and upper right points
123+
x0_legend, x1_legend = obj._legend_box.get_window_extent(renderer).get_points()
124+
x0_axes, x1_axes = obj.axes.get_window_extent(renderer).get_points()
125+
dimension_legend = x1_legend - x0_legend
126+
dimension_axes = x1_axes - x0_axes
127+
128+
# To determine the actual position of the legend, check which corner
129+
# (or center) of the legend is closest to the corresponding corner
130+
# (or center) of the axes box.
131+
# 1. Key points of the legend
132+
lower_left_legend = x0_legend
133+
lower_right_legend = numpy.array([x1_legend[0], x0_legend[1]], dtype=numpy.float_)
134+
upper_left_legend = numpy.array([x0_legend[0], x1_legend[1]], dtype=numpy.float_)
135+
upper_right_legend = x1_legend
136+
center_legend = x0_legend + dimension_legend / 2.
137+
center_left_legend = numpy.array(
138+
[x0_legend[0], x0_legend[1] + dimension_legend[1] / 2.], dtype=numpy.float_
139+
)
140+
center_right_legend = numpy.array(
141+
[x1_legend[0], x0_legend[1] + dimension_legend[1] / 2.], dtype=numpy.float_
142+
)
143+
lower_center_legend = numpy.array(
144+
[x0_legend[0] + dimension_legend[0] / 2., x0_legend[1]], dtype=numpy.float_
145+
)
146+
upper_center_legend = numpy.array(
147+
[x0_legend[0] + dimension_legend[0] / 2., x1_legend[1]], dtype=numpy.float_
148+
)
149+
150+
# 2. Key points of the axes
151+
lower_left_axes = x0_axes
152+
lower_right_axes = numpy.array([x1_axes[0], x0_axes[1]], dtype=numpy.float_)
153+
upper_left_axes = numpy.array([x0_axes[0], x1_axes[1]], dtype=numpy.float_)
154+
upper_right_axes = x1_axes
155+
center_axes = x0_axes + dimension_axes / 2.
156+
center_left_axes = numpy.array(
157+
[x0_axes[0], x0_axes[1] + dimension_axes[1] / 2.], dtype=numpy.float_
158+
)
159+
center_right_axes = numpy.array(
160+
[x1_axes[0], x0_axes[1] + dimension_axes[1] / 2.], dtype=numpy.float_
161+
)
162+
lower_center_axes = numpy.array(
163+
[x0_axes[0] + dimension_axes[0] / 2., x0_axes[1]], dtype=numpy.float_
164+
)
165+
upper_center_axes = numpy.array(
166+
[x0_axes[0] + dimension_axes[0] / 2., x1_axes[1]], dtype=numpy.float_
167+
)
168+
169+
# 3. Compute the distances between comparable points.
170+
distances = {
171+
1: upper_right_axes - upper_right_legend, # upper right
172+
2: upper_left_axes - upper_left_legend, # upper left
173+
3: lower_left_axes - lower_left_legend, # lower left
174+
4: lower_right_axes - lower_right_legend, # lower right
175+
# 5:, Not Implemented # right
176+
6: center_left_axes - center_left_legend, # center left
177+
7: center_right_axes - center_right_legend, # center right
178+
8: lower_center_axes - lower_center_legend, # lower center
179+
9: upper_center_axes - upper_center_legend, # upper center
180+
10: center_axes - center_legend, # center
181+
}
182+
for k, v in distances.items():
183+
distances[k] = numpy.linalg.norm(v, ord=2)
184+
185+
# 4. Take the shortest distance between key points as the final
186+
# location
187+
return min(distances, key=distances.get)

0 commit comments

Comments
 (0)