6
6
7
7
import numpy as np
8
8
import pandas as pd
9
- import stan
10
9
import matplotlib .pyplot as plt
11
10
from tqdm import trange
11
+ from cmdstanpy import CmdStanModel
12
12
13
13
from sccala .scmlib .models import SCM_Model
14
14
from sccala .utillib .aux import distmod_kin , quantile , split_list , nullify_output
@@ -237,6 +237,7 @@ def get_error_matrix(self, classic=False, rho=1.0, rho_calib=0.0):
237
237
)
238
238
return np .array (errors )
239
239
240
+
240
241
def sample (
241
242
self ,
242
243
model ,
@@ -252,7 +253,8 @@ def sample(
252
253
classic = False ,
253
254
):
254
255
"""
255
- Samples the posterior for the given data and model
256
+ Samples the posterior for the given data and model using
257
+ the cmdstanpy interface.
256
258
257
259
Parameters
258
260
----------
@@ -283,7 +285,7 @@ def sample(
283
285
Returns
284
286
-------
285
287
posterior : pandas DataFrame
286
- Result of the STAN sampling
288
+
287
289
"""
288
290
289
291
assert issubclass (
@@ -368,17 +370,29 @@ def sample(
368
370
369
371
model .set_initial_conditions (init )
370
372
371
- # Setup/ build STAN model
372
- fit = stan .build (model .model , data = model .data )
373
- samples = fit .sample (
374
- num_chains = chains ,
375
- num_samples = iters ,
376
- init = [model .init ] * chains ,
377
- num_warmup = warmup ,
373
+ data_file = model .write_json ("data.json" , path = log_dir )
374
+ stan_file = model .write_stan ("model.stan" , path = log_dir )
375
+
376
+ mdl = CmdStanModel (stan_file = stan_file )
377
+
378
+ fit = mdl .sample (
379
+ data = data_file ,
380
+ chains = chains ,
381
+ iter_warmup = warmup ,
382
+ iter_sampling = iters ,
378
383
save_warmup = save_warmup ,
384
+ inits = [model .init ] * chains ,
379
385
)
380
386
381
- self .posterior = samples .to_frame ()
387
+ summary = fit .summary ()
388
+ diagnose = fit .diagnose ()
389
+
390
+ if not quiet :
391
+ print (summary )
392
+ print (diagnose )
393
+
394
+
395
+ self .posterior = fit .draws_pd ()
382
396
383
397
# Encrypt H0 for blinding
384
398
if self .blind and model .hubble :
@@ -393,13 +407,24 @@ def sample(
393
407
norm = None
394
408
395
409
if log_dir is not None :
396
- self .__save_samples__ (self .posterior , log_dir = log_dir , norm = norm )
410
+ savename = self .__save_samples__ (self .posterior , log_dir = log_dir , norm = norm )
411
+ chains_dir = savename .replace (".csv" , "" )
412
+ os .makedirs (chains_dir )
413
+ with open (os .path .join (chains_dir , "summary.txt" ), "w" ) as f :
414
+ f .write (summary .to_string ())
415
+ with open (os .path .join (chains_dir , "diagnose.txt" ), "w" ) as f :
416
+ f .write (diagnose )
417
+ if not self .blind :
418
+ # Only move the csv files if we're not blinding the result
419
+ # TODO: find a way of blinding the individual chains
420
+ fit .save_csvfiles (chains_dir )
397
421
398
422
if not quiet :
399
423
model .print_results (self .posterior , blind = self .blind )
400
424
401
425
return self .posterior
402
426
427
+
403
428
def bootstrap (
404
429
self ,
405
430
model ,
@@ -614,6 +639,22 @@ def bootstrap(
614
639
else :
615
640
done = []
616
641
642
+ if rank == 0 :
643
+ stan_file = model .write_stan ("model.stan" , path = log_dir )
644
+
645
+ # Create a model instance to trigger compilation and avoid
646
+ # having to compile the model on each rank separately
647
+ print ("Compiling model..." )
648
+ mdl_0 = CmdStanModel (stan_file = stan_file )
649
+ del mdl_0
650
+ print ("Model compiled, starting sampling..." )
651
+ else :
652
+ # Should be done via broadcast, but this is easier
653
+ # and the path is 'hardcoded' anyway
654
+ stan_file = os .path .join (log_dir , "model.stan" )
655
+
656
+ comm .Barrier ()
657
+
617
658
for k in tr :
618
659
if parallel :
619
660
inds = bt_inds_lists [rank ][k ]
@@ -633,12 +674,12 @@ def bootstrap(
633
674
continue
634
675
635
676
model .data ["calib_sn_idx" ] = len (self .calib_sn )
636
- model .data ["calib_obs" ] = [calib_obs [i ] for i in inds ]
637
- model .data ["calib_errors" ] = [calib_errors [i ] for i in inds ]
638
- model .data ["calib_mag_sys" ] = [self .calib_mag_sys [i ] for i in inds ]
639
- model .data ["calib_vel_sys" ] = [self .calib_v_sys [i ] for i in inds ]
640
- model .data ["calib_col_sys" ] = [self .calib_c_sys [i ] for i in inds ]
641
- model .data ["calib_dist_mod" ] = [self .calib_dist_mod [i ] for i in inds ]
677
+ model .data ["calib_obs" ] = np . array ( [calib_obs [i ] for i in inds ])
678
+ model .data ["calib_errors" ] = np . array ( [calib_errors [i ] for i in inds ])
679
+ model .data ["calib_mag_sys" ] = np . array ( [self .calib_mag_sys [i ] for i in inds ])
680
+ model .data ["calib_vel_sys" ] = np . array ( [self .calib_v_sys [i ] for i in inds ])
681
+ model .data ["calib_col_sys" ] = np . array ( [self .calib_c_sys [i ] for i in inds ])
682
+ model .data ["calib_dist_mod" ] = np . array ( [self .calib_dist_mod [i ] for i in inds ])
642
683
643
684
# Convert differnet datasets to dataset indices
644
685
active_datasets = [self .calib_datasets [i ] for i in inds ]
@@ -652,18 +693,23 @@ def bootstrap(
652
693
653
694
model .set_initial_conditions (init )
654
695
696
+
655
697
# Setup/ build STAN model
656
698
with nullify_output (suppress_stdout = True , suppress_stderr = True ):
657
- fit = stan .build (model .model , data = model .data )
658
- samples = fit .sample (
659
- num_chains = chains ,
660
- num_samples = iters ,
661
- init = [model .init ] * chains ,
662
- num_warmup = warmup ,
699
+ data_file = model .write_json (f"data_{ rank } .json" , path = log_dir )
700
+
701
+ mdl = CmdStanModel (stan_file = stan_file )
702
+
703
+ fit = mdl .sample (
704
+ data = data_file ,
705
+ chains = chains ,
706
+ iter_warmup = warmup ,
707
+ iter_sampling = iters ,
663
708
save_warmup = save_warmup ,
709
+ inits = [model .init ] * chains ,
664
710
)
665
711
666
- self .posterior = samples . to_frame ()
712
+ self .posterior = fit . draws_pd ()
667
713
668
714
# Append found H0 values to list
669
715
h0 = quantile (self .posterior ["H0" ], 0.5 )
0 commit comments