diff --git a/pymc3/distributions/bart.py b/pymc3/distributions/bart.py
index 20fd32a801..4914844555 100644
--- a/pymc3/distributions/bart.py
+++ b/pymc3/distributions/bart.py
@@ -14,6 +14,8 @@
 
 import numpy as np
 
+from pandas import DataFrame, Series
+
 from pymc3.distributions.distribution import NoDistribution
 from pymc3.distributions.tree import LeafNode, SplitNode, Tree
 
@@ -21,9 +23,10 @@
 
 
 class BaseBART(NoDistribution):
-    def __init__(self, X, Y, m=200, alpha=0.25, *args, **kwargs):
-        self.X = X
-        self.Y = Y
+    def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, *args, **kwargs):
+
+        self.X, self.Y, self.missing_data = self.preprocess_XY(X, Y)
+
         super().__init__(shape=X.shape[0], dtype="float64", testval=0, *args, **kwargs)
 
         if self.X.ndim != 2:
@@ -48,12 +51,24 @@ def __init__(self, X, Y, m=200, alpha=0.25, *args, **kwargs):
 
         self.num_observations = X.shape[0]
         self.num_variates = X.shape[1]
+        self.available_predictors = list(range(self.num_variates))
+        self.ssv = SampleSplittingVariable(split_prior, self.num_variates)
         self.m = m
         self.alpha = alpha
         self.trees = self.init_list_of_trees()
+        self.all_trees = []
         self.mean = fast_mean()
         self.prior_prob_leaf_node = compute_prior_probability(alpha)
 
+    def preprocess_XY(self, X, Y):
+        if isinstance(Y, (Series, DataFrame)):
+            Y = Y.to_numpy()
+        if isinstance(X, (Series, DataFrame)):
+            X = X.to_numpy()
+        missing_data = np.any(np.isnan(X))
+        X = np.random.normal(X, np.std(X, 0) / 100)
+        return X, Y, missing_data
+
     def init_list_of_trees(self):
         initial_value_leaf_nodes = self.Y.mean() / self.m
         initial_idx_data_points_leaf_nodes = np.array(range(self.num_observations), dtype="int32")
@@ -79,39 +94,26 @@ def __iter__(self):
     def __repr_latex(self):
         raise NotImplementedError
 
-    def get_available_predictors(self, idx_data_points_split_node):
-        possible_splitting_variables = []
-        for j in range(self.num_variates):
-            x_j = self.X[idx_data_points_split_node, j]
-            x_j = x_j[~np.isnan(x_j)]
-            for i in range(1, len(x_j)):
-                if x_j[i - 1] != x_j[i]:
-                    possible_splitting_variables.append(j)
-                    break
-        return possible_splitting_variables
-
     def get_available_splitting_rules(self, idx_data_points_split_node, idx_split_variable):
         x_j = self.X[idx_data_points_split_node, idx_split_variable]
-        x_j = x_j[~np.isnan(x_j)]
-        values, indices = np.unique(x_j, return_index=True)
-        # The last value is not consider since if we choose it as the value of
-        # the splitting rule assignment, it would leave the right subtree empty.
-        return values[:-1], indices[:-1]
+        if self.missing_data:
+            x_j = x_j[~np.isnan(x_j)]
+        values = np.unique(x_j)
+        # The last value is never available as it would leave the right subtree empty.
+        return values[:-1]
 
     def grow_tree(self, tree, index_leaf_node):
-        # This can be unsuccessful when there are not available predictors
         current_node = tree.get_node(index_leaf_node)
 
-        available_predictors = self.get_available_predictors(current_node.idx_data_points)
-
-        if not available_predictors:
-            return False, None
-
-        index_selected_predictor = discrete_uniform_sampler(len(available_predictors))
-        selected_predictor = available_predictors[index_selected_predictor]
-        available_splitting_rules, _ = self.get_available_splitting_rules(
+        index_selected_predictor = self.ssv.rvs()
+        selected_predictor = self.available_predictors[index_selected_predictor]
+        available_splitting_rules = self.get_available_splitting_rules(
             current_node.idx_data_points, selected_predictor
         )
+        # This can be unsuccessful when there are not available splitting rules
+        if available_splitting_rules.size == 0:
+            return False, None
+
         index_selected_splitting_rule = discrete_uniform_sampler(len(available_splitting_rules))
         selected_splitting_rule = available_splitting_rules[index_selected_splitting_rule]
         new_split_node = SplitNode(
@@ -167,6 +169,19 @@ def draw_leaf_value(self, idx_data_points):
         draw = self.mean(R_j)
         return draw
 
+    def predict(self, X_new):
+        """Compute out of sample predictions evaluated at X_new"""
+        trees = self.all_trees
+        num_observations = X_new.shape[0]
+        pred = np.zeros((len(trees), num_observations))
+        np.random.randint(len(trees))
+        for draw, trees_to_sum in enumerate(trees):
+            new_Y = np.zeros(num_observations)
+            for tree in trees_to_sum:
+                new_Y += [tree.predict_out_of_sample(x) for x in X_new]
+            pred[draw] = new_Y
+        return pred
+
 
 def compute_prior_probability(alpha):
     """
@@ -217,6 +232,31 @@ def discrete_uniform_sampler(upper_value):
     return int(np.random.random() * upper_value)
 
 
+class SampleSplittingVariable:
+    def __init__(self, prior, num_variates):
+        self.prior = prior
+        self.num_variates = num_variates
+
+        if self.prior is not None:
+            self.prior = np.asarray(self.prior)
+            self.prior = self.prior / self.prior.sum()
+            if self.prior.size != self.num_variates:
+                raise ValueError(
+                    f"The size of split_prior ({self.prior.size}) should be the "
+                    f"same as the number of covariates ({self.num_variates})"
+                )
+            self.enu = list(enumerate(np.cumsum(self.prior)))
+
+    def rvs(self):
+        if self.prior is None:
+            return int(np.random.random() * self.num_variates)
+        else:
+            r = np.random.random()
+            for i, v in self.enu:
+                if r <= v:
+                    return i
+
+
 class BART(BaseBART):
     """
     BART distribution.
@@ -225,19 +265,23 @@ class BART(BaseBART):
 
     Parameters
     ----------
-    X :
+    X : array-like
         The design matrix.
-    Y :
+    Y : array-like
         The response vector.
     m : int
         Number of trees
     alpha : float
         Control the prior probability over the depth of the trees. Must be in the interval (0, 1),
         altought it is recomenned to be in the interval (0, 0.5].
+    split_prior : array-like
+        Each element of split_prior should be in the [0, 1] interval and the elements should sum
+        to 1. Otherwise they will be normalized.
+        Defaults to None, all variable have the same a prior probability
     """
 
-    def __init__(self, X, Y, m=200, alpha=0.25):
-        super().__init__(X, Y, m, alpha)
+    def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None):
+        super().__init__(X, Y, m, alpha, split_prior)
 
     def _str_repr(self, name=None, dist=None, formatting="plain"):
         if dist is None:
diff --git a/pymc3/distributions/tree.py b/pymc3/distributions/tree.py
index 81c727232a..8e84bd9a7c 100644
--- a/pymc3/distributions/tree.py
+++ b/pymc3/distributions/tree.py
@@ -84,6 +84,22 @@ def predict_output(self, num_observations):
             output[current_node.idx_data_points] = current_node.value
         return output
 
+    def predict_out_of_sample(self, x):
+        """
+        Predict output of tree for an unobserved point x.
+
+        Parameters
+        ----------
+        x : numpy array
+
+        Returns
+        -------
+        float
+            Value of the leaf value where the unobserved point lies.
+        """
+        leaf_node = self._traverse_tree(x=x, node_index=0)
+        return leaf_node.value
+
     def _traverse_tree(self, x, node_index=0):
         """
         Traverse the tree starting from a particular node given an unobserved point.
@@ -99,15 +115,13 @@ def _traverse_tree(self, x, node_index=0):
         """
         current_node = self.get_node(node_index)
         if isinstance(current_node, SplitNode):
-            if x is not np.NaN:
+            if x[current_node.idx_split_variable] <= current_node.split_value:
                 left_child = current_node.get_idx_left_child()
-                final_node = self._traverse_tree(x, left_child)
+                current_node = self._traverse_tree(x, left_child)
             else:
                 right_child = current_node.get_idx_right_child()
-                final_node = self._traverse_tree(x, right_child)
-        else:
-            final_node = current_node
-        return final_node
+                current_node = self._traverse_tree(x, right_child)
+        return current_node
 
     def grow_tree(self, index_leaf_node, new_split_node, new_left_node, new_right_node):
         """
diff --git a/pymc3/step_methods/pgbart.py b/pymc3/step_methods/pgbart.py
index ca0c61aa25..c3bac3ade9 100644
--- a/pymc3/step_methods/pgbart.py
+++ b/pymc3/step_methods/pgbart.py
@@ -64,8 +64,13 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
 
         self.tune = True
         self.idx = 0
+        self.iter = 0
+        self.sum_trees = []
+        self.chunk = chunk
+
         if chunk == "auto":
             self.chunk = max(1, int(self.bart.m * 0.1))
+        self.bart.chunk = self.chunk
         self.num_particles = num_particles
         self.log_num_particles = np.log(num_particles)
         self.indices = list(range(1, num_particles))
@@ -96,14 +101,14 @@ def astep(self, _):
             self.idx = 0
 
         for idx in range(self.idx, self.idx + self.chunk):
-            if idx > bart.m:
+            if idx >= bart.m:
                 break
             self.idx += 1
             tree = bart.trees[idx]
             R_j = bart.get_residuals_loo(tree)
             # Generate an initial set of SMC particles
             # at the end of the algorithm we return one of these particles as the new tree
-            particles = self.init_particles(tree.tree_id, R_j, bart.num_observations)
+            particles = self.init_particles(tree.tree_id, R_j, num_observations)
 
             for t in range(1, max_stages):
                 # Get old particle at stage t
@@ -147,6 +152,11 @@ def astep(self, _):
             bart.sum_trees_output = bart.Y - R_j + new_prediction
 
             if not self.tune:
+                self.iter += 1
+                self.sum_trees.append(new_tree.tree)
+                if not self.iter % bart.m:
+                    bart.all_trees.append(self.sum_trees)
+                    self.sum_trees = []
                 for index in new_tree.used_variates:
                     variable_inclusion[index] += 1