Skip to content

Commit 2dd4c8c

Browse files
aloctavodiaricardoV94
authored andcommitted
allow external step method
1 parent 50ceb2b commit 2dd4c8c

File tree

4 files changed

+39
-45
lines changed

4 files changed

+39
-45
lines changed

pymc/sampling.py

+6-44
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,7 @@
6464
)
6565
from pymc.model import Model, modelcontext
6666
from pymc.parallel_sampling import Draw, _cpu_count
67-
from pymc.step_methods import (
68-
NUTS,
69-
BinaryGibbsMetropolis,
70-
BinaryMetropolis,
71-
CategoricalGibbsMetropolis,
72-
CompoundStep,
73-
DEMetropolis,
74-
HamiltonianMC,
75-
Metropolis,
76-
Slice,
77-
)
67+
from pymc.step_methods import NUTS, CompoundStep, DEMetropolis
7868
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
7969
from pymc.step_methods.hmc import quadpotential
8070
from pymc.util import (
@@ -98,15 +88,6 @@
9888
"draw",
9989
]
10090

101-
STEP_METHODS = (
102-
NUTS,
103-
HamiltonianMC,
104-
Metropolis,
105-
BinaryMetropolis,
106-
BinaryGibbsMetropolis,
107-
Slice,
108-
CategoricalGibbsMetropolis,
109-
)
11091
Step: TypeAlias = Union[BlockedStep, CompoundStep]
11192

11293
ArrayLike: TypeAlias = Union[np.ndarray, List[float]]
@@ -164,7 +145,7 @@ def instantiate_steppers(
164145
return steps
165146

166147

167-
def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None):
148+
def assign_step_methods(model, step=None, methods=None, step_kwargs=None):
168149
"""Assign model variables to appropriate step methods.
169150
170151
Passing a specified model will auto-assign its constituent stochastic
@@ -197,6 +178,9 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None
197178
steps = []
198179
assigned_vars = set()
199180

181+
if methods is None:
182+
methods = pm.STEP_METHODS
183+
200184
if step is not None:
201185
try:
202186
steps += list(step)
@@ -481,29 +465,7 @@ def sample(
481465
draws += tune
482466

483467
initial_points = None
484-
if step is None and init is not None and all_continuous(model.value_vars):
485-
try:
486-
# By default, try to use NUTS
487-
_log.info("Auto-assigning NUTS sampler...")
488-
initial_points, step = init_nuts(
489-
init=init,
490-
chains=chains,
491-
n_init=n_init,
492-
model=model,
493-
seeds=random_seed,
494-
progressbar=progressbar,
495-
jitter_max_retries=jitter_max_retries,
496-
tune=tune,
497-
initvals=initvals,
498-
**kwargs,
499-
)
500-
except (AttributeError, NotImplementedError, tg.NullTypeGradError):
501-
# gradient computation failed
502-
_log.info("Initializing NUTS failed. Falling back to elementwise auto-assignment.")
503-
_log.debug("Exception in init nuts", exc_info=True)
504-
step = assign_step_methods(model, step, step_kwargs=kwargs)
505-
else:
506-
step = assign_step_methods(model, step, step_kwargs=kwargs)
468+
step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
507469

508470
if isinstance(step, list):
509471
step = CompoundStep(step)

pymc/step_methods/__init__.py

+10
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,13 @@
3636
RecursiveDAProposal,
3737
)
3838
from pymc.step_methods.slicer import Slice
39+
40+
STEP_METHODS = (
41+
NUTS,
42+
HamiltonianMC,
43+
Metropolis,
44+
BinaryMetropolis,
45+
BinaryGibbsMetropolis,
46+
Slice,
47+
CategoricalGibbsMetropolis,
48+
)

pymc/step_methods/hmc/nuts.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def competence(var, has_grad):
200200

201201
dist = getattr(var.owner, "op", None)
202202
if var.dtype in continuous_types and has_grad:
203-
return Competence.IDEAL
203+
return Competence.PREFERRED
204204
return Competence.INCOMPATIBLE
205205

206206
def warnings(self):

pymc/tests/test_step.py

+22
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from aesara.graph.op import Op
2828
from numpy.testing import assert_array_almost_equal
2929

30+
import pymc as pm
31+
3032
from pymc.aesaraf import floatX
3133
from pymc.data import Data
3234
from pymc.distributions import (
@@ -741,6 +743,26 @@ def kill_grad(x):
741743
steps = assign_step_methods(model, [])
742744
assert isinstance(steps, Slice)
743745

746+
def test_modify_step_methods(self):
747+
"""Test step methods can be changed"""
748+
# remove nuts from step_methods
749+
step_methods = list(pm.STEP_METHODS)
750+
step_methods.remove(NUTS)
751+
pm.STEP_METHODS = step_methods
752+
753+
with Model() as model:
754+
Normal("x", 0, 1)
755+
steps = assign_step_methods(model, [])
756+
assert not isinstance(steps, NUTS)
757+
758+
# add back nuts
759+
pm.STEP_METHODS = step_methods + [NUTS]
760+
761+
with Model() as model:
762+
Normal("x", 0, 1)
763+
steps = assign_step_methods(model, [])
764+
assert isinstance(steps, NUTS)
765+
744766

745767
class TestPopulationSamplers:
746768

0 commit comments

Comments
 (0)