Skip to content

Commit

Permalink
REF: refactor ids_finder and process_events functions to improve parameter handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Beforerr committed Nov 8, 2024
1 parent 4b13b15 commit 9a9f9a5
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 14 deletions.
9 changes: 4 additions & 5 deletions notebooks/00_ids_finder.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"import polars as pl\n",
"from discontinuitypy.detection.variance import detect_variance\n",
"from discontinuitypy.core.propeties import process_events\n",
"from discontinuitypy.utils.basic import df2ts\n",
"from space_analysis.ds.ts.io import df2ts\n",
"from loguru import logger\n",
"\n",
"from datetime import timedelta\n",
Expand Down Expand Up @@ -89,10 +89,9 @@
"# | export\n",
"def ids_finder(\n",
" detection_df: pl.LazyFrame, # data used for anomaly dectection (typically low cadence data)\n",
" tau: timedelta,\n",
" ts: timedelta,\n",
" bcols=None,\n",
" detect_func: Callable[..., pl.LazyFrame] = detect_variance,\n",
" detect_kwargs: dict = {},\n",
" extract_df: pl.LazyFrame = None, # data used for feature extraction (typically high cadence data),\n",
" **kwargs,\n",
"):\n",
Expand All @@ -108,11 +107,11 @@
" detection_df = detection_df.sort(\"time\")\n",
" extract_df = extract_df.sort(\"time\")\n",
"\n",
" events = detect_func(detection_df, tau, ts, bcols, **kwargs)\n",
" events = detect_func(detection_df, bcols=bcols, **detect_kwargs)\n",
"\n",
" data_c = compress_data_by_events(extract_df.collect(), events)\n",
" sat_fgm = df2ts(data_c, bcols)\n",
" ids = process_events(events, sat_fgm, ts, **kwargs)\n",
" ids = process_events(events, sat_fgm, **kwargs)\n",
" return ids"
]
},
Expand Down
2 changes: 0 additions & 2 deletions notebooks/02_ids_properties.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
"except ImportError:\n",
" import pandas as pd\n",
"\n",
"from datetime import timedelta\n",
"from loguru import logger\n",
"\n",
"from beforerr.polars import convert_to_pd_dataframe, decompose_vector\n",
Expand Down Expand Up @@ -388,7 +387,6 @@
"def process_events(\n",
" candidates_pl: pl.DataFrame, # potential candidates DataFrame\n",
" sat_fgm: xr.DataArray, # satellite FGM data\n",
" data_resolution: timedelta, # time resolution of the data\n",
" method: Literal[\"fit\", \"derivative\"] = \"fit\",\n",
" **kwargs,\n",
") -> pl.DataFrame:\n",
Expand Down
9 changes: 4 additions & 5 deletions src/discontinuitypy/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import polars as pl
from ..detection.variance import detect_variance
from .propeties import process_events
from ..utils.basic import df2ts
from space_analysis.ds.ts.io import df2ts
from loguru import logger

from datetime import timedelta
Expand All @@ -28,10 +28,9 @@ def compress_data_by_events(data: pl.DataFrame, events: pl.DataFrame):
# %% ../../../notebooks/00_ids_finder.ipynb 6
def ids_finder(
detection_df: pl.LazyFrame, # data used for anomaly dectection (typically low cadence data)
tau: timedelta,
ts: timedelta,
bcols=None,
detect_func: Callable[..., pl.LazyFrame] = detect_variance,
detect_kwargs: dict = {},
extract_df: pl.LazyFrame = None, # data used for feature extraction (typically high cadence data),
**kwargs,
):
Expand All @@ -47,11 +46,11 @@ def ids_finder(
detection_df = detection_df.sort("time")
extract_df = extract_df.sort("time")

events = detect_func(detection_df, tau, ts, bcols, **kwargs)
events = detect_func(detection_df, bcols=bcols, **detect_kwargs)

data_c = compress_data_by_events(extract_df.collect(), events)
sat_fgm = df2ts(data_c, bcols)
ids = process_events(events, sat_fgm, ts, **kwargs)
ids = process_events(events, sat_fgm, **kwargs)
return ids

# %% ../../../notebooks/00_ids_finder.ipynb 8
Expand Down
2 changes: 0 additions & 2 deletions src/discontinuitypy/core/propeties.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
except ImportError:
import pandas as pd

from datetime import timedelta
from loguru import logger

from beforerr.polars import convert_to_pd_dataframe, decompose_vector
Expand Down Expand Up @@ -254,7 +253,6 @@ def calc_normal_direction(data, name="normal_direction", **kwargs):
def process_events(
candidates_pl: pl.DataFrame, # potential candidates DataFrame
sat_fgm: xr.DataArray, # satellite FGM data
data_resolution: timedelta, # time resolution of the data
method: Literal["fit", "derivative"] = "fit",
**kwargs,
) -> pl.DataFrame:
Expand Down

0 comments on commit 9a9f9a5

Please sign in to comment.