diff --git a/gtsam/hybrid/HybridNonlinearFactor.cpp b/gtsam/hybrid/HybridNonlinearFactor.cpp new file mode 100644 index 0000000000..9613233b15 --- /dev/null +++ b/gtsam/hybrid/HybridNonlinearFactor.cpp @@ -0,0 +1,134 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridNonlinearFactor.h + * @brief A set of nonlinear factors indexed by a set of discrete keys. + * @author Varun Agrawal + * @date Sep 12, 2024 + */ + +#include + +namespace gtsam { + +/* *******************************************************************************/ +HybridNonlinearFactor::HybridNonlinearFactor(const KeyVector& keys, + const DiscreteKeys& discreteKeys, + const Factors& factors) + : Base(keys, discreteKeys), factors_(factors) {} + +/* *******************************************************************************/ +AlgebraicDecisionTree HybridNonlinearFactor::errorTree( + const Values& continuousValues) const { + // functor to convert from sharedFactor to double error value. + auto errorFunc = + [continuousValues](const std::pair& f) { + auto [factor, val] = f; + return factor->error(continuousValues) + val; + }; + DecisionTree result(factors_, errorFunc); + return result; +} + +/* *******************************************************************************/ +double HybridNonlinearFactor::error( + const Values& continuousValues, + const DiscreteValues& discreteValues) const { + // Retrieve the factor corresponding to the assignment in discreteValues. + auto [factor, val] = factors_(discreteValues); + // Compute the error for the selected factor + const double factorError = factor->error(continuousValues); + return factorError + val; +} + +/* *******************************************************************************/ +double HybridNonlinearFactor::error(const HybridValues& values) const { + return error(values.nonlinear(), values.discrete()); +} + +/* *******************************************************************************/ +size_t HybridNonlinearFactor::dim() const { + const auto assignments = DiscreteValues::CartesianProduct(discreteKeys_); + auto [factor, val] = factors_(assignments.at(0)); + return factor->dim(); +} + +/* *******************************************************************************/ +void HybridNonlinearFactor::print(const std::string& s, + const KeyFormatter& keyFormatter) const { + std::cout << (s.empty() ? "" : s + " "); + Base::print("", keyFormatter); + std::cout << "\nHybridNonlinearFactor\n"; + auto valueFormatter = [](const std::pair& v) { + auto [factor, val] = v; + if (factor) { + return "Nonlinear factor on " + std::to_string(factor->size()) + " keys"; + } else { + return std::string("nullptr"); + } + }; + factors_.print("", keyFormatter, valueFormatter); +} + +/* *******************************************************************************/ +bool HybridNonlinearFactor::equals(const HybridFactor& other, + double tol) const { + // We attempt a dynamic cast from HybridFactor to HybridNonlinearFactor. If + // it fails, return false. + if (!dynamic_cast(&other)) return false; + + // If the cast is successful, we'll properly construct a + // HybridNonlinearFactor object from `other` + const HybridNonlinearFactor& f( + static_cast(other)); + + // Ensure that this HybridNonlinearFactor and `f` have the same `factors_`. + auto compare = [tol](const std::pair& a, + const std::pair& b) { + return traits::Equals(*a.first, *b.first, tol) && + (a.second == b.second); + }; + if (!factors_.equals(f.factors_, compare)) return false; + + // If everything above passes, and the keys_ and discreteKeys_ + // member variables are identical, return true. + return (std::equal(keys_.begin(), keys_.end(), f.keys().begin()) && + (discreteKeys_ == f.discreteKeys_)); +} + +/* *******************************************************************************/ +GaussianFactor::shared_ptr HybridNonlinearFactor::linearize( + const Values& continuousValues, + const DiscreteValues& discreteValues) const { + auto factor = factors_(discreteValues).first; + return factor->linearize(continuousValues); +} + +/* *******************************************************************************/ +std::shared_ptr HybridNonlinearFactor::linearize( + const Values& continuousValues) const { + // functional to linearize each factor in the decision tree + auto linearizeDT = + [continuousValues]( + const std::pair& f) -> GaussianFactorValuePair { + auto [factor, val] = f; + return {factor->linearize(continuousValues), val}; + }; + + DecisionTree> + linearized_factors(factors_, linearizeDT); + + return std::make_shared(continuousKeys_, discreteKeys_, + linearized_factors); +} + +} // namespace gtsam \ No newline at end of file diff --git a/gtsam/hybrid/HybridNonlinearFactor.h b/gtsam/hybrid/HybridNonlinearFactor.h index a0c7af92be..6da846abe5 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.h +++ b/gtsam/hybrid/HybridNonlinearFactor.h @@ -11,7 +11,7 @@ /** * @file HybridNonlinearFactor.h - * @brief Nonlinear Mixture factor of continuous and discrete. + * @brief A set of nonlinear factors indexed by a set of discrete keys. * @author Kevin Doherty, kdoherty@mit.edu * @author Varun Agrawal * @date December 2021 @@ -57,7 +57,7 @@ using NonlinearFactorValuePair = std::pair; * * @ingroup hybrid */ -class HybridNonlinearFactor : public HybridFactor { +class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { public: using Base = HybridFactor; using This = HybridNonlinearFactor; @@ -85,8 +85,7 @@ class HybridNonlinearFactor : public HybridFactor { * @param factors Decision tree with of shared factors. */ HybridNonlinearFactor(const KeyVector& keys, const DiscreteKeys& discreteKeys, - const Factors& factors) - : Base(keys, discreteKeys), factors_(factors) {} + const Factors& factors); /** * @brief Convenience constructor that generates the underlying factor @@ -140,16 +139,7 @@ class HybridNonlinearFactor : public HybridFactor { * @return AlgebraicDecisionTree A decision tree with the same keys * as the factor, and leaf values as the error. */ - AlgebraicDecisionTree errorTree(const Values& continuousValues) const { - // functor to convert from sharedFactor to double error value. - auto errorFunc = - [continuousValues](const std::pair& f) { - auto [factor, val] = f; - return factor->error(continuousValues) + val; - }; - DecisionTree result(factors_, errorFunc); - return result; - } + AlgebraicDecisionTree errorTree(const Values& continuousValues) const; /** * @brief Compute error of factor given both continuous and discrete values. @@ -159,13 +149,7 @@ class HybridNonlinearFactor : public HybridFactor { * @return double The error of this factor. */ double error(const Values& continuousValues, - const DiscreteValues& discreteValues) const { - // Retrieve the factor corresponding to the assignment in discreteValues. - auto [factor, val] = factors_(discreteValues); - // Compute the error for the selected factor - const double factorError = factor->error(continuousValues); - return factorError + val; - } + const DiscreteValues& discreteValues) const; /** * @brief Compute error of factor given hybrid values. @@ -173,67 +157,24 @@ class HybridNonlinearFactor : public HybridFactor { * @param values The continuous Values and the discrete assignment. * @return double The error of this factor. */ - double error(const HybridValues& values) const override { - return error(values.nonlinear(), values.discrete()); - } + double error(const HybridValues& values) const override; /** * @brief Get the dimension of the factor (number of rows on linearization). * Returns the dimension of the first component factor. * @return size_t */ - size_t dim() const { - const auto assignments = DiscreteValues::CartesianProduct(discreteKeys_); - auto [factor, val] = factors_(assignments.at(0)); - return factor->dim(); - } + size_t dim() const; /// Testable /// @{ /// print to stdout - void print( - const std::string& s = "", - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override { - std::cout << (s.empty() ? "" : s + " "); - Base::print("", keyFormatter); - std::cout << "\nHybridNonlinearFactor\n"; - auto valueFormatter = [](const std::pair& v) { - auto [factor, val] = v; - if (factor) { - return "Nonlinear factor on " + std::to_string(factor->size()) + - " keys"; - } else { - return std::string("nullptr"); - } - }; - factors_.print("", keyFormatter, valueFormatter); - } + void print(const std::string& s = "", const KeyFormatter& keyFormatter = + DefaultKeyFormatter) const override; /// Check equality - bool equals(const HybridFactor& other, double tol = 1e-9) const override { - // We attempt a dynamic cast from HybridFactor to HybridNonlinearFactor. If - // it fails, return false. - if (!dynamic_cast(&other)) return false; - - // If the cast is successful, we'll properly construct a - // HybridNonlinearFactor object from `other` - const HybridNonlinearFactor& f( - static_cast(other)); - - // Ensure that this HybridNonlinearFactor and `f` have the same `factors_`. - auto compare = [tol](const std::pair& a, - const std::pair& b) { - return traits::Equals(*a.first, *b.first, tol) && - (a.second == b.second); - }; - if (!factors_.equals(f.factors_, compare)) return false; - - // If everything above passes, and the keys_ and discreteKeys_ - // member variables are identical, return true. - return (std::equal(keys_.begin(), keys_.end(), f.keys().begin()) && - (discreteKeys_ == f.discreteKeys_)); - } + bool equals(const HybridFactor& other, double tol = 1e-9) const override; /// @} @@ -241,28 +182,11 @@ class HybridNonlinearFactor : public HybridFactor { /// discreteValues. GaussianFactor::shared_ptr linearize( const Values& continuousValues, - const DiscreteValues& discreteValues) const { - auto factor = factors_(discreteValues).first; - return factor->linearize(continuousValues); - } + const DiscreteValues& discreteValues) const; /// Linearize all the continuous factors to get a HybridGaussianFactor. std::shared_ptr linearize( - const Values& continuousValues) const { - // functional to linearize each factor in the decision tree - auto linearizeDT = - [continuousValues](const std::pair& f) - -> GaussianFactorValuePair { - auto [factor, val] = f; - return {factor->linearize(continuousValues), val}; - }; - - DecisionTree> - linearized_factors(factors_, linearizeDT); - - return std::make_shared( - continuousKeys_, discreteKeys_, linearized_factors); - } + const Values& continuousValues) const; }; } // namespace gtsam