diff --git a/map2loop/contact_sampler.py b/map2loop/contact_sampler.py new file mode 100644 index 00000000..e14a17eb --- /dev/null +++ b/map2loop/contact_sampler.py @@ -0,0 +1,49 @@ +from .sampler import SamplerSpacing +from .m2l_enums import Datatype, SampleType +from .contact_extractor import ContactExtractor +from .utils import set_z_values_from_raster_df + +class ContactSampler(SamplerSpacing): + def __init__(self, spacing=50.0, + dtm_data=None, + geology_data=None, + fault_data=None, + stratigraphic_column=None, + ): + super().__init__(spacing, dtm_data, geology_data) + self.sampler_label = "ContactSampler" + self.stratigraphic_column = stratigraphic_column + self.fault_data = fault_data + self.contact_extractor = None + + def get_contact_extractor(self): + if self.contact_extractor is None: + self.contact_extractor = ContactExtractor( + self.geology_data, + self.fault_data + ) + return self.contact_extractor + + + def extract_all_contacts(self): + extractor = self.get_contact_extractor() + contacts = extractor.extract_all_contacts() + return contacts + + + def extract_basal_contacts(self): + extractor = self.get_contact_extractor() + if extractor.contacts is None: + self.extract_all_contacts() + + basal_contacts = extractor.extract_basal_contacts(self.stratigraphic_column) + + return basal_contacts + + def sample(self, spatial_data=None): + basal_contacts = self.extract_basal_contacts() + sampled_contacts = super().sample(basal_contacts) + + set_z_values_from_raster_df(self.dtm_data, sampled_contacts) + + return sampled_contacts \ No newline at end of file diff --git a/map2loop/fault_orientation_sampler.py b/map2loop/fault_orientation_sampler.py new file mode 100644 index 00000000..e4c216ae --- /dev/null +++ b/map2loop/fault_orientation_sampler.py @@ -0,0 +1,23 @@ +from map2loop.fault_orientation import FaultOrientationNearest +from .m2l_enums import Datatype, SampleType +from .sampler import Sampler +from .utils import set_z_values_from_raster_df + +class FaultOrientationSampler(Sampler): + def __init__(self, dtm_data=None, geology_data=None, fault_data=None,map_data=None): + super().__init__(dtm_data,geology_data) + self.sampler_label = "FaultOrientationSampler" + self.fault_data = fault_data + self.fault_orientation = FaultOrientationNearest() + + def sample(self, spatial_data): + + fault_orientations = self.fault_orientation.calculate( + self.fault_data, + spatial_data, + self.map_data + ) + + set_z_values_from_raster_df(self.dtm_data, fault_orientations) + + return fault_orientations \ No newline at end of file diff --git a/map2loop/m2l_enums.py b/map2loop/m2l_enums.py index f390361a..626ae9e1 100644 --- a/map2loop/m2l_enums.py +++ b/map2loop/m2l_enums.py @@ -30,3 +30,13 @@ class VerboseLevel(IntEnum): NONE = 0 TEXTONLY = 1 ALL = 2 + +class SampleType(IntEnum): + GEOLOGY = 0 + STRUCTURE = 1 + FAULT = 2 + FOLD = 3 + DTM = 4 + FAULT_ORIENTATION = 5 + CONTACT = 6 + BASAL_CONTACT = 7 \ No newline at end of file diff --git a/map2loop/sample_storage.py b/map2loop/sample_storage.py new file mode 100644 index 00000000..1d2b8de4 --- /dev/null +++ b/map2loop/sample_storage.py @@ -0,0 +1,204 @@ + +from .m2l_enums import Datatype, SampleType +from .sampler import SamplerDecimator, SamplerSpacing, Sampler +import beartype +from .mapdata import MapData +from .stratigraphic_column import StratigraphicColumn +from .fault_orientation_sampler import FaultOrientationSampler +from .contact_sampler import ContactSampler + +from .logging import getLogger + +logger = getLogger(__name__) + + +class SampleSupervisor: + """ + The SampleSupervisor class is responsible for managing the samples and samplers in the project. + It extends the AccessStorage abstract base class. + + Attributes: + storage_label (str): The label of the storage. + samples (list): A list of samples. + samplers (list): A list of samplers. + sampler_dirtyflags (list): A list of flags indicating if the sampler has changed. + dirtyflags (list): A list of flags indicating the state of the data, sample or sampler. + project (Project): The project associated with the SampleSupervisor. + map_data (MapData): The map data associated with the project. + """ + + def __init__(self, project: "Project", map_data: MapData, stratigraphic_column: StratigraphicColumn ): + """ + The constructor for the SampleSupervisor class. + + Args: + project (Project): The Project class associated with the SampleSupervisor. + """ + + self.storage_label = "SampleSupervisor" + self.map_data = map_data + self.stratigraphic_column = stratigraphic_column + self.samples = [None] * len(SampleType) + self.samplers = [None] * len(SampleType) + self.sampler_dirtyflags = [True] * len(SampleType) + self.set_default_samplers() + + def type(self): + return self.storage_label + + def set_default_samplers(self): + """ + Initialisation function to set or reset the point samplers + """ + + geology_data = self.map_data.get_map_data(Datatype.GEOLOGY) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + fault_data = self.map_data.get_map_data(Datatype.FAULT) + + self._set_decimator_sampler(SampleType.STRUCTURE, decimation=1) + self._set_spacing_sampler(SampleType.GEOLOGY, spacing=50.0) + self._set_spacing_sampler(SampleType.FAULT, spacing=50.0) + self._set_spacing_sampler(SampleType.FOLD, spacing=50.0) + self._set_spacing_sampler(SampleType.DTM, spacing=50.0) + self._set_contact_sampler(SampleType.CONTACT, spacing=50.0) + self._set_fault_orientation_sampler(SampleType.FAULT_ORIENTATION) + + # dirty flags to false after initialisation + self.sampler_dirtyflags = [False] * len(SampleType) + + def _verify_sampler_type(self, sampletype: SampleType, sampler_type: str): + allowed_samplers = { + SampleType.STRUCTURE: ["SamplerDecimator"], + SampleType.GEOLOGY: ["SamplerSpacing"], + SampleType.FAULT: ["SamplerSpacing"], + SampleType.FOLD: ["SamplerSpacing"], + SampleType.DTM: ["SamplerSpacing"], + SampleType.CONTACT: ["ContactSampler"], + SampleType.FAULT_ORIENTATION: ["FaultOrientationSampler"] + } + + if sampletype in allowed_samplers and sampler_type not in allowed_samplers[sampletype]: + allowed = ", ".join(allowed_samplers[sampletype]) + raise ValueError(f"Invalid sampler type '{sampler_type}' for sample '{sampletype}', please use {allowed}") + + @beartype.beartype + def set_sampler(self, sampletype: SampleType, sampler_type: str, **kwargs): + """ + Set the point sampler for a specific datatype + + Args: + sampletype (SampleType): + The sample type (SampleType) to use this sampler on + samplertype (str): + The sampler to use + """ + self._verify_sampler_type(sampletype, sampler_type) + + if sampler_type == "SamplerDecimator": + self._set_decimator_sampler(sampletype, **kwargs) + elif sampler_type == "SamplerSpacing": + self._set_spacing_sampler(sampletype, **kwargs) + elif sampler_type == "ContactSampler": + self._set_contact_sampler(sampletype, **kwargs) + elif sampler_type == "FaultOrientationSampler": + self._set_fault_orientation_sampler(sampletype, **kwargs) + else: + raise ValueError('incorrect sampler type') + + # set the dirty flag to True to indicate that the sampler has changed + self.sampler_dirtyflags[sampletype] = True + + @beartype.beartype + def _set_decimator_sampler(self, sampletype, decimation=1): + geology_data = self.map_data.get_map_data(Datatype.GEOLOGY) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + self.samplers[sampletype] = SamplerDecimator(decimation=decimation, dtm_data=dtm_data, geology_data=geology_data) + + @beartype.beartype + def _set_spacing_sampler(self, sampletype, spacing=50.0): + self.samplers[sampletype] = SamplerSpacing(spacing=spacing) + + @beartype.beartype + def _set_contact_sampler(self, sampletype, spacing=50.0): + geology_data = self.map_data.get_map_data(Datatype.GEOLOGY) + fault_data = self.map_data.get_map_data(Datatype.FAULT) + self.samplers[sampletype] = ContactSampler(spacing=spacing,geology_data=geology_data,fault_data=fault_data, stratigraphic_column=self.stratigraphic_column.column) + + @beartype.beartype + def _set_fault_orientation_sampler(self, sampletype): + geology_data = self.map_data.get_map_data(Datatype.GEOLOGY) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + fault_data = self.map_data.get_map_data(Datatype.FAULT) + self.samplers[sampletype] = FaultOrientationSampler(dtm_data=dtm_data, geology_data=geology_data, fault_data=fault_data, map_data=self.map_data) + + @beartype.beartype + def get_sampler(self, sampletype: SampleType): + """ + Get the sampler name being used for a datatype + + Args: + sampletype: The sample type of the sampler + + Returns: + str: The name of the sampler being used on the specified datatype + """ + return self.samplers[sampletype].sampler_label + + @beartype.beartype + def get_samples(self, sampletype: SampleType): + """ + Get a sample given a sample type + + Args: + sampletype: The sample type of the sampler + + Returns: + The sample data associated with the specified sample type + """ + return self.samples[sampletype] + + @beartype.beartype + def store(self, sampletype: SampleType, sample_data): + self.samples[sampletype] = sample_data + self.sampler_dirtyflags[sampletype] = False + + @beartype.beartype + def sample(self, sampletype: SampleType): + """ + sample sample based on the sample type. + + Args: + sampletype (SampleType): The type of the sample. + + Returns: + The sample data for the specified sample type + """ + if self.samples[sampletype] is not None and not self.sampler_dirtyflags[sampletype]: + return self.samples[sampletype] + if sampletype == SampleType.BASAL_CONTACT: + self._sample_basal_contact() + elif sampletype == SampleType.CONTACT: + self._sample_contact() + else: + self._sample_other_types(sampletype) + + return self.samples[sampletype] + + @beartype.beartype + def _sample_basal_contact(self): + contact_sampler = self.samplers[SampleType.CONTACT] + basal_contacts = contact_sampler.extract_basal_contacts() + self.store(SampleType.BASAL_CONTACT, basal_contacts) + + @beartype.beartype + def _sample_contact(self): + contact_sampler = self.samplers[SampleType.CONTACT] + sampled_contacts = contact_sampler.sample() + self.store(SampleType.CONTACT, sampled_contacts) + + @beartype.beartype + def _sample_other_types(self, sampletype: SampleType): + datatype = Datatype(sampletype) + spatial_data = self.map_data.get_map_data(datatype) + sampled_data = self.samplers[sampletype].sample(spatial_data) + self.store(sampletype, sampled_data)