Skip to content

Commit 41dfe7d

Browse files
committed
Add seed parameter to stable_random_distribution
1 parent 7a10a2f commit 41dfe7d

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

micall/utils/stable_random_distribution.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33
import numpy as np
44

55

6-
def stable_random_distribution(maximum: int) -> Iterator[int]:
6+
def stable_random_distribution(maximum: int, seed: int = 42) -> Iterator[int]:
77
n = maximum
8+
rng = np.random.default_rng(seed)
89

910
weights = np.zeros(n) + 1
1011
forward = np.arange(1, n + 1)
1112
backwards = np.arange(n, 0, -1)
1213

1314
while True:
1415
probabilities = weights / weights.sum()
15-
index = np.random.choice(n, p=probabilities)
16+
index = rng.choice(n, p=probabilities)
1617
yield index
1718
weights[:index] += forward[:index]
1819
weights[index:] += backwards[index:]

0 commit comments

Comments
 (0)