Skip to content

Commit 3362b02

Browse files
committed
Fix typo in factory.py add pilot_cost to compute_variance_reductions
1 parent 951562f commit 3362b02

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

pyapprox/multifidelity/factory.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def __call__(self, est_covariance, est):
555555
if self._stat_type != "mean" and isinstance(
556556
est._stat, MultiOutputMeanAndVariance):
557557
return (
558-
est_covariance[est.nqoi+self._qoi_idx,
558+
est_covariance[est._nqoi+self._qoi_idx,
559559
est._nqoi+self._qoi_idx])
560560
elif (isinstance(
561561
est._stat, (MultiOutputVariance, MultiOutputMean)) or
@@ -570,7 +570,7 @@ def __repr__(self):
570570

571571
def compute_variance_reductions(optimized_estimators,
572572
criteria=ComparisonCriteria("det"),
573-
nhf_samples=None):
573+
nhf_samples=None, pilot_cost=None):
574574
"""
575575
Compute the variance reduction (relative to single model MC) for a
576576
list of optimized estimtors.
@@ -597,6 +597,11 @@ def compute_variance_reductions(optimized_estimators,
597597
evaluations that produce a estimator cost equal to the optimized
598598
target cost of the estimator is used. Usually, nhf_samples should be
599599
set to None.
600+
601+
pilot_cost : float
602+
The cost of running the pilot study. if not None this is used to
603+
determine the number of high-fidelity samples used by a single
604+
fidelity MC study that does not need a pilot
600605
"""
601606
var_red, est_criterias, sf_criterias = [], [], []
602607
optimized_estimators = optimized_estimators.copy()
@@ -606,7 +611,16 @@ def compute_variance_reductions(optimized_estimators,
606611
est_criteria = criteria(est._covariance_from_npartition_samples(
607612
est._rounded_npartition_samples), est)
608613
if nhf_samples is None:
609-
nhf_samples = int(est._rounded_target_cost/est._costs[0])
614+
if pilot_cost is None:
615+
nhf_samples = int(est._rounded_target_cost/est._costs[0])
616+
else:
617+
nhf_samples = int(
618+
(est._rounded_target_cost+pilot_cost)/est._costs[0])
619+
else:
620+
if pilot_cost is not None:
621+
msg = "pilot cost was specified even though nhf_samples was "
622+
msg += "not None"
623+
raise ValueError(msg)
610624
sf_criteria = criteria(
611625
est._stat.high_fidelity_estimator_covariance(
612626
nhf_samples), est)

0 commit comments

Comments
 (0)