Skip to content

Commit

Permalink
Vectorize beta
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrofale committed Apr 22, 2024
1 parent db98972 commit 05278c5
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions scatrex/models/cna/node_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,14 @@ def sample_gene_scales(key, log_alpha, log_beta): # G
def gene_scales_logp(sample, log_alpha, log_beta): # single sample
return tfd.Gamma(jnp.exp(log_alpha), jnp.exp(log_beta)).log_prob(sample) # sum across obs and dimensions
univ_gene_scales_logp_val_and_grad = jax.jit(jax.value_and_grad(gene_scales_logp, argnums=0)) # Take grad wrt to sample (G,)
gene_scales_logp_val_and_grad = jax.jit(jax.vmap(univ_gene_scales_logp_val_and_grad, in_axes=(0,None,None))) # Take grad wrt to sample (G,)
gene_scales_logp_val_and_grad = jax.jit(jax.vmap(univ_gene_scales_logp_val_and_grad, in_axes=(0,None,0))) # Take grad wrt to sample (G,)
mc_gene_scales_logp_val_and_grad = jax.jit(jax.vmap(gene_scales_logp_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad: SxG

@jax.jit
def gene_scales_logq(log_alpha, log_beta):
return tfd.Gamma(jnp.exp(log_alpha), jnp.exp(log_beta)).entropy()
gene_scales_logq_val_and_grad = jax.jit(jax.vmap(jax.value_and_grad(gene_scales_logq, argnums=(0,1)), in_axes=(0,0))) # Take grad wrt to parameters


# Factor variances
@jax.jit
def sample_factor_precisions(key, log_alpha, log_beta): # Kx1
Expand Down

0 comments on commit 05278c5

Please sign in to comment.