Skip to content

Commit 27672b6

Browse files
committed
Fix logical error in stable_random_distribution
1 parent 41dfe7d commit 27672b6

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed
+21-9
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,31 @@
1-
from typing import Iterator
1+
from typing import Iterator, Sequence
22

33
import numpy as np
4+
import random
45

56

67
def stable_random_distribution(maximum: int, seed: int = 42) -> Iterator[int]:
8+
if maximum <= 0:
9+
return
10+
711
n = maximum
8-
rng = np.random.default_rng(seed)
12+
rng = random.Random(seed)
913

10-
weights = np.zeros(n) + 1
11-
forward = np.arange(1, n + 1)
12-
backwards = np.arange(n, 0, -1)
14+
population = np.arange(n)
15+
forward = np.arange(1, n + 1) ** 0.5
16+
backwards = np.copy(np.flip(forward))
17+
np_weights = np.zeros(n) + 0.1
1318

1419
while True:
15-
probabilities = weights / weights.sum()
16-
index = rng.choice(n, p=probabilities)
20+
weights: Sequence[float] = 1 - np_weights # type: ignore
21+
indexes = rng.choices(population=population, weights=weights)
22+
index = indexes[0]
1723
yield index
18-
weights[:index] += forward[:index]
19-
weights[index:] += backwards[index:]
24+
25+
if index == 0:
26+
np_weights += backwards
27+
else:
28+
np_weights[:(index + 1)] += forward[-(index + 1):]
29+
np_weights[(index + 1):] += backwards[1:-index]
30+
31+
np_weights /= np_weights.sum()

0 commit comments

Comments
 (0)