From 0a8ab715f09341eed5588493ff4c86091f5bec49 Mon Sep 17 00:00:00 2001 From: kleeman Date: Thu, 27 May 2021 13:00:48 -0700 Subject: [PATCH] make_prediction approach --- cmake/common | 2 +- include/albatross/src/core/model.hpp | 6 +- include/albatross/src/core/prediction.hpp | 118 ++++++++++-------- include/albatross/src/core/traits.hpp | 22 ++-- .../src/evaluation/prediction_metrics.hpp | 9 ++ .../src/models/conditional_gaussian.hpp | 12 +- include/albatross/src/models/ransac_gp.hpp | 40 +++--- 7 files changed, 122 insertions(+), 87 deletions(-) diff --git a/cmake/common b/cmake/common index 3a6f1a72..57f1cbc4 160000 --- a/cmake/common +++ b/cmake/common @@ -1 +1 @@ -Subproject commit 3a6f1a7225c7f225ea3c651e1597caef897deddc +Subproject commit 57f1cbc446cee99ba0118955859907f6e4a1bf82 diff --git a/include/albatross/src/core/model.hpp b/include/albatross/src/core/model.hpp index 1b48c84b..a165d032 100644 --- a/include/albatross/src/core/model.hpp +++ b/include/albatross/src/core/model.hpp @@ -21,9 +21,9 @@ constexpr bool DEFAULT_USE_ASYNC = false; template class ModelBase : public ParameterHandlingMixin { - friend class JointPredictor; - friend class MarginalPredictor; - friend class MeanPredictor; + friend class detail::JointPredictor; + friend class detail::MarginalPredictor; + friend class detail::MeanPredictor; template friend class fit_model_type; diff --git a/include/albatross/src/core/prediction.hpp b/include/albatross/src/core/prediction.hpp index 724cca30..25e3eff3 100644 --- a/include/albatross/src/core/prediction.hpp +++ b/include/albatross/src/core/prediction.hpp @@ -19,6 +19,7 @@ namespace albatross { // which behave different conditional on the type of predictions desired. template struct PredictTypeIdentity { typedef T type; }; +namespace detail { /* * MeanPredictor is responsible for determining if a valid form of * predicting exists for a given set of model, feature, fit. The @@ -33,8 +34,8 @@ class MeanPredictor { typename std::enable_if< has_valid_predict_mean::value, int>::type = 0> - Eigen::VectorXd _mean(const ModelType &model, const FitType &fit, - const std::vector &features) const { + static Eigen::VectorXd _mean(const ModelType &model, const FitType &fit, + const std::vector &features) { return model.predict_(features, fit, PredictTypeIdentity()); } @@ -46,8 +47,8 @@ class MeanPredictor { has_valid_predict_marginal::value, int>::type = 0> - Eigen::VectorXd _mean(const ModelType &model, const FitType &fit, - const std::vector &features) const { + static Eigen::VectorXd _mean(const ModelType &model, const FitType &fit, + const std::vector &features) { return model .predict_(features, fit, PredictTypeIdentity()) .mean; @@ -61,8 +62,8 @@ class MeanPredictor { FitType>::value && has_valid_predict_joint::value, int>::type = 0> - Eigen::VectorXd _mean(const ModelType &model, const FitType &fit, - const std::vector &features) const { + static Eigen::VectorXd _mean(const ModelType &model, const FitType &fit, + const std::vector &features) { return model .predict_(features, fit, PredictTypeIdentity()) .mean; @@ -75,9 +76,9 @@ class MarginalPredictor { typename std::enable_if::value, int>::type = 0> - MarginalDistribution + static MarginalDistribution _marginal(const ModelType &model, const FitType &fit, - const std::vector &features) const { + const std::vector &features) { return model.predict_(features, fit, PredictTypeIdentity()); } @@ -88,9 +89,9 @@ class MarginalPredictor { !has_valid_predict_marginal::value && has_valid_predict_joint::value, int>::type = 0> - MarginalDistribution + static MarginalDistribution _marginal(const ModelType &model, const FitType &fit, - const std::vector &features) const { + const std::vector &features) { const auto joint_pred = model.predict_(features, fit, PredictTypeIdentity()); return joint_pred.marginal(); @@ -103,13 +104,56 @@ class JointPredictor { typename std::enable_if< has_valid_predict_joint::value, int>::type = 0> - JointDistribution _joint(const ModelType &model, const FitType &fit, - const std::vector &features) const { + static JointDistribution _joint(const ModelType &model, const FitType &fit, + const std::vector &features) { return model.predict_(features, fit, PredictTypeIdentity()); } }; +template < + typename ModelType, typename FeatureType, typename FitType, + typename std::enable_if::value, + int>::type = 0> +auto make_prediction(const ModelType &model, const FitType &fit, + const std::vector &features, + PredictTypeIdentity &&) { + return JointPredictor::_joint(model, fit, features); +} + +template < + typename ModelType, typename FeatureType, typename FitType, + typename std::enable_if::value, + int>::type = 0> +auto make_prediction(const ModelType &model, const FitType &fit, + const std::vector &features, + PredictTypeIdentity &&) { + return MarginalPredictor::_marginal(model, fit, features); +} + +template ::value, + int>::type = 0> +auto make_prediction(const ModelType &model, const FitType &fit, + const std::vector &features, + PredictTypeIdentity &&) { + return MeanPredictor::_mean(model, fit, features); +} +} // namespace detail + +template +auto make_prediction( + const ModelType &model, const FitType &fit, + const std::vector &features, + PredictTypeIdentity = PredictTypeIdentity()) { + return detail::make_prediction(model, fit, features, + PredictTypeIdentity()); +} + template class Prediction { @@ -126,62 +170,30 @@ class Prediction { : model_(std::move(model)), fit_(std::move(fit)), features_(features) {} // Mean - template ::value, - int>::type = 0> - Eigen::VectorXd mean() const { + template Eigen::VectorXd mean() const { static_assert(std::is_same::value, "never do prediction.mean()"); - return MeanPredictor()._mean(model_, fit_, features_); + return make_prediction(model_, fit_, features_, + PredictTypeIdentity()); } - template < - typename DummyType = FeatureType, - typename std::enable_if::value, - int>::type = 0> - Eigen::VectorXd mean() const = delete; // No valid predict method found. - // Marginal - template < - typename DummyType = FeatureType, - typename std::enable_if::value, - int>::type = 0> + template MarginalDistribution marginal() const { static_assert(std::is_same::value, "never do prediction.mean()"); - return MarginalPredictor()._marginal(model_, fit_, features_); + return make_prediction(model_, fit_, features_, + PredictTypeIdentity()); } - template ::value, - int>::type = 0> - MarginalDistribution - marginal() const = delete; // No valid predict method found. - // Joint - template < - typename DummyType = FeatureType, - typename std::enable_if::value, - int>::type = 0> - JointDistribution joint() const { + template JointDistribution joint() const { static_assert(std::is_same::value, "never do prediction.mean()"); - return JointPredictor()._joint(model_, fit_, features_); + return make_prediction(model_, fit_, features_, + PredictTypeIdentity()); } - template < - typename DummyType = FeatureType, - typename std::enable_if::value, - int>::type = 0> - JointDistribution joint() const = delete; // No valid predict method found. - template PredictType get(PredictTypeIdentity = PredictTypeIdentity()) const { diff --git a/include/albatross/src/core/traits.hpp b/include/albatross/src/core/traits.hpp index 3a6841da..520032c9 100644 --- a/include/albatross/src/core/traits.hpp +++ b/include/albatross/src/core/traits.hpp @@ -148,28 +148,30 @@ DEFINE_CLASS_METHOD_TRAITS(_mean); template struct can_predict_mean - : public has__mean::type, - typename const_ref::type, - typename const_ref>::type> {}; + : public has__mean_with_return_type< + T, Eigen::VectorXd, typename const_ref::type, + typename const_ref::type, + typename const_ref>::type> {}; DEFINE_CLASS_METHOD_TRAITS(_marginal); template struct can_predict_marginal - : public has__marginal::type, - typename const_ref::type, - typename const_ref>::type> { -}; + : public has__marginal_with_return_type< + T, MarginalDistribution, typename const_ref::type, + typename const_ref::type, + typename const_ref>::type> {}; DEFINE_CLASS_METHOD_TRAITS(_joint); template struct can_predict_joint - : public has__joint::type, - typename const_ref::type, - typename const_ref>::type> {}; + : public has__joint_with_return_type< + T, JointDistribution, typename const_ref::type, + typename const_ref::type, + typename const_ref>::type> {}; /* * Methods for inspecting `Prediction` types. diff --git a/include/albatross/src/evaluation/prediction_metrics.hpp b/include/albatross/src/evaluation/prediction_metrics.hpp index 74226caf..a9e81e3f 100644 --- a/include/albatross/src/evaluation/prediction_metrics.hpp +++ b/include/albatross/src/evaluation/prediction_metrics.hpp @@ -46,6 +46,15 @@ template struct PredictionMetric { return eval_(prediction, truth); } + template + double operator()(const ModelType &model, const FitType &fit, + const std::vector &features, + const MarginalDistribution &truth) const { + return (*this)(make_prediction(model, fit, features, + PredictTypeIdentity()), + truth); + } + template double operator()(const Prediction &prediction, diff --git a/include/albatross/src/models/conditional_gaussian.hpp b/include/albatross/src/models/conditional_gaussian.hpp index 4d3d12a5..757141d7 100644 --- a/include/albatross/src/models/conditional_gaussian.hpp +++ b/include/albatross/src/models/conditional_gaussian.hpp @@ -23,9 +23,17 @@ struct ConditionalFit { class ConditionalGaussian : public ModelBase { public: + ConditionalGaussian(JointDistribution &&prior, + const MarginalDistribution &truth) + : prior_(std::move(prior)), truth_(truth) { + std::cout << "FROM RVALUE" << std::endl; + } + ConditionalGaussian(const JointDistribution &prior, const MarginalDistribution &truth) - : prior_(prior), truth_(truth) {} + : prior_(prior), truth_(truth) { + std::cout << "COPYING" << std::endl; + } ConditionalFit fit_from_indices(const GroupIndices &indices) const { @@ -71,7 +79,7 @@ class ConditionalGaussian : public ModelBase { cross, predict_prior.covariance, fit.information, fit.cov_ldlt); conditional_pred.mean += predict_prior.mean; - return conditional_pred; + return std::move(conditional_pred); } MarginalDistribution diff --git a/include/albatross/src/models/ransac_gp.hpp b/include/albatross/src/models/ransac_gp.hpp index 8c55fc81..44848446 100644 --- a/include/albatross/src/models/ransac_gp.hpp +++ b/include/albatross/src/models/ransac_gp.hpp @@ -17,53 +17,54 @@ namespace albatross { template inline typename RansacFunctions::FitterFunc -get_gp_ransac_fitter(const ConditionalGaussian &model, +get_gp_ransac_fitter(const std::shared_ptr &model, const GroupIndexer &indexer) { return [&, model, indexer](const std::vector &groups) { auto indices = indices_from_groups(indexer, groups); - return model.fit_from_indices(indices); + return model->fit_from_indices(indices); }; } template inline typename RansacFunctions::IsValidCandidate -get_gp_ransac_is_valid_candidate(const ConditionalGaussian &model, - const GroupIndexer &indexer, - const IsValidCandidateMetric &metric) { +get_gp_ransac_is_valid_candidate( + const std::shared_ptr &model, + const GroupIndexer &indexer, + const IsValidCandidateMetric &metric) { return [&, model, indexer](const std::vector &groups) { const auto indices = indices_from_groups(indexer, groups); - const auto prior = model.get_prior(indices); - const auto truth = model.get_truth(indices); + const auto prior = model->get_prior(indices); + const auto truth = model->get_truth(indices); return metric(prior, truth); }; } template inline typename RansacFunctions::InlierMetric -get_gp_ransac_inlier_metric(const ConditionalGaussian &model, +get_gp_ransac_inlier_metric(const std::shared_ptr &model, const GroupIndexer &indexer, const InlierMetricType &metric) { return [&, indexer, model](const GroupKey &group, const ConditionalFit &fit) { const auto indices = indexer.at(group); - const auto pred = get_prediction_reference(model, fit, indices); - const auto truth = model.get_truth(indices); + const auto pred = get_prediction_reference(*model, fit, indices); + const auto truth = model->get_truth(indices); return metric(pred, truth); }; } template inline typename RansacFunctions::ConsensusMetric -get_gp_ransac_consensus_metric(const ConditionalGaussian &model, - const GroupIndexer &indexer, - const ConsensusMetric &metric) { +get_gp_ransac_consensus_metric( + const std::shared_ptr &model, + const GroupIndexer &indexer, const ConsensusMetric &metric) { return [&, model, indexer](const std::vector &groups) { const auto indices = indices_from_groups(indexer, groups); - const auto prior = model.get_prior(indices); - const auto truth = model.get_truth(indices); + const auto prior = model->get_prior(indices); + const auto truth = model->get_truth(indices); return metric(prior, truth); }; } @@ -109,9 +110,10 @@ struct AlwaysAcceptCandidateMetric { }; template + typename IsValidCandidateMetric, typename GroupKey, + typename PriorDistribution> inline RansacFunctions get_gp_ransac_functions( - const JointDistribution &prior, const MarginalDistribution &truth, + PriorDistribution &&prior, const MarginalDistribution &truth, const GroupIndexer &indexer, const InlierMetric &inlier_metric, const ConsensusMetric &consensus_metric, const IsValidCandidateMetric &is_valid_candidate_metric) { @@ -119,7 +121,9 @@ inline RansacFunctions get_gp_ransac_functions( static_assert(is_prediction_metric::value, "InlierMetric must be an PredictionMetric."); - const ConditionalGaussian model(prior, truth); + const std::shared_ptr model = + std::make_shared( + std::forward(prior), truth); const auto fitter = get_gp_ransac_fitter(model, indexer);