From 3437ff0189ef85927a88e25bdc65fd5467557468 Mon Sep 17 00:00:00 2001 From: zethson Date: Thu, 8 Feb 2024 12:01:05 +0100 Subject: [PATCH] Address some Jax random key issues Signed-off-by: zethson --- pertpy/tools/_coda/_base_coda.py | 9 ++++----- pertpy/tools/_coda/_sccoda.py | 2 +- pertpy/tools/_coda/_tasccoda.py | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/pertpy/tools/_coda/_base_coda.py b/pertpy/tools/_coda/_base_coda.py index b7a68e11..9a296b87 100644 --- a/pertpy/tools/_coda/_base_coda.py +++ b/pertpy/tools/_coda/_base_coda.py @@ -30,7 +30,6 @@ import numpyro as npy import toytree as tt from ete3 import Tree - from jax._src.prng import PRNGKeyArray from jax._src.typing import Array from matplotlib.axes import Axes @@ -188,7 +187,7 @@ def __run_mcmc( self, sample_adata: AnnData, kernel: npy.infer.mcmc.MCMCKernel, - rng_key: Array | PRNGKeyArray, + rng_key: Array, copy: bool = False, *args, **kwargs, @@ -304,7 +303,7 @@ def run_nuts( if copy: sample_adata = sample_adata.copy() - rng_key_array = random.PRNGKey(rng_key) + rng_key_array = random.key(rng_key) sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = np.array(rng_key_array) # Set up NUTS kernel @@ -367,10 +366,10 @@ def run_hmc( # Set rng key if needed if rng_key is None: rng = np.random.default_rng() - rng_key = random.PRNGKey(rng.integers(0, 10000)) + rng_key = random.key(rng.integers(0, 10000)) sample_adata.uns["scCODA_params"]["mcmc"]["rng_key"] = rng_key else: - rng_key = random.PRNGKey(rng_key) + rng_key = random.key(rng_key) # Set up HMC kernel sample_adata = self.set_init_mcmc_states( diff --git a/pertpy/tools/_coda/_sccoda.py b/pertpy/tools/_coda/_sccoda.py index 3ee7b75d..7f734ef2 100644 --- a/pertpy/tools/_coda/_sccoda.py +++ b/pertpy/tools/_coda/_sccoda.py @@ -365,7 +365,7 @@ def make_arviz( # type: ignore if rng_key is None: rng = np.random.default_rng() - rng_key = random.PRNGKey(rng.integers(0, 10000)) + rng_key = random.key(rng.integers(0, 10000)) if use_posterior_predictive: posterior_predictive = Predictive(self.model, self.mcmc.get_samples())( diff --git a/pertpy/tools/_coda/_tasccoda.py b/pertpy/tools/_coda/_tasccoda.py index 6e7aa393..31166bad 100644 --- a/pertpy/tools/_coda/_tasccoda.py +++ b/pertpy/tools/_coda/_tasccoda.py @@ -552,7 +552,7 @@ def make_arviz( # type: ignore if rng_key is None: rng = np.random.default_rng() - rng_key = random.PRNGKey(rng.integers(0, 10000)) + rng_key = random.key(rng.integers(0, 10000)) if use_posterior_predictive: posterior_predictive = Predictive(self.model, self.mcmc.get_samples())(