Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Beforerr committed Jul 14, 2024
1 parent a3252eb commit 1e70b6c
Show file tree
Hide file tree
Showing 15 changed files with 550 additions and 404 deletions.
55 changes: 29 additions & 26 deletions discontinuitypy/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,27 @@
'git_url': 'https://github.com/Beforerr/discontinuitypy',
'lib_path': 'discontinuitypy'},
'syms': { 'discontinuitypy.config': { 'discontinuitypy.config.IDsConfig': ('ids_config.html#idsconfig', 'discontinuitypy/config.py'),
'discontinuitypy.config.IDsConfig.file_prefix': ( 'ids_config.html#idsconfig.file_prefix',
'discontinuitypy/config.py'),
'discontinuitypy.config.IDsConfig.timeranges': ( 'ids_config.html#idsconfig.timeranges',
'discontinuitypy/config.py'),
'discontinuitypy.config.SpeasyIDsConfig': ( 'ids_config.html#speasyidsconfig',
'discontinuitypy/config.py'),
'discontinuitypy.config.SpeasyIDsConfig._get_and_process_data': ( 'ids_config.html#speasyidsconfig._get_and_process_data',
'discontinuitypy/config.py'),
'discontinuitypy.config.SpeasyIDsConfig.e_temp_var': ( 'ids_config.html#speasyidsconfig.e_temp_var',
'discontinuitypy/config.py'),
'discontinuitypy.config.SpeasyIDsConfig.get_and_process_data': ( 'ids_config.html#speasyidsconfig.get_and_process_data',
'discontinuitypy/config.py'),
'discontinuitypy.config.SpeasyIDsConfig.get_data': ( 'ids_config.html#speasyidsconfig.get_data',
'discontinuitypy/config.py'),
'discontinuitypy.config.SpeasyIDsConfig.get_vars': ( 'ids_config.html#speasyidsconfig.get_vars',
'discontinuitypy/config.py'),
'discontinuitypy.config.SpeasyIDsConfig.get_vars_df': ( 'ids_config.html#speasyidsconfig.get_vars_df',
'discontinuitypy/config.py'),
'discontinuitypy.config.SpeasyIDsConfig.ion_temp_var': ( 'ids_config.html#speasyidsconfig.ion_temp_var',
'discontinuitypy/config.py'),
'discontinuitypy.config.SpeasyIDsConfig.mag_vars': ( 'ids_config.html#speasyidsconfig.mag_vars',
'discontinuitypy/config.py'),
'discontinuitypy.config.SpeasyIDsConfig.model_post_init': ( 'ids_config.html#speasyidsconfig.model_post_init',
'discontinuitypy/config.py'),
'discontinuitypy.config.SpeasyIDsConfig.plasma_vars': ( 'ids_config.html#speasyidsconfig.plasma_vars',
'discontinuitypy/config.py'),
'discontinuitypy.config.SpeasyIDsConfig.set_data_from_vars': ( 'ids_config.html#speasyidsconfig.set_data_from_vars',
'discontinuitypy/config.py'),
'discontinuitypy.config.SpeasyIDsConfig.produce_or_load': ( 'ids_config.html#speasyidsconfig.produce_or_load',
'discontinuitypy/config.py'),
'discontinuitypy.config.get_vars': ('ids_config.html#get_vars', 'discontinuitypy/config.py'),
'discontinuitypy.config.split_timerange': ( 'ids_config.html#split_timerange',
'discontinuitypy/config.py'),
'discontinuitypy.config.standardize_plasma_data': ( 'ids_config.html#standardize_plasma_data',
'discontinuitypy/config.py')},
'discontinuitypy.core.detection': { 'discontinuitypy.core.detection.detect_events': ( 'ids_detection.html#detect_events',
Expand Down Expand Up @@ -82,33 +79,39 @@
'discontinuitypy/core/propeties.py')},
'discontinuitypy.datasets': { 'discontinuitypy.datasets.IDsDataset': ( 'datasets.html#idsdataset',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IDsDataset.config_updates': ( 'datasets.html#idsdataset.config_updates',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IDsDataset.file': ( 'datasets.html#idsdataset.file',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IDsDataset.plot': ( 'datasets.html#idsdataset.plot',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IDsDataset.update_events': ( 'datasets.html#idsdataset.update_events',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IDsDataset.update_events_with_plasma_data': ( 'datasets.html#idsdataset.update_events_with_plasma_data',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IDsDataset.update_events_with_temp_data': ( 'datasets.html#idsdataset.update_events_with_temp_data',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IDsDataset.produce_or_load': ( 'datasets.html#idsdataset.produce_or_load',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IdsEvents': ('datasets.html#idsevents', 'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IdsEvents.Config': ( 'datasets.html#idsevents.config',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IdsEvents.config_detection': ( 'datasets.html#idsevents.config_detection',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IdsEvents.file': ( 'datasets.html#idsevents.file',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IdsEvents.file_prefix': ( 'datasets.html#idsevents.file_prefix',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IdsEvents.find_events': ( 'datasets.html#idsevents.find_events',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IdsEvents.fname': ( 'datasets.html#idsevents.fname',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IdsEvents.get_event': ( 'datasets.html#idsevents.get_event',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IdsEvents.get_event_data': ( 'datasets.html#idsevents.get_event_data',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IdsEvents.produce_or_load_file': ( 'datasets.html#idsevents.produce_or_load_file',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.IdsEvents.produce_or_load': ( 'datasets.html#idsevents.produce_or_load',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.log_event_change': ( 'datasets.html#log_event_change',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.select_row': ( 'datasets.html#select_row',
'discontinuitypy/datasets.py')},
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.update_events': ( 'datasets.html#update_events',
'discontinuitypy/datasets.py'),
'discontinuitypy.datasets.update_events_with_temp_data': ( 'datasets.html#update_events_with_temp_data',
'discontinuitypy/datasets.py')},
'discontinuitypy.detection.variance': { 'discontinuitypy.detection.variance.add_neighbor_std': ( 'detection/variance.html#add_neighbor_std',
'discontinuitypy/detection/variance.py'),
'discontinuitypy.detection.variance.compute_combinded_std': ( 'detection/variance.html#compute_combinded_std',
Expand All @@ -135,12 +138,12 @@
'discontinuitypy/integration.py'),
'discontinuitypy.integration.combine_features': ( 'mag_plasma.html#combine_features',
'discontinuitypy/integration.py'),
'discontinuitypy.integration.format_time': ( 'mag_plasma.html#format_time',
'discontinuitypy/integration.py'),
'discontinuitypy.integration.interpolate': ( 'mag_plasma.html#interpolate',
'discontinuitypy/integration.py'),
'discontinuitypy.integration.interpolate2': ( 'mag_plasma.html#interpolate2',
'discontinuitypy/integration.py')},
'discontinuitypy/integration.py'),
'discontinuitypy.integration.update_events_with_plasma_data': ( 'mag_plasma.html#update_events_with_plasma_data',
'discontinuitypy/integration.py')},
'discontinuitypy.mission.stereo': { 'discontinuitypy.mission.stereo.StereoConfigBase': ( 'missions/stereo.html#stereoconfigbase',
'discontinuitypy/mission/stereo.py')},
'discontinuitypy.mission.themis': { 'discontinuitypy.mission.themis.ThemisConfigBase': ( 'missions/themis.html#themisconfigbase',
Expand Down
118 changes: 62 additions & 56 deletions discontinuitypy/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/11_ids_config.ipynb.

# %% auto 0
__all__ = ['standardize_plasma_data', 'IDsConfig', 'SpeasyIDsConfig']
__all__ = ['standardize_plasma_data', 'split_timerange', 'IDsConfig', 'get_vars', 'SpeasyIDsConfig']

# %% ../notebooks/11_ids_config.ipynb 0
from datetime import datetime
from beforerr.project import produce_or_load_file
from .datasets import IDsDataset
from space_analysis.meta import PlasmaDataset, Dataset
from space_analysis.utils.speasy import Variables
Expand All @@ -27,50 +28,59 @@ def standardize_plasma_data(data: pl.LazyFrame, meta: PlasmaDataset):
return data

# %% ../notebooks/11_ids_config.ipynb 2
def split_timerange(timerange: list[datetime], split: int = 1):
    """Divide a timerange into `split` consecutive sub-ranges.

    Returns a list of ``[start, end]`` pairs, one per sub-range, with the
    endpoints produced by sunpy's ``TimeRange.split``.
    """
    from sunpy.time import TimeRange

    segments = TimeRange(timerange).split(split)
    return [[segment.start.value, segment.end.value] for segment in segments]


class IDsConfig(IDsDataset):
"""
Extend the IDsDataset class to provide additional functionalities:
- Export and load data
- Standardize data
- Split data to handle large datasets
"""

timerange: list[datetime] = None
split: int = 1

@property
def timeranges(self):
from sunpy.time import TimeRange

trs: list[TimeRange] = TimeRange(self.timerange).split(self.split)
return [[tr.start.value, tr.end.value] for tr in trs]
def timeranges(self):
return split_timerange(self.timerange, self.split)

    @property
    def file_prefix(self):
        """File-name prefix, extended with the timerange when one is set."""
        if self.timerange is None:
            return super().file_prefix
        else:
            # Encode the timerange as YYYYMMDD-YYYYMMDD so files produced for
            # different timeranges do not collide on disk.
            tr_str = "-".join(t.strftime("%Y%m%d") for t in self.timerange)
            return super().file_prefix + f"_tr={tr_str}"

# %% ../notebooks/11_ids_config.ipynb 3
def get_vars(self, vars: str, timerange: list[datetime] = None):
    """Build a `Variables` container for the dataset named by `vars`.

    Resolves metadata from the ``{vars}_meta`` attribute of `self`
    (e.g. ``vars="mag"`` reads ``self.mag_meta``).  The timerange falls back
    from the explicit argument, to ``self.timerange``, to the metadata's own
    timerange.
    """
    meta: Dataset = getattr(self, f"{vars}_meta")
    timerange = timerange or self.timerange or meta.timerange
    return Variables(
        timerange=timerange,
        provider=self.provider,
        **meta.model_dump(exclude_unset=True),
    )


class SpeasyIDsConfig(IDsConfig):
"""Based on `speasy` Variables to get the data"""

provider: str = "cda"

    def model_post_init(self, __context):
        # Post-init hook (pydantic-style, given the model_copy/model_dump use
        # elsewhere in this file); intentionally a no-op for now.
        # TODO: directly get columns from the data without loading them
        # self.plasma_meta.density_col = self.plasma_vars.data[0].columns[0]
        # self.plasma_meta.velocity_cols = self.plasma_vars.data[1].columns
        pass

def get_vars(self, vars: str):
meta: Dataset = getattr(self, f"{vars}_meta")
return Variables(
timerange=self.timerange,
provider=self.provider,
**meta.model_dump(exclude_unset=True),
)
    def get_vars(self, *args, **kwargs):
        # Thin wrapper delegating to the module-level `get_vars` helper,
        # preserving the method-style API on the config object.
        return get_vars(self, *args, **kwargs)

def get_vars_df(self, vars: str, cached: bool = False):
if not cached:
return self.get_vars(vars).to_polars()
else:
return NotImplementedError
    def get_vars_df(self, vars: str, **kwargs):
        # Fetch the variables for `vars` and convert them via `to_polars()`.
        return get_vars(self, vars, **kwargs).to_polars()

# Variables
@cached_property
Expand All @@ -81,38 +91,34 @@ def mag_vars(self):
def plasma_vars(self):
return self.get_vars("plasma")

    @cached_property
    def ion_temp_var(self):
        # Ion-temperature variables, resolved once per instance and cached.
        return self.get_vars("ion_temp")

    @cached_property
    def e_temp_var(self):
        # Electron-temperature variables, resolved once per instance and cached.
        return self.get_vars("e_temp")

# DataFrames
def set_data_from_vars(self, update: False):
pass

def _get_and_process_data(self, **kwargs):
self.plasma_meta.density_col = self.plasma_vars.data[0].columns[0]
self.plasma_meta.velocity_cols = self.plasma_vars.data[1].columns

# TODO: optimize for no-split timeranges
def get_data(self):
# TODO: directly get columns from the data without loading them
self.plasma_meta.density_col = (
self.plasma_meta.density_col or self.plasma_vars.data[0].columns[0]
)
self.plasma_meta.velocity_cols = (
self.plasma_meta.velocity_cols or self.plasma_vars.data[1].columns
)

for tr in tqdm(self.timeranges):
ids_ds = self.model_copy(update={"timerange": tr, "split": 1}, deep=True)
self.data = self.data or self.get_vars_df("mag")
self.plasma_data = self.plasma_data or self.get_vars_df("plasma").pipe(
standardize_plasma_data, self.plasma_meta
)
return self

ids_ds.data = ids_ds.get_vars_df("mag", cached=False)
ids_ds.plasma_data = ids_ds.get_vars_df("plasma", cached=False).pipe(
standardize_plasma_data, ids_ds.plasma_meta
def produce_or_load(self, **kwargs):
if self.split == 1:
self.file.exists() or self.get_data()
return super().produce_or_load(**kwargs)
else:
updates = [{"timerange": tr, "split": 1} for tr in self.timeranges]
configs = [self.model_copy(update=update, deep=True) for update in updates]
datas, files = zip(
config.produce_or_load(**kwargs) for config in tqdm(configs)
)

yield (
ids_ds.find_events(return_best_fit=False)
.update_events_with_plasma_data()
.events
return produce_or_load_file(
f=pl.concat,
c=dict(items=datas),
file=self.file,
)

def get_and_process_data(self, **kwargs):
self.events = pl.concat(self._get_and_process_data(**kwargs))
return self
7 changes: 3 additions & 4 deletions discontinuitypy/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,13 @@ def ids_finder(
extract_df: pl.LazyFrame = None, # data used for feature extraction (typically high cadence data),
**kwargs
):
if extract_df is None:
extract_df = detection_df
extract_df = extract_df or detection_df
if bcols is None:
bcols = detection_df.columns
bcols.remove("time")

detection_df = detection_df.sort("time").with_columns(pl.col("time").dt.cast_time_unit("us")) # https://github.com/pola-rs/polars/issues/12023
extract_df = extract_df.sort("time").with_columns(pl.col("time").dt.cast_time_unit("us"))
detection_df = detection_df.sort("time") # https://github.com/pola-rs/polars/issues/12023
extract_df = extract_df.sort("time")

events = detect_events(detection_df, tau, ts, bcols, **kwargs)

Expand Down
Loading

0 comments on commit 1e70b6c

Please sign in to comment.