Skip to content

Commit c9b546b

Browse files
authored
Merge pull request #16 from AlexHls/cmdstanpy
replace pystan with cmdstanpy
2 parents 082e8df + d7ea8d5 commit c9b546b

File tree

3 files changed

+109
-25
lines changed

3 files changed

+109
-25
lines changed

src/sccala/scmlib/models.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
import os
2+
13
import numpy as np
24

35

6+
from sccala.utillib.aux import NumpyEncoder
7+
8+
49
class SCM_Model:
510
def __init__(self):
611
self.data = {}
@@ -28,6 +33,31 @@ def print_results(self, df, blind=True):
2833
print("%s = %.2e +/- %.2e" % (np.mean(df[key][0]), np.std(df[key][0])))
2934
return
3035

36+
def write_json(self, filename, path=""):
    """
    Writes the model data dictionary (``self.data``) to a JSON file.

    Parameters
    ----------
    filename : str
        Name of the JSON file to be written.
    path : str or None
        Directory in which the file is created. An empty string
        (default) or None resolves relative to the current working
        directory.

    Returns
    -------
    target : str
        Path of the written file, i.e. ``path`` joined with ``filename``.
    """
    # json is part of the standard library, so the import cannot
    # reasonably fail; guarding it only hid errors behind a silent
    # ``None`` return value.
    import json

    # Callers may forward an unset (None) log directory; os.path.join
    # would raise a TypeError on None, so fall back to "".
    target = os.path.join(path or "", filename)

    with open(target, "w") as f:
        # NumpyEncoder is passed as the encoder class for self.data
        json.dump(self.data, f, cls=NumpyEncoder)

    return target
47+
48+
def write_stan(self, filename, path=""):
    """
    Writes the model code (``self.model``) to a file, unless a file
    with identical contents already exists at the target location.

    Parameters
    ----------
    filename : str
        Name of the model file to be written.
    path : str or None
        Directory in which the file is created. An empty string
        (default) or None resolves relative to the current working
        directory.

    Returns
    -------
    target : str
        Path of the model file, i.e. ``path`` joined with ``filename``.
    """
    # Callers may forward an unset (None) log directory; os.path.join
    # would raise a TypeError on None, so fall back to "".
    target = os.path.join(path or "", filename)

    # Check if the file exists and if the contents are identical
    # to the current model to avoid re-compilation. Opening the file
    # directly (instead of testing for existence first) avoids a
    # second file system lookup.
    try:
        with open(target, "r") as f:
            unchanged = f.read() == self.model
    except FileNotFoundError:
        unchanged = False

    if unchanged:
        print("Model already exists, skipping compilation...")
        return target

    with open(target, "w") as f:
        f.write(self.model)
    return target
60+
3161

3262
class NHHubbleFreeSCM(SCM_Model):
3363
def __init__(self):

src/sccala/scmlib/sccala.py

Lines changed: 71 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
import numpy as np
88
import pandas as pd
9-
import stan
109
import matplotlib.pyplot as plt
1110
from tqdm import trange
11+
from cmdstanpy import CmdStanModel
1212

1313
from sccala.scmlib.models import SCM_Model
1414
from sccala.utillib.aux import distmod_kin, quantile, split_list, nullify_output
@@ -237,6 +237,7 @@ def get_error_matrix(self, classic=False, rho=1.0, rho_calib=0.0):
237237
)
238238
return np.array(errors)
239239

240+
240241
def sample(
241242
self,
242243
model,
@@ -252,7 +253,8 @@ def sample(
252253
classic=False,
253254
):
254255
"""
255-
Samples the posterior for the given data and model
256+
Samples the posterior for the given data and model using
257+
the cmdstanpy interface.
256258
257259
Parameters
258260
----------
@@ -283,7 +285,7 @@ def sample(
283285
Returns
284286
-------
285287
posterior : pandas DataFrame
286-
Result of the STAN sampling
288+
287289
"""
288290

289291
assert issubclass(
@@ -368,17 +370,29 @@ def sample(
368370

369371
model.set_initial_conditions(init)
370372

371-
# Setup/ build STAN model
372-
fit = stan.build(model.model, data=model.data)
373-
samples = fit.sample(
374-
num_chains=chains,
375-
num_samples=iters,
376-
init=[model.init] * chains,
377-
num_warmup=warmup,
373+
data_file = model.write_json("data.json", path=log_dir)
374+
stan_file = model.write_stan("model.stan", path=log_dir)
375+
376+
mdl = CmdStanModel(stan_file=stan_file)
377+
378+
fit = mdl.sample(
379+
data=data_file,
380+
chains=chains,
381+
iter_warmup=warmup,
382+
iter_sampling=iters,
378383
save_warmup=save_warmup,
384+
inits=[model.init] * chains,
379385
)
380386

381-
self.posterior = samples.to_frame()
387+
summary = fit.summary()
388+
diagnose = fit.diagnose()
389+
390+
if not quiet:
391+
print(summary)
392+
print(diagnose)
393+
394+
395+
self.posterior = fit.draws_pd()
382396

383397
# Encrypt H0 for blinding
384398
if self.blind and model.hubble:
@@ -393,13 +407,24 @@ def sample(
393407
norm = None
394408

395409
if log_dir is not None:
396-
self.__save_samples__(self.posterior, log_dir=log_dir, norm=norm)
410+
savename = self.__save_samples__(self.posterior, log_dir=log_dir, norm=norm)
411+
chains_dir = savename.replace(".csv", "")
412+
os.makedirs(chains_dir)
413+
with open(os.path.join(chains_dir, "summary.txt"), "w") as f:
414+
f.write(summary.to_string())
415+
with open(os.path.join(chains_dir, "diagnose.txt"), "w") as f:
416+
f.write(diagnose)
417+
if not self.blind:
418+
# Only move the csv files if we're not blinding the result
419+
# TODO: find a way of blinding the individual chains
420+
fit.save_csvfiles(chains_dir)
397421

398422
if not quiet:
399423
model.print_results(self.posterior, blind=self.blind)
400424

401425
return self.posterior
402426

427+
403428
def bootstrap(
404429
self,
405430
model,
@@ -614,6 +639,22 @@ def bootstrap(
614639
else:
615640
done = []
616641

642+
if rank == 0:
643+
stan_file = model.write_stan("model.stan", path=log_dir)
644+
645+
# Create a model instance to trigger compilation and avoid
646+
# having to compile the model on each rank separately
647+
print("Compiling model...")
648+
mdl_0 = CmdStanModel(stan_file=stan_file)
649+
del mdl_0
650+
print("Model compiled, starting sampling...")
651+
else:
652+
# Should be done via broadcast, but this is easier
653+
# and the path is 'hardcoded' anyway
654+
stan_file = os.path.join(log_dir, "model.stan")
655+
656+
comm.Barrier()
657+
617658
for k in tr:
618659
if parallel:
619660
inds = bt_inds_lists[rank][k]
@@ -633,12 +674,12 @@ def bootstrap(
633674
continue
634675

635676
model.data["calib_sn_idx"] = len(self.calib_sn)
636-
model.data["calib_obs"] = [calib_obs[i] for i in inds]
637-
model.data["calib_errors"] = [calib_errors[i] for i in inds]
638-
model.data["calib_mag_sys"] = [self.calib_mag_sys[i] for i in inds]
639-
model.data["calib_vel_sys"] = [self.calib_v_sys[i] for i in inds]
640-
model.data["calib_col_sys"] = [self.calib_c_sys[i] for i in inds]
641-
model.data["calib_dist_mod"] = [self.calib_dist_mod[i] for i in inds]
677+
model.data["calib_obs"] = np.array([calib_obs[i] for i in inds])
678+
model.data["calib_errors"] = np.array([calib_errors[i] for i in inds])
679+
model.data["calib_mag_sys"] = np.array([self.calib_mag_sys[i] for i in inds])
680+
model.data["calib_vel_sys"] = np.array([self.calib_v_sys[i] for i in inds])
681+
model.data["calib_col_sys"] = np.array([self.calib_c_sys[i] for i in inds])
682+
model.data["calib_dist_mod"] = np.array([self.calib_dist_mod[i] for i in inds])
642683

643684
# Convert different datasets to dataset indices
644685
active_datasets = [self.calib_datasets[i] for i in inds]
@@ -652,18 +693,23 @@ def bootstrap(
652693

653694
model.set_initial_conditions(init)
654695

696+
655697
# Setup/ build STAN model
656698
with nullify_output(suppress_stdout=True, suppress_stderr=True):
657-
fit = stan.build(model.model, data=model.data)
658-
samples = fit.sample(
659-
num_chains=chains,
660-
num_samples=iters,
661-
init=[model.init] * chains,
662-
num_warmup=warmup,
699+
data_file = model.write_json(f"data_{rank}.json", path=log_dir)
700+
701+
mdl = CmdStanModel(stan_file=stan_file)
702+
703+
fit = mdl.sample(
704+
data=data_file,
705+
chains=chains,
706+
iter_warmup=warmup,
707+
iter_sampling=iters,
663708
save_warmup=save_warmup,
709+
inits=[model.init] * chains,
664710
)
665711

666-
self.posterior = samples.to_frame()
712+
self.posterior = fit.draws_pd()
667713

668714
# Append found H0 values to list
669715
h0 = quantile(self.posterior["H0"], 0.5)

src/sccala/utillib/aux.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import sys
3+
import json
34
from contextlib import contextmanager
45

56
import numpy as np
@@ -9,6 +10,13 @@
910
from sccala.utillib.const import H_ERG, C_AA, C_LIGHT
1011

1112

13+
class NumpyEncoder(json.JSONEncoder):
    """
    JSON encoder that additionally serializes numpy arrays and
    numpy scalars.

    Intended to be passed as ``cls`` to ``json.dump``/``json.dumps``.
    """

    def default(self, obj):
        """
        Converts numpy objects to built-in Python equivalents.

        Parameters
        ----------
        obj : object
            Object the standard encoder could not serialize.

        Returns
        -------
        converted : list or scalar
            Nested list for numpy arrays, built-in Python scalar for
            numpy scalar types.

        Raises
        ------
        TypeError
            Raised by the base class for any other unsupported type.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # numpy scalars (e.g. np.int64, np.float32, np.bool_) are not
        # instances of the built-in types the standard encoder knows
        if isinstance(obj, np.generic):
            return obj.item()
        return super().default(obj)
18+
19+
1220
def calc_single_error(err_low, err_high, mode="mean"):
1321
"""
1422
Calculates single error from asymmetric errors

0 commit comments

Comments
 (0)