Skip to content

Commit b97481d

Browse files
authored
Merge pull request #123 from markovmodel/trajgen
[generation] replaced scipy.stats.rv_discrete by np.random.choice
2 parents 3bd4e38 + 013f630 commit b97481d

File tree

2 files changed

+20
-28
lines changed

2 files changed

+20
-28
lines changed

Diff for: msmtools/generation/api.py

+12-28
Original file line numberDiff line numberDiff line change
@@ -68,28 +68,17 @@ def __init__(self, P, dt=1, random_state=None):
6868
self.P = np.array(P)
6969
self.n = self.P.shape[0]
7070

71-
# initialize mu
72-
self.mudist = None
73-
71+
if random_state is None:
72+
random_state = np.random.RandomState()
7473
self.random_state = random_state
7574

76-
# generate discrete random value generators for each line
77-
self.rgs = np.ndarray(self.n, dtype=object)
78-
from scipy.stats import rv_discrete
79-
for i, row in enumerate(self.P):
80-
nz = row.nonzero()[0]
81-
self.rgs[i] = rv_discrete(values=(nz, row[nz]))
82-
8375
def _get_start_state(self):
84-
if self.mudist is None:
85-
# compute mu, the stationary distribution of P
86-
from ..analysis import stationary_distribution
87-
from scipy.stats import rv_discrete
88-
89-
mu = stationary_distribution(self.P)
90-
self.mudist = rv_discrete(values=(np.arange(self.n), mu))
91-
# sample starting point from mu
92-
start = self.mudist.rvs(random_state=self.random_state)
76+
# compute mu, the stationary distribution of P
77+
from ..analysis import stationary_distribution
78+
79+
mu = stationary_distribution(self.P)
80+
start = self.random_state.choice(self.n, p=mu)
81+
9382
return start
9483

9584
def trajectory(self, N, start=None, stop=None):
@@ -113,24 +102,19 @@ def trajectory(self, N, start=None, stop=None):
113102
if start is None:
114103
start = self._get_start_state()
115104

116-
# evaluate stopping set
117-
stopat = np.zeros(self.n, dtype=bool)
118-
if stop is not None:
119-
stopat[np.array(stop)] = True
120-
121105
# result
122106
traj = np.zeros(N, dtype=int)
123107
traj[0] = start
124108
# already at stopping state?
125-
if stopat[traj[0]]:
109+
if traj[0] == stop:
126110
return traj[:1]
127111
# else run until end or stopping state
128112
for t in range(1, N):
129-
traj[t] = self.rgs[traj[t - 1]].rvs(random_state=self.random_state)
130-
if stopat[traj[t]]:
113+
traj[t] = self.random_state.choice(self.n, p=self.P[traj[t - 1]])
114+
if traj[t] == stop:
131115
traj = np.resize(traj, t + 1)
132116
break
133-
# return
117+
134118
return traj
135119

136120
def trajectories(self, M, N, start=None, stop=None):

Diff for: tests/generation/test_generation.py

+8
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ def test_stats(self):
6565
piest = msmest.count_states(ss) / float(N)
6666
np.testing.assert_allclose(piest, pi, atol=0.025)
6767

68+
def test_transitionmatrix(self):
69+
# test if transition matrix can be reconstructed
70+
N = 5000
71+
trajs = msmgen.generate_traj(self.P, N, random_state=self.random_state)
72+
C = msmest.count_matrix(trajs, 1, sparse_return=False)
73+
T = msmest.transition_matrix(C)
74+
np.testing.assert_allclose(T, self.P, atol=.01)
75+
6876
def test_stop_eq_start(self):
6977
M = 10
7078
N = 10

0 commit comments

Comments
 (0)