From 2c6ef1ffe00799e6438953c4b3488cc4496f8bd3 Mon Sep 17 00:00:00 2001 From: Alex Andorra Date: Fri, 4 Apr 2025 15:08:10 -0400 Subject: [PATCH 1/3] Set up working infrastructure for batched KF --- conda-envs/environment-test.yml | 3 + notebooks/batch-examples.ipynb | 388 ++++++++++++++++++ .../statespace/filters/distributions.py | 51 ++- 3 files changed, 422 insertions(+), 20 deletions(-) create mode 100644 notebooks/batch-examples.ipynb diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 450b46e3..8234b954 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -3,6 +3,7 @@ channels: - conda-forge - nodefaults dependencies: +- ipywidgets - pymc>=5.21 - pytest-cov>=2.5 - pytest>=3.0 @@ -10,8 +11,10 @@ dependencies: - xhistogram - statsmodels - numba<=0.60.0 +- nutpie - pip - pip: - blackjax - scikit-learn - better_optimize + - -e . diff --git a/notebooks/batch-examples.ipynb b/notebooks/batch-examples.ipynb new file mode 100644 index 00000000..6139992b --- /dev/null +++ b/notebooks/batch-examples.ipynb @@ -0,0 +1,388 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "0a5841d3", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pytensor\n", + "import pytensor.tensor as pt\n", + "from pymc_extras.statespace.filters import StandardFilter\n", + "from tests.statespace.utilities.test_helpers import make_test_inputs\n", + "from pytensor.graph.replace import vectorize_graph\n", + "from importlib import reload\n", + "import pymc_extras.statespace.filters.distributions as pmss_dist\n", + "from pymc_extras.statespace.filters.distributions import SequenceMvNormal\n", + "import pymc as pm" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "14299e50", + "metadata": {}, + "outputs": [], + "source": [ + "seed = sum(map(ord, \"batched-kf\"))\n", + "rng = np.random.default_rng(seed)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "71bc513e", + "metadata": {}, + "outputs": [], + "source": [ + "def create_batch_inputs(batch_size, p=1, m=5, r=1, n=10, rng=rng):\n", + " \"\"\"\n", + " Create batched inputs for testing.\n", + "\n", + " Parameters\n", + " ----------\n", + " batch_size : int\n", + " Number of batches to create\n", + " p : int\n", + " First dimension parameter\n", + " m : int\n", + " Second dimension parameter\n", + " r : int\n", + " Third dimension parameter\n", + " n : int\n", + " Fourth dimension parameter\n", + " rng : numpy.random.Generator\n", + " Random number generator\n", + "\n", + " Returns\n", + " -------\n", + " list\n", + " List of stacked inputs for each batch\n", + " \"\"\"\n", + " # Create individual inputs for each batch\n", + " np_batch_inputs = []\n", + " for i in range(batch_size):\n", + " inputs = make_test_inputs(p, m, r, n, rng)\n", + " np_batch_inputs.append(inputs)\n", + "\n", + " return [np.stack(x, axis=0) for x in zip(*np_batch_inputs)]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "0c1824cf", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 10, 1)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create batch inputs with batch size 3\n", + "np_batch_inputs = create_batch_inputs(3)\n", + "np_batch_inputs[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "773d4cb4", + "metadata": {}, + "outputs": [], + "source": [ + "p, m, r, n = 1, 5, 1, 10\n", + "inputs = 
[pt.as_tensor(x).type() for x in make_test_inputs(p, m, r, n, rng)]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "511de29f", + "metadata": {}, + "outputs": [], + "source": [ + "kf = StandardFilter()\n", + "kf_outputs = kf.build_graph(*inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "33006d8e", + "metadata": {}, + "outputs": [], + "source": [ + "batched_inputs = [pt.tensor(shape=(None, *x.type.shape)) for x in inputs]\n", + "vec_subs = dict(zip(inputs, batched_inputs))\n", + "bacthed_kf_outputs = vectorize_graph(kf_outputs, vec_subs)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "987a4647", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[filtered_states,\n", + " predicted_states,\n", + " observed_states,\n", + " filtered_covariances,\n", + " predicted_covariances,\n", + " observed_covariances,\n", + " loglike_obs]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "kf_outputs" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "4b8be0f9", + "metadata": {}, + "outputs": [], + "source": [ + "mu = bacthed_kf_outputs[1]\n", + "cov = bacthed_kf_outputs[4]\n", + "logp = bacthed_kf_outputs[-1]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "1dc80f94", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(None, 10, 5)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mu.type.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "1262c7d4", + "metadata": {}, + "outputs": [], + "source": [ + "pmss_dist = reload(pmss_dist)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "2dcd3958", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mus_.type.shape: (None, 10, 5), covs_.type.shape: (None, 10, 5, 5)\n", + "mus.type.shape: (10, None, 5), covs.type.shape: (10, None, 5, 5)\n", + "mvn_seq.type.shape: (None, None, 5)\n", + "mvn_seq.type.shape: (None, 10, 5)\n", + "mvn_seq.type.shape: (None, 10, 5)\n", + "mvn_seq.type.shape: (None, 10, 5)\n", + "mus_.type.shape: (None, 10, 5), covs_.type.shape: (None, 10, 5, 5)\n", + "mus.type.shape: (10, None, 5), covs.type.shape: (10, None, 5, 5)\n", + "mvn_seq.type.shape: (None, None, 5)\n", + "mvn_seq.type.shape: (None, 10, 5)\n", + "mvn_seq.type.shape: (None, 10, 5)\n", + "mvn_seq.type.shape: (None, 10, 5)\n" + ] + } + ], + "source": [ + "mv_outputs = pmss_dist.SequenceMvNormal.dist(mus=mu, covs=cov, logp=logp)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "6f41344f", + "metadata": {}, + "outputs": [], + "source": [ + "np_batch_inputs = create_batch_inputs(3)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "44905b8a", + "metadata": {}, + "outputs": [], + "source": [ + "np_batch_inputs[0] = rng.normal(size=(3, 10, 1))" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "34fe01b8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 10, 5)" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "f_test = pytensor.function(batched_inputs, mv_outputs)\n", + "f_test(*np_batch_inputs).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "f37efe79", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + 
"(None, 10, 1) (None, 10, 5) (None, 10, 5, 5)\n" + ] + } + ], + "source": [ + "f_mv = pytensor.function(batched_inputs, pm.logp(mv_outputs, batched_inputs[0]))" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "7b45de74", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 10)" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "f_mv(*np_batch_inputs).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f14596aa", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "69519822", + "metadata": {}, + "outputs": [], + "source": [ + "f = pytensor.function(batched_inputs, bacthed_kf_outputs)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "3f745449", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "633 μs ± 18.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n", + "1.52 ms ± 35.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n", + "4.76 ms ± 259 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], + "source": [ + "for s in [1, 3, 10]:\n", + " np_batch_inputs = create_batch_inputs(s)\n", + " %timeit outputs = f(*np_batch_inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5fcadef", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c479ff22", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pymc-extras-test", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pymc_extras/statespace/filters/distributions.py b/pymc_extras/statespace/filters/distributions.py index 1e4f2b15..60b74c99 100644 --- a/pymc_extras/statespace/filters/distributions.py +++ b/pymc_extras/statespace/filters/distributions.py @@ -374,44 +374,55 @@ def dist(cls, mus, covs, logp, **kwargs): @classmethod def rv_op(cls, mus, covs, logp, size=None): # Batch dimensions (if any) will be on the far left, but scan requires time to be there instead - if mus.ndim > 2: - mus = pt.moveaxis(mus, -2, 0) - if covs.ndim > 3: - covs = pt.moveaxis(covs, -3, 0) - mus_, covs_ = mus.type(), covs.type() + print(f"mus_.type.shape: {mus_.type.shape}, covs_.type.shape: {covs_.type.shape}") logp_ = logp.type() rng = pytensor.shared(np.random.default_rng()) - def step(mu, cov, rng): - new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method="svd").owner.outputs - return mvn, {rng: new_rng} + def recursion(mus, covs, rng): + if mus.ndim > 2: + mus = pt.moveaxis(mus, -2, 0) + if covs.ndim > 3: + covs = pt.moveaxis(covs, -3, 0) + print(f"mus.type.shape: {mus.type.shape}, covs.type.shape: {covs.type.shape}") - mvn_seq, updates = pytensor.scan( - step, sequences=[mus_, covs_], non_sequences=[rng], strict=True, n_steps=mus_.shape[0] - ) - mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape) + def step(mu, cov, rng): + new_rng, mvn = pm.MvNormal.dist(mu=mu, cov=cov, rng=rng, method="svd").owner.outputs + return mvn, {rng: new_rng} + + mvn_seq, updates = pytensor.scan( + step, sequences=[mus, 
covs], non_sequences=[rng], strict=True, n_steps=mus.shape[0] + ) + print(f"mvn_seq.type.shape: {mvn_seq.type.shape}") + mvn_seq = pt.specify_shape(mvn_seq, mus.type.shape) + + # Move time axis back to position -2 so batches are on the left + if mvn_seq.ndim > 2: + mvn_seq = pt.moveaxis(mvn_seq, 0, -2) + print(f"mvn_seq.type.shape: {mvn_seq.type.shape}") + + (seq_mvn_rng,) = tuple(updates.values()) - # Move time axis back to position -2 so batches are on the left - if mvn_seq.ndim > 2: - mvn_seq = pt.moveaxis(mvn_seq, 0, -2) + print(f"mvn_seq.type.shape: {mvn_seq.type.shape}") - (seq_mvn_rng,) = tuple(updates.values()) + return [seq_mvn_rng, mvn_seq] mvn_seq_op = KalmanFilterRV( - inputs=[mus_, covs_, logp_, rng], outputs=[seq_mvn_rng, mvn_seq], ndim_supp=2 + inputs=[mus_, covs_, logp_, rng], outputs=recursion(mus_, covs_, rng), ndim_supp=2 ) mvn_seq = mvn_seq_op(mus, covs, logp, rng) + print(f"mvn_seq.type.shape: {mvn_seq.type.shape}") return mvn_seq @_logprob.register(KalmanFilterRV) def sequence_mvnormal_logp(op, values, mus, covs, logp, rng, **kwargs): + print(values[0].type.shape, mus.type.shape, covs.type.shape) return check_parameters( logp, - pt.eq(values[0].shape[0], mus.shape[0]), - pt.eq(covs.shape[0], mus.shape[0]), - msg="Observed data and parameters must have the same number of timesteps (dimension 0)", + pt.eq(values[0].shape[-2], mus.shape[-2]), + pt.eq(covs.shape[-3], mus.shape[-2]), + msg="Observed data and parameters must have the same number of timesteps", ) From 7f6845eace50412e0d257030cd5a96436d38c8ec Mon Sep 17 00:00:00 2001 From: Alex Andorra Date: Tue, 8 Apr 2025 16:11:35 -0400 Subject: [PATCH 2/3] Update conda env file --- conda-envs/environment-test.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 8234b954..16cfcc3b 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -1,20 +1,20 @@ name: pymc-extras-test channels: - conda-forge -- nodefaults dependencies: +- blackjax - ipywidgets -- pymc>=5.21 -- pytest-cov>=2.5 -- pytest>=3.0 +- ipython +- pymc +- pytest-cov +- pytest - dask - xhistogram - statsmodels -- numba<=0.60.0 +- numba - nutpie - pip +- scikit-learn - pip: - - blackjax - - scikit-learn - better_optimize - -e . 
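
Editor's note: the `SequenceMvNormal.rv_op` change in the first patch hinges on one axis convention: `pytensor.scan` always iterates over the leading axis, so batched `(batch, time, ...)` parameters are transposed to put time first before the scan and moved back afterwards so the batch dimension stays on the far left. Below is a minimal, self-contained sketch of that pattern (not the patch's code): the `step` function is only a stand-in for the per-timestep MvNormal draw, and all shapes and names are illustrative.

import numpy as np
import pytensor
import pytensor.tensor as pt

# Hypothetical batched parameters: (batch, time, state) and (batch, time, state, state)
mus = pt.tensor(name="mus", shape=(None, 10, 5), dtype="float64")
covs = pt.tensor(name="covs", shape=(None, 10, 5, 5), dtype="float64")

# scan iterates over the leading axis, so move time to the front
mus_t = pt.moveaxis(mus, -2, 0)    # (time, batch, state)
covs_t = pt.moveaxis(covs, -3, 0)  # (time, batch, state, state)

def step(mu, cov):
    # stand-in for the per-timestep MvNormal draw; each step sees one time slice
    return mu + pt.diagonal(cov, axis1=1, axis2=2)

draws, _ = pytensor.scan(step, sequences=[mus_t, covs_t], n_steps=mus_t.shape[0])

# move time back to position -2 so batches stay on the far left, as rv_op does
draws = pt.moveaxis(draws, 0, -2)  # (batch, time, state)

f = pytensor.function([mus, covs], draws)
print(f(np.zeros((3, 10, 5)), np.tile(np.eye(5), (3, 10, 1, 1))).shape)  # (3, 10, 5)

The same convention explains why the logp check in patch 1 compares `shape[-2]` of the values and means, and `shape[-3]` of the covariances: with batches on the left, those are the time axes.
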
From cc9f7509a46496f190ba170664f7635d4e837b71 Mon Sep 17 00:00:00 2001 From: Alex Andorra Date: Tue, 8 Apr 2025 16:12:21 -0400 Subject: [PATCH 3/3] Working with Filter, not with Smoother --- notebooks/batch-examples.ipynb | 28 ++++---- .../statespace/filters/kalman_filter.py | 29 +++++++++ .../statespace/filters/kalman_smoother.py | 39 ++++++++++- tests/statespace/test_kalman_filter.py | 65 ++++++++++--------- tests/statespace/utilities/test_helpers.py | 54 ++++++++++----- 5 files changed, 153 insertions(+), 62 deletions(-) diff --git a/notebooks/batch-examples.ipynb b/notebooks/batch-examples.ipynb index 6139992b..72b1ca92 100644 --- a/notebooks/batch-examples.ipynb +++ b/notebooks/batch-examples.ipynb @@ -189,7 +189,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 11, "id": "1262c7d4", "metadata": {}, "outputs": [], @@ -199,7 +199,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 12, "id": "2dcd3958", "metadata": {}, "outputs": [ @@ -228,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 13, "id": "6f41344f", "metadata": {}, "outputs": [], @@ -238,7 +238,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 14, "id": "44905b8a", "metadata": {}, "outputs": [], @@ -248,7 +248,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 15, "id": "34fe01b8", "metadata": {}, "outputs": [ @@ -258,7 +258,7 @@ "(3, 10, 5)" ] }, - "execution_count": 24, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -270,7 +270,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 16, "id": "f37efe79", "metadata": {}, "outputs": [ @@ -288,7 +288,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 17, "id": "7b45de74", "metadata": {}, "outputs": [ @@ -298,7 +298,7 @@ "(3, 10)" ] }, - "execution_count": 26, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -317,7 +317,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 18, "id": "69519822", "metadata": {}, "outputs": [], @@ -327,7 +327,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 19, "id": "3f745449", "metadata": {}, "outputs": [ @@ -335,9 +335,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "633 μs ± 18.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n", - "1.52 ms ± 35.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n", - "4.76 ms ± 259 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + "675 μs ± 22.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n", + "1.64 ms ± 37.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n", + "5.28 ms ± 424 μs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" ] } ], diff --git a/pymc_extras/statespace/filters/kalman_filter.py b/pymc_extras/statespace/filters/kalman_filter.py index 0ca47b50..887130f4 100644 --- a/pymc_extras/statespace/filters/kalman_filter.py +++ b/pymc_extras/statespace/filters/kalman_filter.py @@ -10,6 +10,7 @@ from pytensor.raise_op import Assert from pytensor.tensor import TensorVariable from pytensor.tensor.slinalg import solve_triangular +from pytensor.graph.replace import vectorize_graph from pymc_extras.statespace.filters.utilities import ( quad_form_sym, @@ -20,6 +21,7 @@ MVN_CONST = pt.log(2 * pt.constant(np.pi, dtype="float64")) PARAM_NAMES = ["c", "d", "T", "Z", "R", "H", "Q"] +CORE_NDIM = (2, 1, 2, 1, 1, 2, 2, 2, 2, 2) assert_time_varying_dim_correct = Assert( "The first dimension of a time varying matrix (the time dimension) must be " @@ -73,6 +75,23 @@ def check_params(self, data, a0, P0, c, d, T, Z, R, H, Q): """ return data, a0, P0, c, d, T, Z, R, H, Q + def has_batched_input(self, data, a0, P0, c, d, T, Z, R, H, Q): + """ + Check if any of the inputs are batched. + """ + return any(x.ndim > CORE_NDIM[i] for i, x in enumerate([data, a0, P0, c, d, T, Z, R, H, Q])) + + def get_dummy_core_inputs(self, data, a0, P0, c, d, T, Z, R, H, Q): + """ + Get dummy inputs for the core parameters. + """ + out = [] + for x, core_ndim in zip([data, a0, P0, c, d, T, Z, R, H, Q], CORE_NDIM): + out.append( + pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:]) + ) + return out + @staticmethod def add_check_on_time_varying_shapes( data: TensorVariable, sequence_params: list[TensorVariable] @@ -202,6 +221,7 @@ def build_graph( self.mode = mode self.missing_fill_value = missing_fill_value self.cov_jitter = cov_jitter + is_batched = self.has_batched_input(data, a0, P0, c, d, T, Z, R, H, Q) [R_shape] = constant_fold([R.shape], raise_not_constant=False) [Z_shape] = constant_fold([Z.shape], raise_not_constant=False) @@ -209,6 +229,10 @@ def build_graph( self.n_states, self.n_shocks = R_shape[-2:] self.n_endog = Z_shape[-2] + if is_batched: + batched_inputs = [data, a0, P0, c, d, T, Z, R, H, Q] + data, a0, P0, c, d, T, Z, R, H, Q = self.get_dummy_core_inputs(*batched_inputs) + data, a0, P0, *params = self.check_params(data, a0, P0, c, d, T, Z, R, H, Q) sequences, non_sequences, seq_names, non_seq_names = split_vars_into_seq_and_nonseq( @@ -233,8 +257,13 @@ def build_graph( filter_results = self._postprocess_scan_results(results, a0, P0, n=data.type.shape[0]) + if is_batched: + vec_subs = dict(zip([data, a0, P0, c, d, T, Z, R, H, Q], batched_inputs)) + filter_results = vectorize_graph(filter_results, vec_subs) + if return_updates: return filter_results, updates + return filter_results def _postprocess_scan_results(self, results, a0, P0, n) -> list[TensorVariable]: diff --git a/pymc_extras/statespace/filters/kalman_smoother.py b/pymc_extras/statespace/filters/kalman_smoother.py index f15913b8..671d9366 100644 --- a/pymc_extras/statespace/filters/kalman_smoother.py +++ b/pymc_extras/statespace/filters/kalman_smoother.py @@ -3,7 +3,7 @@ from pytensor.compile import get_mode from pytensor.tensor.nlinalg import matrix_dot - +from pytensor.graph.replace import vectorize_graph from pymc_extras.statespace.filters.utilities import ( quad_form_sym, split_vars_into_seq_and_nonseq, @@ -11,6 +11,8 @@ ) from pymc_extras.statespace.utils.constants import JITTER_DEFAULT +SMOOTHER_CORE_NDIM = (2, 2, 2, 2, 3) + class KalmanSmoother: """ @@ -63,12 +65,41 @@ def unpack_args(self, args): return a, P, 
a_smooth, P_smooth, T, R, Q + def has_batched_input(self, T, R, Q, filtered_states, filtered_covariances): + """ + Check if any of the inputs are batched. + """ + return any( + x.ndim > SMOOTHER_CORE_NDIM[i] + for i, x in enumerate([T, R, Q, filtered_states, filtered_covariances]) + ) + + def get_dummy_core_inputs(self, T, R, Q, filtered_states, filtered_covariances): + """ + Get dummy inputs for the core parameters. + """ + out = [] + for x, core_ndim in zip( + [T, R, Q, filtered_states, filtered_covariances], SMOOTHER_CORE_NDIM + ): + out.append( + pt.tensor(f"{x.name}_core_case", dtype=x.dtype, shape=x.type.shape[-core_ndim:]) + ) + return out + def build_graph( self, T, R, Q, filtered_states, filtered_covariances, mode=None, cov_jitter=JITTER_DEFAULT ): self.mode = mode self.cov_jitter = cov_jitter + is_batched = self.has_batched_input(T, R, Q, filtered_states, filtered_covariances) + if is_batched: + batched_inputs = [T, R, Q, filtered_states, filtered_covariances] + T, R, Q, filtered_states, filtered_covariances = self.get_dummy_core_inputs( + *batched_inputs + ) + n, k = filtered_states.type.shape a_last = pt.specify_shape(filtered_states[-1], (k,)) @@ -98,6 +129,12 @@ def build_graph( smoothed_covariances = pt.concatenate( [smoothed_covariances[::-1], pt.expand_dims(P_last, axis=(0,))], axis=0 ) + smoothed_states.dprint() + if is_batched: + vec_subs = dict(zip([T, R, Q, filtered_states, filtered_covariances], batched_inputs)) + smoothed_states, smoothed_covariances = vectorize_graph( + [smoothed_states, smoothed_covariances], vec_subs + ) smoothed_states.name = "smoothed_states" smoothed_covariances.name = "smoothed_covariances" diff --git a/tests/statespace/test_kalman_filter.py b/tests/statespace/test_kalman_filter.py index 6c0bc18c..3cdfa569 100644 --- a/tests/statespace/test_kalman_filter.py +++ b/tests/statespace/test_kalman_filter.py @@ -31,19 +31,22 @@ RTOL = 1e-6 if floatX.endswith("64") else 1e-3 standard_inout = initialize_filter(StandardFilter()) +standard_inout_batched = initialize_filter(StandardFilter(), batched=True) cholesky_inout = initialize_filter(SquareRootFilter()) univariate_inout = initialize_filter(UnivariateFilter()) f_standard = pytensor.function(*standard_inout, on_unused_input="ignore") +f_standard_batched = pytensor.function(*standard_inout_batched, on_unused_input="ignore") f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore") f_univariate = pytensor.function(*univariate_inout, on_unused_input="ignore") -filter_funcs = [f_standard, f_cholesky, f_univariate] +filter_funcs = [f_standard, f_standard_batched] # , f_cholesky, f_univariate] filter_names = [ "StandardFilter", - "CholeskyFilter", - "UnivariateFilter", + "StandardFilterBatched", + # "CholeskyFilter", + # "UnivariateFilter", ] output_names = [ @@ -65,17 +68,21 @@ def test_base_class_update_raises(): filter.update(*inputs) -@pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) -def test_output_shapes_one_state_one_observed(filter_func, rng): +@pytest.mark.parametrize( + "filter_func, filter_name", zip(filter_funcs, filter_names), ids=filter_names +) +def test_output_shapes_one_state_one_observed(filter_func, filter_name, rng): + batch_size = 3 if "batched" in filter_name.lower() else 0 p, m, r, n = 1, 1, 1, 10 - inputs = make_test_inputs(p, m, r, n, rng) - outputs = filter_func(*inputs) + inputs = make_test_inputs(p, m, r, n, rng, batch_size=batch_size) + assert 0 + # outputs = filter_func(*inputs) for output_idx, name in enumerate(output_names): - 
expected_output = get_expected_shape(name, p, m, r, n) - assert ( - outputs[output_idx].shape == expected_output - ), f"Shape of {name} does not match expected" + expected_shape = get_expected_shape(name, p, m, r, n, batch_size) + # assert outputs[output_idx].shape == expected_shape, ( + # f"Shape of {name} does not match expected" + # ) @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) @@ -86,9 +93,9 @@ def test_output_shapes_when_all_states_are_stochastic(filter_func, rng): outputs = filter_func(*inputs) for output_idx, name in enumerate(output_names): expected_output = get_expected_shape(name, p, m, r, n) - assert ( - outputs[output_idx].shape == expected_output - ), f"Shape of {name} does not match expected" + assert outputs[output_idx].shape == expected_output, ( + f"Shape of {name} does not match expected" + ) @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) @@ -99,9 +106,9 @@ def test_output_shapes_when_some_states_are_deterministic(filter_func, rng): outputs = filter_func(*inputs) for output_idx, name in enumerate(output_names): expected_output = get_expected_shape(name, p, m, r, n) - assert ( - outputs[output_idx].shape == expected_output - ), f"Shape of {name} does not match expected" + assert outputs[output_idx].shape == expected_output, ( + f"Shape of {name} does not match expected" + ) @pytest.fixture @@ -161,9 +168,9 @@ def test_output_shapes_with_time_varying_matrices(f_standard_nd, rng): for output_idx, name in enumerate(output_names): expected_output = get_expected_shape(name, p, m, r, n) - assert ( - outputs[output_idx].shape == expected_output - ), f"Shape of {name} does not match expected" + assert outputs[output_idx].shape == expected_output, ( + f"Shape of {name} does not match expected" + ) @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) @@ -175,9 +182,9 @@ def test_output_with_deterministic_observation_equation(filter_func, rng): for output_idx, name in enumerate(output_names): expected_output = get_expected_shape(name, p, m, r, n) - assert ( - outputs[output_idx].shape == expected_output - ), f"Shape of {name} does not match expected" + assert outputs[output_idx].shape == expected_output, ( + f"Shape of {name} does not match expected" + ) @pytest.mark.parametrize( @@ -190,9 +197,9 @@ def test_output_with_multiple_observed(filter_func, filter_name, rng): outputs = filter_func(*inputs) for output_idx, name in enumerate(output_names): expected_output = get_expected_shape(name, p, m, r, n) - assert ( - outputs[output_idx].shape == expected_output - ), f"Shape of {name} does not match expected" + assert outputs[output_idx].shape == expected_output, ( + f"Shape of {name} does not match expected" + ) @pytest.mark.parametrize( @@ -206,9 +213,9 @@ def test_missing_data(filter_func, filter_name, p, rng): outputs = filter_func(*inputs) for output_idx, name in enumerate(output_names): expected_output = get_expected_shape(name, p, m, r, n) - assert ( - outputs[output_idx].shape == expected_output - ), f"Shape of {name} does not match expected" + assert outputs[output_idx].shape == expected_output, ( + f"Shape of {name} does not match expected" + ) @pytest.mark.parametrize("filter_func", filter_funcs, ids=filter_names) diff --git a/tests/statespace/utilities/test_helpers.py b/tests/statespace/utilities/test_helpers.py index c6170f88..fa970f14 100644 --- a/tests/statespace/utilities/test_helpers.py +++ b/tests/statespace/utilities/test_helpers.py @@ -34,18 +34,18 @@ def load_nile_test_data(): return nile -def 
initialize_filter(kfilter, mode=None, p=None, m=None, r=None, n=None): +def initialize_filter(kfilter, mode=None, p=None, m=None, r=None, n=None, batched=False): ksmoother = KalmanSmoother() - data = pt.tensor(name="data", dtype=floatX, shape=(n, p)) - a0 = pt.tensor(name="x0", dtype=floatX, shape=(m,)) - P0 = pt.tensor(name="P0", dtype=floatX, shape=(m, m)) - c = pt.tensor(name="c", dtype=floatX, shape=(m,)) - d = pt.tensor(name="d", dtype=floatX, shape=(p,)) - Q = pt.tensor(name="Q", dtype=floatX, shape=(r, r)) - H = pt.tensor(name="H", dtype=floatX, shape=(p, p)) - T = pt.tensor(name="T", dtype=floatX, shape=(m, m)) - R = pt.tensor(name="R", dtype=floatX, shape=(m, r)) - Z = pt.tensor(name="Z", dtype=floatX, shape=(p, m)) + data = pt.tensor(name="data", dtype=floatX, shape=(None, n, p) if batched else (n, p)) + a0 = pt.tensor(name="x0", dtype=floatX, shape=(None, m) if batched else (m,)) + P0 = pt.tensor(name="P0", dtype=floatX, shape=(None, m, m) if batched else (m, m)) + c = pt.tensor(name="c", dtype=floatX, shape=(None, m) if batched else (m,)) + d = pt.tensor(name="d", dtype=floatX, shape=(None, p) if batched else (p,)) + Q = pt.tensor(name="Q", dtype=floatX, shape=(None, r, r) if batched else (r, r)) + H = pt.tensor(name="H", dtype=floatX, shape=(None, p, p) if batched else (p, p)) + T = pt.tensor(name="T", dtype=floatX, shape=(None, m, m) if batched else (m, m)) + R = pt.tensor(name="R", dtype=floatX, shape=(None, m, r) if batched else (m, r)) + Z = pt.tensor(name="Z", dtype=floatX, shape=(None, p, m) if batched else (p, m)) inputs = [data, a0, P0, c, d, T, Z, R, H, Q] @@ -68,7 +68,7 @@ def initialize_filter(kfilter, mode=None, p=None, m=None, r=None, n=None): filtered_covs, predicted_covs, smoothed_covs, - ll_obs.sum(), + ll_obs.sum(axis=-1), ll_obs, ] @@ -83,7 +83,7 @@ def add_missing_data(data, n_missing, rng): return data -def make_test_inputs(p, m, r, n, rng, missing_data=None, H_is_zero=False): +def make_1d_test_inputs(p, m, r, n, rng, missing_data=None, H_is_zero=False): data = np.arange(n * p, dtype=floatX).reshape(-1, p) if missing_data is not None: data = add_missing_data(data, missing_data, rng) @@ -106,16 +106,34 @@ def make_test_inputs(p, m, r, n, rng, missing_data=None, H_is_zero=False): return data, a0, P0, c, d, T, Z, R, H, Q -def get_expected_shape(name, p, m, r, n): +def make_test_inputs(p, m, r, n, rng, missing_data=None, H_is_zero=False, batch_size=0): + if batch_size == 0: + return make_1d_test_inputs(p, m, r, n, rng, missing_data, H_is_zero) + + # Create individual inputs for each batch + np_batch_inputs = [] + for i in range(batch_size): + inputs = make_1d_test_inputs(p, m, r, n, rng, missing_data, H_is_zero) + np_batch_inputs.append(inputs) + + return [np.stack(x, axis=0) for x in zip(*np_batch_inputs)] + + +def get_expected_shape(name, p, m, r, n, batch_size=0): if name == "log_likelihood": - return () + shape = () elif name == "ll_obs": - return (n,) + shape = (n,) filter_type, variable = name.split("_") if variable == "states": - return n, m + shape = n, m if variable == "covs": - return n, m, m + shape = n, m, m + + if batch_size != 0: + shape = (batch_size, *shape) + + return shape def get_sm_state_from_output_name(res, name):
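
Editor's note: for reference, a sketch of how the updated test helpers above would be exercised — assuming `make_test_inputs` and `get_expected_shape` exactly as defined in this patch, with illustrative dimension values.

import numpy as np
from tests.statespace.utilities.test_helpers import make_test_inputs, get_expected_shape

rng = np.random.default_rng(0)
p, m, r, n, batch_size = 1, 5, 1, 10, 3

# batch_size > 0 stacks independent systems along a new leading axis
data, a0, P0, c, d, T, Z, R, H, Q = make_test_inputs(p, m, r, n, rng, batch_size=batch_size)
print(data.shape)  # (3, 10, 1)

# expected output shapes gain the same leading batch dimension
print(get_expected_shape("filtered_states", p, m, r, n, batch_size=batch_size))  # (3, 10, 5)
print(get_expected_shape("filtered_covs", p, m, r, n, batch_size=batch_size))    # (3, 10, 5, 5)
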