Skip to content

Commit 3af70ea

Browse files
authored
Merge pull request #121 from marscher/traj_gen_random_state
added random_state handling to traj generation
2 parents 1611b57 + 8267211 commit 3af70ea

File tree

2 files changed

+64
-55
lines changed

2 files changed

+64
-55
lines changed

msmtools/generation/api.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class MarkovChainSampler(object):
4242
4343
"""
4444

45-
def __init__(self, P, dt=1):
45+
def __init__(self, P, dt=1, random_state=None):
4646
"""
4747
Constructs a sampling object with transition matrix P. The results will be produced every dt'th time step
4848
@@ -71,12 +71,26 @@ def __init__(self, P, dt=1):
7171
# initialize mu
7272
self.mudist = None
7373

74+
self.random_state = random_state
75+
7476
# generate discrete random value generators for each line
75-
self.rgs = np.ndarray((self.n), dtype=object)
77+
self.rgs = np.ndarray(self.n, dtype=object)
7678
from scipy.stats import rv_discrete
77-
for i in range(self.n):
78-
nz = np.nonzero(self.P[i])
79-
self.rgs[i] = rv_discrete(values=(nz, self.P[i, nz]))
79+
for i, row in enumerate(self.P):
80+
nz = row.nonzero()[0]
81+
self.rgs[i] = rv_discrete(values=(nz, row[nz]))
82+
83+
def _get_start_state(self):
84+
if self.mudist is None:
85+
# compute mu, the stationary distribution of P
86+
from ..analysis import stationary_distribution
87+
from scipy.stats import rv_discrete
88+
89+
mu = stationary_distribution(self.P)
90+
self.mudist = rv_discrete(values=(np.arange(self.n), mu))
91+
# sample starting point from mu
92+
start = self.mudist.rvs(random_state=self.random_state)
93+
return start
8094

8195
def trajectory(self, N, start=None, stop=None):
8296
"""
@@ -97,22 +111,12 @@ def trajectory(self, N, start=None, stop=None):
97111
stop = types.ensure_int_vector_or_None(stop, require_order=False)
98112

99113
if start is None:
100-
if self.mudist is None:
101-
# compute mu, the stationary distribution of P
102-
from ..analysis import stationary_distribution
103-
from scipy.stats import rv_discrete
104-
105-
mu = stationary_distribution(self.P)
106-
self.mudist = rv_discrete(values=(np.arange(self.n), mu))
107-
# sample starting point from mu
108-
start = self.mudist.rvs()
114+
start = self._get_start_state()
109115

110116
# evaluate stopping set
111-
stopat = np.ndarray((self.n), dtype=bool)
112-
stopat[:] = False
113-
if (stop is not None):
114-
for s in np.array(stop):
115-
stopat[s] = True
117+
stopat = np.zeros(self.n, dtype=bool)
118+
if stop is not None:
119+
stopat[np.array(stop)] = True
116120

117121
# result
118122
traj = np.zeros(N, dtype=int)
@@ -122,9 +126,10 @@ def trajectory(self, N, start=None, stop=None):
122126
return traj[:1]
123127
# else run until end or stopping state
124128
for t in range(1, N):
125-
traj[t] = self.rgs[traj[t - 1]].rvs()
129+
traj[t] = self.rgs[traj[t - 1]].rvs(random_state=self.random_state)
126130
if stopat[traj[t]]:
127-
return traj[:t+1]
131+
traj = np.resize(traj, t + 1)
132+
break
128133
# return
129134
return traj
130135

@@ -149,7 +154,7 @@ def trajectories(self, M, N, start=None, stop=None):
149154
return trajs
150155

151156

152-
def generate_traj(P, N, start=None, stop=None, dt=1):
157+
def generate_traj(P, N, start=None, stop=None, dt=1, random_state=None):
153158
"""
154159
Generates a realization of the Markov chain with transition matrix P.
155160
@@ -167,18 +172,22 @@ def generate_traj(P, N, start=None, stop=None, dt=1):
167172
dt : int
168173
trajectory will be saved every dt time steps.
169174
Internally, the dt'th power of P is taken to ensure a more efficient simulation.
175+
random_state : None or int or numpy.random.RandomState instance, optional
176+
This parameter defines the RandomState object to use for drawing random variates.
177+
If None, the global np.random state is used. If integer, it is used to seed the local RandomState instance.
178+
Default is None.
170179
171180
Returns
172181
-------
173182
traj_sliced : (N/dt, ) ndarray
174183
A discrete trajectory with length N/dt
175184
176185
"""
177-
sampler = MarkovChainSampler(P, dt=dt)
186+
sampler = MarkovChainSampler(P, dt=dt, random_state=random_state)
178187
return sampler.trajectory(N, start=start, stop=stop)
179188

180189

181-
def generate_trajs(P, M, N, start=None, stop=None, dt=1):
190+
def generate_trajs(P, M, N, start=None, stop=None, dt=1, random_state=None):
182191
"""
183192
Generates multiple realizations of the Markov chain with transition matrix P.
184193
@@ -198,14 +207,18 @@ def generate_trajs(P, M, N, start=None, stop=None, dt=1):
198207
dt : int
199208
trajectory will be saved every dt time steps.
200209
Internally, the dt'th power of P is taken to ensure a more efficient simulation.
210+
random_state : None or int or numpy.random.RandomState instance, optional
211+
This parameter defines the RandomState object to use for drawing random variates.
212+
If None, the global np.random state is used. If integer, it is used to seed the local RandomState instance.
213+
Default is None.
201214
202215
Returns
203216
-------
204217
traj_sliced : (N/dt, ) ndarray
205218
A discrete trajectory with length N/dt
206219
207220
"""
208-
sampler = MarkovChainSampler(P, dt=dt)
221+
sampler = MarkovChainSampler(P, dt=dt, random_state=random_state)
209222
return sampler.trajectories(M, N, start=start, stop=stop)
210223

211224

@@ -235,12 +248,12 @@ def transition_matrix_metropolis_1d(E, d=1.0):
235248
236249
"""
237250
# check input
238-
if (d <= 0 or d > 1):
251+
if d <= 0 or d > 1:
239252
raise ValueError('Diffusivity must be in (0,1]. Trying to set the invalid value', str(d))
240253
# init
241254
n = len(E)
242255
P = np.zeros((n, n))
243-
# set offdiagonals
256+
# set off diagonals
244257
P[0, 1] = 0.5 * d * min(1.0, math.exp(-(E[1] - E[0])))
245258
for i in range(1, n - 1):
246259
P[i, i - 1] = 0.5 * d * min(1.0, math.exp(-(E[i - 1] - E[i])))

tests/generation/test_generation.py

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
# This file is part of MSMTools.
32
#
43
# Copyright (c) 2015, 2014 Computational Molecular Biology Group
@@ -25,66 +24,63 @@
2524
import msmtools.estimation as msmest
2625
import msmtools.analysis as msmana
2726

28-
class Test(unittest.TestCase):
2927

30-
def setUp(self):
31-
"""Safe random state"""
32-
self.state = np.random.get_state()
33-
"""Set seed to enforce deterministic behavior"""
34-
np.random.seed(42)
28+
class TestTrajGeneration(unittest.TestCase):
3529

36-
def tearDown(self):
37-
"""Reset state"""
38-
np.random.set_state(self.state)
30+
@classmethod
31+
def setUpClass(cls):
32+
cls.P = np.array([[0.9, 0.1],
33+
[0.1, 0.9]])
34+
35+
def setUp(self):
36+
self.random_state = np.random.RandomState(42)
3937

4038
def test_trajectory(self):
41-
P = np.array([[0.9,0.1],
42-
[0.1,0.9]])
4339
N = 1000
44-
traj = msmgen.generate_traj(P, N, start=0)
40+
traj = msmgen.generate_traj(self.P, N, start=0, random_state=self.random_state)
4541

4642
# test shapes and sizes
4743
assert traj.size == N
4844
assert traj.min() >= 0
4945
assert traj.max() <= 1
5046

5147
# test statistics of transition matrix
52-
C = msmest.count_matrix(traj,1)
48+
C = msmest.count_matrix(traj, 1)
5349
Pest = msmest.transition_matrix(C)
54-
assert np.max(np.abs(Pest - P)) < 0.025
55-
50+
assert np.max(np.abs(Pest - self.P)) < 0.025
5651

5752
def test_trajectories(self):
58-
P = np.array([[0.9,0.1],
59-
[0.1,0.9]])
60-
6153
# test number of trajectories
6254
M = 10
6355
N = 10
64-
trajs = msmgen.generate_trajs(P, M, N, start=0)
56+
trajs = msmgen.generate_trajs(self.P, M, N, start=0, random_state=self.random_state)
6557
assert len(trajs) == M
6658

59+
def test_stats(self):
6760
# test statistics of starting state
68-
trajs = msmgen.generate_trajs(P, 1000, 1)
61+
N = 5000
62+
trajs = msmgen.generate_trajs(self.P, N, 1, random_state=self.random_state)
6963
ss = np.concatenate(trajs).astype(int)
70-
pi = msmana.stationary_distribution(P)
71-
piest = msmest.count_states(ss) / 1000.0
72-
assert np.max(np.abs(pi - piest)) < 0.025
64+
pi = msmana.stationary_distribution(self.P)
65+
piest = msmest.count_states(ss) / float(N)
66+
np.testing.assert_allclose(piest, pi, atol=0.025)
7367

74-
# test stopping state = starting state
68+
def test_stop_eq_start(self):
7569
M = 10
76-
trajs = msmgen.generate_trajs(P, M, N, start=0, stop=0)
70+
N = 10
71+
trajs = msmgen.generate_trajs(self.P, M, N, start=0, stop=0, random_state=self.random_state)
7772
for traj in trajs:
7873
assert traj.size == 1
7974

75+
def test_stop(self):
8076
# test if we always stop at stopping state
8177
M = 100
78+
N = 10
8279
stop = 1
83-
trajs = msmgen.generate_trajs(P, M, N, start=0, stop=stop)
80+
trajs = msmgen.generate_trajs(self.P, M, N, start=0, stop=stop, random_state=self.random_state)
8481
for traj in trajs:
8582
assert traj.size == N or traj[-1] == stop
8683
assert stop not in traj[:-1]
8784

8885
if __name__ == "__main__":
89-
# import sys;sys.argv = ['', 'Test.testName']
9086
unittest.main()

0 commit comments

Comments
 (0)