diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index 56622eb903..802f40f816 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -8,6 +8,7 @@
 - `DEMetropolisZ`, an improved variant of `DEMetropolis`, brings better parallelization and higher efficiency with fewer chains, at the cost of slower initial convergence. This implementation is experimental. See [#3784](https://github.com/pymc-devs/pymc3/pull/3784) for more info.
 - Notebooks that give insight into `DEMetropolis`, `DEMetropolisZ` and the `DifferentialEquation` interface are now located in the [Tutorials/Deep Dive](https://docs.pymc.io/nb_tutorials/index.html) section.
 - Add `fast_sample_posterior_predictive`, a vectorized alternative to `sample_posterior_predictive`.  This alternative is substantially faster for large models.
+- `sample_posterior_predictive` can now accept an `xarray.Dataset` as input, e.g. `InferenceData.posterior` (see [#3846](https://github.com/pymc-devs/pymc3/pull/3846) and the usage sketch below)
 - `SamplerReport` (`MultiTrace.report`) now has properties `n_tune`, `n_draws`, `t_sampling` for increased convenience (see [#3827](https://github.com/pymc-devs/pymc3/pull/3827))
 
 ### Maintenance
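
A minimal usage sketch of the new input type described in the release note above. The toy model and variable names are illustrative only and not part of the PR; the call pattern mirrors `test_sample_from_xarray_posterior` further down in this diff.

```python
import arviz as az
import numpy as np
import pymc3 as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=np.random.randn(10))
    trace = pm.sample(draws=500, tune=500)

# idata.posterior is an xarray.Dataset with (chain, draw, ...) dimensions
idata = az.from_pymc3(trace)

with model:
    # previously this required the MultiTrace; now the Dataset works directly
    ppc = pm.sample_posterior_predictive(idata.posterior, var_names=["obs"])
```
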
diff --git a/pymc3/distributions/posterior_predictive.py b/pymc3/distributions/posterior_predictive.py
index c012329d3e..46eff017b6 100644
--- a/pymc3/distributions/posterior_predictive.py
+++ b/pymc3/distributions/posterior_predictive.py
@@ -12,12 +12,14 @@
 import numpy as np
 import theano
 import theano.tensor as tt
+from xarray import Dataset
 
 from ..backends.base import MultiTrace #, TraceLike, TraceDict
 from .distribution import _DrawValuesContext, _DrawValuesContextBlocker, is_fast_drawable, _compile_theano_function, vectorized_ppc
 from ..model import Model, get_named_nodes_and_relations, ObservedRV, MultiObservedRV, modelcontext
 from ..exceptions import IncorrectArgumentsError
 from ..vartypes import theano_constant
+from ..util import dataset_to_point_dict
 # Failing tests:
 #    test_mixture_random_shape::test_mixture_random_shape
 #
@@ -119,7 +121,7 @@ def __getitem__(self, item):
 
 
 
-def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.ndarray]]],
+def fast_sample_posterior_predictive(trace: Union[MultiTrace, Dataset, List[Dict[str, np.ndarray]]],
                                 samples: Optional[int]=None,
                                 model: Optional[Model]=None,
                                 var_names: Optional[List[str]]=None,
@@ -135,7 +137,7 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.
 
     Parameters
     ----------
-    trace : MultiTrace or List of points
+    trace : MultiTrace, xarray.Dataset, or list of points (dicts)
         Trace generated from MCMC sampling.
     samples : int, optional
         Number of posterior predictive samples to generate. Defaults to one posterior predictive
@@ -168,6 +170,9 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.
     ### greater than the number of samples in the trace parameter, we sample repeatedly.  This
     ### makes the shape issues just a little easier to deal with.
 
+    if isinstance(trace, Dataset):
+        trace = dataset_to_point_dict(trace)
+
     model = modelcontext(model)
     assert model is not None
     with model:
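
For the `fast_sample_posterior_predictive` path changed above, a hedged sketch of the equivalent call with an `xarray.Dataset`. The model and variable names are made up; the final assertion only states that both input forms yield the same number of predictive samples.

```python
import arviz as az
import numpy as np
import pymc3 as pm

with pm.Model() as model:
    a = pm.Normal("a", 0.0, 1.0)
    pm.Normal("b", a, 1.0, observed=np.zeros(5))
    trace = pm.sample(draws=200, tune=200, chains=2)

idata = az.from_pymc3(trace)

with model:
    # MultiTrace input (existing behaviour) and Dataset input (this PR)
    ppc_trace = pm.fast_sample_posterior_predictive(trace, var_names=["b"])
    ppc_xarray = pm.fast_sample_posterior_predictive(idata.posterior, var_names=["b"])

# both input types should produce equally many posterior predictive samples
assert ppc_trace["b"].shape == ppc_xarray["b"].shape
```
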
diff --git a/pymc3/sampling.py b/pymc3/sampling.py
index b0f920c672..ca1773fb27 100644
--- a/pymc3/sampling.py
+++ b/pymc3/sampling.py
@@ -30,6 +30,7 @@
 import numpy as np
 import theano.gradient as tg
 from theano.tensor import Tensor
+import xarray
 
 from .backends.base import BaseTrace, MultiTrace
 from .backends.ndarray import NDArray
@@ -53,6 +54,7 @@
     get_untransformed_name,
     is_transformed_name,
     get_default_varnames,
+    dataset_to_point_dict,
 )
 from .vartypes import discrete_types
 from .exceptions import IncorrectArgumentsError
@@ -1520,9 +1522,9 @@ def sample_posterior_predictive(
 
     Parameters
     ----------
-    trace: backend, list, or MultiTrace
-        Trace generated from MCMC sampling. Or a list containing dicts from
-        find_MAP() or points
+    trace: backend, list, xarray.Dataset, or MultiTrace
+        Trace generated from MCMC sampling, a list of dicts (e.g. points or the result of
+        find_MAP()), or an xarray.Dataset (e.g. InferenceData.posterior or InferenceData.prior).
     samples: int
         Number of posterior predictive samples to generate. Defaults to one posterior predictive
         sample per posterior sample, that is, the number of draws times the number of chains. It
@@ -1556,6 +1558,9 @@ def sample_posterior_predictive(
         Dictionary with the variable names as keys, and values numpy arrays containing
         posterior predictive samples.
     """
+    if isinstance(trace, xarray.Dataset):
+        trace = dataset_to_point_dict(trace)
+
     len_trace = len(trace)
     try:
         nchain = trace.nchains
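
The docstring above also lists `InferenceData.prior` as an accepted input. A sketch of that path, mirroring `test_sample_from_xarray_prior` below; the model and variable names are made up.

```python
import arviz as az
import numpy as np
import pymc3 as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu, 1.0, observed=np.zeros(3))
    trace = pm.sample(draws=200, tune=200)
    prior = pm.sample_prior_predictive(samples=20)

# idata.prior is an xarray.Dataset, just like idata.posterior
idata = az.from_pymc3(trace, prior=prior)

with model:
    # 20 prior points -> 20 predictive draws of "y" by default
    ppc = pm.sample_posterior_predictive(idata.prior, var_names=["y"])
```
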
diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py
index 8d1bf00bd4..6c0eece037 100644
--- a/pymc3/tests/test_sampling.py
+++ b/pymc3/tests/test_sampling.py
@@ -22,6 +22,7 @@
     import mock
 
 import numpy.testing as npt
+import arviz as az
 import pymc3 as pm
 import theano.tensor as tt
 from theano import shared
@@ -880,3 +881,32 @@ def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture):
                 var_names=['d']
             )
 
+    def test_sample_from_xarray_prior(self, point_list_arg_bug_fixture):
+        pmodel, trace = point_list_arg_bug_fixture
+
+        with pmodel:
+            prior = pm.sample_prior_predictive(samples=20)
+        idat = az.from_pymc3(trace, prior=prior)
+        with pmodel:
+            pp = pm.sample_posterior_predictive(
+                idat.prior,
+                var_names=['d']
+            )
+
+    def test_sample_from_xarray_posterior(self, point_list_arg_bug_fixture):
+        pmodel, trace = point_list_arg_bug_fixture
+        idat = az.from_pymc3(trace)
+        with pmodel:
+            pp = pm.sample_posterior_predictive(
+                idat.posterior,
+                var_names=['d']
+            )
+
+    def test_sample_from_xarray_posterior_fast(self, point_list_arg_bug_fixture):
+        pmodel, trace = point_list_arg_bug_fixture
+        idat = az.from_pymc3(trace)
+        with pmodel:
+            pp = pm.fast_sample_posterior_predictive(
+                idat.posterior,
+                var_names=['d']
+            )
diff --git a/pymc3/util.py b/pymc3/util.py
index a9bd3c5af6..18a78aed97 100644
--- a/pymc3/util.py
+++ b/pymc3/util.py
@@ -14,7 +14,11 @@
 
 import re
 import functools
-from numpy import asscalar
+from typing import List, Dict
+
+import xarray
+from numpy import asscalar, ndarray
+
 
 LATEX_ESCAPE_RE = re.compile(r'(%|_|\$|#|&)', re.MULTILINE)
 
@@ -179,3 +183,21 @@ def enhanced(*args, **kwargs):
             newwrapper = functools.partial(wrapper, *args, **kwargs)
             return newwrapper
     return enhanced
+
+def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, ndarray]]:
+    # extract the raw sample arrays, shaped (chain, draw, *var_shape), per variable
+    _samples = {
+        vn: ds[vn].values
+        for vn in ds.keys()
+    }
+    # build one point dict per (chain, draw) pair
+    points = []
+    for c in ds.chain:
+        for d in ds.draw:
+            points.append({
+                vn: s[c, d]
+                for vn, s in _samples.items()
+            })
+    # each point maps a variable name to its value,
+    # with the chain and draw dimensions dropped
+    return points
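
For reference, a small sketch of what `dataset_to_point_dict` produces, using a hand-built toy `Dataset`. The variable names and shapes are assumptions, not taken from the PR.

```python
import numpy as np
import xarray as xr

from pymc3.util import dataset_to_point_dict

# toy "posterior": 2 chains x 3 draws of a scalar `mu` and a length-2 vector `beta`
ds = xr.Dataset(
    {
        "mu": (("chain", "draw"), np.zeros((2, 3))),
        "beta": (("chain", "draw", "beta_dim"), np.ones((2, 3, 2))),
    },
    coords={"chain": [0, 1], "draw": [0, 1, 2]},
)

points = dataset_to_point_dict(ds)
assert len(points) == 2 * 3             # one dict per (chain, draw) pair
assert points[0]["beta"].shape == (2,)  # chain/draw dimensions are dropped per point
```
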