Skip to content

Commit 4bc151f

Browse files
committed
Add type-hints to adaptive/learner/triangulation.py
1 parent 157574f commit 4bc151f

File tree

1 file changed

+96
-48
lines changed

1 file changed

+96
-48
lines changed

adaptive/learner/triangulation.py

+96-48
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
from __future__ import annotations
2+
3+
import collections.abc
4+
import numbers
15
from collections import Counter
2-
from collections.abc import Iterable, Sized
36
from itertools import chain, combinations
47
from math import factorial, sqrt
8+
from typing import Any, Iterable, Iterator, List, Sequence, Tuple, Union
59

610
import scipy.spatial
711
from numpy import abs as np_abs
@@ -13,6 +17,7 @@
1317
dot,
1418
eye,
1519
mean,
20+
ndarray,
1621
ones,
1722
square,
1823
subtract,
@@ -22,8 +27,22 @@
2227
from numpy.linalg import det as ndet
2328
from numpy.linalg import matrix_rank, norm, slogdet, solve
2429

30+
from adaptive.types import Bool
31+
32+
try:
33+
from typing import TypeAlias
34+
except ImportError:
35+
# Remove this when we drop support for Python 3.9
36+
from typing_extensions import TypeAlias
37+
2538

26-
def fast_norm(v):
39+
SimplexPoints: TypeAlias = Union[List[Tuple[float, ...]], ndarray]
40+
Simplex: TypeAlias = Union[Sequence[numbers.Integral], ndarray]
41+
Point: TypeAlias = Union[Tuple[float, ...], ndarray]
42+
Points: TypeAlias = Union[Sequence[Tuple[float, ...]], ndarray]
43+
44+
45+
def fast_norm(v: tuple[float, ...] | ndarray) -> float:
2746
"""Take the vector norm for len 2, 3 vectors.
2847
Defaults to a square root of the dot product for larger vectors.
2948
@@ -41,7 +60,9 @@ def fast_norm(v):
4160
return sqrt(dot(v, v))
4261

4362

44-
def fast_2d_point_in_simplex(point, simplex, eps=1e-8):
63+
def fast_2d_point_in_simplex(
64+
point: Point, simplex: SimplexPoints, eps: float = 1e-8
65+
) -> Bool:
4566
(p0x, p0y), (p1x, p1y), (p2x, p2y) = simplex
4667
px, py = point
4768

@@ -55,7 +76,7 @@ def fast_2d_point_in_simplex(point, simplex, eps=1e-8):
5576
return (t >= -eps) and (s + t <= 1 + eps)
5677

5778

58-
def point_in_simplex(point, simplex, eps=1e-8):
79+
def point_in_simplex(point: Point, simplex: SimplexPoints, eps: float = 1e-8) -> Bool:
5980
if len(point) == 2:
6081
return fast_2d_point_in_simplex(point, simplex, eps)
6182

@@ -66,7 +87,7 @@ def point_in_simplex(point, simplex, eps=1e-8):
6687
return all(alpha > -eps) and sum(alpha) < 1 + eps
6788

6889

69-
def fast_2d_circumcircle(points):
90+
def fast_2d_circumcircle(points: Points) -> tuple[tuple[float, float], float]:
7091
"""Compute the center and radius of the circumscribed circle of a triangle
7192
7293
Parameters
@@ -79,7 +100,7 @@ def fast_2d_circumcircle(points):
79100
tuple
80101
(center point : tuple(float), radius: float)
81102
"""
82-
points = array(points)
103+
points = array(points, dtype=float)
83104
# transform to relative coordinates
84105
pts = points[1:] - points[0]
85106

@@ -102,7 +123,9 @@ def fast_2d_circumcircle(points):
102123
return (x + points[0][0], y + points[0][1]), radius
103124

104125

105-
def fast_3d_circumcircle(points):
126+
def fast_3d_circumcircle(
127+
points: Points,
128+
) -> tuple[tuple[float, float, float], float]:
106129
"""Compute the center and radius of the circumscribed sphere of a simplex.
107130
108131
Parameters
@@ -142,7 +165,7 @@ def fast_3d_circumcircle(points):
142165
return center, radius
143166

144167

145-
def fast_det(matrix):
168+
def fast_det(matrix: ndarray) -> float:
146169
matrix = asarray(matrix, dtype=float)
147170
if matrix.shape == (2, 2):
148171
return matrix[0][0] * matrix[1][1] - matrix[1][0] * matrix[0][1]
@@ -153,7 +176,7 @@ def fast_det(matrix):
153176
return ndet(matrix)
154177

155178

156-
def circumsphere(pts):
179+
def circumsphere(pts: Simplex) -> tuple[tuple[float, ...], float]:
157180
"""Compute the center and radius of a N dimension sphere which touches each point in pts.
158181
159182
Parameters
@@ -201,7 +224,7 @@ def circumsphere(pts):
201224
return tuple(center), radius
202225

203226

204-
def orientation(face, origin):
227+
def orientation(face: tuple | ndarray, origin: tuple | ndarray) -> int:
205228
"""Compute the orientation of the face with respect to a point, origin.
206229
207230
Parameters
@@ -224,14 +247,14 @@ def orientation(face, origin):
224247
sign, logdet = slogdet(vectors - origin)
225248
if logdet < -50: # assume it to be zero when it's close to zero
226249
return 0
227-
return sign
250+
return int(sign)
228251

229252

230-
def is_iterable_and_sized(obj):
231-
return isinstance(obj, Iterable) and isinstance(obj, Sized)
253+
def is_iterable_and_sized(obj: Any) -> bool:
254+
return isinstance(obj, collections.abc.Collection)
232255

233256

234-
def simplex_volume_in_embedding(vertices) -> float:
257+
def simplex_volume_in_embedding(vertices: Sequence[Point]) -> float:
235258
"""Calculate the volume of a simplex in a higher dimensional embedding.
236259
That is: dim > len(vertices) - 1. For example if you would like to know the
237260
surface area of a triangle in a 3d space.
@@ -312,7 +335,7 @@ class Triangulation:
312335
or more simplices in the
313336
"""
314337

315-
def __init__(self, coords):
338+
def __init__(self, coords: Points) -> None:
316339
if not is_iterable_and_sized(coords):
317340
raise TypeError("Please provide a 2-dimensional list of points")
318341
coords = list(coords)
@@ -340,38 +363,40 @@ def __init__(self, coords):
340363
"(the points are linearly dependent)"
341364
)
342365

343-
self.vertices = list(coords)
344-
self.simplices = set()
366+
self.vertices: list[Point] = list(coords)
367+
self.simplices: set[Simplex] = set()
345368
# initialise empty set for each vertex
346-
self.vertex_to_simplices = [set() for _ in coords]
369+
self.vertex_to_simplices: list[set[Simplex]] = [set() for _ in coords]
347370

348371
# find a Delaunay triangulation to start with, then we will throw it
349372
# away and continue with our own algorithm
350373
initial_tri = scipy.spatial.Delaunay(coords)
351374
for simplex in initial_tri.simplices:
352375
self.add_simplex(simplex)
353376

354-
def delete_simplex(self, simplex):
377+
def delete_simplex(self, simplex: Simplex) -> None:
355378
simplex = tuple(sorted(simplex))
356379
self.simplices.remove(simplex)
357380
for vertex in simplex:
358381
self.vertex_to_simplices[vertex].remove(simplex)
359382

360-
def add_simplex(self, simplex):
383+
def add_simplex(self, simplex: Simplex) -> None:
361384
simplex = tuple(sorted(simplex))
362385
self.simplices.add(simplex)
363386
for vertex in simplex:
364387
self.vertex_to_simplices[vertex].add(simplex)
365388

366-
def get_vertices(self, indices):
389+
def get_vertices(self, indices: Iterable[numbers.Integral]) -> list[Point | None]:
367390
return [self.get_vertex(i) for i in indices]
368391

369-
def get_vertex(self, index):
392+
def get_vertex(self, index: numbers.Integral | None) -> Point | None:
370393
if index is None:
371394
return None
372395
return self.vertices[index]
373396

374-
def get_reduced_simplex(self, point, simplex, eps=1e-8) -> list:
397+
def get_reduced_simplex(
398+
self, point: Point, simplex: Simplex, eps: float = 1e-8
399+
) -> list[numbers.Integral]:
375400
"""Check whether vertex lies within a simplex.
376401
377402
Returns
@@ -396,11 +421,13 @@ def get_reduced_simplex(self, point, simplex, eps=1e-8) -> list:
396421

397422
return [simplex[i] for i in result]
398423

399-
def point_in_simplex(self, point, simplex, eps=1e-8):
424+
def point_in_simplex(
425+
self, point: Point, simplex: Simplex, eps: float = 1e-8
426+
) -> Bool:
400427
vertices = self.get_vertices(simplex)
401428
return point_in_simplex(point, vertices, eps)
402429

403-
def locate_point(self, point):
430+
def locate_point(self, point: Point) -> Simplex:
404431
"""Find to which simplex the point belongs.
405432
406433
Return indices of the simplex containing the point.
@@ -412,10 +439,15 @@ def locate_point(self, point):
412439
return ()
413440

414441
@property
415-
def dim(self):
442+
def dim(self) -> int:
416443
return len(self.vertices[0])
417444

418-
def faces(self, dim=None, simplices=None, vertices=None):
445+
def faces(
446+
self,
447+
dim: int | None = None,
448+
simplices: Iterable[Simplex] | None = None,
449+
vertices: Iterable[int] | None = None,
450+
) -> Iterator[tuple[numbers.Integral, ...]]:
419451
"""Iterator over faces of a simplex or vertex sequence."""
420452
if dim is None:
421453
dim = self.dim
@@ -436,11 +468,11 @@ def faces(self, dim=None, simplices=None, vertices=None):
436468
else:
437469
return faces
438470

439-
def containing(self, face):
471+
def containing(self, face: tuple[int, ...]) -> set[Simplex]:
440472
"""Simplices containing a face."""
441473
return set.intersection(*(self.vertex_to_simplices[i] for i in face))
442474

443-
def _extend_hull(self, new_vertex, eps=1e-8):
475+
def _extend_hull(self, new_vertex: Point, eps: float = 1e-8) -> set[Simplex]:
444476
# count multiplicities in order to get all hull faces
445477
multiplicities = Counter(face for face in self.faces())
446478
hull_faces = [face for face, count in multiplicities.items() if count == 1]
@@ -480,7 +512,9 @@ def _extend_hull(self, new_vertex, eps=1e-8):
480512

481513
return new_simplices
482514

483-
def circumscribed_circle(self, simplex, transform):
515+
def circumscribed_circle(
516+
self, simplex: Simplex, transform: ndarray
517+
) -> tuple[tuple[float, ...], float]:
484518
"""Compute the center and radius of the circumscribed circle of a simplex.
485519
486520
Parameters
@@ -496,7 +530,9 @@ def circumscribed_circle(self, simplex, transform):
496530
pts = dot(self.get_vertices(simplex), transform)
497531
return circumsphere(pts)
498532

499-
def point_in_cicumcircle(self, pt_index, simplex, transform):
533+
def point_in_cicumcircle(
534+
self, pt_index: int, simplex: Simplex, transform: ndarray
535+
) -> Bool:
500536
# return self.fast_point_in_circumcircle(pt_index, simplex, transform)
501537
eps = 1e-8
502538

@@ -506,10 +542,15 @@ def point_in_cicumcircle(self, pt_index, simplex, transform):
506542
return norm(center - pt) < (radius * (1 + eps))
507543

508544
@property
509-
def default_transform(self):
545+
def default_transform(self) -> ndarray:
510546
return eye(self.dim)
511547

512-
def bowyer_watson(self, pt_index, containing_simplex=None, transform=None):
548+
def bowyer_watson(
549+
self,
550+
pt_index: int,
551+
containing_simplex: Simplex | None = None,
552+
transform: ndarray | None = None,
553+
) -> tuple[set[Simplex], set[Simplex]]:
513554
"""Modified Bowyer-Watson point adding algorithm.
514555
515556
Create a hole in the triangulation around the new point,
@@ -569,10 +610,10 @@ def bowyer_watson(self, pt_index, containing_simplex=None, transform=None):
569610
new_triangles = self.vertex_to_simplices[pt_index]
570611
return bad_triangles - new_triangles, new_triangles - bad_triangles
571612

572-
def _simplex_is_almost_flat(self, simplex):
613+
def _simplex_is_almost_flat(self, simplex: Simplex) -> Bool:
573614
return self._relative_volume(simplex) < 1e-8
574615

575-
def _relative_volume(self, simplex):
616+
def _relative_volume(self, simplex: Simplex) -> float:
576617
"""Compute the volume of a simplex divided by the average (Manhattan)
577618
distance of its vertices. The advantage of this is that the relative
578619
volume is only dependent on the shape of the simplex and not on the
@@ -583,20 +624,25 @@ def _relative_volume(self, simplex):
583624
average_edge_length = mean(np_abs(vectors))
584625
return self.volume(simplex) / (average_edge_length**self.dim)
585626

586-
def add_point(self, point, simplex=None, transform=None):
627+
def add_point(
628+
self,
629+
point: Point,
630+
simplex: Simplex | None = None,
631+
transform: ndarray | None = None,
632+
) -> tuple[set[Simplex], set[Simplex]]:
587633
"""Add a new vertex and create simplices as appropriate.
588634
589635
Parameters
590636
----------
591637
point : float vector
592638
Coordinates of the point to be added.
593-
transform : N*N matrix of floats
594-
Multiplication matrix to apply to the point (and neighbouring
595-
simplices) when running the Bowyer Watson method.
596639
simplex : tuple of ints, optional
597640
Simplex containing the point. Empty tuple indicates points outside
598641
the hull. If not provided, the algorithm costs O(N), so this should
599642
be used whenever possible.
643+
transform : N*N matrix of floats
644+
Multiplication matrix to apply to the point (and neighbouring
645+
simplices) when running the Bowyer Watson method.
600646
"""
601647
point = tuple(point)
602648
if simplex is None:
@@ -632,16 +678,16 @@ def add_point(self, point, simplex=None, transform=None):
632678
self.vertices.append(point)
633679
return self.bowyer_watson(pt_index, actual_simplex, transform)
634680

635-
def volume(self, simplex):
681+
def volume(self, simplex: Simplex) -> float:
636682
prefactor = factorial(self.dim)
637683
vertices = array(self.get_vertices(simplex))
638684
vectors = vertices[1:] - vertices[0]
639685
return float(abs(fast_det(vectors)) / prefactor)
640686

641-
def volumes(self):
687+
def volumes(self) -> list[float]:
642688
return [self.volume(sim) for sim in self.simplices]
643689

644-
def reference_invariant(self):
690+
def reference_invariant(self) -> bool:
645691
"""vertex_to_simplices and simplices are compatible."""
646692
for vertex in range(len(self.vertices)):
647693
if any(vertex not in tri for tri in self.vertex_to_simplices[vertex]):
@@ -655,26 +701,28 @@ def vertex_invariant(self, vertex):
655701
"""Simplices originating from a vertex don't overlap."""
656702
raise NotImplementedError
657703

658-
def get_neighbors_from_vertices(self, simplex):
704+
def get_neighbors_from_vertices(self, simplex: Simplex) -> set[Simplex]:
659705
return set.union(*[self.vertex_to_simplices[p] for p in simplex])
660706

661-
def get_face_sharing_neighbors(self, neighbors, simplex):
707+
def get_face_sharing_neighbors(
708+
self, neighbors: set[Simplex], simplex: Simplex
709+
) -> set[Simplex]:
662710
"""Keep only the simplices sharing a whole face with simplex."""
663711
return {
664712
simpl for simpl in neighbors if len(set(simpl) & set(simplex)) == self.dim
665713
} # they share a face
666714

667-
def get_simplices_attached_to_points(self, indices):
715+
def get_simplices_attached_to_points(self, indices: Simplex) -> set[Simplex]:
668716
# Get all simplices that share at least a point with the simplex
669717
neighbors = self.get_neighbors_from_vertices(indices)
670718
return self.get_face_sharing_neighbors(neighbors, indices)
671719

672-
def get_opposing_vertices(self, simplex):
720+
def get_opposing_vertices(self, simplex: Simplex) -> tuple[int, ...]:
673721
if simplex not in self.simplices:
674722
raise ValueError("Provided simplex is not part of the triangulation")
675723
neighbors = self.get_simplices_attached_to_points(simplex)
676724

677-
def find_opposing_vertex(vertex):
725+
def find_opposing_vertex(vertex: int):
678726
# find the simplex:
679727
simp = next((x for x in neighbors if vertex not in x), None)
680728
if simp is None:
@@ -687,7 +735,7 @@ def find_opposing_vertex(vertex):
687735
return result
688736

689737
@property
690-
def hull(self):
738+
def hull(self) -> set[numbers.Integral]:
691739
"""Compute hull from triangulation.
692740
693741
Parameters

0 commit comments

Comments
 (0)