Skip to content

Commit a1737f7

Browse files
committed
Don't be so permissive on tree learning
1 parent f2af21c commit a1737f7

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

Diff for: scatrex/scatrex.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -211,12 +211,13 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1
211211
self.ntssb.learn_roots(n_epochs, memoized=False, mc_samples=mc_samples, step_size=step_size, return_trace=False)
212212

213213
# Update assignments
214-
# self.ntssb.update_local_params(jax.random.PRNGKey(seed), update_ass=True, update_globals=False)
214+
if update_outer_ass:
215+
self.ntssb.update_local_params(jax.random.PRNGKey(seed), update_ass=True, update_globals=False)
215216

216-
# Learn a tree with root updates on noiseless data (over-cluster) and more permissive prior on tree
217+
# Learn a tree with root updates on noiseless data (over-cluster)
217218
searcher = StructureSearch(self.ntssb)
218-
searcher.tree.set_tssb_params(dp_alpha=1., dp_gamma=1.,)
219-
searcher.tree.set_node_hyperparams(direction_shape=1.)
219+
searcher.tree.set_tssb_params(dp_alpha=.01, dp_gamma=.01,)
220+
searcher.tree.set_node_hyperparams(direction_shape=.1)
220221
searcher.tree.sample_variational_distributions(n_samples=mc_samples)
221222
searcher.tree.reset_sufficient_statistics()
222223
for batch_idx in range(len(searcher.tree.batch_indices)):

0 commit comments

Comments
 (0)