Commit 7aaea3f

Commit message: cleaned
1 parent: 5fdfebe

File tree: 5 files changed, +145 -38 lines changed

simba/mixins/circular_statistics.py

Lines changed: 32 additions & 8 deletions
@@ -607,7 +607,9 @@ def sliding_angular_diff(
 
     @staticmethod
     @njit("(float32[:], float64[:], int64)")
-    def agg_angular_diff_timebins(data: np.ndarray, time_windows: np.ndarray, fps: int) -> np.ndarray:
+    def agg_angular_diff_timebins(
+        data: np.ndarray, time_windows: np.ndarray, fps: int
+    ) -> np.ndarray:
         """
         Compute the difference between the median angle in the current time-window versus the previous time window.
         For example, computes the difference between the mean angle in the first 1s of the video versus
@@ -678,9 +680,14 @@ def rao_spacing(data: np.array):
         data = np.sort(data)
         Ti, TiL = np.full((data.shape[0]), np.nan), np.full((data.shape[0]), np.nan)
         l = np.int8(360 / len(data))
-        Ti[-1] = np.rad2deg(np.pi - np.abs(np.pi - np.abs(np.deg2rad(data[0]) - np.deg2rad(data[-1]))))
+        Ti[-1] = np.rad2deg(
+            np.pi - np.abs(np.pi - np.abs(np.deg2rad(data[0]) - np.deg2rad(data[-1])))
+        )
         for j in prange(data.shape[0] - 1, -1, -1):
-            Ti[j] = np.rad2deg(np.pi - np.abs(np.pi - np.abs(np.deg2rad(data[j]) - np.deg2rad(data[j - 1]))))
+            Ti[j] = np.rad2deg(
+                np.pi
+                - np.abs(np.pi - np.abs(np.deg2rad(data[j]) - np.deg2rad(data[j - 1])))
+            )
         for k in prange(Ti.shape[0]):
             TiL[int(k)] = max((l, Ti[k])) - min((l, Ti[k]))
         S = np.sum(TiL)
@@ -689,7 +696,9 @@ def rao_spacing(data: np.array):
 
     @staticmethod
     @njit("(float32[:], float64[:], int64)")
-    def sliding_rao_spacing(data: np.ndarray, time_windows: np.ndarray, fps: int) -> np.ndarray:
+    def sliding_rao_spacing(
+        data: np.ndarray, time_windows: np.ndarray, fps: int
+    ) -> np.ndarray:
         """
         Jitted compute of the uniformity of a circular dataset in sliding windows.
 
@@ -726,11 +735,24 @@ def sliding_rao_spacing(data: np.ndarray, time_windows: np.ndarray, fps: int) ->
             window_size = int(time_windows[win_cnt] * fps)
             for i in range(window_size, data.shape[0]):
                 w_data = np.sort(data[i - window_size : i])
-                Ti, TiL = np.full((w_data.shape[0]), np.nan), np.full((w_data.shape[0]), np.nan)
+                Ti, TiL = np.full((w_data.shape[0]), np.nan), np.full(
+                    (w_data.shape[0]), np.nan
+                )
                 l = np.int8(360 / len(w_data))
-                Ti[-1] = np.rad2deg(np.pi - np.abs(np.pi - np.abs(np.deg2rad(w_data[0]) - np.deg2rad(w_data[-1]))))
+                Ti[-1] = np.rad2deg(
+                    np.pi
+                    - np.abs(
+                        np.pi - np.abs(np.deg2rad(w_data[0]) - np.deg2rad(w_data[-1]))
+                    )
+                )
                 for j in prange(w_data.shape[0] - 1, -1, -1):
-                    Ti[j] = np.rad2deg(np.pi - np.abs(np.pi - np.abs(np.deg2rad(w_data[j]) - np.deg2rad(w_data[j - 1]))))
+                    Ti[j] = np.rad2deg(
+                        np.pi
+                        - np.abs(
+                            np.pi
+                            - np.abs(np.deg2rad(w_data[j]) - np.deg2rad(w_data[j - 1]))
+                        )
+                    )
                 for k in prange(Ti.shape[0]):
                     TiL[int(k)] = max((l, Ti[k])) - min((l, Ti[k]))
                 S = np.sum(TiL)
@@ -965,7 +987,9 @@ def circular_hotspots(data: np.ndarray, bins: np.ndarray) -> np.ndarray:
 
     @staticmethod
    @njit("(float32[:], int64[:, :], float64, float64)")
-    def sliding_circular_hotspots(data: np.ndarray, bins: np.ndarray, time_window: float, fps: float) -> np.ndarray:
+    def sliding_circular_hotspots(
+        data: np.ndarray, bins: np.ndarray, time_window: float, fps: float
+    ) -> np.ndarray:
         """
         Jitted compute of sliding circular hotspots in a dataset. Calculates circular hotspots in a time-series dataset by sliding a time window
         across the data and computing hotspot statistics for specified circular bins.
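
The hunks above only re-wrap these circular-statistics functions for line length; signatures and behavior are unchanged. As a rough usage sketch (assuming the functions live as static methods on SimBA's CircularStatisticsMixin class, which this diff does not show), sliding_rao_spacing takes float32 angles, float64 window sizes in seconds, and an integer fps:

import numpy as np

from simba.mixins.circular_statistics import CircularStatisticsMixin

# Hypothetical input: 10 seconds of head-direction angles (degrees) at 10 fps.
data = np.random.uniform(0, 360, 100).astype(np.float32)

# Rao's spacing statistic in sliding 1s and 2s windows; dtypes follow the
# @njit("(float32[:], float64[:], int64)") signature in the hunk above.
results = CircularStatisticsMixin().sliding_rao_spacing(
    data=data, time_windows=np.array([1.0, 2.0]), fps=10
)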

simba/mixins/timeseries_features_mixin.py

Lines changed: 25 additions & 9 deletions
@@ -76,7 +76,8 @@ def hjort_parameters(data: np.ndarray) -> (float, float, float):
         dx = np.diff(np.ascontiguousarray(data))
         ddx = np.diff(np.ascontiguousarray(dx))
         x_var, dx_var = np.var(data), np.var(dx)
-        if (x_var <= 0) or (dx_var <= 0): return 0, 0, 0
+        if (x_var <= 0) or (dx_var <= 0):
+            return 0, 0, 0
 
         ddx_var = np.var(ddx)
         mobility = np.sqrt(dx_var / x_var)
@@ -209,7 +210,9 @@ def crossings(data: np.ndarray, val: float) -> int:
 
     @staticmethod
     @njit("(float32[:], float64, float64[:], int64,)")
-    def sliding_crossings(data: np.ndarray, val: float, time_windows: np.ndarray, fps: int) -> np.ndarray:
+    def sliding_crossings(
+        data: np.ndarray, val: float, time_windows: np.ndarray, fps: int
+    ) -> np.ndarray:
         """
         Compute the number of crossings over sliding windows in a data array.
 
@@ -258,7 +261,9 @@ def sliding_crossings(data: np.ndarray, val: float, time_windows: np.ndarray, fp
 
     @staticmethod
     @njit("(float32[:], int64, int64, )", cache=True, fastmath=True)
-    def percentile_difference(data: np.ndarray, upper_pct: int, lower_pct: int) -> float:
+    def percentile_difference(
+        data: np.ndarray, upper_pct: int, lower_pct: int
+    ) -> float:
         """
         Jitted compute of the difference between the ``upper`` and ``lower`` percentiles of the data as
         a percentage of the median value. Helps understanding the spread or variability of the data within specified percentiles.
@@ -281,13 +286,20 @@ def percentile_difference(data: np.ndarray, upper_pct: int, lower_pct: int) -> f
         >>> 0.7401574764125177
 
         """
-        upper_val, lower_val = np.percentile(data, upper_pct), np.percentile(data, lower_pct)
+        upper_val, lower_val = np.percentile(data, upper_pct), np.percentile(
+            data, lower_pct
+        )
         return np.abs(upper_val - lower_val) / np.median(data)
 
     @staticmethod
     @njit("(float32[:], int64, int64, float64[:], int64, )", cache=True, fastmath=True)
-    def sliding_percentile_difference(data: np.ndarray,upper_pct: int,lower_pct: int,window_sizes: np.ndarray, fps: int) -> np.ndarray:
-
+    def sliding_percentile_difference(
+        data: np.ndarray,
+        upper_pct: int,
+        lower_pct: int,
+        window_sizes: np.ndarray,
+        fps: int,
+    ) -> np.ndarray:
         """
         Jitted computes the difference between the upper and lower percentiles within a sliding window for each position
         in the time series using various window sizes. It returns a 2D array where each row corresponds to a position in the time series,
@@ -352,7 +364,9 @@ def percent_beyond_n_std(data: np.ndarray, n: float) -> float:
 
     @staticmethod
     @njit("(float64[:], float64, float64[:], int64,)", cache=True, fastmath=True)
-    def sliding_percent_beyond_n_std(data: np.ndarray, n: float, window_sizes: np.ndarray, sample_rate: int) -> np.ndarray:
+    def sliding_percent_beyond_n_std(
+        data: np.ndarray, n: float, window_sizes: np.ndarray, sample_rate: int
+    ) -> np.ndarray:
         """
         Computed the percentage of data points that exceed 'n' standard deviations from the mean for each position in
         the time series using various window sizes. It returns a 2D array where each row corresponds to a position in the time series,
@@ -1015,7 +1029,9 @@ def sliding_longest_strike(
             (float32[:], float64, int64, types.misc.Omitted(True)),
         ]
     )
-    def time_since_previous_threshold(data: np.ndarray, threshold: float, fps: int, above: bool) -> np.ndarray:
+    def time_since_previous_threshold(
+        data: np.ndarray, threshold: float, fps: int, above: bool
+    ) -> np.ndarray:
         """
         Jitted compute of the time (in seconds) that has elapsed since the last occurrence of a value above (or below)
         a specified threshold in a time series. The time series is assumed to have a constant sample rate.
@@ -1548,4 +1564,4 @@ def acceleration(
             else:
                 results[wS:wE] = v - pv
             pv = v
-        return results
\ No newline at end of file
+        return results
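
These timeseries hunks are likewise pure re-wraps. A minimal calling sketch, assuming the functions are static methods on SimBA's TimeseriesFeatureMixin class (the class name is not shown in this diff):

import numpy as np

from simba.mixins.timeseries_features_mixin import TimeseriesFeatureMixin

# Hypothetical signal: 10 seconds of data sampled at 10 fps.
data = np.random.randint(-10, 10, size=(100,)).astype(np.float32)

# Count crossings of 0.0 within sliding 1s and 2s windows; dtypes match the
# @njit("(float32[:], float64, float64[:], int64,)") signature above.
crossings = TimeseriesFeatureMixin().sliding_crossings(
    data=data, val=0.0, time_windows=np.array([1.0, 2.0]), fps=10
)

# Spread between the 75th and 25th percentiles as a fraction of the median.
spread = TimeseriesFeatureMixin().percentile_difference(
    data=data, upper_pct=75, lower_pct=25
)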

simba/unsupervised/grid_search_visualizers.py

Lines changed: 8 additions & 4 deletions
@@ -10,9 +10,10 @@
 
 from simba.mixins.unsupervised_mixin import UnsupervisedMixin
 from simba.unsupervised.enums import Clustering, Unsupervised
-from simba.utils.checks import (check_if_dir_exists, check_if_filepath_list_is_empty)
-from simba.utils.printing import stdout_success
+from simba.utils.checks import (check_if_dir_exists,
+                                check_if_filepath_list_is_empty)
 from simba.utils.enums import Formats
+from simba.utils.printing import stdout_success
 
 
 class GridSearchVisualizer(UnsupervisedMixin):
@@ -36,7 +37,10 @@ def __init__(self, model_dir: str, save_dir: str, settings: dict):
         check_if_dir_exists(in_dir=model_dir)
         self.save_dir, self.settings, self.model_dir = save_dir, settings, model_dir
         self.data_path = glob.glob(model_dir + f"/*.{Formats.PICKLE.value}")
-        check_if_filepath_list_is_empty(filepaths=self.data_path, error_msg=f"SIMBA ERROR: No pickle files found in {model_dir}")
+        check_if_filepath_list_is_empty(
+            filepaths=self.data_path,
+            error_msg=f"SIMBA ERROR: No pickle files found in {model_dir}",
+        )
 
     def __join_data(self, data: object):
         embedding_data = pd.DataFrame(
@@ -127,7 +131,7 @@ def continuous_visualizer(self, continuous_vars: list):
         )
 
 
-settings = {'PALETTE': 'Pastel1'}
+settings = {"PALETTE": "Pastel1"}
 # test = GridSearchVisualizer(model_dir='/Users/simon/Desktop/envs/troubleshooting/unsupervised/cluster_models',
 #                             save_dir='/Users/simon/Desktop/envs/troubleshooting/unsupervised/images',
 #                             settings=settings)
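
The constructor change only re-wraps the empty-file-list guard, so instantiation is unchanged. A sketch with hypothetical paths, mirroring the commented-out test call above:

from simba.unsupervised.grid_search_visualizers import GridSearchVisualizer

# Hypothetical directories; per the diff, the constructor globs model_dir for
# .pickle files and errors via check_if_filepath_list_is_empty() if none exist.
visualizer = GridSearchVisualizer(
    model_dir="/path/to/cluster_models",
    save_dir="/path/to/images",
    settings={"PALETTE": "Pastel1"},
)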

simba/unsupervised/tsne.py

Lines changed: 3 additions & 2 deletions
@@ -14,8 +14,9 @@
 from sklearn.manifold import TSNE
 
 import simba
-#from simba.misc_tools import SimbaTimer, check_file_exist_and_readable
-#from simba.utils.enums import Paths
+
+# from simba.misc_tools import SimbaTimer, check_file_exist_and_readable
+# from simba.utils.enums import Paths
 
 
 class TSNEGridSearch(object):

simba/unsupervised/ui.py

Lines changed: 77 additions & 15 deletions
@@ -25,7 +25,12 @@ class UnsupervisedGUI(ConfigReader, PopUpMixin):
 
     def __init__(self, config_path: str):
         ConfigReader.__init__(self, config_path=config_path)
-        PopUpMixin.__init__( self, title="UNSUPERVISED ANALYSIS", config_path=config_path, size=(1000, 800),)
+        PopUpMixin.__init__(
+            self,
+            title="UNSUPERVISED ANALYSIS",
+            config_path=config_path,
+            size=(1000, 800),
+        )
         self.main_frm = Toplevel()
         self.main_frm.minsize(1000, 800)
         self.main_frm.wm_title("UNSUPERVISED ANALYSIS")
@@ -38,27 +43,84 @@ def __init__(self, config_path: str):
         self.visualization_tab = ttk.Frame(self.main_frm)
         self.metrics_tab = ttk.Frame(self.main_frm)
 
-        self.main_frm.add(self.create_dataset_tab,text=f'{"[CREATE DATASET]": ^20s}',compound="left",image=self.menu_icons["features"]["img"])
-        self.main_frm.add(self.dimensionality_reduction_tab,text=f'{"[DIMENSIONALITY REDUCTION]": ^20s}',compound="left",image=self.menu_icons["dimensionality_reduction"]["img"])
-        self.main_frm.add(self.clustering_tab,text=f'{"[CLUSTERING]": ^20s}',compound="left",image=self.menu_icons["cluster"]["img"])
-        self.main_frm.add(self.visualization_tab,text=f'{"[VISUALIZATION]": ^20s}',compound="left",image=self.menu_icons["visualize"]["img"])
-        self.main_frm.add(self.metrics_tab, text=f'{"[METRICS]": ^20s}', compound="left", image=self.menu_icons["metrics"]["img"])
+        self.main_frm.add(
+            self.create_dataset_tab,
+            text=f'{"[CREATE DATASET]": ^20s}',
+            compound="left",
+            image=self.menu_icons["features"]["img"],
+        )
+        self.main_frm.add(
+            self.dimensionality_reduction_tab,
+            text=f'{"[DIMENSIONALITY REDUCTION]": ^20s}',
+            compound="left",
+            image=self.menu_icons["dimensionality_reduction"]["img"],
+        )
+        self.main_frm.add(
+            self.clustering_tab,
+            text=f'{"[CLUSTERING]": ^20s}',
+            compound="left",
+            image=self.menu_icons["cluster"]["img"],
+        )
+        self.main_frm.add(
+            self.visualization_tab,
+            text=f'{"[VISUALIZATION]": ^20s}',
+            compound="left",
+            image=self.menu_icons["visualize"]["img"],
+        )
+        self.main_frm.add(
+            self.metrics_tab,
+            text=f'{"[METRICS]": ^20s}',
+            compound="left",
+            image=self.menu_icons["metrics"]["img"],
+        )
         self.main_frm.grid(row=0)
 
         self.clf_slice_options = [f"ALL CLASSIFIERS ({len(self.clf_names)})"]
-        for clf_name in self.clf_names: self.clf_slice_options.append(f"{clf_name}")
-        create_dataset_frm = LabelFrame(self.create_dataset_tab,text="CREATE DATASET",pady=5,padx=5,font=Formats.LABELFRAME_HEADER_FORMAT.value,fg="black")
-        self.feature_file_selected = FileSelect(create_dataset_frm, "FEATURE FILE (CSV)", lblwidth=25)
-        self.data_slice_dropdown = DropDownMenu(create_dataset_frm,"FEATURE SLICE:",UMLOptions.FEATURE_SLICE_OPTIONS.value,"25",com=lambda x: self.change_status_of_file_select())
+        for clf_name in self.clf_names:
+            self.clf_slice_options.append(f"{clf_name}")
+        create_dataset_frm = LabelFrame(
+            self.create_dataset_tab,
+            text="CREATE DATASET",
+            pady=5,
+            padx=5,
+            font=Formats.LABELFRAME_HEADER_FORMAT.value,
+            fg="black",
+        )
+        self.feature_file_selected = FileSelect(
+            create_dataset_frm, "FEATURE FILE (CSV)", lblwidth=25
+        )
+        self.data_slice_dropdown = DropDownMenu(
+            create_dataset_frm,
+            "FEATURE SLICE:",
+            UMLOptions.FEATURE_SLICE_OPTIONS.value,
+            "25",
+            com=lambda x: self.change_status_of_file_select(),
+        )
         self.data_slice_dropdown.setChoices(UMLOptions.FEATURE_SLICE_OPTIONS.value[0])
-        self.clf_slice_dropdown = DropDownMenu(create_dataset_frm, "CLASSIFIER SLICE:", self.clf_slice_options, "25")
+        self.clf_slice_dropdown = DropDownMenu(
+            create_dataset_frm, "CLASSIFIER SLICE:", self.clf_slice_options, "25"
+        )
         self.clf_slice_dropdown.setChoices(self.clf_slice_options[0])
         self.change_status_of_file_select()
-        self.bout_dropdown = DropDownMenu(create_dataset_frm,"BOUT AGGREGATION METHOD:",UMLOptions.BOUT_AGGREGATION_METHODS.value,"25")
-        self.bout_dropdown.setChoices(choice=UMLOptions.BOUT_AGGREGATION_METHODS.value[0])
-        self.min_bout_length = Entry_Box(create_dataset_frm, "MINIMUM BOUT LENGTH (MS): ", "25", validation="numeric")
+        self.bout_dropdown = DropDownMenu(
+            create_dataset_frm,
+            "BOUT AGGREGATION METHOD:",
+            UMLOptions.BOUT_AGGREGATION_METHODS.value,
+            "25",
+        )
+        self.bout_dropdown.setChoices(
+            choice=UMLOptions.BOUT_AGGREGATION_METHODS.value[0]
+        )
+        self.min_bout_length = Entry_Box(
+            create_dataset_frm, "MINIMUM BOUT LENGTH (MS): ", "25", validation="numeric"
+        )
         self.min_bout_length.entry_set(val=0)
-        self.create_btn = Button(create_dataset_frm,text="CREATE DATASET",fg="blue",command=lambda: self.create_dataset())
+        self.create_btn = Button(
+            create_dataset_frm,
+            text="CREATE DATASET",
+            fg="blue",
+            command=lambda: self.create_dataset(),
+        )
 
         create_dataset_frm.grid(row=0, column=0, sticky=NW)
         self.data_slice_dropdown.grid(row=0, column=0, sticky=NW)