Skip to content

Commit f44ff48

Browse files
fix nofeatures model (remove coef)
1 parent 52ed01e commit f44ff48

File tree

1 file changed

+35
-32
lines changed

1 file changed

+35
-32
lines changed

pyrocov/mutrans.py

+35-32
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,8 @@ def model(dataset, model_type, *, forecast_steps=None):
491491
# Configure reparametrization (which does not affect model density).
492492
reparam = {}
493493
if "reparam" in model_type:
494-
reparam["coef"] = LocScaleReparam()
494+
if "nofeatures" not in model_type:
495+
reparam["coef"] = LocScaleReparam()
495496
if "localrate" in model_type or "nofeatures" in model_type:
496497
reparam["rate_loc"] = LocScaleReparam()
497498
if "localinit" in model_type:
@@ -501,22 +502,22 @@ def model(dataset, model_type, *, forecast_steps=None):
501502
with poutine.reparam(config=reparam):
502503

503504
# Sample global random variables.
504-
coef_scale = pyro.sample("coef_scale", dist.LogNormal(-4, 2))
505+
if "nofeatures" not in model_type:
506+
coef_scale = pyro.sample("coef_scale", dist.LogNormal(-4, 2))
505507
rate_scale = pyro.sample("rate_scale", dist.LogNormal(-4, 2))
506508
init_scale = pyro.sample("init_scale", dist.LogNormal(0, 2))
507-
if "localrate" in model_type:
508-
rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2))
509-
if "nofeatures" in model_type:
509+
if "localrate" or "nofeatures" in model_type:
510510
rate_loc_scale = pyro.sample("rate_loc_scale", dist.LogNormal(-4, 2))
511511
if "localinit" in model_type:
512512
init_loc_scale = pyro.sample("init_loc_scale", dist.LogNormal(0, 2))
513513

514514
# Assume relative growth rate depends strongly on mutations and weakly
515515
# on clade and place. Assume initial infections depend strongly on
516516
# clade and place.
517-
coef = pyro.sample(
518-
"coef", dist.Laplace(torch.zeros(F), coef_scale).to_event(1)
519-
) # [F]
517+
if "nofeatures" not in model_type:
518+
coef = pyro.sample(
519+
"coef", dist.Laplace(torch.zeros(F), coef_scale).to_event(1)
520+
) # [F]
520521
with clade_plate:
521522
if "localrate" in model_type:
522523
rate_loc = pyro.sample(
@@ -959,33 +960,35 @@ def log_stats(dataset: dict, result: dict) -> dict:
959960
stats = {k: float(v) for k, v in result["median"].items() if v.numel() == 1}
960961
stats["loss"] = float(np.median(result["losses"][-100:]))
961962
mutations = dataset["mutations"]
962-
mean = result["mean"]["coef"].cpu()
963-
if not mean.shape:
964-
return stats # Work around error in map estimation.
965-
logger.info(
966-
"Dense data has shape {} totaling {} sequences".format(
967-
" x ".join(map(str, dataset["weekly_clades"].shape)),
968-
int(dataset["weekly_clades"].sum()),
963+
964+
if "coef" in result["mean"]:
965+
mean = result["mean"]["coef"].cpu()
966+
if not mean.shape:
967+
return stats # Work around error in map estimation.
968+
logger.info(
969+
"Dense data has shape {} totaling {} sequences".format(
970+
" x ".join(map(str, dataset["weekly_clades"].shape)),
971+
int(dataset["weekly_clades"].sum()),
972+
)
969973
)
970-
)
971974

972-
# Statistical significance.
973-
std = result["std"]["coef"].cpu()
974-
sig = mean.abs() / std
975-
logger.info(f"|μ|/σ [median,max] = [{sig.median():0.3g},{sig.max():0.3g}]")
976-
stats["|μ|/σ median"] = sig.median()
977-
stats["|μ|/σ max"] = sig.max()
975+
# Statistical significance.
976+
std = result["std"]["coef"].cpu()
977+
sig = mean.abs() / std
978+
logger.info(f"|μ|/σ [median,max] = [{sig.median():0.3g},{sig.max():0.3g}]")
979+
stats["|μ|/σ median"] = sig.median()
980+
stats["|μ|/σ max"] = sig.max()
978981

979-
# Effects of individual mutations.
980-
for name in ["S:D614G", "S:N501Y", "S:E484K", "S:L452R"]:
981-
if name not in mutations:
982-
continue
983-
i = mutations.index(name)
984-
m = mean[i] * 0.01
985-
s = std[i] * 0.01
986-
logger.info(f"ΔlogR({name}) = {m:0.3g} ± {s:0.2f}")
987-
stats[f"ΔlogR({name}) mean"] = m
988-
stats[f"ΔlogR({name}) std"] = s
982+
# Effects of individual mutations.
983+
for name in ["S:D614G", "S:N501Y", "S:E484K", "S:L452R"]:
984+
if name not in mutations:
985+
continue
986+
i = mutations.index(name)
987+
m = mean[i] * 0.01
988+
s = std[i] * 0.01
989+
logger.info(f"ΔlogR({name}) = {m:0.3g} ± {s:0.2f}")
990+
stats[f"ΔlogR({name}) mean"] = m
991+
stats[f"ΔlogR({name}) std"] = s
989992

990993
# Growth rates of individual clades.
991994
rate = quotient_central_moments(

0 commit comments

Comments
 (0)