Skip to content

Commit 6f7496d

Browse files
committed
[test] add more tests for trait functions
1 parent bafec58 commit 6f7496d

File tree

3 files changed

+270
-108
lines changed

3 files changed

+270
-108
lines changed

README.md

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ similar to those in the MATLAB/Octave versions of Iso2Mesh.
3232
* PIP+Git (latest version): `python3 -m pip install git+https://github.com/NeuroJSON/pyiso2mesh.git`
3333

3434
MacOS users: you need to run the following commands to install this module
35-
```bash
36-
python3 -m venv /tmp/pyiso2mesh-venv
37-
source /tmp/pyiso2mesh-venv/bin/activate
38-
python3 -m pip install iso2mesh
35+
36+
```
37+
python3 -m venv /tmp/pyiso2mesh-venv
38+
source /tmp/pyiso2mesh-venv/bin/activate
39+
python3 -m pip install iso2mesh
3940
```
4041

4142
## Runtime Dependencies
@@ -64,9 +65,9 @@ including Windows, Linux, and macOS.
6465

6566
2. Clone the repository:
6667

67-
```bash
68-
git clone --recursive https://github.com/NeuroJSON/pyiso2mesh.git
69-
cd pyiso2mesh
68+
```
69+
git clone --recursive https://github.com/NeuroJSON/pyiso2mesh.git
70+
cd pyiso2mesh
7071
```
7172

7273
3. Type `python3 -m build` to build the package
@@ -83,8 +84,8 @@ If you want to modify the source, and verify that it still produces correct resu
8384
the built-in unit-test script inside the downloaded git repository by using this command
8485
inside the `pyiso2mesh` root folder
8586

86-
```bash
87-
python -m unittest test.run_test
87+
```
88+
python -m unittest test.run_test
8889
```
8990

9091

iso2mesh/trait.py

Lines changed: 14 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -772,11 +772,7 @@ def mesheuler(face):
772772
V = len(np.unique(face))
773773

774774
# Construct edges from faces
775-
E = np.vstack((face[:, [0, 2]], face[:, [0, 1]], face[:, [1, 2]]))
776-
E = np.unique(
777-
np.sort(E, axis=1), axis=0
778-
) # Sort edge vertex pairs and remove duplicates
779-
E = len(E)
775+
E = uniqedges(face)[0].shape[0]
780776

781777
# Number of faces
782778
F = face.shape[0]
@@ -1382,14 +1378,13 @@ def uniqedges(elem):
13821378
edges = edges[idx, :]
13831379

13841380
# Compute edgemap if requested
1385-
edgemap = None
13861381
edgemap = np.reshape(
1387-
jdx,
1388-
(elem.shape[0], np.array(list(combinations(range(elem.shape[1]), 2))).shape[0]),
1382+
jdx + 1,
1383+
(-1, elem.shape[0]),
13891384
)
13901385
edgemap = edgemap.T
13911386

1392-
return edges, idx, edgemap
1387+
return edges, idx + 1, edgemap
13931388

13941389

13951390
# _________________________________________________________________________________________________________
@@ -1436,7 +1431,7 @@ def uniqfaces(elem):
14361431
order="F",
14371432
)
14381433

1439-
return faces, idx, facemap
1434+
return faces, idx + 1, facemap
14401435

14411436

14421437
def innersurf(node, face, outface=None):
@@ -1653,13 +1648,8 @@ def elemfacecenter(node, elem):
16531648
faces, idx, newelem = uniqfaces(elem[:, :4])
16541649

16551650
# Extract the coordinates of the nodes forming these faces
1656-
newnode = node[faces.flatten() - 1, :3]
1657-
1658-
# Reshape newnode to group coordinates of nodes in each face
1659-
newnode = newnode.reshape(3, 3, faces.shape[0])
1660-
1661-
# Compute the mean of the coordinates to find the face centers
1662-
newnode = np.mean(np.transpose(newnode, (2, 1, 0)), axis=1).squeeze()
1651+
newnode = node[faces.T - 1, :3]
1652+
newnode = np.mean(newnode, axis=0)
16631653

16641654
return newnode, newelem
16651655

@@ -1724,10 +1714,7 @@ def barydualmesh(node, elem, flag=None):
17241714
) # Adjust to 0-based indexing for Python
17251715

17261716
newelem = newidx[:, newelem.flatten()]
1727-
1728-
newelem = newelem.reshape((elem.shape[0], 4, 6))
1729-
newelem = np.transpose(newelem, (0, 2, 1))
1730-
newelem = newelem.reshape((elem.shape[0] * 6, 4))
1717+
newelem = newelem.T.reshape(4, -1).T
17311718

17321719
# If the 'cell' flag is set, return `newelem` as a list of lists (cells)
17331720
if flag == "cell":
@@ -1754,70 +1741,10 @@ def highordertet(node, elem, order=2, opt=None):
17541741
newelem: Element connectivity of the higher-order tetrahedral mesh.
17551742
"""
17561743

1757-
if order < 2:
1758-
raise ValueError("Order must be greater than or equal to 2")
1759-
1760-
if opt is None:
1761-
opt = {}
1762-
1763-
# Example: linear to quadratic conversion (order=2)
1764-
if order == 2:
1765-
newnode, newelem = lin_to_quad_tet(node, elem)
1766-
else:
1767-
raise NotImplementedError(
1768-
f"Higher order {order} mesh refinement is not yet implemented"
1769-
)
1770-
1771-
return newnode, newelem
1772-
1773-
1774-
# _________________________________________________________________________________________________________
1775-
1776-
1777-
def lin_to_quad_tet(node, elem):
1778-
"""
1779-
Convert linear tetrahedral elements (4-node) to quadratic tetrahedral elements (10-node).
1780-
1781-
Args:
1782-
node: Nodal coordinates (n_nodes, 3).
1783-
elem: Element connectivity (n_elements, 4).
1784-
1785-
Returns:
1786-
newnode: Nodal coordinates of the quadratic mesh.
1787-
newelem: Element connectivity of the quadratic mesh.
1788-
"""
1789-
1790-
n_elem = elem.shape[0]
1791-
n_node = node.shape[0]
1792-
1793-
# Initialize new node and element lists
1794-
edge_midpoints = {}
1795-
new_nodes = []
1796-
new_elements = []
1744+
if order >= 3 or order <= 1:
1745+
raise ValueError("currently this function only supports order=2")
17971746

1798-
for i in range(n_elem):
1799-
element = elem[i] - 1
1800-
quad_element = list(element) # Start with linear nodes
1801-
1802-
# Loop over each edge of the tetrahedron
1803-
edges = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
1804-
1805-
for e in edges:
1806-
n1, n2 = sorted([element[e[0]], element[e[1]]])
1807-
edge_key = (n1, n2)
1808-
1809-
if edge_key not in edge_midpoints:
1810-
# Compute midpoint and add it as a new node
1811-
midpoint = (node[n1] + node[n2]) / 2
1812-
new_nodes.append(midpoint)
1813-
edge_midpoints[edge_key] = n_node + len(new_nodes) - 1
1814-
1815-
quad_element.append(edge_midpoints[edge_key])
1816-
1817-
new_elements.append(quad_element)
1818-
1819-
# Combine old and new nodes
1820-
newnode = np.vstack([node, np.array(new_nodes)])
1821-
newelem = np.array(new_elements) + 1
1822-
1823-
return newnode, newelem
1747+
edges, idx, newelem = uniqedges(elem[:, : min(elem.shape[1], 4)])
1748+
newnode = node[edges.T - 1, :3] # adjust for 1-based index
1749+
newnode = np.mean(newnode, axis=0)
1750+
return newnode, newelem + 1

0 commit comments

Comments
 (0)