Skip to content

Commit

Permalink
Compatibility update (#17)
Browse files Browse the repository at this point in the history
* Update for newer scipy and pandas syntax

* Fix label strings

* More label strings

* More pandas indexing

* Fix classic STAN model syntax

* Add output dir to sampling procedure

* Avoid race condition

* Add file cleanup
  • Loading branch information
AlexHls authored Oct 22, 2024
1 parent c9b546b commit e4df8bc
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 46 deletions.
16 changes: 9 additions & 7 deletions src/sccala/asynphot/synphot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def calculate_vega_zp(filter):
/ H_ERG
/ C_AA
* integrate.simpson(
vega_flux * filter.interpolate(vega_wav) * vega_wav, vega_wav
vega_flux * filter.interpolate(vega_wav) * vega_wav, x=vega_wav
)
)
+ 0.03
Expand All @@ -38,8 +38,10 @@ def calculate_lambda_eff(spec_wav, spec_flux, filter):
"""

lambda_eff = integrate.simpson(
spec_flux * filter.interpolate(spec_wav) * spec_wav**2, spec_wav
) / integrate.simpson(spec_flux * filter.interpolate(spec_wav) * spec_wav, spec_wav)
spec_flux * filter.interpolate(spec_wav) * spec_wav**2, x=spec_wav
) / integrate.simpson(
spec_flux * filter.interpolate(spec_wav) * spec_wav, x=spec_wav
)

return lambda_eff

Expand Down Expand Up @@ -79,7 +81,7 @@ def calculate_vega_magnitude(
/ H_ERG
/ C_AA
* integrate.simpson(
spec_flux * filter.interpolate(spec_wav) * spec_wav, spec_wav
spec_flux * filter.interpolate(spec_wav) * spec_wav, x=spec_wav
)
)
+ vega_zp
Expand All @@ -101,7 +103,7 @@ def calculate_vega_magnitude(
/ C_AA
* integrate.simpson(
noisy_flux * filter.interpolate(spec_wav) * spec_wav,
spec_wav,
x=spec_wav,
)
)
+ vega_zp
Expand All @@ -113,12 +115,12 @@ def calculate_vega_magnitude(
2.5
/ np.log(10)
/ integrate.simpson(
spec_flux * filter.interpolate(spec_wav) * spec_wav, spec_wav
spec_flux * filter.interpolate(spec_wav) * spec_wav, x=spec_wav
)
* np.sqrt(
err_integrate.mod_simpson(
(spec_err * filter.interpolate(spec_wav) * spec_wav) ** 2,
spec_wav,
x=spec_wav,
)
)
)
Expand Down
8 changes: 4 additions & 4 deletions src/sccala/interplib/epoch_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,14 @@ def diagnostic_plot(self, diagnostic, target, flux_interp=False):
self.plustwosigma[plotind] / conv,
alpha=0.1,
color="red",
label="2$\sigma$ (95.44%)",
label=r"2$\sigma$ (95.44%)",
)
ax1.axvspan(
self.minusonesigma[plotind] / conv,
self.plusonesigma[plotind] / conv,
alpha=0.3,
color="red",
label="1$\sigma$ (68.26%)",
label=r"1$\sigma$ (68.26%)",
)
ax1.axvline(
self.median[plotind] / conv,
Expand Down Expand Up @@ -253,7 +253,7 @@ def diagnostic_plot(self, diagnostic, target, flux_interp=False):
- np.percentile(self.tkde.resample(10000), 84.13)
)
ax2.axvspan(
lower, upper, alpha=0.3, color="blue", label="1$\sigma$ (68.26%)"
lower, upper, alpha=0.3, color="blue", label=r"1$\sigma$ (68.26%)"
)
ax2.axvline(
self.dates[plotind],
Expand Down Expand Up @@ -295,7 +295,7 @@ def diagnostic_plot(self, diagnostic, target, flux_interp=False):
lower = self.dates[plotind] + self.toe - np.percentile(self.tkde, 15.87)
upper = self.dates[plotind] + self.toe - np.percentile(self.tkde, 84.13)
ax3.axvspan(
lower, upper, alpha=0.3, color="blue", label="1$\sigma$ (68.26%)"
lower, upper, alpha=0.3, color="blue", label=r"1$\sigma$ (68.26%)"
)
ax3.axvline(
self.dates[plotind],
Expand Down
5 changes: 5 additions & 0 deletions src/sccala/sccala_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def main(args):
replacement=args.no_replacement,
restart=args.disable_restart,
walltime=args.time,
output_dir=args.output_dir,
)

print("Finished bootstrap resampling")
Expand Down Expand Up @@ -123,6 +124,10 @@ def cli():
help="Disables writing or restart file",
action="store_false",
)
parser.add_argument(
"--output_dir",
help="Directory used for storing STAN temporary files",
)

args = parser.parse_args()

Expand Down
5 changes: 5 additions & 0 deletions src/sccala/sccala_scm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def main(args):
save_warmup=args.save_warmup,
quiet=False,
classic=args.classic,
output_dir=args.output_dir,
)

print("Finished sampling")
Expand Down Expand Up @@ -147,6 +148,10 @@ def cli():
default="HUBBLE",
help="Encryption key used for blinding H0. Default: HUBBLE",
)
parser.add_argument(
"--output_dir",
help="Directory used for storing STAN temporary files.",
)

args = parser.parse_args()

Expand Down
2 changes: 1 addition & 1 deletion src/sccala/scmlib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ def __init__(self):
calib_c_true ~ normal(cs,rc);
for (i in 1:sn_idx) {
target += normal_lpdf(obs[i] | [mag_true[i] + mag_sys[i], v_true[i] + vel_sys[i], c_true[i] + col_sys[i]]', errors[i] + [[sigma_int^2, 0, 0], [0, 0, 0], [0, 0, 0]]);
target += multi_normal_lpdf(obs[i] | [mag_true[i] + mag_sys[i], v_true[i] + vel_sys[i], c_true[i] + col_sys[i]]', errors[i] + [[sigma_int^2, 0, 0], [0, 0, 0], [0, 0, 0]]);
}
for (i in 1:calib_sn_idx) {
target += normal_lpdf(calib_obs[i] | [calib_mag_true[i] + calib_mag_sys[i], calib_v_true[i] + calib_vel_sys[i], calib_c_true[i] + calib_col_sys[i]]', sqrt(calib_errors[i] + [calib_sigma_int[calib_dset_idx[i]]^2, 0, 0]'));
Expand Down
34 changes: 28 additions & 6 deletions src/sccala/scmlib/sccala.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,6 @@ def get_error_matrix(self, classic=False, rho=1.0, rho_calib=0.0):
)
return np.array(errors)


def sample(
self,
model,
Expand All @@ -251,6 +250,7 @@ def sample(
quiet=False,
init=None,
classic=False,
output_dir=None,
):
"""
Samples the posterior for the given data and model using
Expand Down Expand Up @@ -281,6 +281,8 @@ def sample(
classic : bool
Switches classic mode on if True. In classic mode, a/e input is
ignored.
output_dir : str
Directory where temporary STAN files will be stored. Default: None
Returns
-------
Expand Down Expand Up @@ -382,6 +384,7 @@ def sample(
iter_sampling=iters,
save_warmup=save_warmup,
inits=[model.init] * chains,
output_dir=output_dir,
)

summary = fit.summary()
Expand All @@ -391,7 +394,6 @@ def sample(
print(summary)
print(diagnose)


self.posterior = fit.draws_pd()

# Encrypt H0 for blinding
Expand Down Expand Up @@ -424,7 +426,6 @@ def sample(

return self.posterior


def bootstrap(
self,
model,
Expand All @@ -442,6 +443,7 @@ def bootstrap(
replacement=True,
restart=True,
walltime=24.0,
output_dir=None,
):
"""
Samples the posterior for the given data and model
Expand Down Expand Up @@ -480,6 +482,8 @@ def bootstrap(
Wallclock time (in h) available. Once 95% of the available wallclock
time is used, no new iteration will be started and job will exit
cleanly. Should be used with restart set to True. Default 24.0
output_dir : str
Directory where temporary STAN files will be stored. Default: None
Returns
-------
Expand Down Expand Up @@ -653,6 +657,13 @@ def bootstrap(
# and the path is 'hardcoded' anyway
stan_file = os.path.join(log_dir, "model.stan")

if output_dir is not None:
output_dir_rank = os.path.join(output_dir, "rank_%03d" % rank)
if not os.path.exists(output_dir_rank):
os.makedirs(output_dir_rank)
else:
output_dir_rank = None

comm.Barrier()

for k in tr:
Expand All @@ -676,10 +687,14 @@ def bootstrap(
model.data["calib_sn_idx"] = len(self.calib_sn)
model.data["calib_obs"] = np.array([calib_obs[i] for i in inds])
model.data["calib_errors"] = np.array([calib_errors[i] for i in inds])
model.data["calib_mag_sys"] = np.array([self.calib_mag_sys[i] for i in inds])
model.data["calib_mag_sys"] = np.array(
[self.calib_mag_sys[i] for i in inds]
)
model.data["calib_vel_sys"] = np.array([self.calib_v_sys[i] for i in inds])
model.data["calib_col_sys"] = np.array([self.calib_c_sys[i] for i in inds])
model.data["calib_dist_mod"] = np.array([self.calib_dist_mod[i] for i in inds])
model.data["calib_dist_mod"] = np.array(
[self.calib_dist_mod[i] for i in inds]
)

# Convert different datasets to dataset indices
active_datasets = [self.calib_datasets[i] for i in inds]
Expand All @@ -693,7 +708,6 @@ def bootstrap(

model.set_initial_conditions(init)


# Setup/ build STAN model
with nullify_output(suppress_stdout=True, suppress_stderr=True):
data_file = model.write_json(f"data_{rank}.json", path=log_dir)
Expand All @@ -707,6 +721,7 @@ def bootstrap(
iter_sampling=iters,
save_warmup=save_warmup,
inits=[model.init] * chains,
output_dir=output_dir_rank,
)

self.posterior = fit.draws_pd()
Expand Down Expand Up @@ -741,6 +756,13 @@ def bootstrap(
print("[TIMELIMIT] Rank %d reached wallclock limit, exiting..." % rank)
break

# If not using the default output_dir, clean up the temporary files to avoid
# excessive disk usage
if output_dir is not None:
files = glob.glob(os.path.join(output_dir_rank, "*"))
for f in files:
os.remove(f)

if parallel:
comm.Barrier()
h0_vals = comm.gather(h0_vals, root=0)
Expand Down
27 changes: 13 additions & 14 deletions src/sccala/speclib/linefit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ def __init__(self, wav, flux, error, numcode=100):
"""

# Check if error has correct format
assert (
len(error) == len(flux) or len(error) == 1
), "Length of error does not match length of flux: %d <-> %d" % (
len(error),
len(flux),
assert len(error) == len(flux) or len(error) == 1, (
"Length of error does not match length of flux: %d <-> %d"
% (
len(error),
len(flux),
)
)

self.wav = wav
Expand Down Expand Up @@ -468,14 +469,14 @@ def diagnostic_plot(self, line, save):
np.percentile(ae_avg, 97.72),
alpha=0.1,
color="red",
label="2$\sigma$ (95.44%)",
label=r"2$\sigma$ (95.44%)",
)
ax1.axvspan(
np.percentile(ae_avg, 15.87),
np.percentile(ae_avg, 84.13),
alpha=0.3,
color="red",
label="1$\sigma$ (68.26%)",
label=r"1$\sigma$ (68.26%)",
)
ax1.axvline(np.percentile(ae_avg, 50), color="red", label="Median")

Expand Down Expand Up @@ -527,15 +528,15 @@ def diagnostic_plot(self, line, save):
)
ax2.axvline(4861, color="k", ls="--", alpha=0.3)
ax2.set_title(r"H$_\alpha$ line fit")
ax2.set_xlabel("Wavelength ($\AA$)")
ax2.set_xlabel(r"Wavelength ($\AA$)")
ax2.set_ylabel("Flux (arb. unit)")
ax2.legend()
ax2.set_xlim([min(x), max(x)])
ax2.grid(which="major")

velocity, vel_err_lower, vel_err_upper = self.get_results(line)
ax2.set_title(
"MinWavelength: {:.2f} +{:.2f}/ -{:.2f} $\AA$\n a/e: {:.2e} +{:.2e}/ -{:.2e}".format(
r"MinWavelength: {:.2f} +{:.2f}/ -{:.2f} $\AA$\n a/e: {:.2e} +{:.2e}/ -{:.2e}".format(
median,
min_error_upper,
min_error_lower,
Expand Down Expand Up @@ -580,15 +581,13 @@ def diagnostic_plot(self, line, save):
axes2_ticks = []
for X in ax1_ticks:
# Velocity in km/s
vel_value = (
299792458 * (4861**2 - X**2) / (4861**2 + X**2) / 1000
)
vel_value = 299792458 * (4861**2 - X**2) / (4861**2 + X**2) / 1000
axes2_ticks.append("%.0f" % vel_value)

axes2.set_xticks(ax1_ticks)
axes2.set_xbound(ax1.get_xbound())
axes2.set_xticklabels(axes2_ticks)
ax1.set_xlabel("Wavelength ($\AA$)")
ax1.set_xlabel(r"Wavelength ($\AA$)")
axes2.set_xlabel("Velocity (km/s)")

# Plot fit with error band for peak position
Expand Down Expand Up @@ -623,7 +622,7 @@ def diagnostic_plot(self, line, save):
velocity, vel_err_lower, vel_err_upper = self.get_results(line)

ax2.set_title(
"MinWavelength: {:.2f} +{:.2f}/ -{:.2f} $\AA$\n Velocity: {:.2f} +{:.2f}/ -{:.2f} km/s".format(
r"MinWavelength: {:.2f} +{:.2f}/ -{:.2f} $\AA$\n Velocity: {:.2f} +{:.2f}/ -{:.2f} km/s".format(
median,
min_error_upper,
min_error_lower,
Expand Down
14 changes: 4 additions & 10 deletions src/sccala/spectral_line_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def main(args):
"'noisefit' should be 'True' or 'False', but it is %s" % noisefit[i]
)

# try:
print("ID: %s" % str(sid))
print("Noisefit: ", nf)
print("HODLR solver: ", hodlrsolver)
Expand All @@ -114,11 +113,6 @@ def main(args):
hodlrsolver=hodlrsolver,
num_live_points=args.num_live_points,
)
# except ValueError as e:
# warnings.warn(
# "Encountered error '%s' for ID %s, skipping..." % (str(e), str(sid))
# )
# continue

peak_loc, peak_error_lower, peak_error_upper = fit.get_results(line[i])

Expand All @@ -139,10 +133,10 @@ def main(args):
if not data.index.isin([(line[i], sid)]).any():
data = pd.concat([data, expdf])
else:
data.loc[(line[i], sid)]["MJD"] = mjd
data.loc[(line[i], sid)]["PeakLoc"] = peak_loc
data.loc[(line[i], sid)]["PeakErrorLower"] = peak_error_lower
data.loc[(line[i], sid)]["PeakErrorUpper"] = peak_error_upper
data.loc["MJD", (line[i], sid)] = mjd
data.loc["PeakLoc", (line[i], sid)] = peak_loc
data.loc["PeakErrorLower", (line[i], sid)] = peak_error_lower
data.loc["PeakErrorUpper", (line[i], sid)] = peak_error_upper
data.to_csv(exp_name)
else:
expdf.to_csv(exp_name)
Expand Down
Loading

0 comments on commit e4df8bc

Please sign in to comment.