Skip to content

Commit f3c1f3a

Browse files
committed
Format
1 parent 09f51f8 commit f3c1f3a

File tree

6 files changed

+27
-61
lines changed

6 files changed

+27
-61
lines changed

many/stats/continuous_categorical.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,6 @@ def mat_mwu_gpu(a_mat, b_mat, melt: bool, effect: str, use_continuity=True):
476476

477477

478478
def biserial_continuous_nan(a_mat, b_mat, melt: bool, effect: str):
479-
480479
"""
481480
Compute biserial (point or rank) correlations for every column-column pair of
482481
a_mat (continuous) and b_mat (binary). Allows for missing values in a_mat.

many/stats/continuous_continuous.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,7 @@
33
import numpy as np
44
import pandas as pd
55
import scipy.special as special
6-
from scipy.stats import (
7-
pearsonr,
8-
spearmanr,
9-
)
6+
from scipy.stats import pearsonr, spearmanr
107
from statsmodels.stats.multitest import multipletests
118
from tqdm import tqdm_notebook as tqdm
129

tests/benchmark_stats.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@
3030

3131
for submodule in submodules:
3232

33-
params = importlib.import_module(
34-
f"stats_benchmark_params.{submodule}"
35-
).params
33+
params = importlib.import_module(f"stats_benchmark_params.{submodule}").params
3634

3735
base_times = []
3836
method_times = []
@@ -51,21 +49,15 @@
5149
method_times.append(method_time)
5250
ratios.append(ratio)
5351

54-
benchmarks_df = pd.DataFrame(
55-
params, columns=inspect.getfullargspec(compare)[0]
56-
)
52+
benchmarks_df = pd.DataFrame(params, columns=inspect.getfullargspec(compare)[0])
5753

5854
benchmarks_df["base_method"] = benchmarks_df["base_method"].apply(
5955
lambda x: x.__name__
6056
)
61-
benchmarks_df["method"] = benchmarks_df["method"].apply(
62-
lambda x: x.__name__
63-
)
57+
benchmarks_df["method"] = benchmarks_df["method"].apply(lambda x: x.__name__)
6458

6559
benchmarks_df["base_times"] = base_times
6660
benchmarks_df["method_times"] = method_times
6761
benchmarks_df["ratios"] = ratios
6862

69-
benchmarks_df.to_csv(
70-
config.BENCHMARK_DATA_DIR / f"{submodule}.txt", sep="\t"
71-
)
63+
benchmarks_df.to_csv(config.BENCHMARK_DATA_DIR / f"{submodule}.txt", sep="\t")

tests/plot_benchmarks.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,7 @@
3636

3737
for method_id in method_ids:
3838

39-
method_benchmarks = benchmarks_df[
40-
benchmarks_df["method_id"] == method_id
41-
]
39+
method_benchmarks = benchmarks_df[benchmarks_df["method_id"] == method_id]
4240

4341
ax = fig.add_subplot(SUBPLOT_ROWS, SUBPLOT_COLS, position)
4442

tests/test_visuals.py

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
# import importlib
55

66
import config
7-
import many
87
import matplotlib.pyplot as plt
98
import numpy as np
109
import utils
1110

11+
import many
12+
1213
DPI = 256
1314

1415
# p = Path("./test_stats").glob("*.py")
@@ -32,7 +33,8 @@
3233
print(many.visuals.colorize([0, 0, 0, 0, 1, 1, 1, 2, 2], cmap="Blues"))
3334
print(
3435
many.visuals.colorize(
35-
["a", "b", "c", "a", "a", "b", "c", "d", "b", "d", "a"], cmap="Blues",
36+
["a", "b", "c", "a", "a", "b", "c", "d", "b", "d", "a"],
37+
cmap="Blues",
3638
)
3739
)
3840

@@ -54,9 +56,7 @@
5456
a[3] = a[2] + np.random.normal(size=100)
5557

5658
many.visuals.scatter_grid(a)
57-
plt.savefig(
58-
config.PLOTS_DIR / "scatter_grid.png", bbox_inches="tight", dpi=DPI
59-
)
59+
plt.savefig(config.PLOTS_DIR / "scatter_grid.png", bbox_inches="tight", dpi=DPI)
6060
plt.clf()
6161

6262
# ----------
@@ -70,9 +70,7 @@
7070
y = x + np.random.normal(size=1000)
7171

7272
many.visuals.regression(x, y, method="pearson")
73-
plt.savefig(
74-
config.PLOTS_DIR / "regression_pearson.png", bbox_inches="tight", dpi=DPI
75-
)
73+
plt.savefig(config.PLOTS_DIR / "regression_pearson.png", bbox_inches="tight", dpi=DPI)
7674
plt.clf()
7775

7876
# -----------
@@ -89,7 +87,9 @@
8987
x, y, text_adjust=False, ax=ax, colormap="Blues", cmap_offset=0.1
9088
)
9189
plt.savefig(
92-
config.PLOTS_DIR / "dense_plot_default.png", bbox_inches="tight", dpi=DPI,
90+
config.PLOTS_DIR / "dense_plot_default.png",
91+
bbox_inches="tight",
92+
dpi=DPI,
9393
)
9494
plt.clf()
9595

@@ -103,9 +103,7 @@
103103
x = np.random.normal(size=1000)
104104
y = x + np.random.normal(size=1000)
105105

106-
many.visuals.dense_regression(
107-
x, y, method="pearson", colormap="Blues", cmap_offset=0.1
108-
)
106+
many.visuals.dense_regression(x, y, method="pearson", colormap="Blues", cmap_offset=0.1)
109107
plt.savefig(
110108
config.PLOTS_DIR / "dense_regression_pearson.png",
111109
bbox_inches="tight",
@@ -129,12 +127,8 @@
129127

130128
b[0] = b[0] + a[0]
131129

132-
many.visuals.two_dists(
133-
a[0], b[0], method="t_test", summary_type="box", stripplot=True
134-
)
135-
plt.savefig(
136-
config.PLOTS_DIR / "two_dists_t_test_box.png", bbox_inches="tight", dpi=DPI
137-
)
130+
many.visuals.two_dists(a[0], b[0], method="t_test", summary_type="box", stripplot=True)
131+
plt.savefig(config.PLOTS_DIR / "two_dists_t_test_box.png", bbox_inches="tight", dpi=DPI)
138132
plt.clf()
139133

140134
# -----------
@@ -158,9 +152,7 @@
158152
b = (b * 25).astype(int)
159153

160154
many.visuals.multi_dists(a[0], b[0], count_cutoff=0, summary_type="box", ax=ax)
161-
plt.savefig(
162-
config.PLOTS_DIR / "multi_dists_box.png", bbox_inches="tight", dpi=DPI
163-
)
155+
plt.savefig(config.PLOTS_DIR / "multi_dists_box.png", bbox_inches="tight", dpi=DPI)
164156
plt.clf()
165157

166158
# -------------
@@ -180,9 +172,7 @@
180172
b[0] = b[0] + a[0] * np.random.random(size=100)
181173

182174
many.visuals.roc_auc_curve(a[0], b[0])
183-
plt.savefig(
184-
config.PLOTS_DIR / "roc_auc_curve.png", bbox_inches="tight", dpi=DPI
185-
)
175+
plt.savefig(config.PLOTS_DIR / "roc_auc_curve.png", bbox_inches="tight", dpi=DPI)
186176
plt.clf()
187177

188178
# --------
@@ -222,9 +212,7 @@
222212
b[0] = b[0] + a[0] * np.random.random(size=100)
223213

224214
many.visuals.binary_metrics(a[0], b[0])
225-
plt.savefig(
226-
config.PLOTS_DIR / "binary_metrics.png", bbox_inches="tight", dpi=DPI
227-
)
215+
plt.savefig(config.PLOTS_DIR / "binary_metrics.png", bbox_inches="tight", dpi=DPI)
228216
plt.clf()
229217

230218
# ------------------

tests/utils.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import time
22
from typing import List
33

4-
import many
54
import numpy as np
65
import pandas as pd
76
from config import TOLERANCE
87

8+
import many
9+
910

1011
class bcolors:
1112
HEADER = "\033[95m"
@@ -118,7 +119,6 @@ def compare(
118119
output_names: List[str],
119120
report_benchmark: bool,
120121
):
121-
122122
"""
123123
General test handler for comparing two methods.
124124
@@ -158,9 +158,7 @@ def compare(
158158
)
159159

160160
# announce method parameters
161-
print(
162-
f" with {bcolors.BOLD}{bcolors.HEADER}{base_method.__name__}{bcolors.ENDC}"
163-
)
161+
print(f" with {bcolors.BOLD}{bcolors.HEADER}{base_method.__name__}{bcolors.ENDC}")
164162

165163
args_string = ", ".join(
166164
f"{bcolors.BOLD}{key}{bcolors.ENDC} = {value}"
@@ -201,12 +199,8 @@ def compare(
201199

202200
if report_benchmark:
203201
print(f"\tNaive speed: {bcolors.BOLD}{base_time:.2f}s{bcolors.ENDC}")
204-
print(
205-
f"\tVectorized speed: {bcolors.BOLD}{method_time:.2f}s{bcolors.ENDC}"
206-
)
207-
print(
208-
f"\tSpeedup: {bcolors.BOLD}{base_time/method_time:.2f}x{bcolors.ENDC}"
209-
)
202+
print(f"\tVectorized speed: {bcolors.BOLD}{method_time:.2f}s{bcolors.ENDC}")
203+
print(f"\tSpeedup: {bcolors.BOLD}{base_time/method_time:.2f}x{bcolors.ENDC}")
210204

211205
benchmark_results = {"base_time": base_time, "method_time": method_time}
212206

@@ -253,8 +247,6 @@ def compare(
253247

254248
else:
255249

256-
print(
257-
f"max deviation is {bcolors.FAIL}{max_deviation_str}{bcolors.ENDC}"
258-
)
250+
print(f"max deviation is {bcolors.FAIL}{max_deviation_str}{bcolors.ENDC}")
259251

260252
return base_result, result, benchmark_results

0 commit comments

Comments (0)