Skip to content

Commit 1a63b41

Browse files
committed
Improve stable random distribution
1 parent df3786e commit 1a63b41

File tree

2 files changed

+11
-14
lines changed

2 files changed

+11
-14
lines changed

micall/tests/test_stable_random_distribution.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ def test_indices_in_range():
2222
def test_bounds_are_reachable():
2323
"""Test that both min and max-1 can be generated."""
2424

25-
high = 999
26-
rng = random.Random(123456)
25+
high = 200
26+
rng = random.Random(123)
2727
gen = stable_random_distribution(high, rng=rng)
2828
lst = islice(gen, 1000)
2929

@@ -154,4 +154,4 @@ def test_fill_domain_speed():
154154
if stable_steps < uniform_steps:
155155
wins += 1
156156

157-
assert wins / trials > 0.90
157+
assert wins / trials > 0.999
+8-11
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
from typing import Iterator, Optional
1+
from typing import Iterator, Optional, Sequence
22

33
import random
44
import numpy as np
55

6-
DUPLICATION_FACTOR = 1
7-
86

97
def stable_random_distribution(high: int,
108
rng: Optional[random.Random] = None,
@@ -16,14 +14,13 @@ def stable_random_distribution(high: int,
1614
if rng is None:
1715
rng = random.Random()
1816

19-
maximum = high - 1
20-
block = np.arange(high)
21-
population = np.concatenate([block] * DUPLICATION_FACTOR, axis=0)
22-
23-
assert len(population) == DUPLICATION_FACTOR * len(block)
17+
population = np.arange(high)
18+
weights = np.zeros(high) + 1
2419

2520
while True:
26-
choice = rng.randint(0, maximum)
27-
index = population[choice]
21+
pweights: Sequence[float] = weights # type: ignore
22+
index = rng.choices(population, weights=pweights)[0]
2823
yield index
29-
population[choice] = rng.randint(0, maximum)
24+
weights[index] /= 2
25+
if weights[index] < 0.01:
26+
weights += 1

0 commit comments

Comments
 (0)