Skip to content

Commit 9ccbdf9

Browse files
committed
Reimplement stable_random_distribution
1 parent c069d5f commit 9ccbdf9

File tree

2 files changed

+80
-38
lines changed

2 files changed

+80
-38
lines changed

Diff for: micall/tests/test_stable_random_distribution.py

+69-16
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,47 @@
22
from micall.utils.stable_random_distribution import stable_random_distribution
33
import numpy as np
44
from itertools import islice
5+
from typing import Set
56

67

78
def test_indices_in_range():
8-
"""Test that each index generated is within the range [0, maximum)."""
9+
"""Test that each index generated is within the range [0, high)."""
910

10-
maximum = 10
11-
gen = stable_random_distribution(maximum, seed=123)
11+
high = 10
12+
gen = stable_random_distribution(high, seed=123)
1213
# Grab a bunch of values from the infinite generator
1314

1415
for _ in range(1000):
1516
idx = next(gen)
16-
assert 0 <= idx < maximum, f"Index {idx} out of range [0,{maximum})"
17+
assert 0 <= idx < high, f"Index {idx} out of range [0,{high})"
1718

1819

1920
def test_bounds_are_reachable():
2021
"""Test that both min and max-1 can be generated."""
2122

22-
maximum = 999
23-
gen = stable_random_distribution(maximum, seed=123)
23+
high = 999
24+
gen = stable_random_distribution(high, seed=123)
2425
lst = islice(gen, 1000)
2526

2627
assert 0 in lst
27-
assert (maximum-1) in lst
28+
assert (high-1) in lst
29+
30+
31+
def test_everything_is_reachable():
32+
"""Test that every number in the range [0, high) can be generated."""
33+
34+
high = 30
35+
fun = stable_random_distribution
36+
# def fun(high, seed):
37+
# import random
38+
# while True:
39+
# yield random.randint(0, high)
40+
41+
gen = fun(high, seed=123)
42+
lst = tuple(map(int, islice(gen, 1000)))
43+
44+
for x in range(high):
45+
assert x in lst
2846

2947

3048
def test_deterministic_output_with_seed():
@@ -33,10 +51,10 @@ def test_deterministic_output_with_seed():
3351
re-seeded with the same seed.
3452
"""
3553

36-
maximum = 15
54+
high = 15
3755
seed = 456
38-
gen1 = stable_random_distribution(maximum, seed=seed)
39-
gen2 = stable_random_distribution(maximum, seed=seed)
56+
gen1 = stable_random_distribution(high, seed=seed)
57+
gen2 = stable_random_distribution(high, seed=seed)
4058

4159
# Compare the first 50 generated values.
4260
values1 = [next(gen1) for _ in range(50)]
@@ -50,9 +68,9 @@ def test_different_seeds_differ():
5068
A sanity check that different seeds usually lead to a different sequence.
5169
"""
5270

53-
maximum = 15
54-
gen1 = stable_random_distribution(maximum, seed=789)
55-
gen2 = stable_random_distribution(maximum, seed=987)
71+
high = 15
72+
gen1 = stable_random_distribution(high, seed=789)
73+
gen2 = stable_random_distribution(high, seed=987)
5674

5775
# Compare the first 50 generated values: while not guaranteed to
5876
# be different, it is extremely unlikely that the two sequences
@@ -76,18 +94,18 @@ def test_fair_distribution_behavior():
7694
- With the adaptive update, values should tend to be farther apart.
7795
"""
7896

79-
maximum = 1_000
97+
high = 100
8098
num_samples = 3_000
8199
for seed in range(100):
82100
# Gather samples from our generator.
83-
gen = stable_random_distribution(maximum, seed=seed)
101+
gen = stable_random_distribution(high, seed=seed)
84102
samples = np.array([next(gen) for _ in range(num_samples)])
85103
diff_stable = np.abs(np.diff(np.sort(samples))) ** 2
86104
avg_diff_stable = diff_stable.mean()
87105

88106
# For comparison, generate num_samples indices uniformly at random.
89107
rng = np.random.default_rng(seed)
90-
uniform_samples = rng.choice(maximum, size=num_samples)
108+
uniform_samples = rng.choice(high, size=num_samples)
91109
diff_uniform = np.abs(np.diff(np.sort(uniform_samples))) ** 2
92110
avg_diff_uniform = diff_uniform.mean()
93111

@@ -98,3 +116,38 @@ def test_fair_distribution_behavior():
98116
f"Expected stable generator to have a higher average jump than a uniform generator: "
99117
f"stable {avg_diff_stable} vs uniform {avg_diff_uniform}"
100118
)
119+
120+
121+
def test_fill_domain_speed():
122+
"""
123+
Test that stable_random_distribution fills out the domain
124+
quicker than a simple uniform generator.
125+
126+
Idea is similar to the previous test.
127+
"""
128+
129+
high = 100
130+
trials = 100
131+
wins = 0
132+
133+
for seed in range(trials):
134+
# Gather samples from our generator.
135+
gen = stable_random_distribution(high, seed=seed)
136+
stable_bucket: Set[int] = set()
137+
stable_steps = 0
138+
while len(stable_bucket) < high:
139+
stable_bucket.add(next(gen))
140+
stable_steps += 1
141+
142+
# For comparison, draw indices uniformly at random until the domain is filled.
143+
rng = np.random.default_rng(seed)
144+
uniform_bucket: Set[int] = set()
145+
uniform_steps = 0
146+
while len(uniform_bucket) < high:
147+
uniform_bucket.add(rng.integers(0, high))
148+
uniform_steps += 1
149+
150+
if stable_steps < uniform_steps:
151+
wins += 1
152+
153+
assert wins / trials > 0.85

Diff for: micall/utils/stable_random_distribution.py

+11-22
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,21 @@
1-
from typing import Iterator, Sequence
1+
from typing import Iterator
22

33
import numpy as np
4-
import random
54

5+
DUPLICATION_FACTOR = 1
66

7-
def stable_random_distribution(maximum: int, seed: int = 42) -> Iterator[int]:
8-
if maximum <= 0:
7+
8+
def stable_random_distribution(high: int, seed: int = 42) -> Iterator[int]:
9+
if high <= 0:
910
return
1011

11-
n = maximum
12-
rng = random.Random(seed)
12+
rng = np.random.default_rng(seed)
13+
block = np.arange(high)
14+
population = np.concatenate([block] * DUPLICATION_FACTOR, axis=0)
1315

14-
population = np.arange(n)
15-
forward = np.arange(1, n + 1)
16-
backwards = np.copy(np.flip(forward))
17-
np_weights = np.zeros(n)
16+
assert len(population) == DUPLICATION_FACTOR * len(block)
1817

1918
while True:
20-
top = np.max(np_weights) + 1
21-
weights: Sequence[float] = top - np_weights # type: ignore
22-
index = rng.choices(population=population, weights=weights)[0]
19+
index = rng.choice(population)
2320
yield index
24-
25-
if index == 0:
26-
np_weights += backwards
27-
else:
28-
np_weights[:(index + 1)] += forward[-(index + 1):]
29-
np_weights[(index + 1):] += backwards[1:-index]
30-
31-
# Prevent overflow.
32-
np_weights -= np_weights.min()
21+
population[index] = rng.integers(low=0, high=high)

0 commit comments

Comments
 (0)