Skip to content

Commit c069d5f

Browse files
committed
Revert "Remove stable_random_distribution"
This reverts commit 2e1656f.
1 parent c287614 commit c069d5f

File tree

2 files changed

+132
-0
lines changed

2 files changed

+132
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
2+
from micall.utils.stable_random_distribution import stable_random_distribution
3+
import numpy as np
4+
from itertools import islice
5+
6+
7+
def test_indices_in_range():
    """Every drawn index must satisfy 0 <= index < maximum."""
    maximum = 10
    draws = stable_random_distribution(maximum, seed=123)

    # The generator is endless, so pull a fixed batch of 1000 values up
    # front and validate them afterwards.
    batch = [next(draws) for _ in range(1000)]

    for idx in batch:
        assert 0 <= idx < maximum, f"Index {idx} out of range [0,{maximum})"
17+
18+
19+
def test_bounds_are_reachable():
    """Test that both the lowest (0) and highest (maximum - 1) index occur."""
    maximum = 999
    gen = stable_random_distribution(maximum, seed=123)

    # Materialize the draws before testing membership. ``in`` applied to a
    # bare ``islice`` iterator consumes it up to the first match, so a
    # second ``in`` check would only see the values drawn *after* that
    # match and could fail even though the value had been generated.
    drawn = set(islice(gen, 1000))

    assert 0 in drawn
    assert (maximum - 1) in drawn
28+
29+
30+
def test_deterministic_output_with_seed():
    """
    Two generators created with the same maximum and the same seed
    must emit the same sequence.
    """
    maximum = 15
    seed = 456
    first = stable_random_distribution(maximum, seed=seed)
    second = stable_random_distribution(maximum, seed=seed)

    # Take a 50-value prefix from each generator and compare them.
    values1 = list(islice(first, 50))
    values2 = list(islice(second, 50))

    assert values1 == values2, \
        "Generators with the same seed produced different outputs."
46+
47+
48+
def test_different_seeds_differ():
    """
    Sanity check: generators seeded differently are expected to diverge.
    """
    maximum = 15
    seeds = (789, 987)

    # A 50-value prefix per seed. Identical prefixes are not strictly
    # impossible, only extremely unlikely.
    values1, values2 = (
        list(islice(stable_random_distribution(maximum, seed=seed), 50))
        for seed in seeds
    )

    assert values1 != values2, \
        "Generators with different seeds produced identical sequences."
65+
66+
67+
def test_fair_distribution_behavior():
    """
    Compare the gap statistics of stable_random_distribution against a
    plain uniform sampler.

    For each seed in range(100):
    - Draw num_samples indices from our generator.
    - Sort them and take the mean of the squared gaps between neighbours.
    - Compute the same statistic for uniformly drawn indices.
    - Require our generator's statistic to be at least the uniform one.
    """

    def mean_squared_gap(values):
        # Mean of the squared differences between consecutive sorted values.
        gaps = np.abs(np.diff(np.sort(values)))
        return (gaps ** 2).mean()

    maximum = 1_000
    num_samples = 3_000

    for seed in range(100):
        # Statistic for the generator under test.
        gen = stable_random_distribution(maximum, seed=seed)
        samples = np.array(list(islice(gen, num_samples)))
        avg_diff_stable = mean_squared_gap(samples)

        # Reference statistic from NumPy's uniform sampler, same seed.
        rng = np.random.default_rng(seed)
        uniform_samples = rng.choice(maximum, size=num_samples)
        avg_diff_uniform = mean_squared_gap(uniform_samples)

        # Strict comparison: no tolerance is applied here.
        assert avg_diff_stable >= avg_diff_uniform, (
            f"Expected stable generator to have a higher average jump than a uniform generator: "
            f"stable {avg_diff_stable} vs uniform {avg_diff_uniform}"
        )
+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Iterator, Sequence
2+
3+
import numpy as np
4+
import random
5+
6+
7+
def stable_random_distribution(maximum: int, seed: int = 42) -> Iterator[int]:
    """
    Yield an endless, seed-reproducible stream of indices in [0, maximum).

    Every yielded index adds a tent-shaped penalty centred on itself:
    ``maximum`` at the index, one less for each step of distance. The next
    index is drawn with weights that shrink as the accumulated penalty
    grows, so positions near earlier picks become less likely.

    :param maximum: exclusive upper bound of the yielded indices. When it
        is zero or negative the generator finishes without yielding.
    :param seed: seed of the private ``random.Random`` instance; the same
        ``maximum`` and ``seed`` reproduce the same sequence.
    :return: an iterator of built-in ``int`` values.
    """
    if maximum <= 0:
        return

    n = maximum
    rng = random.Random(seed)

    population = np.arange(n)
    forward = np.arange(1, n + 1)           # 1, 2, ..., n
    backwards = np.copy(np.flip(forward))   # n, n-1, ..., 1
    np_weights = np.zeros(n)                # accumulated penalty per index

    while True:
        # Invert the penalties into strictly positive sampling weights.
        top = np.max(np_weights) + 1
        weights: Sequence[float] = top - np_weights  # type: ignore

        # ``choices`` indexes into the NumPy population and therefore hands
        # back a numpy integer; convert it so the yielded value matches the
        # declared ``Iterator[int]`` return type.
        index = int(rng.choices(population=population, weights=weights)[0])
        yield index

        if index == 0:
            # ``-index`` would be ``-0`` and produce empty slices below, so
            # the left-edge case is handled separately.
            np_weights += backwards
        else:
            # Rising slope up to and including the chosen index ...
            np_weights[:(index + 1)] += forward[-(index + 1):]
            # ... and the falling slope to its right.
            np_weights[(index + 1):] += backwards[1:-index]

        # Prevent overflow.
        np_weights -= np_weights.min()

0 commit comments

Comments
 (0)