@@ -491,7 +491,8 @@ def model(dataset, model_type, *, forecast_steps=None):
491
491
# Configure reparametrization (which does not affect model density).
492
492
reparam = {}
493
493
if "reparam" in model_type :
494
- reparam ["coef" ] = LocScaleReparam ()
494
+ if "nofeatures" not in model_type :
495
+ reparam ["coef" ] = LocScaleReparam ()
495
496
if "localrate" in model_type or "nofeatures" in model_type :
496
497
reparam ["rate_loc" ] = LocScaleReparam ()
497
498
if "localinit" in model_type :
@@ -501,22 +502,22 @@ def model(dataset, model_type, *, forecast_steps=None):
501
502
with poutine .reparam (config = reparam ):
502
503
503
504
# Sample global random variables.
504
- coef_scale = pyro .sample ("coef_scale" , dist .LogNormal (- 4 , 2 ))
505
+ if "nofeatures" not in model_type :
506
+ coef_scale = pyro .sample ("coef_scale" , dist .LogNormal (- 4 , 2 ))
505
507
rate_scale = pyro .sample ("rate_scale" , dist .LogNormal (- 4 , 2 ))
506
508
init_scale = pyro .sample ("init_scale" , dist .LogNormal (0 , 2 ))
507
- if "localrate" in model_type :
508
- rate_loc_scale = pyro .sample ("rate_loc_scale" , dist .LogNormal (- 4 , 2 ))
509
- if "nofeatures" in model_type :
509
+ if "localrate" or "nofeatures" in model_type :
510
510
rate_loc_scale = pyro .sample ("rate_loc_scale" , dist .LogNormal (- 4 , 2 ))
511
511
if "localinit" in model_type :
512
512
init_loc_scale = pyro .sample ("init_loc_scale" , dist .LogNormal (0 , 2 ))
513
513
514
514
# Assume relative growth rate depends strongly on mutations and weakly
515
515
# on clade and place. Assume initial infections depend strongly on
516
516
# clade and place.
517
- coef = pyro .sample (
518
- "coef" , dist .Laplace (torch .zeros (F ), coef_scale ).to_event (1 )
519
- ) # [F]
517
+ if "nofeatures" not in model_type :
518
+ coef = pyro .sample (
519
+ "coef" , dist .Laplace (torch .zeros (F ), coef_scale ).to_event (1 )
520
+ ) # [F]
520
521
with clade_plate :
521
522
if "localrate" in model_type :
522
523
rate_loc = pyro .sample (
@@ -959,33 +960,35 @@ def log_stats(dataset: dict, result: dict) -> dict:
959
960
stats = {k : float (v ) for k , v in result ["median" ].items () if v .numel () == 1 }
960
961
stats ["loss" ] = float (np .median (result ["losses" ][- 100 :]))
961
962
mutations = dataset ["mutations" ]
962
- mean = result ["mean" ]["coef" ].cpu ()
963
- if not mean .shape :
964
- return stats # Work around error in map estimation.
965
- logger .info (
966
- "Dense data has shape {} totaling {} sequences" .format (
967
- " x " .join (map (str , dataset ["weekly_clades" ].shape )),
968
- int (dataset ["weekly_clades" ].sum ()),
963
+
964
+ if "coef" in result ["mean" ]:
965
+ mean = result ["mean" ]["coef" ].cpu ()
966
+ if not mean .shape :
967
+ return stats # Work around error in map estimation.
968
+ logger .info (
969
+ "Dense data has shape {} totaling {} sequences" .format (
970
+ " x " .join (map (str , dataset ["weekly_clades" ].shape )),
971
+ int (dataset ["weekly_clades" ].sum ()),
972
+ )
969
973
)
970
- )
971
974
972
- # Statistical significance.
973
- std = result ["std" ]["coef" ].cpu ()
974
- sig = mean .abs () / std
975
- logger .info (f"|μ|/σ [median,max] = [{ sig .median ():0.3g} ,{ sig .max ():0.3g} ]" )
976
- stats ["|μ|/σ median" ] = sig .median ()
977
- stats ["|μ|/σ max" ] = sig .max ()
975
+ # Statistical significance.
976
+ std = result ["std" ]["coef" ].cpu ()
977
+ sig = mean .abs () / std
978
+ logger .info (f"|μ|/σ [median,max] = [{ sig .median ():0.3g} ,{ sig .max ():0.3g} ]" )
979
+ stats ["|μ|/σ median" ] = sig .median ()
980
+ stats ["|μ|/σ max" ] = sig .max ()
978
981
979
- # Effects of individual mutations.
980
- for name in ["S:D614G" , "S:N501Y" , "S:E484K" , "S:L452R" ]:
981
- if name not in mutations :
982
- continue
983
- i = mutations .index (name )
984
- m = mean [i ] * 0.01
985
- s = std [i ] * 0.01
986
- logger .info (f"ΔlogR({ name } ) = { m :0.3g} ± { s :0.2f} " )
987
- stats [f"ΔlogR({ name } ) mean" ] = m
988
- stats [f"ΔlogR({ name } ) std" ] = s
982
+ # Effects of individual mutations.
983
+ for name in ["S:D614G" , "S:N501Y" , "S:E484K" , "S:L452R" ]:
984
+ if name not in mutations :
985
+ continue
986
+ i = mutations .index (name )
987
+ m = mean [i ] * 0.01
988
+ s = std [i ] * 0.01
989
+ logger .info (f"ΔlogR({ name } ) = { m :0.3g} ± { s :0.2f} " )
990
+ stats [f"ΔlogR({ name } ) mean" ] = m
991
+ stats [f"ΔlogR({ name } ) std" ] = s
989
992
990
993
# Growth rates of individual clades.
991
994
rate = quotient_central_moments (
0 commit comments