Skip to content
49 changes: 49 additions & 0 deletions map2loop/contact_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from .sampler import SamplerSpacing
from .m2l_enums import Datatype, SampleType
from .contact_extractor import ContactExtractor
from .utils import set_z_values_from_raster_df

class ContactSampler(SamplerSpacing):
def __init__(self, spacing=50.0,
dtm_data=None,
geology_data=None,
fault_data=None,
stratigraphic_column=None,
):
super().__init__(spacing, dtm_data, geology_data)
self.sampler_label = "ContactSampler"
self.stratigraphic_column = stratigraphic_column
self.fault_data = fault_data
self.contact_extractor = None

def get_contact_extractor(self):
if self.contact_extractor is None:
self.contact_extractor = ContactExtractor(
self.geology_data,
self.fault_data
)
return self.contact_extractor


def extract_all_contacts(self):
extractor = self.get_contact_extractor()
contacts = extractor.extract_all_contacts()
return contacts


def extract_basal_contacts(self):
extractor = self.get_contact_extractor()
if extractor.contacts is None:
self.extract_all_contacts()

basal_contacts = extractor.extract_basal_contacts(self.stratigraphic_column)

return basal_contacts

def sample(self, spatial_data=None):
basal_contacts = self.extract_basal_contacts()
sampled_contacts = super().sample(basal_contacts)

set_z_values_from_raster_df(self.dtm_data, sampled_contacts)

return sampled_contacts
23 changes: 23 additions & 0 deletions map2loop/fault_orientation_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from map2loop.fault_orientation import FaultOrientationNearest
from .m2l_enums import Datatype, SampleType
from .sampler import Sampler
from .utils import set_z_values_from_raster_df

class FaultOrientationSampler(Sampler):
def __init__(self, dtm_data=None, geology_data=None, fault_data=None,map_data=None):
super().__init__(dtm_data,geology_data)
self.sampler_label = "FaultOrientationSampler"
self.fault_data = fault_data
self.fault_orientation = FaultOrientationNearest()

def sample(self, spatial_data):

fault_orientations = self.fault_orientation.calculate(
self.fault_data,
spatial_data,
self.map_data
)

set_z_values_from_raster_df(self.dtm_data, fault_orientations)

return fault_orientations
10 changes: 10 additions & 0 deletions map2loop/m2l_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,13 @@ class VerboseLevel(IntEnum):
NONE = 0
TEXTONLY = 1
ALL = 2

class SampleType(IntEnum):
GEOLOGY = 0
STRUCTURE = 1
FAULT = 2
FOLD = 3
DTM = 4
FAULT_ORIENTATION = 5
CONTACT = 6
BASAL_CONTACT = 7
204 changes: 204 additions & 0 deletions map2loop/sample_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@

from .m2l_enums import Datatype, SampleType
from .sampler import SamplerDecimator, SamplerSpacing, Sampler
import beartype
from .mapdata import MapData
from .stratigraphic_column import StratigraphicColumn
from .fault_orientation_sampler import FaultOrientationSampler
from .contact_sampler import ContactSampler

from .logging import getLogger

logger = getLogger(__name__)


class SampleSupervisor:
"""
The SampleSupervisor class is responsible for managing the samples and samplers in the project.
It extends the AccessStorage abstract base class.

Attributes:
storage_label (str): The label of the storage.
samples (list): A list of samples.
samplers (list): A list of samplers.
sampler_dirtyflags (list): A list of flags indicating if the sampler has changed.
dirtyflags (list): A list of flags indicating the state of the data, sample or sampler.
project (Project): The project associated with the SampleSupervisor.
map_data (MapData): The map data associated with the project.
"""

def __init__(self, project: "Project", map_data: MapData, stratigraphic_column: StratigraphicColumn ):
"""
The constructor for the SampleSupervisor class.

Args:
project (Project): The Project class associated with the SampleSupervisor.
"""

self.storage_label = "SampleSupervisor"
self.map_data = map_data
self.stratigraphic_column = stratigraphic_column
self.samples = [None] * len(SampleType)
self.samplers = [None] * len(SampleType)
self.sampler_dirtyflags = [True] * len(SampleType)
self.set_default_samplers()

def type(self):
return self.storage_label

def set_default_samplers(self):
"""
Initialisation function to set or reset the point samplers
"""

geology_data = self.map_data.get_map_data(Datatype.GEOLOGY)
dtm_data = self.map_data.get_map_data(Datatype.DTM)
fault_data = self.map_data.get_map_data(Datatype.FAULT)

self._set_decimator_sampler(SampleType.STRUCTURE, decimation=1)
self._set_spacing_sampler(SampleType.GEOLOGY, spacing=50.0)
self._set_spacing_sampler(SampleType.FAULT, spacing=50.0)
self._set_spacing_sampler(SampleType.FOLD, spacing=50.0)
self._set_spacing_sampler(SampleType.DTM, spacing=50.0)
self._set_contact_sampler(SampleType.CONTACT, spacing=50.0)
self._set_fault_orientation_sampler(SampleType.FAULT_ORIENTATION)

# dirty flags to false after initialisation
self.sampler_dirtyflags = [False] * len(SampleType)

def _verify_sampler_type(self, sampletype: SampleType, sampler_type: str):
allowed_samplers = {
SampleType.STRUCTURE: ["SamplerDecimator"],
SampleType.GEOLOGY: ["SamplerSpacing"],
SampleType.FAULT: ["SamplerSpacing"],
SampleType.FOLD: ["SamplerSpacing"],
SampleType.DTM: ["SamplerSpacing"],
SampleType.CONTACT: ["ContactSampler"],
SampleType.FAULT_ORIENTATION: ["FaultOrientationSampler"]
}

if sampletype in allowed_samplers and sampler_type not in allowed_samplers[sampletype]:
allowed = ", ".join(allowed_samplers[sampletype])
raise ValueError(f"Invalid sampler type '{sampler_type}' for sample '{sampletype}', please use {allowed}")

@beartype.beartype
def set_sampler(self, sampletype: SampleType, sampler_type: str, **kwargs):
"""
Set the point sampler for a specific datatype

Args:
sampletype (SampleType):
The sample type (SampleType) to use this sampler on
samplertype (str):
The sampler to use
"""
self._verify_sampler_type(sampletype, sampler_type)

if sampler_type == "SamplerDecimator":
self._set_decimator_sampler(sampletype, **kwargs)
elif sampler_type == "SamplerSpacing":
self._set_spacing_sampler(sampletype, **kwargs)
elif sampler_type == "ContactSampler":
self._set_contact_sampler(sampletype, **kwargs)
elif sampler_type == "FaultOrientationSampler":
self._set_fault_orientation_sampler(sampletype, **kwargs)
else:
raise ValueError('incorrect sampler type')

# set the dirty flag to True to indicate that the sampler has changed
self.sampler_dirtyflags[sampletype] = True

@beartype.beartype
def _set_decimator_sampler(self, sampletype, decimation=1):
geology_data = self.map_data.get_map_data(Datatype.GEOLOGY)
dtm_data = self.map_data.get_map_data(Datatype.DTM)
self.samplers[sampletype] = SamplerDecimator(decimation=decimation, dtm_data=dtm_data, geology_data=geology_data)

@beartype.beartype
def _set_spacing_sampler(self, sampletype, spacing=50.0):
self.samplers[sampletype] = SamplerSpacing(spacing=spacing)

@beartype.beartype
def _set_contact_sampler(self, sampletype, spacing=50.0):
geology_data = self.map_data.get_map_data(Datatype.GEOLOGY)
fault_data = self.map_data.get_map_data(Datatype.FAULT)
self.samplers[sampletype] = ContactSampler(spacing=spacing,geology_data=geology_data,fault_data=fault_data, stratigraphic_column=self.stratigraphic_column.column)

@beartype.beartype
def _set_fault_orientation_sampler(self, sampletype):
geology_data = self.map_data.get_map_data(Datatype.GEOLOGY)
dtm_data = self.map_data.get_map_data(Datatype.DTM)
fault_data = self.map_data.get_map_data(Datatype.FAULT)
self.samplers[sampletype] = FaultOrientationSampler(dtm_data=dtm_data, geology_data=geology_data, fault_data=fault_data, map_data=self.map_data)

@beartype.beartype
def get_sampler(self, sampletype: SampleType):
"""
Get the sampler name being used for a datatype

Args:
sampletype: The sample type of the sampler

Returns:
str: The name of the sampler being used on the specified datatype
"""
return self.samplers[sampletype].sampler_label

@beartype.beartype
def get_samples(self, sampletype: SampleType):
"""
Get a sample given a sample type

Args:
sampletype: The sample type of the sampler

Returns:
The sample data associated with the specified sample type
"""
return self.samples[sampletype]

@beartype.beartype
def store(self, sampletype: SampleType, sample_data):
self.samples[sampletype] = sample_data
self.sampler_dirtyflags[sampletype] = False

@beartype.beartype
def sample(self, sampletype: SampleType):
"""
sample sample based on the sample type.

Args:
sampletype (SampleType): The type of the sample.

Returns:
The sample data for the specified sample type
"""
if self.samples[sampletype] is not None and not self.sampler_dirtyflags[sampletype]:
return self.samples[sampletype]
if sampletype == SampleType.BASAL_CONTACT:
self._sample_basal_contact()
elif sampletype == SampleType.CONTACT:
self._sample_contact()
else:
self._sample_other_types(sampletype)

return self.samples[sampletype]

@beartype.beartype
def _sample_basal_contact(self):
contact_sampler = self.samplers[SampleType.CONTACT]
basal_contacts = contact_sampler.extract_basal_contacts()
self.store(SampleType.BASAL_CONTACT, basal_contacts)

@beartype.beartype
def _sample_contact(self):
contact_sampler = self.samplers[SampleType.CONTACT]
sampled_contacts = contact_sampler.sample()
self.store(SampleType.CONTACT, sampled_contacts)

@beartype.beartype
def _sample_other_types(self, sampletype: SampleType):
datatype = Datatype(sampletype)
spatial_data = self.map_data.get_map_data(datatype)
sampled_data = self.samplers[sampletype].sample(spatial_data)
self.store(sampletype, sampled_data)