-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathadd_arrow.py
58 lines (52 loc) · 1.85 KB
/
add_arrow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from matplotlib import pyplot as plt
import numpy as np
def add_arrow(lines, position=None, direction='forward', size=15, color='black'):
    """
    % (C) Nick Holschuh - Amherst College -- 2022 ([email protected])
    %
    % This function adds an arrow to a line. Adapted from:
    % https://stackoverflow.com/questions/34017866/arrow-on-a-line-plot
    %
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % The inputs are:
    % lines: a single Line2D object, or a list/tuple (any iterable) of them
    % position: x-position of the arrow. If None, the mean of each line's own
    %           xdata is used (computed separately for every line).
    %           Not used when direction is 'forward'.
    % direction: 'forward' (default) -- arrow on the first segment, pointing
    %                                   from the first point to the second
    %            'right' -- arrow from the sample nearest `position` to the
    %                       next sample
    %            any other value (e.g. 'left') -- arrow from the sample
    %                       nearest `position` to the previous sample
    %            Near the ends of a line the arrow is shifted onto the
    %            nearest valid segment instead of running off the data.
    % size: size of the arrow in fontsize points
    % color: arrow color (default 'black'). If None, each arrow takes the
    %        color of the line it is drawn on.
    %
    %%%%%%%%%%%%%%%
    % The outputs are:
    %
    % N/A (each arrow is added to its line's axes via annotate)
    % Lines with fewer than two points are skipped.
    %
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    """
    # A single line exposes get_xdata; anything else is treated as a
    # collection of lines (list, tuple, ax.lines, ...).
    if hasattr(lines, 'get_xdata'):
        lines = [lines]

    for line in lines:
        # Resolve defaults per line so one line's color/position does not
        # leak into the next iteration.
        line_color = line.get_color() if color is None else color
        # asanyarray: accept plain lists while preserving masked arrays
        xdata = np.asanyarray(line.get_xdata())
        ydata = np.asanyarray(line.get_ydata())
        n_points = len(xdata)
        if n_points < 2:
            # An arrow needs two samples; nothing to draw on this line.
            continue

        if direction == 'forward':
            start_ind = 0
            end_ind = 1
        else:
            line_position = xdata.mean() if position is None else position
            # index of the sample closest to the requested x-position
            start_ind = int(np.argmin(np.absolute(xdata - line_position)))
            if direction == 'right':
                # keep the "next" sample in range at the end of the line
                start_ind = min(start_ind, n_points - 2)
                end_ind = start_ind + 1
            else:
                # keep the "previous" sample in range; index -1 would
                # otherwise wrap around to the last sample
                start_ind = max(start_ind, 1)
                end_ind = start_ind - 1

        line.axes.annotate('',
                           xytext=(xdata[start_ind], ydata[start_ind]),
                           xy=(xdata[end_ind], ydata[end_ind]),
                           arrowprops=dict(arrowstyle="->", color=line_color),
                           size=size)