@@ -292,11 +292,11 @@ def sample(
292292 chains : int | None = None ,
293293 cores : int | None = None ,
294294 random_seed : RandomState = None ,
295+ step = None ,
296+ external_sampler : ExternalSampler | None = None ,
295297 progressbar : bool | ProgressBarType = True ,
296298 progressbar_theme : Theme | None = default_progress_theme ,
297- step = None ,
298299 var_names : Sequence [str ] | None = None ,
299- nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
300300 initvals : StartDict | Sequence [StartDict | None ] | None = None ,
301301 init : str = "auto" ,
302302 jitter_max_retries : int = 10 ,
@@ -324,11 +324,11 @@ def sample(
324324 chains : int | None = None ,
325325 cores : int | None = None ,
326326 random_seed : RandomState = None ,
327+ step = None ,
328+ external_sampler : ExternalSampler | None = None ,
327329 progressbar : bool | ProgressBarType = True ,
328330 progressbar_theme : Theme | None = default_progress_theme ,
329- step = None ,
330331 var_names : Sequence [str ] | None = None ,
331- nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
332332 initvals : StartDict | Sequence [StartDict | None ] | None = None ,
333333 init : str = "auto" ,
334334 jitter_max_retries : int = 10 ,
@@ -356,11 +356,11 @@ def sample(
356356 chains : int | None = None ,
357357 cores : int | None = None ,
358358 random_seed : RandomState = None ,
359+ step = None ,
360+ external_sampler : ExternalSampler | None = None ,
359361 progressbar : bool | ProgressBarType = True ,
360362 progressbar_theme : Theme | None = None ,
361- step = None ,
362363 var_names : Sequence [str ] | None = None ,
363- nuts_sampler : None | Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = None ,
364364 initvals : StartDict | Sequence [StartDict | None ] | None = None ,
365365 init : str = "auto" ,
366366 jitter_max_retries : int = 10 ,
@@ -407,6 +407,12 @@ def sample(
407407 A ``TypeError`` will be raised if a legacy :py:class:`~numpy.random.RandomState` object is passed.
408408 We no longer support ``RandomState`` objects because their seeding mechanism does not allow
409409 easy spawning of new independent random streams that are needed by the step methods.
410+ step : function or iterable of functions, optional
411+ A step function or collection of functions. If there are variables without step methods,
412+ step methods for those variables will be assigned automatically. By default the NUTS step
413+ method will be used, if appropriate to the model. Not compatible with external_sampler
414+ external_sampler: ExternalSampler, optional
415+ An external sampler to sample the whole model. Not compatible with step.
410416 progressbar: bool or ProgressType, optional
411417 How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask
412418 for one of the following:
@@ -419,16 +425,8 @@ def sample(
419425 are also displayed.
420426
421427 If True, the default is "split+stats" is used.
422- step : function or iterable of functions
423- A step function or collection of functions. If there are variables without step methods,
424- step methods for those variables will be assigned automatically. By default the NUTS step
425- method will be used, if appropriate to the model.
426428 var_names : list of str, optional
427429 Names of variables to be stored in the trace. Defaults to all free variables and deterministics.
428- nuts_sampler : str
429- Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
430- This requires the chosen sampler to be installed.
431- All samplers, except "pymc", require the full model to be continuous.
432430 blas_cores: int or "auto" or None, default = "auto"
433431 The total number of threads blas and openmp functions should use during sampling.
434432 Setting it to "auto" will ensure that the total number of active blas threads is the
@@ -608,35 +606,40 @@ def joined_blas_limiter():
608606 rngs = get_random_generator (random_seed ).spawn (chains )
609607 random_seed_list = [rng .integers (2 ** 30 ) for rng in rngs ]
610608
611- if step is None and nuts_sampler not in (None , "pymc" ):
612- # Temporarily instantiate external samplers for user, for backwards-compat
613- warnings .warn (
614- f"Setting `pm.sample(nuts_sampler='{ nuts_sampler } , nuts_sampler_kwargs=...)'` is deprecated.\n "
615- f"Use `pm.sample(step=pm.external.{ nuts_sampler .capitalize ()} (**nuts_sampler_kwargs))` instead" ,
616- FutureWarning ,
617- )
618- from pymc .sampling import external
609+ if "nuts_sampler" in kwargs :
610+ # Transition backwards-compatibility
611+ nuts_sampler = kwargs .pop ("nuts_sampler" )
612+ if nuts_sampler != "pymc" :
613+ warnings .warn (
614+ f"Setting `pm.sample(nuts_sampler='{ nuts_sampler } , nuts_sampler_kwargs=...)'` is deprecated.\n "
615+ f"Use `pm.sample(external_sampler=pm.external.{ nuts_sampler .capitalize ()} (**nuts_sampler_kwargs))` instead" ,
616+ FutureWarning ,
617+ )
618+ from pymc .sampling import external
619619
620- step = getattr (external , nuts_sampler .capitalize ())(
621- model = model ,
622- ** (nuts_sampler_kwargs or {}),
623- )
624- nuts_sampler_kwargs = None
620+ external_sampler = getattr (external , nuts_sampler .capitalize ())(
621+ model = model ,
622+ ** (nuts_sampler_kwargs or {}),
623+ ** (kwargs .pop ("nuts" ) or {}),
624+ )
625+ nuts_sampler_kwargs = None
625626
626- if isinstance (step , list | tuple ) and len (step ) == 1 :
627- [step ] = step
627+ if external_sampler is not None :
628+ if step is not None :
629+ raise ValueError ("`step` and `external_sampler` cannot be used together" )
628630
629- if isinstance (step , ExternalSampler ):
630- if step .model is not model :
631- raise ValueError ("External step model does not match model detected by sample" )
631+ if external_sampler .model is not model :
632+ raise ValueError (
633+ "External sampler model does not match model detected by sample function"
634+ )
632635 if nuts_sampler_kwargs :
633636 raise ValueError (
634637 f"{ nuts_sampler_kwargs = } should be passed when constructing external sampler"
635638 )
636639 if "nuts" in kwargs :
637- kwargs .update (kwargs [ "nuts" ] .pop ())
640+ kwargs .update (kwargs .pop ("nuts" ))
638641 with joined_blas_limiter ():
639- return step .sample (
642+ return external_sampler .sample (
640643 tune = tune ,
641644 draws = draws ,
642645 chains = chains ,
0 commit comments