Skip to content

Commit

Permalink
make generic over nalgebra storage WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
astraw committed May 9, 2024
1 parent f7c20fc commit 760bd31
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 135 deletions.
193 changes: 75 additions & 118 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
#![allow(non_snake_case)]
#[cfg(feature = "std")]
use log::trace;
use na::{OMatrix, OVector};
use na::{Matrix, Vector};
use na::{RawStorageMut, Storage, StorageMut};

use nalgebra as na;
use nalgebra::base::dimension::DimMin;

Expand Down Expand Up @@ -81,52 +83,37 @@ mod state_and_covariance;
pub use state_and_covariance::StateAndCovariance;

/// A linear model of process dynamics with no control inputs
pub trait TransitionModelLinearNoControl<R, SS>
pub trait TransitionModelLinearNoControl<R, SS, S1, S2>
where
R: RealField,
SS: DimName,
DefaultAllocator: Allocator<R, SS, SS>,
DefaultAllocator: Allocator<R, SS>,
// DefaultAllocator: Allocator<R, SS, SS>,
// DefaultAllocator: Allocator<R, SS>,
S1: Clone + Storage<R, SS> + StorageMut<R, SS>,
S2: Clone + Storage<R, SS, SS>,
{
/// Get the state transition model, `F`.
fn F(&self) -> &OMatrix<R, SS, SS>;
fn F(&self) -> &Matrix<R, SS, SS, S2>;

/// Get the transpose of the state transition model, `FT`.
fn FT(&self) -> &OMatrix<R, SS, SS>;
fn FT(&self) -> &Matrix<R, SS, SS, S2>;

/// Get the process covariance, `Q`.
fn Q(&self) -> &OMatrix<R, SS, SS>;
fn Q(&self) -> &Matrix<R, SS, SS, S2>;

/// Predict new state from previous estimate.
fn predict(&self, previous_estimate: &StateAndCovariance<R, SS>) -> StateAndCovariance<R, SS> {
fn predict(
&self,
previous_estimate: &StateAndCovariance<R, SS, S1, S2>,
) -> StateAndCovariance<R, SS, S1, S2> {
// The prior.
let P = previous_estimate.state();
let F = self.F();
let state = F * P;
let mut state = P.clone(); // allocate output
F.mul_to(P, &mut state);
let covariance = ((F * previous_estimate.covariance()) * self.FT()) + self.Q();
StateAndCovariance::new(state, covariance)
}

/// Get the state transition model, `F`.
#[deprecated(since = "0.8.0", note = "Please use the F function instead")]
#[inline]
fn transition_model(&self) -> &OMatrix<R, SS, SS> {
self.F()
}

/// Get the transpose of the state transition model, `FT`.
#[deprecated(since = "0.8.0", note = "Please use the FT function instead")]
#[inline]
fn transition_model_transpose(&self) -> &OMatrix<R, SS, SS> {
self.FT()
}

/// Get the transition noise covariance.
#[deprecated(since = "0.8.0", note = "Please use the Q function instead")]
#[inline]
fn transition_noise_covariance(&self) -> &OMatrix<R, SS, SS> {
self.Q()
}
}

/// An observation model, potentially non-linear.
Expand All @@ -136,7 +123,7 @@ where
/// as the basis for a `ObservationModel` implementation. This would be done
/// every timestep. For an example, see
/// [`nonlinear_observation.rs`](https://github.com/strawlab/adskalman-rs/blob/main/examples/src/bin/nonlinear_observation.rs).
pub trait ObservationModel<R, SS, OS>
pub trait ObservationModel<R, SS, OS, S1, S2>
where
R: RealField,
SS: DimName,
Expand All @@ -148,6 +135,10 @@ where
DefaultAllocator: Allocator<R, OS, OS>,
DefaultAllocator: Allocator<R, OS>,
DefaultAllocator: Allocator<(usize, usize), OS>,
// <DefaultAllocator as Allocator<R, SS, SS>>::Buffer = S2,
S1: Clone + Storage<R, SS> + Storage<R, OS>,
S2: Clone + Storage<R, OS, SS> + Storage<R, SS, SS> + Storage<R, SS, OS> + Storage<R, OS, OS>,
Matrix<R, SS, SS, S2>: One,
{
/// For a given state, predict the observation.
///
Expand All @@ -162,29 +153,29 @@ where
/// this trait and must be evaluated for a state for which no observation is
/// possible.) Observations with NaN values are treated as missing
/// observations.
fn predict_observation(&self, state: &OVector<R, SS>) -> OVector<R, OS> {
fn predict_observation(&self, state: &Vector<R, SS, S1>) -> Vector<R, OS, S1> {
self.H() * state
}

/// Get the observation matrix, `H`.
fn H(&self) -> &OMatrix<R, OS, SS>;
fn H(&self) -> &Matrix<R, OS, SS, S2>;

/// Get the transpose of the observation matrix, `HT`.
fn HT(&self) -> &OMatrix<R, SS, OS>;
fn HT(&self) -> &Matrix<R, SS, OS, S2>;

/// Get the observation noise covariance, `R`.
// TODO: ensure this is positive definite?
fn R(&self) -> &OMatrix<R, OS, OS>;
fn R(&self) -> &Matrix<R, OS, OS, S2>;

/// Given prior state and observation, estimate the posterior state.
///
/// This is the *update* step in the Kalman filter literature.
fn update(
&self,
prior: &StateAndCovariance<R, SS>,
observation: &OVector<R, OS>,
prior: &StateAndCovariance<R, SS, S1, S2>,
observation: &Vector<R, OS, S1>,
covariance_method: CovarianceUpdateMethod,
) -> Result<StateAndCovariance<R, SS>, Error> {
) -> Result<StateAndCovariance<R, SS, S1, S2>, Error> {
let h = self.H();
trace!("h {}", pretty_print!(h));

Expand Down Expand Up @@ -217,28 +208,28 @@ where
return Err(ErrorKind::CovarianceNotPositiveSemiDefinite.into());
}
};
let s_inv: OMatrix<R, OS, OS> = s_chol.inverse();
let s_inv: Matrix<R, OS, OS, S2> = s_chol.inverse();
trace!("s_inv {}", pretty_print!(s_inv));

let k_gain: OMatrix<R, SS, OS> = p * ht * s_inv;
let k_gain: Matrix<R, SS, OS, S2> = p * ht * s_inv;
// let k_gain: OMatrix<R,SS,OS> = solve!( (p*ht), s );
trace!("k_gain {}", pretty_print!(k_gain));

let predicted: OVector<R, OS> = self.predict_observation(prior.state());
let predicted: Vector<R, OS, S1> = self.predict_observation(prior.state());
trace!("predicted {}", pretty_print!(predicted));
trace!("observation {}", pretty_print!(observation));
let innovation: OVector<R, OS> = observation - predicted;
let innovation: Vector<R, OS, S1> = observation - predicted;
trace!("innovation {}", pretty_print!(innovation));
let state: OVector<R, SS> = prior.state() + &k_gain * innovation;
let state: Vector<R, SS, S1> = prior.state() + &k_gain * innovation;
trace!("state {}", pretty_print!(state));

trace!("self.observation_matrix() {}", pretty_print!(self.H()));
let kh: OMatrix<R, SS, SS> = &k_gain * self.H();
let kh: Matrix<R, SS, SS, S2> = &k_gain * self.H();
trace!("kh {}", pretty_print!(kh));
let one_minus_kh = OMatrix::<R, SS, SS>::one() - kh;
let one_minus_kh = Matrix::<R, SS, SS, S2>::one() - kh;
trace!("one_minus_kh {}", pretty_print!(one_minus_kh));

let covariance: OMatrix<R, SS, SS> = match covariance_method {
let covariance: Matrix<R, SS, SS, S2> = match covariance_method {
CovarianceUpdateMethod::JosephForm => {
// Joseph form of covariance update keeps covariance matrix symmetric.

Expand All @@ -261,43 +252,6 @@ where

Ok(StateAndCovariance::new(state, covariance))
}

/// Get the observation matrix, `H`.
#[deprecated(since = "0.8.0", note = "Please use the H function instead")]
#[inline]
fn observation_matrix(&self) -> &OMatrix<R, OS, SS> {
self.H()
}

/// Get the transpose of the observation matrix, `HT`.
#[deprecated(since = "0.8.0", note = "Please use the HT function instead")]
#[inline]
fn observation_matrix_transpose(&self) -> &OMatrix<R, SS, OS> {
self.HT()
}

/// Get the observation noise covariance, `R`.
#[deprecated(since = "0.8.0", note = "Please use the R function instead")]
#[inline]
fn observation_noise_covariance(&self) -> &OMatrix<R, OS, OS> {
self.R()
}

/// For a given state, predict the observation.
///
/// If an observation is not possible, this returns NaN values. (This
/// happens, for example, when a non-linear observation model implements
/// this trait and must be evaluated for a state for which no observation is
/// possible.) Observations with NaN values are treated as missing
/// observations.
#[deprecated(
since = "0.8.0",
note = "Please use the predict_observation function instead"
)]
#[inline]
fn evaluate(&self, state: &OVector<R, SS>) -> OVector<R, OS> {
self.predict_observation(state)
}
}

/// Specifies the approach used for updating the covariance matrix
Expand Down Expand Up @@ -325,28 +279,30 @@ pub enum CovarianceUpdateMethod {
/// bound of this struct, a useful strategy to avoid requiring lifetime
/// annotations is to construct it just before [Self::step] and then dropping it
/// immediately afterward.
pub struct KalmanFilterNoControl<'a, R, SS, OS>
pub struct KalmanFilterNoControl<'a, R, SS, OS, S1, S2>
where
R: RealField,
SS: DimName,
OS: DimName,
{
transition_model: &'a dyn TransitionModelLinearNoControl<R, SS>,
observation_matrix: &'a dyn ObservationModel<R, SS, OS>,
transition_model: &'a dyn TransitionModelLinearNoControl<R, SS, S1, S2>,
observation_matrix: &'a dyn ObservationModel<R, SS, OS, S1, S2>,
}

impl<'a, R, SS, OS> KalmanFilterNoControl<'a, R, SS, OS>
impl<'a, R, SS, OS, S1, S2> KalmanFilterNoControl<'a, R, SS, OS, S1, S2>
where
R: RealField,
SS: DimName,
OS: DimName + DimMin<OS, Output = OS>,
DefaultAllocator: Allocator<R, SS, SS>,
DefaultAllocator: Allocator<R, SS>,
DefaultAllocator: Allocator<R, OS, SS>,
DefaultAllocator: Allocator<R, SS, OS>,
DefaultAllocator: Allocator<R, OS, OS>,
DefaultAllocator: Allocator<R, OS>,
DefaultAllocator: Allocator<(usize, usize), OS>,
// DefaultAllocator: Allocator<R, SS, SS>,
// DefaultAllocator: Allocator<R, SS>,
// DefaultAllocator: Allocator<R, OS, SS>,
// DefaultAllocator: Allocator<R, SS, OS>,
// DefaultAllocator: Allocator<R, OS, OS>,
// DefaultAllocator: Allocator<R, OS>,
// DefaultAllocator: Allocator<(usize, usize), OS>,
S1: Clone + Storage<R, SS> + Storage<R, OS> + RawStorageMut<R, SS>,
S2: Clone + Storage<R, SS, SS> + Storage<R, OS, SS> + Storage<R, SS, OS> + Storage<R, OS, OS>,
{
/// Initialize a new `KalmanFilterNoControl` struct.
///
Expand All @@ -356,8 +312,8 @@ where
/// including the measurement function `H` and the measurement covariance
/// `R`.
pub fn new(
transition_model: &'a dyn TransitionModelLinearNoControl<R, SS>,
observation_matrix: &'a dyn ObservationModel<R, SS, OS>,
transition_model: &'a dyn TransitionModelLinearNoControl<R, SS, S1, S2>,
observation_matrix: &'a dyn ObservationModel<R, SS, OS, S1, S2>,
) -> Self {
Self {
transition_model,
Expand All @@ -380,9 +336,9 @@ where
/// [step_with_options](struct.KalmanFilterNoControl.html#method.step_with_options).
pub fn step(
&self,
previous_estimate: &StateAndCovariance<R, SS>,
observation: &OVector<R, OS>,
) -> Result<StateAndCovariance<R, SS>, Error> {
previous_estimate: &StateAndCovariance<R, SS, S1, S2>,
observation: &Vector<R, OS, S1>,
) -> Result<StateAndCovariance<R, SS, S1, S2>, Error> {
self.step_with_options(
previous_estimate,
observation,
Expand All @@ -401,10 +357,10 @@ where
/// observation model using the specified covariance update method.
pub fn step_with_options(
&self,
previous_estimate: &StateAndCovariance<R, SS>,
observation: &OVector<R, OS>,
previous_estimate: &StateAndCovariance<R, SS, S1, S2>,
observation: &Vector<R, OS, S1>,
covariance_update_method: CovarianceUpdateMethod,
) -> Result<StateAndCovariance<R, SS>, Error> {
) -> Result<StateAndCovariance<R, SS, S1, S2>, Error> {
let prior = self.transition_model.predict(previous_estimate);
if observation.iter().any(|x| is_nan(x.clone())) {
Ok(prior)
Expand All @@ -425,9 +381,9 @@ where
/// If any observation has a NaN component, it is treated as missing.
pub fn filter_inplace(
&self,
initial_estimate: &StateAndCovariance<R, SS>,
observations: &[OVector<R, OS>],
state_estimates: &mut [StateAndCovariance<R, SS>],
initial_estimate: &StateAndCovariance<R, SS, S1, S2>,
observations: &[Vector<R, OS, S1>],
state_estimates: &mut [StateAndCovariance<R, SS, S1, S2>],
) -> Result<(), Error> {
let mut previous_estimate = initial_estimate.clone();
assert!(state_estimates.len() >= observations.len());
Expand All @@ -448,11 +404,12 @@ where
#[cfg(feature = "std")]
pub fn filter(
&self,
initial_estimate: &StateAndCovariance<R, SS>,
observations: &[OVector<R, OS>],
) -> Result<Vec<StateAndCovariance<R, SS>>, Error> {
initial_estimate: &StateAndCovariance<R, SS, S1, S2>,
observations: &[Vector<R, OS, S1>],
) -> Result<Vec<StateAndCovariance<R, SS, S1, S2>>, Error> {
let mut state_estimates = Vec::with_capacity(observations.len());
let empty = StateAndCovariance::new(na::zero(), na::OMatrix::<R, SS, SS>::identity());
let empty: StateAndCovariance<R, SS, S1, S2> =
StateAndCovariance::new(S1::zero(), S2::identity());
for _ in 0..observations.len() {
state_estimates.push(empty.clone());
}
Expand All @@ -477,9 +434,9 @@ where
#[cfg(feature = "std")]
pub fn smooth(
&self,
initial_estimate: &StateAndCovariance<R, SS>,
observations: &[OVector<R, OS>],
) -> Result<Vec<StateAndCovariance<R, SS>>, Error> {
initial_estimate: &StateAndCovariance<R, SS, S1, S2>,
observations: &[Vector<R, OS, S1>],
) -> Result<Vec<StateAndCovariance<R, SS, S1, S2>>, Error> {
let forward_results = self.filter(initial_estimate, observations)?;
self.smooth_from_filtered(forward_results)
}
Expand All @@ -492,8 +449,8 @@ where
#[cfg(feature = "std")]
pub fn smooth_from_filtered(
&self,
mut forward_results: Vec<StateAndCovariance<R, SS>>,
) -> Result<Vec<StateAndCovariance<R, SS>>, Error> {
mut forward_results: Vec<StateAndCovariance<R, SS, S1, S2>>,
) -> Result<Vec<StateAndCovariance<R, SS, S1, S2>>, Error> {
forward_results.reverse();

let mut smoothed_backwards = Vec::with_capacity(forward_results.len());
Expand All @@ -512,9 +469,9 @@ where
#[cfg(feature = "std")]
fn smooth_step(
&self,
smooth_future: &StateAndCovariance<R, SS>,
filt: &StateAndCovariance<R, SS>,
) -> Result<StateAndCovariance<R, SS>, Error> {
smooth_future: &StateAndCovariance<R, SS, S1, S2>,
filt: &StateAndCovariance<R, SS, S1, S2>,
) -> Result<StateAndCovariance<R, SS, S1, S2>, Error> {
let prior = self.transition_model.predict(filt);

let v_chol = match na::linalg::Cholesky::new(prior.covariance().clone()) {
Expand All @@ -523,7 +480,7 @@ where
return Err(ErrorKind::CovarianceNotPositiveSemiDefinite.into());
}
};
let inv_prior_covariance: OMatrix<R, SS, SS> = v_chol.inverse();
let inv_prior_covariance: Matrix<R, SS, SS, S2> = v_chol.inverse();
trace!(
"inv_prior_covariance {}",
pretty_print!(inv_prior_covariance)
Expand Down
Loading

0 comments on commit 760bd31

Please sign in to comment.