Skip to content

Commit 765d22c

Browse files
authored
[ENH] Refactor mesh extraction and add scalar field at interface to structural elements | | GEN-12031 (#1019)
# Description Refactored the mesh extraction process using marching cubes algorithm to improve accuracy and efficiency. The main changes include: 1. Renamed `scalar_field` to `scalar_field_at_interface` in the `StructuralElement` class for better clarity 2. Added proper assignment of scalar field values to structural elements 3. Simplified the mesh extraction logic by using the output group's scalar field matrix and mask directly 4. Improved the marching cubes implementation with proper parameters (allow_degenerate=False, method="lewiner") 5. Added comprehensive tests to verify the mesh extraction results Relates to #mesh-extraction-improvements # Checklist - [x] My code uses type hinting for function and method arguments and return values. - [x] I have created tests which cover my code. - [x] The test code either 1. demonstrates at least one valuable use case (e.g. integration tests) or 2. verifies that outputs are as expected for given inputs (e.g. unit tests). - [x] New tests pass locally with my changes.
2 parents 176c6f6 + 95a82f3 commit 765d22c

File tree

4 files changed

+25
-77
lines changed

4 files changed

+25
-77
lines changed

gempy/core/data/geo_model.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def solutions(self) -> Solutions:
146146
return self._solutions
147147

148148
@solutions.setter
149-
def solutions(self, value):
149+
def solutions(self, value: Solutions):
150150
# * This is set from the gempy engine
151151

152152
self._solutions = value
@@ -161,6 +161,8 @@ def solutions(self, value):
161161

162162
# * Set solutions per element
163163
for e, element in enumerate(self.structural_frame.structural_elements[:-1]): # * Ignore basement
164+
element.scalar_field_at_interface = value.scalar_field_at_surface_points[e]
165+
164166
if self._solutions.dc_meshes is None:
165167
continue
166168
dc_mesh = self._solutions.dc_meshes[e]

gempy/core/data/structural_element.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class StructuralElement:
3131
# ? Should we extract this to a separate class?
3232
vertices: Optional[np.ndarray] = None #: The vertices of the element in 3D space.
3333
edges: Optional[np.ndarray] = None #: The edges of the element in 3D space.
34-
scalar_field: Optional[float] = None #: The scalar field value for the element.
34+
scalar_field_at_interface: Optional[float] = None #: The scalar field value for the element.
3535

3636
_id: int = -1
3737

gempy/modules/mesh_extranction/marching_cubes.py

+12-74
Original file line numberDiff line numberDiff line change
@@ -34,91 +34,27 @@ def set_meshes_with_marching_cubes(model: GeoModel) -> None:
3434

3535
output_lvl0: list[InterpOutput] = model.solutions.octrees_output[0].outputs_centers
3636

37-
# TODO: How to get this properly in gempy
38-
# get a list of indices of the lithological groups
39-
lith_group_indices = []
40-
fault_group_indices = []
41-
index = 0
42-
for i in model.structural_frame.structural_groups:
43-
if i.is_fault:
44-
fault_group_indices.append(index)
45-
else:
46-
lith_group_indices.append(index)
47-
index += 1
48-
49-
# extract scalar field values at surface points
50-
scalar_values = model.solutions.raw_arrays.scalar_field_at_surface_points
51-
52-
# TODO: Here I just get my own masks, cause the gempy masks dont work as expected
53-
masks = _get_masking_arrays(lith_group_indices, model, scalar_values)
54-
55-
# TODO: Attribute of element.scalar_field was None, changed it to scalar field value of that element
56-
# This should probably be done somewhere else and maybe renamed to scalar_field_value?
57-
# This is just the most basic solution to be clear what I did
58-
_set_scalar_field_to_element(model, output_lvl0, structural_groups)
59-
60-
# Trying to use the exiting gempy masks
61-
# masks = []
62-
# masks.append(
63-
# np.ones_like(model.solutions.raw_arrays.scalar_field_matrix[0].reshape(model.grid.regular_grid.resolution),
64-
# dtype=bool))
65-
# for idx in lith_group_indices:
66-
# output_group: InterpOutput = output_lvl0[idx]
67-
# masks.append(output_group.mask_components[8:].reshape(model.grid.regular_grid.resolution))
68-
69-
non_fault_counter = 0
7037
for e, structural_group in enumerate(structural_groups):
7138
if e >= len(output_lvl0):
7239
continue
7340

74-
# Outdated?
75-
# output_group: InterpOutput = output_lvl0[e]
76-
# scalar_field_matrix = output_group.exported_fields_dense_grid.scalar_field
77-
78-
# Specify the correct scalar field, can be removed in the future
79-
scalar_field = model.solutions.raw_arrays.scalar_field_matrix[e].reshape(model.grid.regular_grid.resolution)
80-
81-
# pick mask depending on whether the structural group is a fault or not
82-
if structural_group.is_fault:
83-
mask = np.ones_like(scalar_field, dtype=bool)
41+
output_group: InterpOutput = output_lvl0[e]
42+
scalar_field_matrix = output_group.exported_fields_dense_grid.scalar_field
43+
if structural_group.is_fault is False:
44+
slice_: slice = output_group.grid.dense_grid_slice
45+
mask = output_group.combined_scalar_field.squeezed_mask_array[slice_]
8446
else:
85-
mask = masks[non_fault_counter] # TODO: I need the entry without faults here
86-
non_fault_counter += 1
47+
mask = np.ones_like(scalar_field_matrix, dtype=bool)
8748

8849
for element in structural_group.elements:
8950
extract_mesh_for_element(
9051
structural_element=element,
9152
regular_grid=regular_grid,
92-
scalar_field=scalar_field,
53+
scalar_field=scalar_field_matrix,
9354
mask=mask
9455
)
9556

9657

97-
# TODO: This should be set somewhere else
98-
def _set_scalar_field_to_element(model, output_lvl0, structural_groups):
99-
counter = 0
100-
for e, structural_group in enumerate(structural_groups):
101-
if e >= len(output_lvl0):
102-
continue
103-
104-
for element in structural_group.elements:
105-
element.scalar_field = model.solutions.scalar_field_at_surface_points[counter]
106-
counter += 1
107-
108-
109-
# TODO: This should be set somewhere else
110-
def _get_masking_arrays(lith_group_indices, model, scalar_values):
111-
masks = []
112-
masks.append(np.ones_like(model.solutions.raw_arrays.scalar_field_matrix[0].reshape(model.grid.regular_grid.resolution),
113-
dtype=bool))
114-
for idx in lith_group_indices:
115-
mask = model.solutions.raw_arrays.scalar_field_matrix[idx].reshape(model.grid.regular_grid.resolution) <= \
116-
scalar_values[idx][-1]
117-
118-
masks.append(mask)
119-
return masks
120-
121-
12258
def extract_mesh_for_element(structural_element: StructuralElement,
12359
regular_grid: RegularGrid,
12460
scalar_field: np.ndarray,
@@ -138,10 +74,12 @@ def extract_mesh_for_element(structural_element: StructuralElement,
13874
"""
13975
# Extract mesh using marching cubes
14076
verts, faces, _, _ = measure.marching_cubes(
141-
volume=scalar_field,
142-
level=structural_element.scalar_field,
77+
volume=scalar_field.reshape(regular_grid.resolution),
78+
level=structural_element.scalar_field_at_interface,
14379
spacing=(regular_grid.dx, regular_grid.dy, regular_grid.dz),
144-
mask=mask
80+
mask=mask.reshape(regular_grid.resolution) if mask is not None else None,
81+
allow_degenerate=False,
82+
method="lewiner"
14583
)
14684

14785
# Adjust vertices to correct coordinates in the model's extent

test/test_modules/test_marching_cubes.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,20 @@ def test_marching_cubes_implementation():
3838
# assert arrays.scalar_field_matrix.shape == (3, 8_000) # * 3 surfaces, 8000 points
3939

4040
marching_cubes.set_meshes_with_marching_cubes(model)
41+
42+
# Assert
43+
assert model.solutions.block_solution_type == RawArraysSolution.BlockSolutionType.DENSE_GRID
44+
assert model.solutions.dc_meshes is None
45+
assert model.structural_frame.structural_groups[0].elements[0].vertices.shape == (600, 3)
46+
assert model.structural_frame.structural_groups[1].elements[0].vertices.shape == (860, 3)
47+
assert model.structural_frame.structural_groups[2].elements[0].vertices.shape == (1_256, 3)
48+
assert model.structural_frame.structural_groups[2].elements[1].vertices.shape == (1_680, 3)
4149

4250
if PLOT:
4351
gpv = require_gempy_viewer()
4452
gtv: gpv.GemPyToVista = gpv.plot_3d(
4553
model=model,
4654
show_data=True,
47-
image=False,
55+
image=True,
4856
show=True
4957
)

0 commit comments

Comments
 (0)