Skip to content

Commit f95ce27

Browse files
committed
Simplify/robustify segment-point distance calculation.
The version in poly_editor is relatively simple because it only supports array inputs and doesn't vectorize over any input. The version in proj3d is private so the API can be changed, but it needs (currently) to at least support non-array inputs and to vectorize over `p`. - Rename the parameters to make the difference between the "segment ends" (`s0, s1`) and the "point" (`p`) parameters clearer. - Switch `p` to support (N, ndim) inputs instead of (ndim, N) (consistently with most other APIs); adjust test_lines_dists accordingly. - Use vectorized ops everywhere, which also caught the fact that previously, entries beyond the third in (what was) `p1, p2` would be silently ignored (because access was via `p1[0]`, `p1[1]`, `p2[0]`, `p2[1]`). Instead now the vectorized version naturally extends to any number of dimensions. Adjust format_coord and test_lines_dists_nowarning accordingly. - Also support vectorizing over `s0`, `s1`, if they have the same length as `p` (this comes basically for free).
1 parent e9d1f9c commit f95ce27

File tree

4 files changed

+37
-58
lines changed

4 files changed

+37
-58
lines changed

examples/event_handling/poly_editor.py

+10-23
Original file line numberDiff line numberDiff line change
@@ -14,37 +14,24 @@
1414
You can copy and paste individual parts, or download the entire example
1515
using the link at the bottom of the page.
1616
"""
17+
1718
import numpy as np
1819
from matplotlib.lines import Line2D
1920
from matplotlib.artist import Artist
2021

2122

22-
def dist(x, y):
23-
"""
24-
Return the distance between two points.
25-
"""
26-
d = x - y
27-
return np.sqrt(np.dot(d, d))
28-
29-
3023
def dist_point_to_segment(p, s0, s1):
3124
"""
32-
Get the distance of a point to a segment.
33-
*p*, *s0*, *s1* are *xy* sequences
34-
This algorithm from
35-
http://www.geomalgorithms.com/algorithms.html
25+
Get the distance from the point *p* to the segment (*s0*, *s1*), where
26+
*p*, *s0*, *s1* are ``[x, y]`` arrays.
3627
"""
37-
v = s1 - s0
38-
w = p - s0
39-
c1 = np.dot(w, v)
40-
if c1 <= 0:
41-
return dist(p, s0)
42-
c2 = np.dot(v, v)
43-
if c2 <= c1:
44-
return dist(p, s1)
45-
b = c1 / c2
46-
pb = s0 + b * v
47-
return dist(p, pb)
28+
s01 = s1 - s0
29+
s0p = p - s0
30+
if (s01 == 0).all():
31+
return np.hypot(*s0p)
32+
# Project onto segment, without going past segment ends.
33+
p1 = s0 + np.clip((s0p @ s01) / (s01 @ s01), 0, 1) * s01
34+
return np.hypot(*(p - p1))
4835

4936

5037
class PolygonInteractor:

lib/mpl_toolkits/mplot3d/axes3d.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1066,7 +1066,7 @@ def format_coord(self, xd, yd):
10661066
# nearest edge
10671067
p0, p1 = min(self._tunit_edges(),
10681068
key=lambda edge: proj3d._line2d_seg_dist(
1069-
edge[0], edge[1], (xd, yd)))
1069+
(xd, yd), edge[0][:2], edge[1][:2]))
10701070

10711071
# scale the z value to match
10721072
x0, y0, z0 = p0

lib/mpl_toolkits/mplot3d/proj3d.py

+19-23
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,27 @@
66
import numpy.linalg as linalg
77

88

9-
def _line2d_seg_dist(p1, p2, p0):
9+
def _line2d_seg_dist(p, s0, s1):
1010
"""
11-
Return the distance(s) from line defined by p1 - p2 to point(s) p0.
11+
Return the distance(s) from point(s) *p* to segment(s) (*s0*, *s1*).
1212
13-
p0[0] = x(s)
14-
p0[1] = y(s)
15-
16-
intersection point p = p1 + u*(p2-p1)
17-
and intersection point lies within segment if u is between 0 and 1.
18-
19-
If p1 and p2 are identical, the distance between them and p0 is returned.
20-
"""
21-
22-
x01 = np.asarray(p0[0]) - p1[0]
23-
y01 = np.asarray(p0[1]) - p1[1]
24-
if np.all(p1[0:2] == p2[0:2]):
25-
return np.hypot(x01, y01)
26-
27-
x21 = p2[0] - p1[0]
28-
y21 = p2[1] - p1[1]
29-
u = (x01*x21 + y01*y21) / (x21**2 + y21**2)
30-
u = np.clip(u, 0, 1)
31-
d = np.hypot(x01 - u*x21, y01 - u*y21)
32-
33-
return d
13+
Parameters
14+
----------
15+
p : (ndim,) or (N, ndim) array-like
16+
The points from which the distances are computed.
17+
s0, s1 : (ndim,) or (N, ndim) array-like
18+
The xy(z...) coordinates of the segment endpoints.
19+
"""
20+
s0 = np.asarray(s0)
21+
s01 = s1 - s0 # shape (ndim,) or (N, ndim)
22+
s0p = p - s0 # shape (ndim,) or (N, ndim)
23+
l2 = s01 @ s01 # squared segment length
24+
# Avoid div. by zero for degenerate segments (for them, s01 = (0, 0, ...)
25+
# so the value of l2 doesn't matter; this just replaces 0/0 by 0/1).
26+
l2 = np.where(l2, l2, 1)
27+
# Project onto segment, without going past segment ends.
28+
p1 = s0 + np.multiply.outer(np.clip(s0p @ s01 / l2, 0, 1), s01)
29+
return ((p - p1) ** 2).sum(axis=-1) ** (1/2)
3430

3531

3632
def world_transformation(xmin, xmax,

lib/mpl_toolkits/mplot3d/tests/test_axes3d.py

+7-11
Original file line numberDiff line numberDiff line change
@@ -1120,8 +1120,8 @@ def test_lines_dists():
11201120
ys = (100, 150, 30, 200)
11211121
ax.scatter(xs, ys)
11221122

1123-
dist0 = proj3d._line2d_seg_dist(p0, p1, (xs[0], ys[0]))
1124-
dist = proj3d._line2d_seg_dist(p0, p1, np.array((xs, ys)))
1123+
dist0 = proj3d._line2d_seg_dist((xs[0], ys[0]), p0, p1)
1124+
dist = proj3d._line2d_seg_dist(np.array((xs, ys)).T, p0, p1)
11251125
assert dist0 == dist[0]
11261126

11271127
for x, y, d in zip(xs, ys, dist):
@@ -1133,15 +1133,11 @@ def test_lines_dists():
11331133

11341134

11351135
def test_lines_dists_nowarning():
1136-
# Smoke test to see that no RuntimeWarning is emitted when two first
1137-
# arguments are the same, see GH#22624
1138-
p0 = (10, 30, 50)
1139-
p1 = (10, 30, 20)
1140-
p2 = (20, 150)
1141-
proj3d._line2d_seg_dist(p0, p0, p2)
1142-
proj3d._line2d_seg_dist(p0, p1, p2)
1143-
p0 = np.array(p0)
1144-
proj3d._line2d_seg_dist(p0, p0, p2)
1136+
# No RuntimeWarning must be emitted for degenerate segments, see GH#22624.
1137+
s0 = (10, 30, 50)
1138+
p = (20, 150, 180)
1139+
proj3d._line2d_seg_dist(p, s0, s0)
1140+
proj3d._line2d_seg_dist(np.array(p), s0, s0)
11451141

11461142

11471143
def test_autoscale():

0 commit comments

Comments
 (0)