Skip to content

Commit

Permalink
Address some Jax random key issues
Browse files Browse the repository at this point in the history
Signed-off-by: zethson <[email protected]>
  • Loading branch information
Zethson committed Feb 8, 2024
1 parent b8f7359 commit 3437ff0
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 7 deletions.
9 changes: 4 additions & 5 deletions pertpy/tools/_coda/_base_coda.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import numpyro as npy
import toytree as tt
from ete3 import Tree
from jax._src.prng import PRNGKeyArray
from jax._src.typing import Array
from matplotlib.axes import Axes

Expand Down Expand Up @@ -188,7 +187,7 @@ def __run_mcmc(
self,
sample_adata: AnnData,
kernel: npy.infer.mcmc.MCMCKernel,
rng_key: Array | PRNGKeyArray,
rng_key: Array,
copy: bool = False,
*args,
**kwargs,
Expand Down Expand Up @@ -304,7 +303,7 @@ def run_nuts(
if copy:
sample_adata = sample_adata.copy()

rng_key_array = random.PRNGKey(rng_key)
rng_key_array = random.key(rng_key)
sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = np.array(rng_key_array)

# Set up NUTS kernel
Expand Down Expand Up @@ -367,10 +366,10 @@ def run_hmc(
# Set rng key if needed
if rng_key is None:
rng = np.random.default_rng()
rng_key = random.PRNGKey(rng.integers(0, 10000))
rng_key = random.key(rng.integers(0, 10000))
sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = rng_key
else:
rng_key = random.PRNGKey(rng_key)
rng_key = random.key(rng_key)

# Set up HMC kernel
sample_adata = self.set_init_mcmc_states(
Expand Down
2 changes: 1 addition & 1 deletion pertpy/tools/_coda/_sccoda.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def make_arviz( # type: ignore

if rng_key is None:
rng = np.random.default_rng()
rng_key = random.PRNGKey(rng.integers(0, 10000))
rng_key = random.key(rng.integers(0, 10000))

if use_posterior_predictive:
posterior_predictive = Predictive(self.model, self.mcmc.get_samples())(
Expand Down
2 changes: 1 addition & 1 deletion pertpy/tools/_coda/_tasccoda.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def make_arviz( # type: ignore

if rng_key is None:
rng = np.random.default_rng()
rng_key = random.PRNGKey(rng.integers(0, 10000))
rng_key = random.key(rng.integers(0, 10000))

if use_posterior_predictive:
posterior_predictive = Predictive(self.model, self.mcmc.get_samples())(
Expand Down

0 comments on commit 3437ff0

Please sign in to comment.