Skip to content

Commit

Permalink
Add new MVI and MCMC with new 2D model
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrofale committed Aug 9, 2024
1 parent d791dc4 commit c5fbebb
Show file tree
Hide file tree
Showing 10 changed files with 2,008 additions and 231 deletions.
1,159 changes: 1,159 additions & 0 deletions notebooks/vonmises.ipynb

Large diffs are not rendered by default.

243 changes: 179 additions & 64 deletions scatrex/models/trajectory/node.py

Large diffs are not rendered by default.

30 changes: 26 additions & 4 deletions scatrex/models/trajectory/node_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ def sample_angle(key, mu, log_kappa): # univariate: one sample
sample_angle_val_and_grad = jax.vmap(jax.value_and_grad(sample_angle, argnums=(1,2)), in_axes=(None, 0, 0)) # per-dimension val and grad
mc_sample_angle_val_and_grad = jax.jit(jax.vmap(sample_angle_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad

@jax.jit
def sample_event(key, log_alpha, log_beta):
    """Draw a single Gamma sample; both parameters are given on the log scale.

    Working with log(alpha) (log-shape) and log(beta) (log-rate) keeps the
    optimized variational parameters unconstrained.
    """
    alpha = jnp.exp(log_alpha)
    beta = jnp.exp(log_beta)
    return tfd.Gamma(alpha, beta).sample(seed=key)
# Per-dimension value and gradient w.r.t. (log_alpha, log_beta); one shared key.
sample_event_val_and_grad = jax.vmap(jax.value_and_grad(sample_event, argnums=(1, 2)), in_axes=(None, 0, 0))
# Monte Carlo variant: additionally vectorize over a batch of PRNG keys.
mc_sample_event_val_and_grad = jax.jit(jax.vmap(sample_event_val_and_grad, in_axes=(0, None, None)))


@jax.jit
def sample_loc(key, mu, log_std):
    """Draw a single Normal(mu, exp(log_std)) sample (log-scale std is unconstrained)."""
    scale = jnp.exp(log_std)
    return tfd.Normal(mu, scale).sample(seed=key)
Expand All @@ -28,21 +35,36 @@ def angle_logq(mu, log_kappa):
return jnp.sum(tfd.VonMises(mu, jnp.exp(log_kappa)).entropy())
angle_logq_val_and_grad = jax.jit(jax.value_and_grad(angle_logq, argnums=(0,1))) # Take grad wrt to parameters

@jax.jit
def event_logp(this_event, mean, concentration):
    """Gamma log-density of one event sample, summed across dimensions.

    The prior is parameterized by its mean: shape = concentration and
    rate = concentration / mean, so E[event] == mean.
    """
    prior = tfd.Gamma(concentration, concentration / mean)
    return jnp.sum(prior.log_prob(this_event))
event_logp_val_and_grad = jax.jit(jax.value_and_grad(event_logp, argnums=0))  # grad w.r.t. the event sample
mc_event_logp_val_and_grad = jax.jit(jax.vmap(event_logp_val_and_grad, in_axes=(0, None, None)))  # over MC samples

@jax.jit
def event_logq(log_alpha, log_beta):
    """Entropy of the variational Gamma(exp(log_alpha), exp(log_beta)), summed over dimensions."""
    q_dist = tfd.Gamma(jnp.exp(log_alpha), jnp.exp(log_beta))
    return jnp.sum(q_dist.entropy())
# Gradient w.r.t. both (log-scale) variational parameters.
event_logq_val_and_grad = jax.jit(jax.value_and_grad(event_logq, argnums=(0, 1)))

@jax.jit
def loc_logp(this_loc, parent_loc, this_angle, log_std, radius): # single sample
    # Child-location prior: Normal centered at the parent location displaced by
    # `radius` along the direction given by `this_angle` (2D: cos/sin components).
    mean = parent_loc + jnp.hstack([jnp.cos(this_angle)*radius, jnp.sin(this_angle)*radius]) # Use samples from parent
    return jnp.sum(tfd.Normal(mean, jnp.exp(log_std)).log_prob(this_loc)) # sum across dimensions
loc_logp_val = jax.jit(loc_logp)
# NOTE(review): each pair of duplicate assignments below is the pre-/post-change
# version retained by the rendered diff (in_axes for `radius` changed
# None -> 0, i.e. the radius/"event" is now a per-MC-sample quantity).
# At import time the later assignment wins.
mc_loc_logp_val = jax.jit(jax.vmap(loc_logp_val, in_axes=(0,0,0, None, None))) # Multiple sample
mc_loc_logp_val = jax.jit(jax.vmap(loc_logp_val, in_axes=(0,0,0, None, 0))) # Multiple sample

loc_logp_val_and_grad = jax.jit(jax.value_and_grad(loc_logp, argnums=0)) # Take grad wrt to this
mc_loc_logp_val_and_grad = jax.jit(jax.vmap(loc_logp_val_and_grad, in_axes=(0,0,0, None, None))) # Multiple sample value_and_grad
mc_loc_logp_val_and_grad = jax.jit(jax.vmap(loc_logp_val_and_grad, in_axes=(0,0,0, None, 0))) # Multiple sample value_and_grad

loc_logp_val_and_grad_wrt_parent = jax.jit(jax.value_and_grad(loc_logp, argnums=1)) # Take grad wrt to parent
mc_loc_logp_val_and_grad_wrt_parent = jax.jit(jax.vmap(loc_logp_val_and_grad_wrt_parent, in_axes=(0,0,0, None, None))) # Multiple sample value_and_grad
mc_loc_logp_val_and_grad_wrt_parent = jax.jit(jax.vmap(loc_logp_val_and_grad_wrt_parent, in_axes=(0,0,0, None, 0))) # Multiple sample value_and_grad

loc_logp_val_and_grad_wrt_angle = jax.jit(jax.value_and_grad(loc_logp, argnums=2)) # Take grad wrt to angle
mc_loc_logp_val_and_grad_wrt_angle = jax.jit(jax.vmap(loc_logp_val_and_grad_wrt_angle, in_axes=(0,0,0, None, None))) # Multiple sample value_and_grad
mc_loc_logp_val_and_grad_wrt_angle = jax.jit(jax.vmap(loc_logp_val_and_grad_wrt_angle, in_axes=(0,0,0, None, 0))) # Multiple sample value_and_grad

# New in this commit: gradient w.r.t. the event (argnum 4, the `radius` parameter).
loc_logp_val_and_grad_wrt_event = jax.jit(jax.value_and_grad(loc_logp, argnums=4)) # Take grad wrt to event
mc_loc_logp_val_and_grad_wrt_event = jax.jit(jax.vmap(loc_logp_val_and_grad_wrt_event, in_axes=(0,0,0, None, 0))) # Multiple sample value_and_grad


@jax.jit
def loc_logq(mu, log_std):
Expand Down
11 changes: 6 additions & 5 deletions scatrex/models/trajectory/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,19 @@ def __init__(self, **kwargs):
super(TrajectoryTree, self).__init__(**kwargs)
self.node_constructor = TrajectoryNode

def sample_kernel(self, parent_params, mean_dist=1., angle_concentration=1., loc_variance=.1, seed=42, depth=1., **kwargs):
def sample_kernel(self, parent_params, event_mean=1., event_concentration=1., angle_concentration=1., loc_variance=.1, seed=42, depth=1., **kwargs):
    # Sample a child node's kernel: a von Mises angle around the parent's angle
    # plus a Gamma-distributed step ("event") that displaces the child location.
    rng = np.random.default_rng(seed=seed)
    parent_loc = parent_params[0]
    parent_angle = parent_params[1]
    angle_concentration = angle_concentration * depth  # deeper nodes deviate less in angle
    sampled_angle = rng.vonmises(parent_angle, angle_concentration)
    # NOTE(review): the next three lines are the superseded pre-change version
    # kept by the rendered diff (this `return` would short-circuit the current
    # code below if this text were executed as-is); the current implementation
    # follows after it.
    sampled_loc = rng.normal(mean_dist, loc_variance)
    sampled_loc = parent_loc + np.array([np.cos(sampled_angle)*np.abs(sampled_loc), np.sin(sampled_angle)*np.abs(sampled_loc)])
    return [sampled_loc, sampled_angle]
    # Current version: NumPy's gamma takes (shape, scale), so with
    # scale = event_mean/event_concentration, E[event] == event_mean.
    sampled_event = rng.gamma(event_concentration, event_mean/event_concentration)
    loc_mean = parent_loc + np.array([np.cos(sampled_angle)*sampled_event, np.sin(sampled_angle)*sampled_event])
    sampled_loc = rng.normal(loc_mean, loc_variance)
    return [sampled_loc, sampled_angle, sampled_event]

def sample_root(self, **kwargs):
    # Root kernel parameters: origin location, zero angle (and, in the current
    # version, a zero event).
    # NOTE(review): the first return is the superseded pre-change line kept by
    # the rendered diff; the three-element version below is the current one.
    return [np.array([0., 0.]), 0.]
    return [np.array([0., 0.]), 0., 0.]

def get_param_size(self):
    """Return the size of the first kernel parameter (the node location array)."""
    params = self.tree["param"]
    return params[0].size
Expand Down
4 changes: 1 addition & 3 deletions scatrex/ntssb/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,13 +232,11 @@ def get_top_obs(self, q=70, idx=None):
top_obs = idx[np.where(lls > np.percentile(lls, q=q))[0]]
return top_obs

def reset_variational_state(self, **kwargs):
return

def reset_opt(self):
    # Re-initialize the per-parameter optimizer accumulators used for adaptive
    # optimization of the kernel parameters: direction, state, and (new in this
    # commit) event. The initialize_* helpers are defined elsewhere in the class.
    # For adaptive optimization
    self.direction_states = self.initialize_direction_states()
    self.state_states = self.initialize_state_states()
    self.event_states = self.initialize_event_states()

def init_new_node_kernel(self, **kwargs):
return
15 changes: 8 additions & 7 deletions scatrex/ntssb/ntssb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,7 @@ def compute_elbo_batch(self, batch_idx=None):
idx = self.batch_indices[batch_idx]
def descend(root, depth=0, local_contrib=0, global_contrib=0, psi_priors=None):
# Traverse inner TSSB
subtree_ll_contrib, subtree_ass_contrib, subtree_node_contrib = root['node'].compute_elbo(idx)
subtree_ll_contrib, subtree_ass_contrib, subtree_node_contrib = root['node'].compute_elbo_batch(idx)
ll_contrib = subtree_ll_contrib * root['node'].variational_parameters['q_c'][idx]

# Assignments
Expand Down Expand Up @@ -1244,7 +1244,7 @@ def descend(root, depth=0, local_contrib=0, global_contrib=0, psi_priors=None):
# Auxiliary quantities
## Branches
E_log_psi = E_log_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
child['node'].variational_parameters['E_log_phi'] = E_log_psi + sum_E_log_1_psi
child['node'].variational_parameters['E_log_phi'] = root['node'].variational_parameters['E_log_phi'] + E_log_psi + sum_E_log_1_psi
E_log_1_psi = E_log_1_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
sum_E_log_1_psi += E_log_1_psi

Expand Down Expand Up @@ -1306,7 +1306,7 @@ def descend(root, depth=0, local_contrib=0, global_contrib=0, psi_priors=None):
# Auxiliary quantities
## Branches
E_log_psi = E_log_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
child['node'].variational_parameters['E_log_phi'] = E_log_psi + sum_E_log_1_psi
child['node'].variational_parameters['E_log_phi'] = root['node'].variational_parameters['E_log_phi'] + E_log_psi + sum_E_log_1_psi
E_log_1_psi = E_log_1_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
sum_E_log_1_psi += E_log_1_psi

Expand Down Expand Up @@ -1474,7 +1474,7 @@ def descend(root, local_grads=None):
sum_E_log_1_psi = 0.
for child in root['children']:
E_log_psi = E_log_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
child['node'].variational_parameters['E_log_phi'] = E_log_psi + sum_E_log_1_psi
child['node'].variational_parameters['E_log_phi'] = root['node'].variational_parameters['E_log_phi'] + E_log_psi + sum_E_log_1_psi
E_log_1_psi = E_log_1_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
sum_E_log_1_psi += E_log_1_psi

Expand Down Expand Up @@ -2888,13 +2888,14 @@ def descend(root):
descend(child)
descend(self.root)

def show_tree(self, **kwargs):
def show_tree(self, ax=None, **kwargs):
    # Render the learned NTSSB tree. `ax` (new in this commit) lets callers
    # draw into an existing matplotlib axis instead of a fresh figure.
    self.set_learned_parameters()
    self.set_node_names()
    self.set_expected_weights()
    self.assign_samples()
    self.set_ntssb_colors()
    tree = self.get_param_dict()
    # NOTE(review): the next two lines are the superseded pre-change version
    # kept by the rendered diff — they unconditionally created a new figure,
    # clobbering any caller-supplied axis.
    plt.figure(figsize=(4,4))
    ax = plt.gca()
    # Current version: only create a figure/axis when none was passed in.
    if ax is None:
        plt.figure(figsize=(4,4))
        ax = plt.gca()
    plot_full_tree(tree, ax=ax, node_size=101, **kwargs)
Loading

0 comments on commit c5fbebb

Please sign in to comment.