1+ import numpy as np
2+
13from typing import (
24 List ,
35 Tuple ,
@@ -24,9 +26,9 @@ def __init__(
2426 self ,
2527 ):
2628 self .stage_schedulers = []
27- self .stage_reports = [[]]
2829 self .cur_stage = 0
2930 self .iteration = - 1
31+ self .complete_ = False
3032
3133 def add_stage_scheduler (
3234 self ,
@@ -44,6 +46,7 @@ def add_stage_scheduler(
4446
4547 """
4648 self .stage_schedulers .append (stage_scheduler )
49+ self .complete_ = False
4750 return self
4851
4952 def get_stage (self ):
@@ -64,6 +67,13 @@ def get_iteration(self):
6467 """
6568 return self .iteration
6669
70+ def complete (self ):
71+ """
72+ Tell if all stages are converged.
73+
74+ """
75+ return self .complete_
76+
6777 def plan_next_iteration (
6878 self ,
6979 report : ExplorationReport = None ,
@@ -81,8 +91,8 @@ def plan_next_iteration(
8191
8292 Returns
8393 -------
84- converged : bool
85- If DPGEN converges .
94+ complete : bool
95+ If all the DPGEN stages complete .
8696 task: ExplorationTaskGroup
8797 A `ExplorationTaskGroup` defining the exploration of the next iteration. Should be `None` if converged.
8898 conf_selector: ConfSelector
@@ -91,26 +101,118 @@ def plan_next_iteration(
91101 """
92102
93103 try :
94- converged , lmp_task_grp , conf_selector = \
104+ stg_complete , lmp_task_grp , conf_selector = \
95105 self .stage_schedulers [self .cur_stage ].plan_next_iteration (
96- self .stage_reports [self .cur_stage ],
97106 report ,
98107 trajs ,
99108 )
100- self .stage_reports [self .cur_stage ].append (report )
101109 except FatalError as e :
102110 raise FatalError (f'stage { self .cur_stage } : ' + str (e ))
103111
104- if converged :
112+ if stg_complete :
105113 self .cur_stage += 1
106- self .stage_reports .append ([])
107114 if self .cur_stage < len (self .stage_schedulers ):
108115 # goes to next stage
109116 return self .plan_next_iteration ()
110117 else :
111- # all stages converged
118+ # all stages complete
119+ self .complete_ = True
112120 return True , None , None ,
113121 else :
114122 self .iteration += 1
115- return converged , lmp_task_grp , conf_selector
123+ return stg_complete , lmp_task_grp , conf_selector
124+
125+
126+ def get_stage_of_iterations (self ):
127+ """
128+ Get the stage index and the index in the stage of iterations.
116129
130+ """
131+ stages = self .stage_schedulers
132+ n_stage_iters = []
133+ for ii in range (self .get_stage () + 1 ):
134+ if ii < len (stages ) and len (stages [ii ].reports ) > 0 :
135+ n_stage_iters .append (len (stages [ii ].reports ))
136+ cumsum_stage_iters = np .cumsum (n_stage_iters )
137+
138+ max_iter = self .get_iteration ()
139+ if self .complete () or max_iter == - 1 :
140+ max_iter += 1
141+ stage_idx = []
142+ idx_in_stage = []
143+ iter_idx = []
144+ for ii in range (max_iter ):
145+ idx = np .searchsorted (cumsum_stage_iters , ii + 1 )
146+ stage_idx .append (idx )
147+ if idx > 0 :
148+ idx_in_stage .append (ii - cumsum_stage_iters [idx - 1 ])
149+ else :
150+ idx_in_stage .append (ii )
151+ iter_idx .append (ii )
152+ assert ( len (stage_idx ) == max_iter )
153+ assert ( len (idx_in_stage ) == max_iter )
154+ assert ( len (iter_idx ) == max_iter )
155+ return stage_idx , idx_in_stage , iter_idx
156+
157+
158+ def get_convergence_ratio (self ):
159+ """
160+ Get the accurate, candidate and failed ratios of the iterations
161+
162+ Returns
163+ -------
164+ accu np.ndarray
165+ The accurate ratio. length of array the same as # iterations.
166+ cand np.ndarray
167+ The candidate ratio. length of array the same as # iterations.
168+ fail np.ndarray
169+ The failed ration. length of array the same as # iterations.
170+ """
171+ stages = self .stage_schedulers
172+ stag_idx , idx_in_stag , iter_idx = self .get_stage_of_iterations ()
173+ accu = []
174+ cand = []
175+ fail = []
176+ for ii in range (np .size (iter_idx )):
177+ accu .append (stages [stag_idx [ii ]].reports [idx_in_stag [ii ]].accurate_ratio ())
178+ cand .append (stages [stag_idx [ii ]].reports [idx_in_stag [ii ]].candidate_ratio ())
179+ fail .append (stages [stag_idx [ii ]].reports [idx_in_stag [ii ]].failed_ratio ())
180+ return np .array (accu ), np .array (cand ), np .array (fail )
181+
182+ def _print_prev_summary (self , prev_stg_idx ):
183+ if prev_stg_idx >= 0 :
184+ yes = 'YES' if self .stage_schedulers [prev_stg_idx ].converged () else 'NO '
185+ rmx = 'YES' if self .stage_schedulers [prev_stg_idx ].reached_max_iteration () else 'NO '
186+ return f'# Stage { prev_stg_idx :4d} converged { yes } reached max numb iterations { rmx } '
187+ else :
188+ return None
189+
190+ def print_convergence (self ):
191+ spaces = [8 , 8 , 8 , 10 , 10 , 10 ]
192+ fmt_str = ' ' .join ([f'%{ ii } s' for ii in spaces ])
193+ fmt_flt = '%.4f'
194+ header_str = '#' + fmt_str % ('stage' , 'id_stg.' , 'iter.' , 'accu.' , 'cand.' , 'fail.' )
195+ ret = [header_str ]
196+
197+ stage_idx , idx_in_stage , iter_idx = self .get_stage_of_iterations ()
198+ accu , cand , fail = self .get_convergence_ratio ()
199+
200+ iidx = 0
201+ prev_stg_idx = - 1
202+ for iidx in range (len (accu )):
203+ if stage_idx [iidx ] != prev_stg_idx :
204+ if prev_stg_idx >= 0 :
205+ ret .append (self ._print_prev_summary (prev_stg_idx ))
206+ ret .append (f'# Stage { stage_idx [iidx ]:4d} ' + '-' * 20 )
207+ prev_stg_idx = stage_idx [iidx ]
208+ ret .append (' ' + fmt_str % (
209+ str (stage_idx [iidx ]), str (idx_in_stage [iidx ]), str (iidx ),
210+ fmt_flt % (accu [iidx ]* 1 ),
211+ fmt_flt % (cand [iidx ]* 1 ),
212+ fmt_flt % (fail [iidx ]* 1 ),
213+ ))
214+ if self .complete ():
215+ if prev_stg_idx >= 0 :
216+ ret .append (self ._print_prev_summary (prev_stg_idx ))
217+ ret .append (f'# All stages converged' )
218+ return '\n ' .join (ret + ['' ])
0 commit comments