Skip to content

Commit 87b47d4

Browse files
committed
Use pythons random instead of numpy random in stable_number_distribution
1 parent b0ea7a9 commit 87b47d4

File tree

2 files changed

+38
-26
lines changed

2 files changed

+38
-26
lines changed

micall/tests/test_stable_random_distribution.py

+24-20
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
import numpy as np
44
from itertools import islice
55
from typing import Set
6+
import random
67

78

89
def test_indices_in_range():
910
"""Test that each index generated is within the range [0, high)."""
1011

1112
high = 10
12-
gen = stable_random_distribution(high, seed=123)
13+
rng = random.Random(123)
14+
gen = stable_random_distribution(high, rng=rng)
1315
# Grab a bunch of values from the infinite generator
1416

1517
for _ in range(1000):
@@ -21,7 +23,8 @@ def test_bounds_are_reachable():
2123
"""Test that both min and max-1 can be generated."""
2224

2325
high = 999
24-
gen = stable_random_distribution(high, seed=123)
26+
rng = random.Random(123456)
27+
gen = stable_random_distribution(high, rng=rng)
2528
lst = islice(gen, 1000)
2629

2730
assert 0 in lst
@@ -32,13 +35,8 @@ def test_everything_is_reachable():
3235
"""Test that all numbers in the range [0, max-1) can be generated."""
3336

3437
high = 30
35-
fun = stable_random_distribution
36-
# def fun(high, seed):
37-
# import random
38-
# while True:
39-
# yield random.randint(0, high)
40-
41-
gen = fun(high, seed=123)
38+
rng = random.Random(123)
39+
gen = stable_random_distribution(high, rng=rng)
4240
lst = tuple(map(int, islice(gen, 1000)))
4341

4442
for x in range(high):
@@ -53,8 +51,10 @@ def test_deterministic_output_with_seed():
5351

5452
high = 15
5553
seed = 456
56-
gen1 = stable_random_distribution(high, seed=seed)
57-
gen2 = stable_random_distribution(high, seed=seed)
54+
rng1 = random.Random(seed)
55+
rng2 = random.Random(seed)
56+
gen1 = stable_random_distribution(high, rng=rng1)
57+
gen2 = stable_random_distribution(high, rng=rng2)
5858

5959
# Compare the first 50 generated values.
6060
values1 = [next(gen1) for _ in range(50)]
@@ -69,8 +69,10 @@ def test_different_seeds_differ():
6969
"""
7070

7171
high = 15
72-
gen1 = stable_random_distribution(high, seed=789)
73-
gen2 = stable_random_distribution(high, seed=987)
72+
rng1 = random.Random(789)
73+
rng2 = random.Random(987)
74+
gen1 = stable_random_distribution(high, rng=rng1)
75+
gen2 = stable_random_distribution(high, rng=rng2)
7476

7577
# Compare the first 50 generated values: while not guaranteed to
7678
# be different, it is extremely unlikely that the two sequences
@@ -98,14 +100,15 @@ def test_fair_distribution_behavior():
98100
num_samples = 3_000
99101
for seed in range(100):
100102
# Gather samples from our generator.
101-
gen = stable_random_distribution(high, seed=seed)
103+
rng = random.Random(seed)
104+
gen = stable_random_distribution(high, rng=rng)
102105
samples = np.array([next(gen) for _ in range(num_samples)])
103106
diff_stable = np.abs(np.diff(np.sort(samples))) ** 2
104107
avg_diff_stable = diff_stable.mean()
105108

106109
# For comparison, generate num_samples indices uniformly at random.
107-
rng = np.random.default_rng(seed)
108-
uniform_samples = rng.choice(high, size=num_samples)
110+
nprng = np.random.default_rng(seed)
111+
uniform_samples = nprng.choice(high, size=num_samples)
109112
diff_uniform = np.abs(np.diff(np.sort(uniform_samples))) ** 2
110113
avg_diff_uniform = diff_uniform.mean()
111114

@@ -132,22 +135,23 @@ def test_fill_domain_speed():
132135

133136
for seed in range(trials):
134137
# Gather samples from our generator.
135-
gen = stable_random_distribution(high, seed=seed)
138+
rng = random.Random(seed)
139+
gen = stable_random_distribution(high, rng=rng)
136140
stable_bucket: Set[int] = set()
137141
stable_steps = 0
138142
while len(stable_bucket) < high:
139143
stable_bucket.add(next(gen))
140144
stable_steps += 1
141145

142146
# For comparison, generate num_samples indices uniformly at random.
143-
rng = np.random.default_rng(seed)
147+
nprng = np.random.default_rng(seed)
144148
uniform_bucket: Set[int] = set()
145149
uniform_steps = 0
146150
while len(uniform_bucket) < high:
147-
uniform_bucket.add(rng.integers(0, high))
151+
uniform_bucket.add(nprng.integers(0, high))
148152
uniform_steps += 1
149153

150154
if stable_steps < uniform_steps:
151155
wins += 1
152156

153-
assert wins / trials > 0.85
157+
assert wins / trials > 0.80
+14-6
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,29 @@
1-
from typing import Iterator
1+
from typing import Iterator, Optional
22

3+
import random
34
import numpy as np
45

5-
DUPLICATION_FACTOR = 1
6+
DUPLICATION_FACTOR = 100
67

78

8-
def stable_random_distribution(high: int, seed: int = 42) -> Iterator[int]:
9+
def stable_random_distribution(high: int,
10+
rng: Optional[random.Random] = None,
11+
) -> Iterator[int]:
12+
913
if high <= 0:
1014
return
1115

12-
rng = np.random.default_rng(seed)
16+
if rng is None:
17+
rng = random.Random()
18+
19+
maximum = high - 1
1320
block = np.arange(high)
1421
population = np.concatenate([block] * DUPLICATION_FACTOR, axis=0)
1522

1623
assert len(population) == DUPLICATION_FACTOR * len(block)
1724

1825
while True:
19-
index = rng.choice(population)
26+
choice = rng.randint(0, maximum)
27+
index = population[choice]
2028
yield index
21-
population[index] = rng.integers(low=0, high=high)
29+
population[index] = rng.randint(0, maximum)

0 commit comments

Comments
 (0)