Skip to content

Commit

Permalink
Merge pull request #16 from AlexHls/cmdstanpy
Browse files Browse the repository at this point in the history
replace pystan with cmdstanpy
  • Loading branch information
AlexHls authored Aug 11, 2024
2 parents 082e8df + d7ea8d5 commit c9b546b
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 25 deletions.
30 changes: 30 additions & 0 deletions src/sccala/scmlib/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import os

import numpy as np


from sccala.utillib.aux import NumpyEncoder


class SCM_Model:
def __init__(self):
self.data = {}
Expand Down Expand Up @@ -28,6 +33,31 @@ def print_results(self, df, blind=True):
print("%s = %.2e +/- %.2e" % (np.mean(df[key][0]), np.std(df[key][0])))
return

def write_json(self, filename, path=""):
    """
    Serialise ``self.data`` to a JSON file.

    NumPy arrays stored in ``self.data`` are converted through
    ``NumpyEncoder``.

    Parameters
    ----------
    filename : str
        Name of the JSON file to write.
    path : str or None
        Directory the file is written to. An empty string or None
        resolves ``filename`` relative to the current working directory.

    Returns
    -------
    str
        Path of the written file.
    """
    # json ships with the standard library, so no ImportError guard
    # is needed around this import
    import json

    # Callers forward their log directory here, which may be None;
    # os.path.join would raise a TypeError on None
    target = os.path.join("" if path is None else path, filename)

    with open(target, "w") as f:
        json.dump(self.data, f, cls=NumpyEncoder)

    return target

def write_stan(self, filename, path=""):
    """
    Write the Stan program held in ``self.model`` to disk.

    The file is left untouched when it already exists with identical
    contents, to avoid re-compilation of an unchanged model.

    Parameters
    ----------
    filename : str
        Name of the Stan file to write.
    path : str or None
        Directory the file is written to. An empty string or None
        resolves ``filename`` relative to the current working directory.

    Returns
    -------
    str
        Path of the (existing or newly written) Stan file.
    """
    # Callers forward their log directory here, which may be None;
    # os.path.join would raise a TypeError on None
    target = os.path.join("" if path is None else path, filename)

    # Check if file exists and if the contents are identical
    # to the current model to avoid re-compilation
    if os.path.exists(target):
        with open(target, "r") as f:
            unchanged = f.read() == self.model
        if unchanged:
            print("Model already exists, skipping compilation...")
            return target

    with open(target, "w") as f:
        f.write(self.model)
    return target


class NHHubbleFreeSCM(SCM_Model):
def __init__(self):
Expand Down
96 changes: 71 additions & 25 deletions src/sccala/scmlib/sccala.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

import numpy as np
import pandas as pd
import stan
import matplotlib.pyplot as plt
from tqdm import trange
from cmdstanpy import CmdStanModel

from sccala.scmlib.models import SCM_Model
from sccala.utillib.aux import distmod_kin, quantile, split_list, nullify_output
Expand Down Expand Up @@ -237,6 +237,7 @@ def get_error_matrix(self, classic=False, rho=1.0, rho_calib=0.0):
)
return np.array(errors)


def sample(
self,
model,
Expand All @@ -252,7 +253,8 @@ def sample(
classic=False,
):
"""
Samples the posterior for the given data and model
Samples the posterior for the given data and model using
the cmdstanpy interface.
Parameters
----------
Expand Down Expand Up @@ -283,7 +285,7 @@ def sample(
Returns
-------
posterior : pandas DataFrame
Result of the STAN sampling
"""

assert issubclass(
Expand Down Expand Up @@ -368,17 +370,29 @@ def sample(

model.set_initial_conditions(init)

# Setup/ build STAN model
fit = stan.build(model.model, data=model.data)
samples = fit.sample(
num_chains=chains,
num_samples=iters,
init=[model.init] * chains,
num_warmup=warmup,
data_file = model.write_json("data.json", path=log_dir)
stan_file = model.write_stan("model.stan", path=log_dir)

mdl = CmdStanModel(stan_file=stan_file)

fit = mdl.sample(
data=data_file,
chains=chains,
iter_warmup=warmup,
iter_sampling=iters,
save_warmup=save_warmup,
inits=[model.init] * chains,
)

self.posterior = samples.to_frame()
summary = fit.summary()
diagnose = fit.diagnose()

if not quiet:
print(summary)
print(diagnose)


self.posterior = fit.draws_pd()

# Encrypt H0 for blinding
if self.blind and model.hubble:
Expand All @@ -393,13 +407,24 @@ def sample(
norm = None

if log_dir is not None:
self.__save_samples__(self.posterior, log_dir=log_dir, norm=norm)
savename = self.__save_samples__(self.posterior, log_dir=log_dir, norm=norm)
chains_dir = savename.replace(".csv", "")
os.makedirs(chains_dir)
with open(os.path.join(chains_dir, "summary.txt"), "w") as f:
f.write(summary.to_string())
with open(os.path.join(chains_dir, "diagnose.txt"), "w") as f:
f.write(diagnose)
if not self.blind:
# Only move the csv files if we're not blinding the result
# TODO: find a way of blinding the individual chains
fit.save_csvfiles(chains_dir)

if not quiet:
model.print_results(self.posterior, blind=self.blind)

return self.posterior


def bootstrap(
self,
model,
Expand Down Expand Up @@ -614,6 +639,22 @@ def bootstrap(
else:
done = []

if rank == 0:
stan_file = model.write_stan("model.stan", path=log_dir)

# Create a model instance to trigger compilation and avoid
# having to compile the model on each rank separately
print("Compiling model...")
mdl_0 = CmdStanModel(stan_file=stan_file)
del mdl_0
print("Model compiled, starting sampling...")
else:
# Should be done via broadcast, but this is easier
# and the path is 'hardcoded' anyway
stan_file = os.path.join(log_dir, "model.stan")

comm.Barrier()

for k in tr:
if parallel:
inds = bt_inds_lists[rank][k]
Expand All @@ -633,12 +674,12 @@ def bootstrap(
continue

model.data["calib_sn_idx"] = len(self.calib_sn)
model.data["calib_obs"] = [calib_obs[i] for i in inds]
model.data["calib_errors"] = [calib_errors[i] for i in inds]
model.data["calib_mag_sys"] = [self.calib_mag_sys[i] for i in inds]
model.data["calib_vel_sys"] = [self.calib_v_sys[i] for i in inds]
model.data["calib_col_sys"] = [self.calib_c_sys[i] for i in inds]
model.data["calib_dist_mod"] = [self.calib_dist_mod[i] for i in inds]
model.data["calib_obs"] = np.array([calib_obs[i] for i in inds])
model.data["calib_errors"] = np.array([calib_errors[i] for i in inds])
model.data["calib_mag_sys"] = np.array([self.calib_mag_sys[i] for i in inds])
model.data["calib_vel_sys"] = np.array([self.calib_v_sys[i] for i in inds])
model.data["calib_col_sys"] = np.array([self.calib_c_sys[i] for i in inds])
model.data["calib_dist_mod"] = np.array([self.calib_dist_mod[i] for i in inds])

# Convert different datasets to dataset indices
active_datasets = [self.calib_datasets[i] for i in inds]
Expand All @@ -652,18 +693,23 @@ def bootstrap(

model.set_initial_conditions(init)


# Setup/ build STAN model
with nullify_output(suppress_stdout=True, suppress_stderr=True):
fit = stan.build(model.model, data=model.data)
samples = fit.sample(
num_chains=chains,
num_samples=iters,
init=[model.init] * chains,
num_warmup=warmup,
data_file = model.write_json(f"data_{rank}.json", path=log_dir)

mdl = CmdStanModel(stan_file=stan_file)

fit = mdl.sample(
data=data_file,
chains=chains,
iter_warmup=warmup,
iter_sampling=iters,
save_warmup=save_warmup,
inits=[model.init] * chains,
)

self.posterior = samples.to_frame()
self.posterior = fit.draws_pd()

# Append found H0 values to list
h0 = quantile(self.posterior["H0"], 0.5)
Expand Down
8 changes: 8 additions & 0 deletions src/sccala/utillib/aux.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import sys
import json
from contextlib import contextmanager

import numpy as np
Expand All @@ -9,6 +10,13 @@
from sccala.utillib.const import H_ERG, C_AA, C_LIGHT


class NumpyEncoder(json.JSONEncoder):
    """
    JSON encoder that understands NumPy arrays and NumPy scalars.

    Arrays are written as (nested) lists; NumPy integer, floating and
    boolean scalars are written as the matching built-in Python type.
    Everything else is delegated to ``json.JSONEncoder``, which raises
    ``TypeError`` for unsupported objects.
    """

    def default(self, obj):
        """Return a JSON-serialisable equivalent of ``obj``."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # NumPy scalars such as np.int64 or np.float32 are not subclasses
        # of int/float, so the stock encoder rejects them
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        return super().default(obj)


def calc_single_error(err_low, err_high, mode="mean"):
"""
Calculates single error from asymmetric errors
Expand Down

0 comments on commit c9b546b

Please sign in to comment.