|
| 1 | +"""Tests for inference.""" |
| 2 | + |
| 3 | +from typing import Type |
| 4 | + |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +from ..core import DMatrix |
| 8 | +from ..training import train |
| 9 | +from .shared import validate_leaf_output |
| 10 | +from .utils import Device |
| 11 | + |
| 12 | + |
| 13 | +# pylint: disable=invalid-name,too-many-locals |
| 14 | +def run_predict_leaf(device: Device, DMatrixT: Type[DMatrix]) -> np.ndarray: |
| 15 | + """Run tests for leaf index prediction.""" |
| 16 | + rows = 100 |
| 17 | + cols = 4 |
| 18 | + classes = 5 |
| 19 | + num_parallel_tree = 4 |
| 20 | + num_boost_round = 10 |
| 21 | + rng = np.random.RandomState(1994) |
| 22 | + X = rng.randn(rows, cols) |
| 23 | + y = rng.randint(low=0, high=classes, size=rows) |
| 24 | + |
| 25 | + m = DMatrixT(X, y) |
| 26 | + booster = train( |
| 27 | + { |
| 28 | + "num_parallel_tree": num_parallel_tree, |
| 29 | + "num_class": classes, |
| 30 | + "tree_method": "hist", |
| 31 | + }, |
| 32 | + m, |
| 33 | + num_boost_round=num_boost_round, |
| 34 | + ) |
| 35 | + |
| 36 | + booster.set_param({"device": device}) |
| 37 | + empty = DMatrixT(np.ones(shape=(0, cols))) |
| 38 | + empty_leaf = booster.predict(empty, pred_leaf=True) |
| 39 | + assert empty_leaf.shape[0] == 0 |
| 40 | + |
| 41 | + leaf = booster.predict(m, pred_leaf=True, strict_shape=True) |
| 42 | + assert leaf.shape[0] == rows |
| 43 | + assert leaf.shape[1] == num_boost_round |
| 44 | + assert leaf.shape[2] == classes |
| 45 | + assert leaf.shape[3] == num_parallel_tree |
| 46 | + |
| 47 | + validate_leaf_output(leaf, num_parallel_tree) |
| 48 | + |
| 49 | + n_iters = np.int32(2) |
| 50 | + sliced = booster.predict( |
| 51 | + m, |
| 52 | + pred_leaf=True, |
| 53 | + iteration_range=(0, n_iters), |
| 54 | + strict_shape=True, |
| 55 | + ) |
| 56 | + first = sliced[0, ...] |
| 57 | + |
| 58 | + assert np.prod(first.shape) == classes * num_parallel_tree * n_iters |
| 59 | + |
| 60 | + # When there's only 1 tree, the output is a 1 dim vector |
| 61 | + booster = train({"tree_method": "hist"}, num_boost_round=1, dtrain=m) |
| 62 | + booster.set_param({"device": device}) |
| 63 | + assert booster.predict(m, pred_leaf=True).shape == (rows,) |
| 64 | + |
| 65 | + return leaf |
0 commit comments