Skip to content

Commit 4fb8451

Browse files
committed
Add more tests for stable_random_distribution
1 parent 98069bf commit 4fb8451

File tree

2 files changed

+80
-6
lines changed

2 files changed

+80
-6
lines changed

micall/tests/test_stable_random_distribution.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

22
from micall.utils.stable_random_distribution import stable_random_distribution
3+
import numpy as np
34

45

56
def test_indices_in_range():
@@ -12,3 +13,76 @@ def test_indices_in_range():
1213
for _ in range(1000):
1314
idx = next(gen)
1415
assert 0 <= idx < maximum, f"Index {idx} out of range [0,{maximum})"
16+
17+
18+
def test_deterministic_output_with_seed():
19+
"""
20+
Test that the generator produces the same sequence when
21+
re-seeded with the same seed.
22+
"""
23+
24+
maximum = 15
25+
seed = 456
26+
gen1 = stable_random_distribution(maximum, seed=seed)
27+
gen2 = stable_random_distribution(maximum, seed=seed)
28+
29+
# Compare the first 50 generated values.
30+
values1 = [next(gen1) for _ in range(50)]
31+
values2 = [next(gen2) for _ in range(50)]
32+
assert values1 == values2, \
33+
"Generators with the same seed produced different outputs."
34+
35+
36+
def test_different_seeds_differ():
37+
"""
38+
A sanity check that different seeds usually lead to a different sequence.
39+
"""
40+
41+
maximum = 15
42+
gen1 = stable_random_distribution(maximum, seed=789)
43+
gen2 = stable_random_distribution(maximum, seed=987)
44+
45+
# Compare the first 50 generated values: while not guaranteed to
46+
# be different, it is extremely unlikely that the two sequences
47+
# are identical.
48+
values1 = [next(gen1) for _ in range(50)]
49+
values2 = [next(gen2) for _ in range(50)]
50+
51+
assert values1 != values2, \
52+
"Generators with different seeds produced identical sequences."
53+
54+
55+
def test_fair_distribution_behavior():
56+
"""
57+
Test that the stable_random_distribution leads to outputs that are
58+
more 'spread out' than a simple uniform generator.
59+
60+
Idea:
61+
- Generate a long sequence from our generator.
62+
- Compute the average absolute difference (jump) between indices.
63+
- Do the same for a uniformly random generator over the same range.
64+
- With the adaptive update, values should tend to be farther apart.
65+
"""
66+
67+
maximum = 1_000
68+
num_samples = 10_000
69+
for seed in range(20):
70+
# Gather samples from our generator.
71+
gen = stable_random_distribution(maximum, seed=seed)
72+
samples = np.array([next(gen) for _ in range(num_samples)])
73+
diff_stable = np.abs(np.diff(np.sort(samples))) ** 2
74+
avg_diff_stable = diff_stable.mean()
75+
76+
# For comparison, generate num_samples indices uniformly at random.
77+
rng = np.random.default_rng(seed)
78+
uniform_samples = rng.choice(maximum, size=num_samples)
79+
diff_uniform = np.abs(np.diff(np.sort(uniform_samples))) ** 2
80+
avg_diff_uniform = diff_uniform.mean()
81+
82+
# Our expectation: the stable generator should have larger jumps
83+
# on average. We include a tolerance, because both sequences are
84+
# random.
85+
assert avg_diff_stable >= avg_diff_uniform, (
86+
f"Expected stable generator to have a higher average jump than a uniform generator: "
87+
f"stable {avg_diff_stable} vs uniform {avg_diff_uniform}"
88+
)

micall/utils/stable_random_distribution.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,13 @@ def stable_random_distribution(maximum: int, seed: int = 42) -> Iterator[int]:
1212
rng = random.Random(seed)
1313

1414
population = np.arange(n)
15-
forward = np.arange(1, n + 1) ** 0.5
15+
forward = np.arange(1, n + 1)
1616
backwards = np.copy(np.flip(forward))
17-
np_weights = np.zeros(n) + 0.1
17+
np_weights = np.zeros(n) + 1
1818

1919
while True:
20-
weights: Sequence[float] = 1 - np_weights # type: ignore
21-
indexes = rng.choices(population=population, weights=weights)
22-
index = indexes[0]
20+
weights: Sequence[float] = np_weights # type: ignore
21+
index = rng.choices(population=population, weights=weights)[0]
2322
yield index
2423

2524
if index == 0:
@@ -28,4 +27,5 @@ def stable_random_distribution(maximum: int, seed: int = 42) -> Iterator[int]:
2827
np_weights[:(index + 1)] += forward[-(index + 1):]
2928
np_weights[(index + 1):] += backwards[1:-index]
3029

31-
np_weights /= np_weights.sum()
30+
np_weights -= np_weights.min()
31+
np_weights = (1 + np_weights.max()) - np_weights

0 commit comments

Comments
 (0)