@@ -68,28 +68,17 @@ def __init__(self, P, dt=1, random_state=None):
68
68
self .P = np .array (P )
69
69
self .n = self .P .shape [0 ]
70
70
71
- # initialize mu
72
- self .mudist = None
73
-
71
+ if random_state is None :
72
+ random_state = np .random .RandomState ()
74
73
self .random_state = random_state
75
74
76
- # generate discrete random value generators for each line
77
- self .rgs = np .ndarray (self .n , dtype = object )
78
- from scipy .stats import rv_discrete
79
- for i , row in enumerate (self .P ):
80
- nz = row .nonzero ()[0 ]
81
- self .rgs [i ] = rv_discrete (values = (nz , row [nz ]))
82
-
83
75
def _get_start_state (self ):
84
- if self .mudist is None :
85
- # compute mu, the stationary distribution of P
86
- from ..analysis import stationary_distribution
87
- from scipy .stats import rv_discrete
88
-
89
- mu = stationary_distribution (self .P )
90
- self .mudist = rv_discrete (values = (np .arange (self .n ), mu ))
91
- # sample starting point from mu
92
- start = self .mudist .rvs (random_state = self .random_state )
76
+ # compute mu, the stationary distribution of P
77
+ from ..analysis import stationary_distribution
78
+
79
+ mu = stationary_distribution (self .P )
80
+ start = self .random_state .choice (self .n , p = mu )
81
+
93
82
return start
94
83
95
84
def trajectory (self , N , start = None , stop = None ):
@@ -113,24 +102,19 @@ def trajectory(self, N, start=None, stop=None):
113
102
if start is None :
114
103
start = self ._get_start_state ()
115
104
116
- # evaluate stopping set
117
- stopat = np .zeros (self .n , dtype = bool )
118
- if stop is not None :
119
- stopat [np .array (stop )] = True
120
-
121
105
# result
122
106
traj = np .zeros (N , dtype = int )
123
107
traj [0 ] = start
124
108
# already at stopping state?
125
- if stopat [ traj [0 ]] :
109
+ if traj [0 ] == stop :
126
110
return traj [:1 ]
127
111
# else run until end or stopping state
128
112
for t in range (1 , N ):
129
- traj [t ] = self .rgs [traj [t - 1 ]]. rvs ( random_state = self . random_state )
130
- if stopat [ traj [t ]] :
113
+ traj [t ] = self .random_state . choice ( self . n , p = self . P [traj [t - 1 ]])
114
+ if traj [t ] == stop :
131
115
traj = np .resize (traj , t + 1 )
132
116
break
133
- # return
117
+
134
118
return traj
135
119
136
120
def trajectories (self , M , N , start = None , stop = None ):
0 commit comments