diff --git a/pertpy/tools/_coda/_base_coda.py b/pertpy/tools/_coda/_base_coda.py index 03b00e55..544f6bda 100644 --- a/pertpy/tools/_coda/_base_coda.py +++ b/pertpy/tools/_coda/_base_coda.py @@ -30,7 +30,6 @@ import numpyro as npy import toytree as tt from ete3 import Tree - from jax._src.prng import PRNGKeyArray from jax._src.typing import Array from matplotlib.axes import Axes @@ -188,7 +187,7 @@ def __run_mcmc( self, sample_adata: AnnData, kernel: npy.infer.mcmc.MCMCKernel, - rng_key: Array | PRNGKeyArray, + rng_key: Array, copy: bool = False, *args, **kwargs, @@ -304,7 +303,7 @@ def run_nuts( if copy: sample_adata = sample_adata.copy() - rng_key_array = random.PRNGKey(rng_key) + rng_key_array = random.key(rng_key) sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = np.array(rng_key_array) # Set up NUTS kernel @@ -367,10 +366,10 @@ def run_hmc( # Set rng key if needed if rng_key is None: rng = np.random.default_rng() - rng_key = random.PRNGKey(rng.integers(0, 10000)) + rng_key = random.key(rng.integers(0, 10000)) sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = rng_key else: - rng_key = random.PRNGKey(rng_key) + rng_key = random.key(rng_key) # Set up HMC kernel sample_adata = self.set_init_mcmc_states( diff --git a/pertpy/tools/_coda/_sccoda.py b/pertpy/tools/_coda/_sccoda.py index 3ee7b75d..7f734ef2 100644 --- a/pertpy/tools/_coda/_sccoda.py +++ b/pertpy/tools/_coda/_sccoda.py @@ -365,7 +365,7 @@ def make_arviz( # type: ignore if rng_key is None: rng = np.random.default_rng() - rng_key = random.PRNGKey(rng.integers(0, 10000)) + rng_key = random.key(rng.integers(0, 10000)) if use_posterior_predictive: posterior_predictive = Predictive(self.model, self.mcmc.get_samples())( diff --git a/pertpy/tools/_coda/_tasccoda.py b/pertpy/tools/_coda/_tasccoda.py index 6e7aa393..31166bad 100644 --- a/pertpy/tools/_coda/_tasccoda.py +++ b/pertpy/tools/_coda/_tasccoda.py @@ -552,7 +552,7 @@ def make_arviz( # type: ignore if rng_key is None: rng = np.random.default_rng() - rng_key = random.PRNGKey(rng.integers(0, 10000)) + rng_key = random.key(rng.integers(0, 10000)) if use_posterior_predictive: posterior_predictive = Predictive(self.model, self.mcmc.get_samples())(