Skip to content

Commit 3869db6

Browse files
committed
notebooks to docs, small fixes for plots
1 parent b2e015f commit 3869db6

10 files changed

+5269
-94
lines changed

bayes_window/utils.py

+17-11
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@ def add_data_to_posterior(df,
1616
conditions=None, # eg ('stim_on', 'stim_stop')
1717
b_name='b_stim_per_condition', # for posterior
1818
group_name='Condition code', # for posterior
19-
do_make_change=True,
19+
do_make_change='subtract',
2020
do_mean_over_trials=True,
2121
):
2222
index_cols = list(index_cols)
23-
if conditions is None:
24-
conditions = df[condition_name].drop_duplicates().sort_values().values
23+
24+
conditions= conditions or df[condition_name].drop_duplicates().sort_values().values
2525
assert len(conditions) == 2, f'{condition_name}={conditions}. Should be only two instead!'
26+
assert do_make_change in [False, 'subtract', 'divide']
2627
if not (condition_name in index_cols):
2728
index_cols.append(condition_name)
2829
if do_mean_over_trials:
@@ -36,6 +37,7 @@ def add_data_to_posterior(df,
3637
index_cols=index_cols,
3738
condition_name=condition_name,
3839
conditions=conditions,
40+
fold_change_method=do_make_change,
3941
do_take_mean=False)
4042
# Condition is removed from both index columns and dfbayes
4143
index_cols.remove(condition_name)
@@ -98,29 +100,33 @@ def fill_row(rows):
98100

99101

100102
def make_fold_change(df, y='log_firing_rate', index_cols=('Brain region', 'Stim phase'),
101-
condition_name='stim', conditions=(0, 1), do_take_mean=False):
102-
# for index_col in index_cols:
103-
# assert type(df[index_col].iloc[0]) != str, f'Make sure {index_col} contains not strings!'
103+
condition_name='stim', conditions=(0, 1), do_take_mean=False, fold_change_method='divide'):
104104
for condition in conditions:
105105
assert condition in df[condition_name].unique(), f'{condition} not in {df[condition_name].unique()}'
106106
if y not in df.columns:
107107
raise ValueError(f'{y} is not a column in this dataset: {df.columns}')
108+
109+
# Take mean of trials:
108110
if do_take_mean:
109-
# Take mean of trials:
110111
df = df.groupby(list(index_cols)).mean().reset_index()
112+
111113
# Make multiindex
112114
mdf = df.set_index(list(set(index_cols) - {'i_spike'})).copy()
113-
# mdf.xs(0, level='stim') - mdf.xs(1, level='stim')
114115
if (mdf.xs(conditions[1], level=condition_name).size !=
115116
mdf.xs(conditions[0], level=condition_name).size):
116117
raise IndexError(f'Uneven number of entries in conditions! Try setting do_take_mean=True'
117118
f'{mdf.xs(conditions[0], level=condition_name).size, mdf.xs(conditions[1], level=condition_name).size}')
118119

119120
# Subtract/divide
120121
try:
121-
data = (mdf.xs(conditions[1], level=condition_name) -
122-
mdf.xs(conditions[0], level=condition_name)
123-
).reset_index()
122+
if fold_change_method == 'subtract':
123+
data = (mdf.xs(conditions[1], level=condition_name) -
124+
mdf.xs(conditions[0], level=condition_name)
125+
).reset_index()
126+
else:
127+
data = (mdf.xs(conditions[1], level=condition_name) /
128+
mdf.xs(conditions[0], level=condition_name)
129+
).reset_index()
124130
except Exception as e:
125131
print(f'Try recasting {condition_name} as integer and try again. Alternatively, use bayes_window.workflow.'
126132
f' We do that automatically there ')

bayes_window/visualization.py

+4-19
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ def facet(base_chart,
1313
width=80,
1414
height=150,
1515
):
16-
print('custom facet')
1716
alt.themes.enable('vox')
1817
if column is None and row is None:
1918
raise RuntimeError('Need either column, or row, or both!')
@@ -67,32 +66,18 @@ def plot_data(df=None, x=None, y=None, color=None, add_box=True, base_chart=None
6766
chart = base.mark_line(fill=None, opacity=.5, size=3).encode(
6867
x=x,
6968
color=f'{color}',
70-
y=f'{y}:Q'
69+
y=alt.Y(f'{y}:Q')
7170
)
7271
if add_box:
7372
# Shift x axis for box so that it doesn't overlap:
7473
# df['x_box'] = df[x[:-2]] + .01
7574
chart += base.mark_boxplot(opacity=.3, size=12, color='black').encode(
7675
x=x,
77-
y=f'{y}:Q'
76+
y=alt.Y(f'{y}:Q')
7877
)
7978
return chart
8079

8180

82-
# def plot_data_and_posterior(df, y='Coherence diff', title='coherence', x='Stim phase', color='Subject',
83-
# add_box=True, **kwargs):
84-
# # Keep kwargs!
85-
# assert (x in df) | (x[:-2] in df), f'Column {x} is not present in data: {df.columns}'
86-
# assert color in df
87-
# assert y in df.columns, f'{y} is not in {df.columns}'
88-
#
89-
# chart_d = plot_data(df=df, x=x, y=y, color=color, add_box=add_box, base_chart=alt.Chart(df))
90-
# chart_p = plot_posterior(df, title=title, x=x, base_chart=alt.Chart(df))
91-
# chart = chart_d + chart_p
92-
#
93-
# return chart
94-
95-
9681
def plot_posterior(df=None, title='', x='Stim phase', do_make_change=True, base_chart=None, **kwargs):
9782
assert (df is not None) or (base_chart is not None)
9883
data = base_chart.data if df is None else df
@@ -244,15 +229,15 @@ def fake_spikes_explore(df, df_monster, index_cols):
244229
y=alt.Y(y, scale=alt.Scale(zero=True)),
245230
).properties(width=width, height=240).facet(
246231
# row='mouse:N',
247-
column=alt.Column('mouse')) # .resolve_scale(y='independent')
232+
column=alt.Column('mouse'))
248233

249234
bar = (alt.Chart(data=data_fold_change).mark_bar().encode(y=alt.Y(y, aggregate='mean')) +
250235
alt.Chart(data=data_fold_change).mark_errorbar().encode(y=alt.Y(y, aggregate='stderr'))).encode(
251236
x=alt.X('neuron:N', ),
252237
y=alt.Y(y),
253238
).properties(width=width * 2, height=240).facet(
254239
# row='Inversion:N',
255-
column=alt.Column('mouse')) # .resolve_scale(y='independent')
240+
column=alt.Column('mouse'))
256241

257242
bar_combined = (alt.Chart(data=data_fold_change).mark_bar().encode(y=alt.Y(y, aggregate='mean')) +
258243
alt.Chart(data=data_fold_change).mark_errorbar().encode(y=alt.Y(y, aggregate='stderr'))).encode(

bayes_window/workflow.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from importlib import reload
33

44
import altair as alt
5+
from sklearn.preprocessing import LabelEncoder
6+
57
from bayes_window import models
68
from bayes_window import utils
79
from bayes_window import visualization
810
from bayes_window.fitting import fit_numpyro
911
from bayes_window.visualization import plot_posterior
10-
from sklearn.preprocessing import LabelEncoder
1112

1213
reload(visualization)
1314
reload(utils)
@@ -54,9 +55,11 @@ def fit_conditions(self, model=models.model_single_lognormal, add_data=True):
5455
do_make_change=False
5556
)
5657

57-
def fit_slopes(self, add_data=True, model=models.model_hier_normal_stim,
58+
def fit_slopes(self, add_data=True, model=models.model_hier_normal_stim, do_make_change='subtract',
5859
plot_index_cols=None):
60+
assert do_make_change in ['subtract', 'divide']
5961
self.bname = 'b_stim_per_condition'
62+
self.do_make_change = do_make_change
6063
if plot_index_cols is None:
6164
plot_index_cols = self.levels[:3]
6265
# By convention, top condition is first in list of levels:
@@ -89,7 +92,7 @@ def fit_slopes(self, add_data=True, model=models.model_hier_normal_stim,
8992
condition_name=top_condition,
9093
b_name=self.bname,
9194
group_name=self.levels[-1],
92-
do_make_change=True,
95+
do_make_change=do_make_change,
9396
do_mean_over_trials=True,
9497
)
9598
else: # Just convert posterior to dataframe
@@ -107,18 +110,21 @@ def plot_posteriors_slopes(self, x=None, color=None, add_box=True, independent_a
107110
# Plot posterior
108111
if hasattr(self, 'data_and_posterior'):
109112
base_chart = alt.Chart(self.data_and_posterior)
110-
chart_p = plot_posterior(title=f'd_{self.y}', x=x, base_chart=base_chart)
113+
chart_p = plot_posterior(title=f'{self.y}', x=x, base_chart=base_chart)
111114
else:
112115
base_chart = alt.Chart(self.data)
113116
add_data = True # Otherwise nothing to do
114117

115118
if add_data:
116119
chart_d = visualization.plot_data(x=x, y=f'{self.y} diff', color=color, add_box=add_box,
117120
base_chart=base_chart)
118-
self.chart = chart_p + chart_d
121+
self.chart = chart_d + chart_p # Not chart_p + chart_d, or bayes means and HPD get scaled independently
119122
else:
120123
self.chart = chart_p
121-
assert self.chart.data is not None
124+
if independent_axes:
125+
self.chart = self.chart.resolve_scale(y='independent')
126+
elif self.do_make_change == 'divide':
127+
warnings.warn('division change and independent axes will lead to separate axis! Upstream bug I think')
122128
return self.chart
123129

124130
# TODO plot_posteriors_slopes and plot_posteriors_no_slope can be one

0 commit comments

Comments
 (0)