Skip to content

Commit 21eb865

Browse files
michaelosthege authored and ColCarroll committed
Repair parallelized population sampling (#3559)
* test with and without parallelization across cores demonstrates issue #3555
* replace parallelize kwarg by reliance on cores setting closes #3555
* add the changes from pull 3559
* use more general suggestion in the log message

Co-Authored-By: Colin <[email protected]>
1 parent c0edddd commit 21eb865

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

RELEASE-NOTES.md

+1
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111

1212
### Maintenance
1313
- Moved math operations out of `Rice`, `TruncatedNormal`, `Triangular` and `ZeroInflatedNegativeBinomial` `random` methods. Math operations on values returned by `draw_values` might not broadcast well, and all the `size` aware broadcasting is left to `generate_samples`. Fixes [#3481](https://github.com/pymc-devs/pymc3/issues/3481) and [#3508](https://github.com/pymc-devs/pymc3/issues/3508)
14+
- Parallelization of population steppers (`DEMetropolis`) is now set via the `cores` argument. ([#3559](https://github.com/pymc-devs/pymc3/pull/3559))
1415

1516
## PyMC3 3.7 (May 29 2019)
1617

pymc3/sampling.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -453,7 +453,7 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None, trace=N
453453
if has_population_samplers:
454454
_log.info('Population sampling ({} chains)'.format(chains))
455455
_print_step_hierarchy(step)
456-
trace = _sample_population(**sample_args)
456+
trace = _sample_population(**sample_args, parallelize=cores > 1)
457457
else:
458458
_log.info('Sequential sampling ({} chains in 1 job)'.format(chains))
459459
_print_step_hierarchy(step)
@@ -690,7 +690,7 @@ def __init__(self, steppers, parallelize):
690690
if parallelize:
691691
try:
692692
# configure a child process for each stepper
693-
_log.info('Attempting to parallelize chains.')
693+
_log.info('Attempting to parallelize chains to all cores. You can turn this off with `pm.sample(cores=1)`.')
694694
import multiprocessing
695695
for c, stepper in enumerate(tqdm(steppers)):
696696
slave_end, master_end = multiprocessing.Pipe()
@@ -715,7 +715,7 @@ def __init__(self, steppers, parallelize):
715715
_log.debug('Error was: ', exec_info=True)
716716
else:
717717
_log.info('Chains are not parallelized. You can enable this by passing '
718-
'pm.sample(parallelize=True).')
718+
'`pm.sample(cores=n)`, where n > 1.')
719719
return super().__init__()
720720

721721
def __enter__(self):

pymc3/tests/test_step.py

+14 -1
Original file line number | Diff line number | Diff line change
@@ -915,12 +915,25 @@ def test_checks_population_size(self):
915915
trace = sample(draws=100, chains=4, step=step)
916916
pass
917917

918+
def test_nonparallelized_chains_are_random(self):
919+
with Model() as model:
920+
x = Normal("x", 0, 1)
921+
for stepper in TestPopulationSamplers.steppers:
922+
step = stepper()
923+
trace = sample(chains=4, cores=1, draws=20, tune=0, step=DEMetropolis())
924+
samples = np.array(trace.get_values("x", combine=False))[:, 5]
925+
926+
assert len(set(samples)) == 4, "Parallelized {} " "chains are identical.".format(
927+
stepper
928+
)
929+
pass
930+
918931
def test_parallelized_chains_are_random(self):
919932
with Model() as model:
920933
x = Normal("x", 0, 1)
921934
for stepper in TestPopulationSamplers.steppers:
922935
step = stepper()
923-
trace = sample(chains=4, draws=20, tune=0, step=DEMetropolis())
936+
trace = sample(chains=4, cores=4, draws=20, tune=0, step=DEMetropolis())
924937
samples = np.array(trace.get_values("x", combine=False))[:, 5]
925938

926939
assert len(set(samples)) == 4, "Parallelized {} " "chains are identical.".format(

0 commit comments

Comments
 (0)