33import numpy as np
44from itertools import islice
55from typing import Set
6+ import random
67
78
89def test_indices_in_range ():
910 """Test that each index generated is within the range [0, high)."""
1011
1112 high = 10
12- gen = stable_random_distribution (high , seed = 123 )
13+ rng = random .Random (123 )
14+ gen = stable_random_distribution (high , rng = rng )
1315 # Grab a bunch of values from the infinite generator
1416
1517 for _ in range (1000 ):
@@ -21,7 +23,8 @@ def test_bounds_are_reachable():
2123 """Test that both min and max-1 can be generated."""
2224
2325 high = 999
24- gen = stable_random_distribution (high , seed = 123 )
26+ rng = random .Random (123456 )
27+ gen = stable_random_distribution (high , rng = rng )
2528 lst = islice (gen , 1000 )
2629
2730 assert 0 in lst
@@ -32,13 +35,8 @@ def test_everything_is_reachable():
3235 """Test that all numbers in the range [0, max-1) can be generated."""
3336
3437 high = 30
35- fun = stable_random_distribution
36- # def fun(high, seed):
37- # import random
38- # while True:
39- # yield random.randint(0, high)
40-
41- gen = fun (high , seed = 123 )
38+ rng = random .Random (123 )
39+ gen = stable_random_distribution (high , rng = rng )
4240 lst = tuple (map (int , islice (gen , 1000 )))
4341
4442 for x in range (high ):
@@ -53,8 +51,10 @@ def test_deterministic_output_with_seed():
5351
5452 high = 15
5553 seed = 456
56- gen1 = stable_random_distribution (high , seed = seed )
57- gen2 = stable_random_distribution (high , seed = seed )
54+ rng1 = random .Random (seed )
55+ rng2 = random .Random (seed )
56+ gen1 = stable_random_distribution (high , rng = rng1 )
57+ gen2 = stable_random_distribution (high , rng = rng2 )
5858
5959 # Compare the first 50 generated values.
6060 values1 = [next (gen1 ) for _ in range (50 )]
@@ -69,8 +69,10 @@ def test_different_seeds_differ():
6969 """
7070
7171 high = 15
72- gen1 = stable_random_distribution (high , seed = 789 )
73- gen2 = stable_random_distribution (high , seed = 987 )
72+ rng1 = random .Random (789 )
73+ rng2 = random .Random (987 )
74+ gen1 = stable_random_distribution (high , rng = rng1 )
75+ gen2 = stable_random_distribution (high , rng = rng2 )
7476
7577 # Compare the first 50 generated values: while not guaranteed to
7678 # be different, it is extremely unlikely that the two sequences
@@ -98,14 +100,15 @@ def test_fair_distribution_behavior():
98100 num_samples = 3_000
99101 for seed in range (100 ):
100102 # Gather samples from our generator.
101- gen = stable_random_distribution (high , seed = seed )
103+ rng = random .Random (seed )
104+ gen = stable_random_distribution (high , rng = rng )
102105 samples = np .array ([next (gen ) for _ in range (num_samples )])
103106 diff_stable = np .abs (np .diff (np .sort (samples ))) ** 2
104107 avg_diff_stable = diff_stable .mean ()
105108
106109 # For comparison, generate num_samples indices uniformly at random.
107- rng = np .random .default_rng (seed )
108- uniform_samples = rng .choice (high , size = num_samples )
110+ nprng = np .random .default_rng (seed )
111+ uniform_samples = nprng .choice (high , size = num_samples )
109112 diff_uniform = np .abs (np .diff (np .sort (uniform_samples ))) ** 2
110113 avg_diff_uniform = diff_uniform .mean ()
111114
@@ -132,22 +135,23 @@ def test_fill_domain_speed():
132135
133136 for seed in range (trials ):
134137 # Gather samples from our generator.
135- gen = stable_random_distribution (high , seed = seed )
138+ rng = random .Random (seed )
139+ gen = stable_random_distribution (high , rng = rng )
136140 stable_bucket : Set [int ] = set ()
137141 stable_steps = 0
138142 while len (stable_bucket ) < high :
139143 stable_bucket .add (next (gen ))
140144 stable_steps += 1
141145
142146 # For comparison, generate num_samples indices uniformly at random.
143- rng = np .random .default_rng (seed )
147+ nprng = np .random .default_rng (seed )
144148 uniform_bucket : Set [int ] = set ()
145149 uniform_steps = 0
146150 while len (uniform_bucket ) < high :
147- uniform_bucket .add (rng .integers (0 , high ))
151+ uniform_bucket .add (nprng .integers (0 , high ))
148152 uniform_steps += 1
149153
150154 if stable_steps < uniform_steps :
151155 wins += 1
152156
153- assert wins / trials > 0.85
157+ assert wins / trials > 0.80
0 commit comments