2
2
from micall .utils .stable_random_distribution import stable_random_distribution
3
3
import numpy as np
4
4
from itertools import islice
5
+ from typing import Set
5
6
6
7
7
8
def test_indices_in_range():
    """Check that every generated index lies in the half-open range [0, high)."""

    high = 10
    sample_count = 1000
    gen = stable_random_distribution(high, seed=123)

    # The generator is infinite, so only a fixed number of values is drawn.
    for _ in range(sample_count):
        idx = next(gen)
        assert 0 <= idx < high, f"Index { idx } out of range [0,{ high } )"
17
18
18
19
19
20
def test_bounds_are_reachable():
    """Test that both bounds of the range, 0 and high - 1, can be generated."""

    high = 999
    gen = stable_random_distribution(high, seed=123)

    # Materialize the samples once.  A bare islice is a one-shot iterator:
    # the first `in` check would consume every value up to the first 0, so
    # the second check could miss a `high - 1` that was drawn earlier.
    seen = set(map(int, islice(gen, 1000)))

    assert 0 in seen
    assert (high - 1) in seen
29
+
30
+
31
def test_everything_is_reachable():
    """Test that every number in the range [0, high) can be generated."""

    high = 30
    gen = stable_random_distribution(high, seed=123)

    # Collect the distinct values among the first 1000 draws.
    seen = set(map(int, islice(gen, 1000)))

    # One set difference replaces a linear scan per value, and the failure
    # message names exactly which values never showed up.
    missing = sorted(set(range(high)) - seen)
    assert not missing, f"Values never generated: {missing}"
28
46
29
47
30
48
def test_deterministic_output_with_seed ():
@@ -33,10 +51,10 @@ def test_deterministic_output_with_seed():
33
51
re-seeded with the same seed.
34
52
"""
35
53
36
- maximum = 15
54
+ high = 15
37
55
seed = 456
38
- gen1 = stable_random_distribution (maximum , seed = seed )
39
- gen2 = stable_random_distribution (maximum , seed = seed )
56
+ gen1 = stable_random_distribution (high , seed = seed )
57
+ gen2 = stable_random_distribution (high , seed = seed )
40
58
41
59
# Compare the first 50 generated values.
42
60
values1 = [next (gen1 ) for _ in range (50 )]
@@ -50,9 +68,9 @@ def test_different_seeds_differ():
50
68
A sanity check that different seeds usually lead to a different sequence.
51
69
"""
52
70
53
- maximum = 15
54
- gen1 = stable_random_distribution (maximum , seed = 789 )
55
- gen2 = stable_random_distribution (maximum , seed = 987 )
71
+ high = 15
72
+ gen1 = stable_random_distribution (high , seed = 789 )
73
+ gen2 = stable_random_distribution (high , seed = 987 )
56
74
57
75
# Compare the first 50 generated values: while not guaranteed to
58
76
# be different, it is extremely unlikely that the two sequences
@@ -76,18 +94,18 @@ def test_fair_distribution_behavior():
76
94
- With the adaptive update, values should tend to be farther apart.
77
95
"""
78
96
79
- maximum = 1_000
97
+ high = 100
80
98
num_samples = 3_000
81
99
for seed in range (100 ):
82
100
# Gather samples from our generator.
83
- gen = stable_random_distribution (maximum , seed = seed )
101
+ gen = stable_random_distribution (high , seed = seed )
84
102
samples = np .array ([next (gen ) for _ in range (num_samples )])
85
103
diff_stable = np .abs (np .diff (np .sort (samples ))) ** 2
86
104
avg_diff_stable = diff_stable .mean ()
87
105
88
106
# For comparison, generate num_samples indices uniformly at random.
89
107
rng = np .random .default_rng (seed )
90
- uniform_samples = rng .choice (maximum , size = num_samples )
108
+ uniform_samples = rng .choice (high , size = num_samples )
91
109
diff_uniform = np .abs (np .diff (np .sort (uniform_samples ))) ** 2
92
110
avg_diff_uniform = diff_uniform .mean ()
93
111
@@ -98,3 +116,38 @@ def test_fair_distribution_behavior():
98
116
f"Expected stable generator to have a higher average jump than a uniform generator: "
99
117
f"stable { avg_diff_stable } vs uniform { avg_diff_uniform } "
100
118
)
119
+
120
+
121
def test_fill_domain_speed():
    """
    Test that stable_random_distribution fills out the domain
    quicker than a simple uniform generator.

    For every seed, each generator is drawn from until `high` distinct
    values have been collected.  The stable generator wins a trial when
    it needs strictly fewer draws, and it must win more than 85% of
    the trials.
    """

    high = 100
    trials = 100

    def draws_until_full(draw):
        """Count the calls to draw() needed to collect `high` distinct values."""
        seen: Set[int] = set()
        count = 0
        while len(seen) < high:
            seen.add(draw())
            count += 1
        return count

    wins = 0
    for seed in range(trials):
        # Number of draws our generator needs to cover the domain.
        gen = stable_random_distribution(high, seed=seed)
        stable_steps = draws_until_full(lambda: next(gen))

        # Same measurement for uniformly random integers, for comparison.
        rng = np.random.default_rng(seed)
        uniform_steps = draws_until_full(lambda: rng.integers(0, high))

        if stable_steps < uniform_steps:
            wins += 1

    assert wins / trials > 0.85
0 commit comments