diff --git a/adaptive/learner/learnerND.py b/adaptive/learner/learnerND.py index 014692f12..e0d843f44 100644 --- a/adaptive/learner/learnerND.py +++ b/adaptive/learner/learnerND.py @@ -6,14 +6,18 @@ from collections import OrderedDict from collections.abc import Iterable from copy import deepcopy +from typing import Any, Callable, Sequence import numpy as np import scipy.spatial from scipy import interpolate +from scipy.spatial.qhull import ConvexHull from sortedcontainers import SortedKeyList from adaptive.learner.base_learner import BaseLearner, uses_nth_neighbors from adaptive.learner.triangulation import ( + Point, + Simplex, Triangulation, circumsphere, fast_det, @@ -21,6 +25,7 @@ simplex_volume_in_embedding, ) from adaptive.notebook_integration import ensure_holoviews, ensure_plotly +from adaptive.types import Bool from adaptive.utils import ( assign_defaults, cache_latest, @@ -37,13 +42,13 @@ with_pandas = False -def to_list(inp): +def to_list(inp: float) -> list[float]: if isinstance(inp, Iterable): return list(inp) return [inp] -def volume(simplex, ys=None): +def volume(simplex: Simplex, ys: None = None) -> float: # Notice the parameter ys is there so you can use this volume method as # as loss function matrix = np.subtract(simplex[:-1], simplex[-1], dtype=float) @@ -54,14 +59,14 @@ def volume(simplex, ys=None): return vol -def orientation(simplex): +def orientation(simplex: np.ndarray): matrix = np.subtract(simplex[:-1], simplex[-1]) # See https://www.jstor.org/stable/2315353 sign, _logdet = np.linalg.slogdet(matrix) return sign -def uniform_loss(simplex, values, value_scale): +def uniform_loss(simplex: np.ndarray, values: np.ndarray, value_scale: float) -> float: """ Uniform loss. @@ -81,7 +86,7 @@ def uniform_loss(simplex, values, value_scale): return volume(simplex) -def std_loss(simplex, values, value_scale): +def std_loss(simplex: Simplex, values: np.ndarray, value_scale: float) -> np.ndarray: """ Computes the loss of the simplex based on the standard deviation. @@ -107,7 +112,7 @@ def std_loss(simplex, values, value_scale): return r.flat * np.power(vol, 1.0 / dim) + vol -def default_loss(simplex, values, value_scale): +def default_loss(simplex: np.ndarray, values: np.ndarray, value_scale: float) -> float: """ Computes the average of the volumes of the simplex. @@ -132,7 +137,13 @@ def default_loss(simplex, values, value_scale): @uses_nth_neighbors(1) -def triangle_loss(simplex, values, value_scale, neighbors, neighbor_values): +def triangle_loss( + simplex: np.ndarray, + values: np.ndarray, + value_scale: float, + neighbors: list[None | np.ndarray] | list[None] | list[np.ndarray], + neighbor_values: list[None | float] | list[None] | list[float], +) -> int | float: """ Computes the average of the volumes of the simplex combined with each neighbouring point. @@ -169,7 +180,7 @@ def triangle_loss(simplex, values, value_scale, neighbors, neighbor_values): ) -def curvature_loss_function(exploration=0.05): +def curvature_loss_function(exploration: float = 0.05) -> Callable: # XXX: add doc-string! @uses_nth_neighbors(1) def curvature_loss(simplex, values, value_scale, neighbors, neighbor_values): @@ -206,7 +217,9 @@ def curvature_loss(simplex, values, value_scale, neighbors, neighbor_values): return curvature_loss -def choose_point_in_simplex(simplex, transform=None): +def choose_point_in_simplex( + simplex: Simplex, transform: np.ndarray | None = None +) -> np.ndarray: """Choose a new point in inside a simplex. Pick the center of the simplex if the shape is nice (that is, the @@ -247,7 +260,7 @@ def choose_point_in_simplex(simplex, transform=None): return point -def _simplex_evaluation_priority(key): +def _simplex_evaluation_priority(key: Any) -> Any: # We round the loss to 8 digits such that losses # are equal up to numerical precision will be considered # to be equal. This is needed because we want the learner @@ -307,7 +320,12 @@ class LearnerND(BaseLearner): children based on volume. """ - def __init__(self, func, bounds, loss_per_simplex=None): + def __init__( + self, + function: Callable, + bounds: Sequence[tuple[float, float]] | ConvexHull, + loss_per_simplex: Callable | None = None, + ) -> None: self._vdim = None self.loss_per_simplex = loss_per_simplex or default_loss @@ -339,14 +357,16 @@ def __init__(self, func, bounds, loss_per_simplex=None): self.ndim = len(self._bbox) - self.function = func - self._tri = None - self._losses = dict() + self.function = function # type: ignore + self._tri: Triangulation | None = None + self._losses: dict[Simplex, float] = dict() - self._pending_to_simplex = dict() # vertex → simplex + self._pending_to_simplex: dict[Point, Simplex] = dict() # vertex → simplex # triangulation of the pending points inside a specific simplex - self._subtriangulations = dict() # simplex → triangulation + self._subtriangulations: dict[ + Simplex, Triangulation + ] = dict() # simplex → triangulation # scale to unit hypercube # for the input @@ -381,12 +401,12 @@ def new(self) -> LearnerND: return LearnerND(self.function, self.bounds, self.loss_per_simplex) @property - def npoints(self): + def npoints(self) -> int: """Number of evaluated points.""" return len(self.data) @property - def vdim(self): + def vdim(self) -> int: """Length of the output of ``learner.function``. If the output is unsized (when it's a scalar) then `vdim = 1`. @@ -489,17 +509,17 @@ def load_dataframe( ) @property - def bounds_are_done(self): + def bounds_are_done(self) -> bool: return all(p in self.data for p in self._bounds_points) - def _ip(self): + def _ip(self) -> interpolate.LinearNDInterpolator: """A `scipy.interpolate.LinearNDInterpolator` instance containing the learner's data.""" # XXX: take our own triangulation into account when generating the _ip return interpolate.LinearNDInterpolator(self.points, self.values) @property - def tri(self): + def tri(self) -> Triangulation | None: """An `adaptive.learner.triangulation.Triangulation` instance with all the points of the learner.""" if self._tri is not None: @@ -517,16 +537,16 @@ def tri(self): return self._tri @property - def values(self): + def values(self) -> np.ndarray: """Get the values from `data` as a numpy array.""" return np.array(list(self.data.values()), dtype=float) @property - def points(self): + def points(self) -> np.ndarray: """Get the points from `data` as a numpy array.""" return np.array(list(self.data.keys()), dtype=float) - def tell(self, point, value): + def tell(self, point: tuple[float, ...], value: float | np.ndarray) -> None: point = tuple(point) if point in self.data: @@ -545,16 +565,17 @@ def tell(self, point, value): self._update_range(value) if tri is not None: simplex = self._pending_to_simplex.get(point) + assert self.tri is not None if simplex is not None and not self._simplex_exists(simplex): simplex = None to_delete, to_add = tri.add_point(point, simplex, transform=self._transform) self._update_losses(to_delete, to_add) - def _simplex_exists(self, simplex): + def _simplex_exists(self, simplex: Simplex) -> bool: simplex = tuple(sorted(simplex)) return simplex in self.tri.simplices - def inside_bounds(self, point): + def inside_bounds(self, point: tuple[float, ...]) -> Bool: """Check whether a point is inside the bounds.""" if self._interior is not None: return self._interior.find_simplex(point, tol=1e-8) >= 0 @@ -564,7 +585,7 @@ def inside_bounds(self, point): (mn - eps) <= p <= (mx + eps) for p, (mn, mx) in zip(point, self._bbox) ) - def tell_pending(self, point, *, simplex=None): + def tell_pending(self, point: tuple[float, ...], *, simplex=None) -> None: point = tuple(point) if not self.inside_bounds(point): return @@ -591,7 +612,9 @@ def tell_pending(self, point, *, simplex=None): continue self._update_subsimplex_losses(simpl, to_add) - def _try_adding_pending_point_to_simplex(self, point, simplex): + def _try_adding_pending_point_to_simplex( + self, point: Point, simplex: Simplex + ) -> Any: # try to insert it if not self.tri.point_in_simplex(point, simplex): return None, None @@ -603,7 +626,9 @@ def _try_adding_pending_point_to_simplex(self, point, simplex): self._pending_to_simplex[point] = simplex return self._subtriangulations[simplex].add_point(point) - def _update_subsimplex_losses(self, simplex, new_subsimplices): + def _update_subsimplex_losses( + self, simplex: Simplex, new_subsimplices: set[Simplex] + ) -> None: loss = self._losses[simplex] loss_density = loss / self.tri.volume(simplex) @@ -612,11 +637,11 @@ def _update_subsimplex_losses(self, simplex, new_subsimplices): subloss = subtriangulation.volume(subsimplex) * loss_density self._simplex_queue.add((subloss, simplex, subsimplex)) - def _ask_and_tell_pending(self, n=1): + def _ask_and_tell_pending(self, n: int = 1) -> Any: xs, losses = zip(*(self._ask() for _ in range(n))) return list(xs), list(losses) - def ask(self, n, tell_pending=True): + def ask(self, n: int, tell_pending: bool = True) -> Any: """Chose points for learners.""" if not tell_pending: with restore(self): @@ -624,7 +649,9 @@ def ask(self, n, tell_pending=True): else: return self._ask_and_tell_pending(n) - def _ask_bound_point(self): + def _ask_bound_point( + self, + ) -> tuple[Point, float]: # get the next bound point that is still available new_point = next( p @@ -634,7 +661,9 @@ def _ask_bound_point(self): self.tell_pending(new_point) return new_point, np.inf - def _ask_point_without_known_simplices(self): + def _ask_point_without_known_simplices( + self, + ) -> tuple[Point, float]: assert not self._bounds_available # pick a random point inside the bounds # XXX: change this into picking a point based on volume loss @@ -649,7 +678,7 @@ def _ask_point_without_known_simplices(self): self.tell_pending(p) return p, np.inf - def _pop_highest_existing_simplex(self): + def _pop_highest_existing_simplex(self) -> Any: # find the simplex with the highest loss, we do need to check that the # simplex hasn't been deleted yet while len(self._simplex_queue): @@ -675,7 +704,9 @@ def _pop_highest_existing_simplex(self): " be a simplex available if LearnerND.tri() is not None." ) - def _ask_best_point(self): + def _ask_best_point( + self, + ) -> tuple[Point, float]: assert self.tri is not None loss, simplex, subsimplex = self._pop_highest_existing_simplex() @@ -696,13 +727,15 @@ def _ask_best_point(self): return point_new, loss @property - def _bounds_available(self): + def _bounds_available(self) -> bool: return any( (p not in self.pending_points and p not in self.data) for p in self._bounds_points ) - def _ask(self): + def _ask( + self, + ) -> tuple[Point, float]: if self._bounds_available: return self._ask_bound_point() # O(1) @@ -714,7 +747,7 @@ def _ask(self): return self._ask_best_point() # O(log N) - def _compute_loss(self, simplex): + def _compute_loss(self, simplex: Simplex) -> float: # get the loss vertices = self.tri.get_vertices(simplex) values = [self.data[tuple(v)] for v in vertices] @@ -753,7 +786,7 @@ def _compute_loss(self, simplex): ) ) - def _update_losses(self, to_delete: set, to_add: set): + def _update_losses(self, to_delete: set[Simplex], to_add: set[Simplex]) -> None: # XXX: add the points outside the triangulation to this as well pending_points_unbound = set() @@ -799,7 +832,7 @@ def _update_losses(self, to_delete: set, to_add: set): simplex, self._subtriangulations[simplex].simplices ) - def _recompute_all_losses(self): + def _recompute_all_losses(self) -> None: """Recompute all losses and pending losses.""" # amortized O(N) complexity if self.tri is None: @@ -823,11 +856,11 @@ def _recompute_all_losses(self): ) @property - def _scale(self): + def _scale(self) -> float: # get the output scale return self._max_value - self._min_value - def _update_range(self, new_output): + def _update_range(self, new_output: list[int] | float | np.ndarray) -> bool: if self._min_value is None or self._max_value is None: # this is the first point, nothing to do, just set the range self._min_value = np.min(new_output) @@ -863,12 +896,12 @@ def _update_range(self, new_output): return False @cache_latest - def loss(self, real=True): + def loss(self, real: bool = True) -> float: # XXX: compute pending loss if real == False losses = self._losses if self.tri is not None else dict() return max(losses.values()) if losses else float("inf") - def remove_unfinished(self): + def remove_unfinished(self) -> None: # XXX: implement this method self.pending_points = set() self._subtriangulations = dict() @@ -878,7 +911,7 @@ def remove_unfinished(self): # Plotting related stuff # ########################## - def plot(self, n=None, tri_alpha=0): + def plot(self, n: int | None = None, tri_alpha: float = 0): """Plot the function we want to learn, only works in 2D. Parameters @@ -939,7 +972,7 @@ def plot(self, n=None, tri_alpha=0): return im.opts(style=im_opts) * tris.opts(style=tri_opts, **no_hover) - def plot_slice(self, cut_mapping, n=None): + def plot_slice(self, cut_mapping: dict[int, float], n: int | None = None): """Plot a 1D or 2D interpolated slice of a N-dimensional function. Parameters @@ -1009,7 +1042,7 @@ def plot_slice(self, cut_mapping, n=None): else: raise ValueError("Only 1 or 2-dimensional plots can be generated.") - def plot_3D(self, with_triangulation=False, return_fig=False): + def plot_3D(self, with_triangulation: bool = False, return_fig: bool = False): """Plot the learner's data in 3D using plotly. Does *not* work with the @@ -1094,7 +1127,7 @@ def plot_3D(self, with_triangulation=False, return_fig=False): return fig if return_fig else plotly.offline.iplot(fig) - def _get_iso(self, level=0.0, which="surface"): + def _get_iso(self, level: float = 0.0, which: str = "surface"): if which == "surface": if self.ndim != 3 or self.vdim != 1: raise Exception( @@ -1165,7 +1198,9 @@ def _get_vertex_index(a, b): return vertices, faces_or_lines - def plot_isoline(self, level=0.0, n=None, tri_alpha=0): + def plot_isoline( + self, level: float = 0.0, n: int | None = None, tri_alpha: float = 0 + ): """Plot the isoline at a specific level, only works in 2D. Parameters @@ -1205,7 +1240,7 @@ def plot_isoline(self, level=0.0, n=None, tri_alpha=0): contour = contour.opts(style=contour_opts) return plot * contour - def plot_isosurface(self, level=0.0, hull_opacity=0.2): + def plot_isosurface(self, level: float = 0.0, hull_opacity: float = 0.2): """Plots a linearly interpolated isosurface. This is the 3D analog of an isoline. Does *not* work with the @@ -1243,7 +1278,7 @@ def plot_isosurface(self, level=0.0, hull_opacity=0.2): hull_mesh = self._get_hull_mesh(opacity=hull_opacity) return plotly.offline.iplot([isosurface, hull_mesh]) - def _get_hull_mesh(self, opacity=0.2): + def _get_hull_mesh(self, opacity: float = 0.2): plotly = ensure_plotly() hull = scipy.spatial.ConvexHull(self._bounds_points) @@ -1282,9 +1317,9 @@ def _get_plane_color(simplex): lighting=lighting, ) - def _get_data(self): + def _get_data(self) -> dict[str, Any]: return deepcopy(self.__dict__) - def _set_data(self, state): + def _set_data(self, state: dict[str, Any]) -> None: for k, v in state.items(): setattr(self, k, v) diff --git a/adaptive/learner/triangulation.py b/adaptive/learner/triangulation.py index 4eb5952d5..da24ec3af 100644 --- a/adaptive/learner/triangulation.py +++ b/adaptive/learner/triangulation.py @@ -1,7 +1,11 @@ +from __future__ import annotations + +import collections.abc +import numbers from collections import Counter -from collections.abc import Iterable, Sized from itertools import chain, combinations from math import factorial, sqrt +from typing import Any, Iterable, Iterator, List, Sequence, Tuple, Union import scipy.spatial from numpy import abs as np_abs @@ -13,6 +17,7 @@ dot, eye, mean, + ndarray, ones, square, subtract, @@ -22,8 +27,22 @@ from numpy.linalg import det as ndet from numpy.linalg import matrix_rank, norm, slogdet, solve +from adaptive.types import Bool + +try: + from typing import TypeAlias +except ImportError: + # Remove this when we drop support for Python 3.9 + from typing_extensions import TypeAlias + -def fast_norm(v): +SimplexPoints: TypeAlias = Union[List[Tuple[float, ...]], ndarray] +Simplex: TypeAlias = Union[Sequence[numbers.Integral], ndarray] +Point: TypeAlias = Union[Tuple[float, ...], ndarray] +Points: TypeAlias = Union[Sequence[Tuple[float, ...]], ndarray] + + +def fast_norm(v: tuple[float, ...] | ndarray) -> float: """Take the vector norm for len 2, 3 vectors. Defaults to a square root of the dot product for larger vectors. @@ -41,7 +60,9 @@ def fast_norm(v): return sqrt(dot(v, v)) -def fast_2d_point_in_simplex(point, simplex, eps=1e-8): +def fast_2d_point_in_simplex( + point: Point, simplex: SimplexPoints, eps: float = 1e-8 +) -> Bool: (p0x, p0y), (p1x, p1y), (p2x, p2y) = simplex px, py = point @@ -55,7 +76,7 @@ def fast_2d_point_in_simplex(point, simplex, eps=1e-8): return (t >= -eps) and (s + t <= 1 + eps) -def point_in_simplex(point, simplex, eps=1e-8): +def point_in_simplex(point: Point, simplex: SimplexPoints, eps: float = 1e-8) -> Bool: if len(point) == 2: return fast_2d_point_in_simplex(point, simplex, eps) @@ -66,7 +87,7 @@ def point_in_simplex(point, simplex, eps=1e-8): return all(alpha > -eps) and sum(alpha) < 1 + eps -def fast_2d_circumcircle(points): +def fast_2d_circumcircle(points: Points) -> tuple[tuple[float, float], float]: """Compute the center and radius of the circumscribed circle of a triangle Parameters @@ -79,7 +100,7 @@ def fast_2d_circumcircle(points): tuple (center point : tuple(float), radius: float) """ - points = array(points) + points = array(points, dtype=float) # transform to relative coordinates pts = points[1:] - points[0] @@ -102,7 +123,9 @@ def fast_2d_circumcircle(points): return (x + points[0][0], y + points[0][1]), radius -def fast_3d_circumcircle(points): +def fast_3d_circumcircle( + points: Points, +) -> tuple[tuple[float, float, float], float]: """Compute the center and radius of the circumscribed sphere of a simplex. Parameters @@ -142,7 +165,7 @@ def fast_3d_circumcircle(points): return center, radius -def fast_det(matrix): +def fast_det(matrix: ndarray) -> float: matrix = asarray(matrix, dtype=float) if matrix.shape == (2, 2): return matrix[0][0] * matrix[1][1] - matrix[1][0] * matrix[0][1] @@ -153,7 +176,7 @@ def fast_det(matrix): return ndet(matrix) -def circumsphere(pts): +def circumsphere(pts: Simplex) -> tuple[tuple[float, ...], float]: """Compute the center and radius of a N dimension sphere which touches each point in pts. Parameters @@ -201,7 +224,7 @@ def circumsphere(pts): return tuple(center), radius -def orientation(face, origin): +def orientation(face: tuple | ndarray, origin: tuple | ndarray) -> int: """Compute the orientation of the face with respect to a point, origin. Parameters @@ -224,14 +247,14 @@ def orientation(face, origin): sign, logdet = slogdet(vectors - origin) if logdet < -50: # assume it to be zero when it's close to zero return 0 - return sign + return int(sign) -def is_iterable_and_sized(obj): - return isinstance(obj, Iterable) and isinstance(obj, Sized) +def is_iterable_and_sized(obj: Any) -> bool: + return isinstance(obj, collections.abc.Collection) -def simplex_volume_in_embedding(vertices) -> float: +def simplex_volume_in_embedding(vertices: Sequence[Point]) -> float: """Calculate the volume of a simplex in a higher dimensional embedding. That is: dim > len(vertices) - 1. For example if you would like to know the surface area of a triangle in a 3d space. @@ -312,7 +335,7 @@ class Triangulation: or more simplices in the """ - def __init__(self, coords): + def __init__(self, coords: Points) -> None: if not is_iterable_and_sized(coords): raise TypeError("Please provide a 2-dimensional list of points") coords = list(coords) @@ -340,10 +363,10 @@ def __init__(self, coords): "(the points are linearly dependent)" ) - self.vertices = list(coords) - self.simplices = set() + self.vertices: list[Point] = list(coords) + self.simplices: set[Simplex] = set() # initialise empty set for each vertex - self.vertex_to_simplices = [set() for _ in coords] + self.vertex_to_simplices: list[set[Simplex]] = [set() for _ in coords] # find a Delaunay triangulation to start with, then we will throw it # away and continue with our own algorithm @@ -351,27 +374,29 @@ def __init__(self, coords): for simplex in initial_tri.simplices: self.add_simplex(simplex) - def delete_simplex(self, simplex): + def delete_simplex(self, simplex: Simplex) -> None: simplex = tuple(sorted(simplex)) self.simplices.remove(simplex) for vertex in simplex: self.vertex_to_simplices[vertex].remove(simplex) - def add_simplex(self, simplex): + def add_simplex(self, simplex: Simplex) -> None: simplex = tuple(sorted(simplex)) self.simplices.add(simplex) for vertex in simplex: self.vertex_to_simplices[vertex].add(simplex) - def get_vertices(self, indices): + def get_vertices(self, indices: Iterable[numbers.Integral]) -> list[Point | None]: return [self.get_vertex(i) for i in indices] - def get_vertex(self, index): + def get_vertex(self, index: numbers.Integral | None) -> Point | None: if index is None: return None return self.vertices[index] - def get_reduced_simplex(self, point, simplex, eps=1e-8) -> list: + def get_reduced_simplex( + self, point: Point, simplex: Simplex, eps: float = 1e-8 + ) -> list[numbers.Integral]: """Check whether vertex lies within a simplex. Returns @@ -396,11 +421,13 @@ def get_reduced_simplex(self, point, simplex, eps=1e-8) -> list: return [simplex[i] for i in result] - def point_in_simplex(self, point, simplex, eps=1e-8): + def point_in_simplex( + self, point: Point, simplex: Simplex, eps: float = 1e-8 + ) -> Bool: vertices = self.get_vertices(simplex) return point_in_simplex(point, vertices, eps) - def locate_point(self, point): + def locate_point(self, point: Point) -> Simplex: """Find to which simplex the point belongs. Return indices of the simplex containing the point. @@ -412,10 +439,15 @@ def locate_point(self, point): return () @property - def dim(self): + def dim(self) -> int: return len(self.vertices[0]) - def faces(self, dim=None, simplices=None, vertices=None): + def faces( + self, + dim: int | None = None, + simplices: Iterable[Simplex] | None = None, + vertices: Iterable[int] | None = None, + ) -> Iterator[tuple[numbers.Integral, ...]]: """Iterator over faces of a simplex or vertex sequence.""" if dim is None: dim = self.dim @@ -436,11 +468,11 @@ def faces(self, dim=None, simplices=None, vertices=None): else: return faces - def containing(self, face): + def containing(self, face: tuple[int, ...]) -> set[Simplex]: """Simplices containing a face.""" return set.intersection(*(self.vertex_to_simplices[i] for i in face)) - def _extend_hull(self, new_vertex, eps=1e-8): + def _extend_hull(self, new_vertex: Point, eps: float = 1e-8) -> set[Simplex]: # count multiplicities in order to get all hull faces multiplicities = Counter(face for face in self.faces()) hull_faces = [face for face, count in multiplicities.items() if count == 1] @@ -480,7 +512,9 @@ def _extend_hull(self, new_vertex, eps=1e-8): return new_simplices - def circumscribed_circle(self, simplex, transform): + def circumscribed_circle( + self, simplex: Simplex, transform: ndarray + ) -> tuple[tuple[float, ...], float]: """Compute the center and radius of the circumscribed circle of a simplex. Parameters @@ -496,7 +530,9 @@ def circumscribed_circle(self, simplex, transform): pts = dot(self.get_vertices(simplex), transform) return circumsphere(pts) - def point_in_cicumcircle(self, pt_index, simplex, transform): + def point_in_cicumcircle( + self, pt_index: int, simplex: Simplex, transform: ndarray + ) -> Bool: # return self.fast_point_in_circumcircle(pt_index, simplex, transform) eps = 1e-8 @@ -506,10 +542,15 @@ def point_in_cicumcircle(self, pt_index, simplex, transform): return norm(center - pt) < (radius * (1 + eps)) @property - def default_transform(self): + def default_transform(self) -> ndarray: return eye(self.dim) - def bowyer_watson(self, pt_index, containing_simplex=None, transform=None): + def bowyer_watson( + self, + pt_index: int, + containing_simplex: Simplex | None = None, + transform: ndarray | None = None, + ) -> tuple[set[Simplex], set[Simplex]]: """Modified Bowyer-Watson point adding algorithm. Create a hole in the triangulation around the new point, @@ -569,10 +610,10 @@ def bowyer_watson(self, pt_index, containing_simplex=None, transform=None): new_triangles = self.vertex_to_simplices[pt_index] return bad_triangles - new_triangles, new_triangles - bad_triangles - def _simplex_is_almost_flat(self, simplex): + def _simplex_is_almost_flat(self, simplex: Simplex) -> Bool: return self._relative_volume(simplex) < 1e-8 - def _relative_volume(self, simplex): + def _relative_volume(self, simplex: Simplex) -> float: """Compute the volume of a simplex divided by the average (Manhattan) distance of its vertices. The advantage of this is that the relative volume is only dependent on the shape of the simplex and not on the @@ -583,20 +624,25 @@ def _relative_volume(self, simplex): average_edge_length = mean(np_abs(vectors)) return self.volume(simplex) / (average_edge_length**self.dim) - def add_point(self, point, simplex=None, transform=None): + def add_point( + self, + point: Point, + simplex: Simplex | None = None, + transform: ndarray | None = None, + ) -> tuple[set[Simplex], set[Simplex]]: """Add a new vertex and create simplices as appropriate. Parameters ---------- point : float vector Coordinates of the point to be added. - transform : N*N matrix of floats - Multiplication matrix to apply to the point (and neighbouring - simplices) when running the Bowyer Watson method. simplex : tuple of ints, optional Simplex containing the point. Empty tuple indicates points outside the hull. If not provided, the algorithm costs O(N), so this should be used whenever possible. + transform : N*N matrix of floats + Multiplication matrix to apply to the point (and neighbouring + simplices) when running the Bowyer Watson method. """ point = tuple(point) if simplex is None: @@ -632,16 +678,16 @@ def add_point(self, point, simplex=None, transform=None): self.vertices.append(point) return self.bowyer_watson(pt_index, actual_simplex, transform) - def volume(self, simplex): + def volume(self, simplex: Simplex) -> float: prefactor = factorial(self.dim) vertices = array(self.get_vertices(simplex)) vectors = vertices[1:] - vertices[0] return float(abs(fast_det(vectors)) / prefactor) - def volumes(self): + def volumes(self) -> list[float]: return [self.volume(sim) for sim in self.simplices] - def reference_invariant(self): + def reference_invariant(self) -> bool: """vertex_to_simplices and simplices are compatible.""" for vertex in range(len(self.vertices)): if any(vertex not in tri for tri in self.vertex_to_simplices[vertex]): @@ -655,26 +701,28 @@ def vertex_invariant(self, vertex): """Simplices originating from a vertex don't overlap.""" raise NotImplementedError - def get_neighbors_from_vertices(self, simplex): + def get_neighbors_from_vertices(self, simplex: Simplex) -> set[Simplex]: return set.union(*[self.vertex_to_simplices[p] for p in simplex]) - def get_face_sharing_neighbors(self, neighbors, simplex): + def get_face_sharing_neighbors( + self, neighbors: set[Simplex], simplex: Simplex + ) -> set[Simplex]: """Keep only the simplices sharing a whole face with simplex.""" return { simpl for simpl in neighbors if len(set(simpl) & set(simplex)) == self.dim } # they share a face - def get_simplices_attached_to_points(self, indices): + def get_simplices_attached_to_points(self, indices: Simplex) -> set[Simplex]: # Get all simplices that share at least a point with the simplex neighbors = self.get_neighbors_from_vertices(indices) return self.get_face_sharing_neighbors(neighbors, indices) - def get_opposing_vertices(self, simplex): + def get_opposing_vertices(self, simplex: Simplex) -> tuple[int, ...]: if simplex not in self.simplices: raise ValueError("Provided simplex is not part of the triangulation") neighbors = self.get_simplices_attached_to_points(simplex) - def find_opposing_vertex(vertex): + def find_opposing_vertex(vertex: int): # find the simplex: simp = next((x for x in neighbors if vertex not in x), None) if simp is None: @@ -687,7 +735,7 @@ def find_opposing_vertex(vertex): return result @property - def hull(self): + def hull(self) -> set[numbers.Integral]: """Compute hull from triangulation. Parameters diff --git a/adaptive/types.py b/adaptive/types.py index e2d57a44f..67268f822 100644 --- a/adaptive/types.py +++ b/adaptive/types.py @@ -11,3 +11,4 @@ Float: TypeAlias = Union[float, np.float_] Int: TypeAlias = Union[int, np.int_] Real: TypeAlias = Union[Float, Int] +Bool: TypeAlias = Union[bool, np.bool_]