|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. |
| 2 | +# SPDX-FileCopyrightText: All rights reserved. |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | + |
| 17 | +r"""Integration of scalar, vector, and tensor fields over simplicial meshes. |
| 18 | +
|
| 19 | +Provides quadrature rules for integrating fields discretized on simplicial |
| 20 | +meshes of any manifold dimension. The manifold dimension determines the |
| 21 | +measure automatically: arc length for 1-manifolds, surface area for |
| 22 | +2-manifolds, volume for 3-manifolds, etc. |
| 23 | +
|
| 24 | +Two data sources are supported: |
| 25 | +
|
| 26 | +**Cell data (P0)** - piecewise-constant fields: |
| 27 | +
|
| 28 | +.. math:: |
| 29 | + \int_\Omega f\,d\Omega = \sum_c f_c \,|\sigma_c| |
| 30 | +
|
| 31 | +**Point data (P1)** - vertex-centered fields treated as nodal values of a |
| 32 | +piecewise-linear field interpolated via barycentric coordinates. The |
| 33 | +integral of a linear function over an n-simplex equals the volume times the |
| 34 | +arithmetic mean of vertex values: |
| 35 | +
|
| 36 | +.. math:: |
| 37 | + \int_\Omega f\,d\Omega |
| 38 | + = \sum_c |\sigma_c| \cdot \frac{1}{n_v} \sum_{v \in c} f(v) |
| 39 | +
|
| 40 | +This is exact for P1 fields and second-order accurate for smooth fields. |
| 41 | +""" |
| 42 | + |
| 43 | +from typing import TYPE_CHECKING, Literal |
| 44 | + |
| 45 | +import torch |
| 46 | + |
| 47 | +if TYPE_CHECKING: |
| 48 | + from physicsnemo.mesh.mesh import Mesh |
| 49 | + |
| 50 | + |
| 51 | +def _resolve_field( |
| 52 | + mesh: "Mesh", |
| 53 | + field: str | tuple[str, ...] | torch.Tensor, |
| 54 | + data_source: Literal["cells", "points"], |
| 55 | +) -> torch.Tensor: |
| 56 | + r"""Resolve a field specification to a concrete tensor. |
| 57 | +
|
| 58 | + Parameters |
| 59 | + ---------- |
| 60 | + mesh : Mesh |
| 61 | + Source mesh. |
| 62 | + field : str, tuple, or torch.Tensor |
| 63 | + A string or tuple is looked up in ``cell_data`` or ``point_data`` |
| 64 | + depending on *data_source*. A tensor is returned as-is. |
| 65 | + data_source : {"cells", "points"} |
| 66 | + Which data dictionary to use for string key lookups. |
| 67 | +
|
| 68 | + Returns |
| 69 | + ------- |
| 70 | + torch.Tensor |
| 71 | + The resolved field tensor. |
| 72 | + """ |
| 73 | + if isinstance(field, torch.Tensor): |
| 74 | + return field |
| 75 | + match data_source: |
| 76 | + case "cells": |
| 77 | + data, attr_name = mesh.cell_data, "cell_data" |
| 78 | + case "points": |
| 79 | + data, attr_name = mesh.point_data, "point_data" |
| 80 | + case _: |
| 81 | + raise ValueError(f"Invalid {data_source=!r}. Must be 'cells' or 'points'.") |
| 82 | + try: |
| 83 | + return data[field] |
| 84 | + except KeyError: |
| 85 | + available = sorted(data.keys()) |
| 86 | + raise KeyError( |
| 87 | + f"Field {field!r} not found in {attr_name}. Available keys: {available}" |
| 88 | + ) from None |
| 89 | + |
| 90 | + |
| 91 | +def integrate_cell_data( |
| 92 | + mesh: "Mesh", |
| 93 | + field: torch.Tensor, |
| 94 | +) -> torch.Tensor: |
| 95 | + r"""Integrate a cell-centered (P0) field over the mesh. |
| 96 | +
|
| 97 | + Computes the exact integral of a piecewise-constant field: |
| 98 | +
|
| 99 | + .. math:: |
| 100 | + \int_\Omega f\,d\Omega = \sum_c f_c \,|\sigma_c| |
| 101 | +
|
| 102 | + NaN values in *field* are excluded from the sum (treated as zero |
| 103 | + contribution), which is appropriate for fields with patched-out |
| 104 | + regions (e.g. non-physical points in CFD solutions). |
| 105 | +
|
| 106 | + Parameters |
| 107 | + ---------- |
| 108 | + mesh : Mesh |
| 109 | + Simplicial mesh with at least one cell. |
| 110 | + field : torch.Tensor |
| 111 | + Cell-centered values, shape ``(n_cells, *trailing)``. |
| 112 | + Trailing dimensions are preserved in the output. |
| 113 | +
|
| 114 | + Returns |
| 115 | + ------- |
| 116 | + torch.Tensor |
| 117 | + Integral value. Shape matches ``field.shape[1:]`` (the trailing |
| 118 | + dimensions). A scalar field ``(n_cells,)`` produces a 0-d tensor. |
| 119 | +
|
| 120 | + Raises |
| 121 | + ------ |
| 122 | + ValueError |
| 123 | + If ``field.shape[0]`` does not equal ``mesh.n_cells``. |
| 124 | + """ |
| 125 | + if not torch.compiler.is_compiling(): |
| 126 | + if field.shape[0] != mesh.n_cells: |
| 127 | + raise ValueError( |
| 128 | + f"Field leading dimension ({field.shape[0]}) must equal " |
| 129 | + f"n_cells ({mesh.n_cells})." |
| 130 | + ) |
| 131 | + |
| 132 | + cell_areas = mesh.cell_areas # (n_cells,) |
| 133 | + |
| 134 | + ### Reshape cell_areas for broadcasting with arbitrary trailing dims |
| 135 | + weights = cell_areas.reshape(-1, *([1] * (field.ndim - 1))) |
| 136 | + |
| 137 | + return torch.nansum(field * weights, dim=0) |
| 138 | + |
| 139 | + |
| 140 | +def integrate_point_data( |
| 141 | + mesh: "Mesh", |
| 142 | + field: torch.Tensor, |
| 143 | +) -> torch.Tensor: |
| 144 | + r"""Integrate a vertex-centered (P1) field over the mesh. |
| 145 | +
|
| 146 | + Treats vertex values as nodal values of a piecewise-linear field |
| 147 | + and integrates analytically per simplex using the vertex-averaging |
| 148 | + rule (second-order accurate for smooth fields). |
| 149 | +
|
| 150 | + If any vertex of a cell has NaN, that cell's contribution is NaN and |
| 151 | + is excluded by ``nansum`` (the P1 interpolant is undefined on that cell). |
| 152 | +
|
| 153 | + Parameters |
| 154 | + ---------- |
| 155 | + mesh : Mesh |
| 156 | + Simplicial mesh with at least one cell. |
| 157 | + field : torch.Tensor |
| 158 | + Vertex-centered values, shape ``(n_points, *trailing)``. |
| 159 | + Trailing dimensions are preserved in the output. |
| 160 | +
|
| 161 | + Returns |
| 162 | + ------- |
| 163 | + torch.Tensor |
| 164 | + Integral value with shape ``field.shape[1:]``. |
| 165 | +
|
| 166 | + Raises |
| 167 | + ------ |
| 168 | + ValueError |
| 169 | + If ``field.shape[0]`` does not equal ``mesh.n_points``. |
| 170 | + """ |
| 171 | + if not torch.compiler.is_compiling(): |
| 172 | + if field.shape[0] != mesh.n_points: |
| 173 | + raise ValueError( |
| 174 | + f"Field leading dimension ({field.shape[0]}) must equal " |
| 175 | + f"n_points ({mesh.n_points})." |
| 176 | + ) |
| 177 | + |
| 178 | + cell_areas = mesh.cell_areas # (n_cells,) |
| 179 | + |
| 180 | + ### Gather vertex values for each cell: (n_cells, n_verts_per_cell, *trailing) |
| 181 | + cell_vertex_values = field[mesh.cells] |
| 182 | + |
| 183 | + ### Mean over vertices within each cell: (n_cells, *trailing) |
| 184 | + cell_means = cell_vertex_values.mean(dim=1) |
| 185 | + |
| 186 | + ### Weight by cell area and sum |
| 187 | + weights = cell_areas.reshape(-1, *([1] * (cell_means.ndim - 1))) |
| 188 | + return torch.nansum(cell_means * weights, dim=0) |
| 189 | + |
| 190 | + |
| 191 | +def integrate( |
| 192 | + mesh: "Mesh", |
| 193 | + field: str | tuple[str, ...] | torch.Tensor, |
| 194 | + data_source: Literal["cells", "points"] = "cells", |
| 195 | +) -> torch.Tensor: |
| 196 | + r"""Integrate a field over the mesh domain. |
| 197 | +
|
| 198 | + This is the unified entry point for mesh integration. It dispatches to |
| 199 | + :func:`integrate_cell_data` or :func:`integrate_point_data` based on |
| 200 | + *data_source*, and resolves *field* from a string key or tensor. |
| 201 | +
|
| 202 | + Parameters |
| 203 | + ---------- |
| 204 | + mesh : Mesh |
| 205 | + Simplicial mesh. |
| 206 | + field : str, tuple[str, ...], or torch.Tensor |
| 207 | + Field to integrate. |
| 208 | +
|
| 209 | + - ``str`` or ``tuple``: looked up in ``cell_data`` or ``point_data`` |
| 210 | + according to *data_source*. |
| 211 | + - ``torch.Tensor``: used directly. |
| 212 | + data_source : {"cells", "points"} |
| 213 | + Whether *field* is cell-centered (P0) or vertex-centered (P1). |
| 214 | +
|
| 215 | + Returns |
| 216 | + ------- |
| 217 | + torch.Tensor |
| 218 | + Integral value. Shape matches the trailing dimensions of the field |
| 219 | + (scalar field -> 0-d tensor, vector field -> 1-d tensor, etc.). |
| 220 | +
|
| 221 | + Raises |
| 222 | + ------ |
| 223 | + KeyError |
| 224 | + If *field* is a string key not present in the specified data source. |
| 225 | + ValueError |
| 226 | + If the mesh has no cells, or if a raw tensor has the wrong leading |
| 227 | + dimension for the specified *data_source*. |
| 228 | +
|
| 229 | + Examples |
| 230 | + -------- |
| 231 | + >>> import torch |
| 232 | + >>> from physicsnemo.mesh import Mesh |
| 233 | + >>> pts = torch.tensor([[0., 0.], [1., 0.], [0.5, 1.]]) |
| 234 | + >>> cells = torch.tensor([[0, 1, 2]]) |
| 235 | + >>> mesh = Mesh(points=pts, cells=cells) |
| 236 | + >>> mesh.cell_data["p"] = torch.tensor([3.0]) |
| 237 | + >>> mesh.integrate("p") # integrate cell-centered pressure |
| 238 | + tensor(1.5000) |
| 239 | + >>> mesh.point_data["T"] = torch.tensor([1.0, 2.0, 3.0]) |
| 240 | + >>> mesh.integrate("T", data_source="points") # P1 integral |
| 241 | + tensor(1.) |
| 242 | + """ |
| 243 | + if not torch.compiler.is_compiling(): |
| 244 | + if mesh.n_cells == 0: |
| 245 | + raise ValueError( |
| 246 | + "Cannot integrate over a mesh with no cells. " |
| 247 | + "Integration requires simplicial connectivity." |
| 248 | + ) |
| 249 | + |
| 250 | + resolved = _resolve_field(mesh, field, data_source) |
| 251 | + |
| 252 | + match data_source: |
| 253 | + case "cells": |
| 254 | + return integrate_cell_data(mesh, resolved) |
| 255 | + case "points": |
| 256 | + return integrate_point_data(mesh, resolved) |
| 257 | + case _: |
| 258 | + raise ValueError(f"Invalid {data_source=!r}. Must be 'cells' or 'points'.") |
| 259 | + |
| 260 | + |
| 261 | +def integrate_flux( |
| 262 | + mesh: "Mesh", |
| 263 | + field: str | tuple[str, ...] | torch.Tensor, |
| 264 | + data_source: Literal["cells", "points"] = "cells", |
| 265 | +) -> torch.Tensor: |
| 266 | + r"""Compute the surface flux integral for codimension-1 meshes. |
| 267 | +
|
| 268 | + Computes the oriented flux of a vector field through the mesh surface: |
| 269 | +
|
| 270 | + .. math:: |
| 271 | + \int_\Gamma \mathbf{F} \cdot \mathbf{n}\,d\Gamma |
| 272 | +
|
| 273 | + This is only defined for codimension-1 meshes (surfaces in 3D, curves |
| 274 | + in 2D) where unique cell normals exist. |
| 275 | +
|
| 276 | + For cell data, the flux is: |
| 277 | +
|
| 278 | + .. math:: |
| 279 | + \int_\Gamma \mathbf{F} \cdot \mathbf{n}\,d\Gamma |
| 280 | + = \sum_c (\mathbf{F}_c \cdot \mathbf{n}_c)\,|\sigma_c| |
| 281 | +
|
| 282 | + For point data, the P1 vertex-averaged field is dotted with the cell |
| 283 | + normal (which is constant per cell): |
| 284 | +
|
| 285 | + .. math:: |
| 286 | + \int_\Gamma \mathbf{F} \cdot \mathbf{n}\,d\Gamma |
| 287 | + = \sum_c \Bigl(\frac{1}{n_v}\sum_{v \in c} \mathbf{F}(v)\Bigr) |
| 288 | + \cdot \mathbf{n}_c\,|\sigma_c| |
| 289 | +
|
| 290 | + Parameters |
| 291 | + ---------- |
| 292 | + mesh : Mesh |
| 293 | + Codimension-1 simplicial mesh (i.e. ``n_manifold_dims == |
| 294 | + n_spatial_dims - 1``). |
| 295 | + field : str, tuple[str, ...], or torch.Tensor |
| 296 | + Vector field to integrate. Must have last dimension equal to |
| 297 | + ``n_spatial_dims``. |
| 298 | + data_source : {"cells", "points"} |
| 299 | + Whether *field* is cell-centered or vertex-centered. |
| 300 | +
|
| 301 | + Returns |
| 302 | + ------- |
| 303 | + torch.Tensor |
| 304 | + Scalar flux value (0-d tensor). |
| 305 | +
|
| 306 | + Raises |
| 307 | + ------ |
| 308 | + KeyError |
| 309 | + If *field* is a string key not present in the specified data source. |
| 310 | + ValueError |
| 311 | + If the mesh is not codimension-1, if the field leading dimension |
| 312 | + does not match the expected entity count, or if the field does |
| 313 | + not have the correct trailing dimension. |
| 314 | +
|
| 315 | + Examples |
| 316 | + -------- |
| 317 | + >>> import torch |
| 318 | + >>> from physicsnemo.mesh import Mesh |
| 319 | + >>> # Unit square boundary in 2D (4 edges forming a closed loop) |
| 320 | + >>> pts = torch.tensor([[0., 0.], [1., 0.], [1., 1.], [0., 1.]]) |
| 321 | + >>> cells = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]]) |
| 322 | + >>> mesh = Mesh(points=pts, cells=cells) |
| 323 | + >>> # Constant outward velocity field - flux through closed boundary |
| 324 | + >>> mesh.cell_data["v"] = torch.zeros(4, 2) |
| 325 | + >>> mesh.integrate_flux("v") |
| 326 | + tensor(0.) |
| 327 | + """ |
| 328 | + if not torch.compiler.is_compiling(): |
| 329 | + if mesh.codimension != 1: |
| 330 | + raise ValueError( |
| 331 | + f"integrate_flux requires a codimension-1 mesh " |
| 332 | + f"(n_manifold_dims == n_spatial_dims - 1), but got " |
| 333 | + f"{mesh.n_manifold_dims=}, {mesh.n_spatial_dims=} " |
| 334 | + f"(codimension={mesh.codimension})." |
| 335 | + ) |
| 336 | + |
| 337 | + resolved = _resolve_field(mesh, field, data_source) |
| 338 | + |
| 339 | + if not torch.compiler.is_compiling(): |
| 340 | + expected_leading = mesh.n_cells if data_source == "cells" else mesh.n_points |
| 341 | + if resolved.shape[0] != expected_leading: |
| 342 | + entity = "n_cells" if data_source == "cells" else "n_points" |
| 343 | + raise ValueError( |
| 344 | + f"Field leading dimension ({resolved.shape[0]}) must equal " |
| 345 | + f"{entity} ({expected_leading})." |
| 346 | + ) |
| 347 | + if resolved.shape[-1] != mesh.n_spatial_dims: |
| 348 | + raise ValueError( |
| 349 | + f"Field last dimension ({resolved.shape[-1]}) must match " |
| 350 | + f"n_spatial_dims ({mesh.n_spatial_dims}) for flux integration." |
| 351 | + ) |
| 352 | + |
| 353 | + cell_normals = mesh.cell_normals # (n_cells, n_spatial_dims) |
| 354 | + cell_areas = mesh.cell_areas # (n_cells,) |
| 355 | + |
| 356 | + ### Resolve per-cell vector field |
| 357 | + match data_source: |
| 358 | + case "cells": |
| 359 | + cell_field = resolved |
| 360 | + case "points": |
| 361 | + cell_field = resolved[mesh.cells].mean(dim=1) # P1 average |
| 362 | + case _: |
| 363 | + raise ValueError(f"Invalid {data_source=!r}. Must be 'cells' or 'points'.") |
| 364 | + |
| 365 | + f_dot_n = (cell_field * cell_normals).sum(dim=-1) # (n_cells,) |
| 366 | + return torch.nansum(f_dot_n * cell_areas, dim=0) |
0 commit comments