diff --git a/src/sccala/scmlib/models.py b/src/sccala/scmlib/models.py index b50dfa4..9310255 100644 --- a/src/sccala/scmlib/models.py +++ b/src/sccala/scmlib/models.py @@ -1,6 +1,11 @@ +import os + import numpy as np +from sccala.utillib.aux import NumpyEncoder + + class SCM_Model: def __init__(self): self.data = {} @@ -28,6 +33,31 @@ def print_results(self, df, blind=True): print("%s = %.2e +/- %.2e" % (np.mean(df[key][0]), np.std(df[key][0]))) return + def write_json(self, filename, path=""): + try: + import json + except ImportError: + print("json module not available") + return + + with open(os.path.join(path, filename), "w") as f: + json.dump(self.data, f, cls=NumpyEncoder) + + return os.path.join(path, filename) + + def write_stan(self, filename, path=""): + # Check if file exists and if the contents are identical + # to the current to avoid re-compilation + if os.path.exists(os.path.join(path, filename)): + with open(os.path.join(path, filename), "r") as f: + if f.read() == self.model: + print("Model already exists, skipping compilation...") + return os.path.join(path, filename) + + with open(os.path.join(path, filename), "w") as f: + f.write(self.model) + return os.path.join(path, filename) + class NHHubbleFreeSCM(SCM_Model): def __init__(self): diff --git a/src/sccala/scmlib/sccala.py b/src/sccala/scmlib/sccala.py index 8919499..0f30278 100644 --- a/src/sccala/scmlib/sccala.py +++ b/src/sccala/scmlib/sccala.py @@ -6,9 +6,9 @@ import numpy as np import pandas as pd -import stan import matplotlib.pyplot as plt from tqdm import trange +from cmdstanpy import CmdStanModel from sccala.scmlib.models import SCM_Model from sccala.utillib.aux import distmod_kin, quantile, split_list, nullify_output @@ -237,6 +237,7 @@ def get_error_matrix(self, classic=False, rho=1.0, rho_calib=0.0): ) return np.array(errors) + def sample( self, model, @@ -252,7 +253,8 @@ def sample( classic=False, ): """ - Samples the posterior for the given data and model + Samples the posterior for the given data and model using + the cmdstanpy interface. Parameters ---------- @@ -283,7 +285,7 @@ def sample( Returns ------- posterior : pandas DataFrame - Result of the STAN sampling + """ assert issubclass( @@ -368,17 +370,29 @@ def sample( model.set_initial_conditions(init) - # Setup/ build STAN model - fit = stan.build(model.model, data=model.data) - samples = fit.sample( - num_chains=chains, - num_samples=iters, - init=[model.init] * chains, - num_warmup=warmup, + data_file = model.write_json("data.json", path=log_dir) + stan_file = model.write_stan("model.stan", path=log_dir) + + mdl = CmdStanModel(stan_file=stan_file) + + fit = mdl.sample( + data=data_file, + chains=chains, + iter_warmup=warmup, + iter_sampling=iters, save_warmup=save_warmup, + inits=[model.init] * chains, ) - self.posterior = samples.to_frame() + summary = fit.summary() + diagnose = fit.diagnose() + + if not quiet: + print(summary) + print(diagnose) + + + self.posterior = fit.draws_pd() # Encrypt H0 for blinding if self.blind and model.hubble: @@ -393,13 +407,24 @@ def sample( norm = None if log_dir is not None: - self.__save_samples__(self.posterior, log_dir=log_dir, norm=norm) + savename = self.__save_samples__(self.posterior, log_dir=log_dir, norm=norm) + chains_dir = savename.replace(".csv", "") + os.makedirs(chains_dir) + with open(os.path.join(chains_dir, "summary.txt"), "w") as f: + f.write(summary.to_string()) + with open(os.path.join(chains_dir, "diagnose.txt"), "w") as f: + f.write(diagnose) + if not self.blind: + # Only move the csv files if we're not blinding the result + # TODO: find a way of blinding the individual chains + fit.save_csvfiles(chains_dir) if not quiet: model.print_results(self.posterior, blind=self.blind) return self.posterior + def bootstrap( self, model, @@ -614,6 +639,22 @@ def bootstrap( else: done = [] + if rank == 0: + stan_file = model.write_stan("model.stan", path=log_dir) + + # Create a model instance to trigger compilation and avoid + # having to compile the model on each rank separately + print("Compiling model...") + mdl_0 = CmdStanModel(stan_file=stan_file) + del mdl_0 + print("Model compiled, starting sampling...") + else: + # Should be done via broadcast, but this is easier + # and the path is 'hardcoded' anyway + stan_file = os.path.join(log_dir, "model.stan") + + comm.Barrier() + for k in tr: if parallel: inds = bt_inds_lists[rank][k] @@ -633,12 +674,12 @@ def bootstrap( continue model.data["calib_sn_idx"] = len(self.calib_sn) - model.data["calib_obs"] = [calib_obs[i] for i in inds] - model.data["calib_errors"] = [calib_errors[i] for i in inds] - model.data["calib_mag_sys"] = [self.calib_mag_sys[i] for i in inds] - model.data["calib_vel_sys"] = [self.calib_v_sys[i] for i in inds] - model.data["calib_col_sys"] = [self.calib_c_sys[i] for i in inds] - model.data["calib_dist_mod"] = [self.calib_dist_mod[i] for i in inds] + model.data["calib_obs"] = np.array([calib_obs[i] for i in inds]) + model.data["calib_errors"] = np.array([calib_errors[i] for i in inds]) + model.data["calib_mag_sys"] = np.array([self.calib_mag_sys[i] for i in inds]) + model.data["calib_vel_sys"] = np.array([self.calib_v_sys[i] for i in inds]) + model.data["calib_col_sys"] = np.array([self.calib_c_sys[i] for i in inds]) + model.data["calib_dist_mod"] = np.array([self.calib_dist_mod[i] for i in inds]) # Convert differnet datasets to dataset indices active_datasets = [self.calib_datasets[i] for i in inds] @@ -652,18 +693,23 @@ def bootstrap( model.set_initial_conditions(init) + # Setup/ build STAN model with nullify_output(suppress_stdout=True, suppress_stderr=True): - fit = stan.build(model.model, data=model.data) - samples = fit.sample( - num_chains=chains, - num_samples=iters, - init=[model.init] * chains, - num_warmup=warmup, + data_file = model.write_json(f"data_{rank}.json", path=log_dir) + + mdl = CmdStanModel(stan_file=stan_file) + + fit = mdl.sample( + data=data_file, + chains=chains, + iter_warmup=warmup, + iter_sampling=iters, save_warmup=save_warmup, + inits=[model.init] * chains, ) - self.posterior = samples.to_frame() + self.posterior = fit.draws_pd() # Append found H0 values to list h0 = quantile(self.posterior["H0"], 0.5) diff --git a/src/sccala/utillib/aux.py b/src/sccala/utillib/aux.py index 8a228c2..0d5bf15 100644 --- a/src/sccala/utillib/aux.py +++ b/src/sccala/utillib/aux.py @@ -1,5 +1,6 @@ import os import sys +import json from contextlib import contextmanager import numpy as np @@ -9,6 +10,13 @@ from sccala.utillib.const import H_ERG, C_AA, C_LIGHT +class NumpyEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + return super().default(obj) + + def calc_single_error(err_low, err_high, mode="mean"): """ Calculates single error from asymmetric errors