Skip to content

Commit

Permalink
Remove memoization when not needed
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrofale committed Apr 24, 2024
1 parent c9fd280 commit c08312f
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions scatrex/scatrex.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1
self.ntssb.reset_variational_kernels(log_std=0.)
self.ntssb.sample_variational_distributions(n_samples=mc_samples)
self.ntssb.update_sufficient_statistics()
self.ntssb.learn_roots(n_epochs, memoized=memoized, mc_samples=mc_samples, step_size=step_size, return_trace=False)
self.ntssb.learn_roots(n_epochs, memoized=False, mc_samples=mc_samples, step_size=step_size, return_trace=False)

# Update assignments
self.ntssb.update_local_params(jax.random.PRNGKey(seed), update_ass=True, update_globals=False)
Expand All @@ -219,8 +219,8 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1
searcher.tree.set_node_hyperparams(direction_shape=1.)
searcher.tree.sample_variational_distributions(n_samples=mc_samples)
searcher.tree.reset_sufficient_statistics()
searcher.tree.update_sufficient_statistics()
searcher.tree.compute_elbo(memoized=memoized)
for batch_idx in range(len(searcher.tree.batch_indices)):
searcher.tree.update_sufficient_statistics(batch_idx=batch_idx)
searcher.proposed_tree = deepcopy(searcher.tree)
searcher.run_search(n_iters=n_iters, n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size,
memoized=memoized, seed=seed, update_roots=True)
Expand Down Expand Up @@ -295,7 +295,9 @@ def learn_tree(self, n_iters=10, n_epochs=100, memoized=True, mc_samples=10, ste
searcher = StructureSearch(self.ntssb)
searcher.tree.set_tssb_params(dp_alpha=dp_alpha, dp_gamma=dp_gamma)
searcher.tree.sample_variational_distributions(n_samples=mc_samples)
searcher.tree.update_sufficient_statistics()
searcher.tree.reset_sufficient_statistics()
for batch_idx in range(len(searcher.tree.batch_indices)):
searcher.tree.update_sufficient_statistics(batch_idx=batch_idx)
searcher.tree.compute_elbo(memoized=memoized)
searcher.proposed_tree = deepcopy(searcher.tree)
searcher.run_search(n_iters=n_iters, n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size,
Expand Down

0 comments on commit c08312f

Please sign in to comment.