diff --git a/scatrex/scatrex.py b/scatrex/scatrex.py index b3f222a..8bffd60 100644 --- a/scatrex/scatrex.py +++ b/scatrex/scatrex.py @@ -208,7 +208,7 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1 self.ntssb.reset_variational_kernels(log_std=0.) self.ntssb.sample_variational_distributions(n_samples=mc_samples) self.ntssb.update_sufficient_statistics() - self.ntssb.learn_roots(n_epochs, memoized=memoized, mc_samples=mc_samples, step_size=step_size, return_trace=False) + self.ntssb.learn_roots(n_epochs, memoized=False, mc_samples=mc_samples, step_size=step_size, return_trace=False) # Update assignments self.ntssb.update_local_params(jax.random.PRNGKey(seed), update_ass=True, update_globals=False) @@ -219,8 +219,8 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1 searcher.tree.set_node_hyperparams(direction_shape=1.) searcher.tree.sample_variational_distributions(n_samples=mc_samples) searcher.tree.reset_sufficient_statistics() - searcher.tree.update_sufficient_statistics() - searcher.tree.compute_elbo(memoized=memoized) + for batch_idx in range(len(searcher.tree.batch_indices)): + searcher.tree.update_sufficient_statistics(batch_idx=batch_idx) searcher.proposed_tree = deepcopy(searcher.tree) searcher.run_search(n_iters=n_iters, n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size, memoized=memoized, seed=seed, update_roots=True) @@ -295,7 +295,9 @@ def learn_tree(self, n_iters=10, n_epochs=100, memoized=True, mc_samples=10, ste searcher = StructureSearch(self.ntssb) searcher.tree.set_tssb_params(dp_alpha=dp_alpha, dp_gamma=dp_gamma) searcher.tree.sample_variational_distributions(n_samples=mc_samples) - searcher.tree.update_sufficient_statistics() + searcher.tree.reset_sufficient_statistics() + for batch_idx in range(len(searcher.tree.batch_indices)): + searcher.tree.update_sufficient_statistics(batch_idx=batch_idx) searcher.tree.compute_elbo(memoized=memoized) searcher.proposed_tree = deepcopy(searcher.tree) searcher.run_search(n_iters=n_iters, n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size,