Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Prediction #290

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/common
6 changes: 3 additions & 3 deletions include/albatross/src/core/model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ constexpr bool DEFAULT_USE_ASYNC = false;

template <typename ModelType> class ModelBase : public ParameterHandlingMixin {

friend class JointPredictor;
friend class MarginalPredictor;
friend class MeanPredictor;
friend class detail::JointPredictor;
friend class detail::MarginalPredictor;
friend class detail::MeanPredictor;

template <typename T, typename FeatureType> friend class fit_model_type;

Expand Down
118 changes: 65 additions & 53 deletions include/albatross/src/core/prediction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace albatross {
// which behave different conditional on the type of predictions desired.
template <typename T> struct PredictTypeIdentity { typedef T type; };

namespace detail {
/*
* MeanPredictor is responsible for determining if a valid form of
* predicting exists for a given set of model, feature, fit. The
Expand All @@ -33,8 +34,8 @@ class MeanPredictor {
typename std::enable_if<
has_valid_predict_mean<ModelType, FeatureType, FitType>::value,
int>::type = 0>
Eigen::VectorXd _mean(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features) const {
static Eigen::VectorXd _mean(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features) {
return model.predict_(features, fit,
PredictTypeIdentity<Eigen::VectorXd>());
}
Expand All @@ -46,8 +47,8 @@ class MeanPredictor {
has_valid_predict_marginal<ModelType, FeatureType,
FitType>::value,
int>::type = 0>
Eigen::VectorXd _mean(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features) const {
static Eigen::VectorXd _mean(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features) {
return model
.predict_(features, fit, PredictTypeIdentity<MarginalDistribution>())
.mean;
Expand All @@ -61,8 +62,8 @@ class MeanPredictor {
FitType>::value &&
has_valid_predict_joint<ModelType, FeatureType, FitType>::value,
int>::type = 0>
Eigen::VectorXd _mean(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features) const {
static Eigen::VectorXd _mean(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features) {
return model
.predict_(features, fit, PredictTypeIdentity<JointDistribution>())
.mean;
Expand All @@ -75,9 +76,9 @@ class MarginalPredictor {
typename std::enable_if<has_valid_predict_marginal<
ModelType, FeatureType, FitType>::value,
int>::type = 0>
MarginalDistribution
static MarginalDistribution
_marginal(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features) const {
const std::vector<FeatureType> &features) {
return model.predict_(features, fit,
PredictTypeIdentity<MarginalDistribution>());
}
Expand All @@ -88,9 +89,9 @@ class MarginalPredictor {
!has_valid_predict_marginal<ModelType, FeatureType, FitType>::value &&
has_valid_predict_joint<ModelType, FeatureType, FitType>::value,
int>::type = 0>
MarginalDistribution
static MarginalDistribution
_marginal(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features) const {
const std::vector<FeatureType> &features) {
const auto joint_pred =
model.predict_(features, fit, PredictTypeIdentity<JointDistribution>());
return joint_pred.marginal();
Expand All @@ -103,13 +104,56 @@ class JointPredictor {
typename std::enable_if<
has_valid_predict_joint<ModelType, FeatureType, FitType>::value,
int>::type = 0>
JointDistribution _joint(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features) const {
static JointDistribution _joint(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features) {
return model.predict_(features, fit,
PredictTypeIdentity<JointDistribution>());
}
};

template <
typename ModelType, typename FeatureType, typename FitType,
typename std::enable_if<can_predict_joint<JointPredictor, ModelType,
FeatureType, FitType>::value,
int>::type = 0>
auto make_prediction(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features,
PredictTypeIdentity<JointDistribution> &&) {
return JointPredictor::_joint(model, fit, features);
}

template <
typename ModelType, typename FeatureType, typename FitType,
typename std::enable_if<can_predict_marginal<MarginalPredictor, ModelType,
FeatureType, FitType>::value,
int>::type = 0>
auto make_prediction(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features,
PredictTypeIdentity<MarginalDistribution> &&) {
return MarginalPredictor::_marginal(model, fit, features);
}

template <typename ModelType, typename FeatureType, typename FitType,
typename std::enable_if<can_predict_mean<MeanPredictor, ModelType,
FeatureType, FitType>::value,
int>::type = 0>
auto make_prediction(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features,
PredictTypeIdentity<Eigen::VectorXd> &&) {
return MeanPredictor::_mean(model, fit, features);
}
} // namespace detail

template <typename PredictType, typename ModelType, typename FeatureType,
typename FitType>
auto make_prediction(
const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features,
PredictTypeIdentity<PredictType> = PredictTypeIdentity<PredictType>()) {
return detail::make_prediction(model, fit, features,
PredictTypeIdentity<PredictType>());
}

template <typename ModelType, typename FeatureType, typename FitType>
class Prediction {

Expand All @@ -126,62 +170,30 @@ class Prediction {
: model_(std::move(model)), fit_(std::move(fit)), features_(features) {}

// Mean
template <typename DummyType = FeatureType,
typename std::enable_if<can_predict_mean<MeanPredictor, ModelType,
DummyType, FitType>::value,
int>::type = 0>
Eigen::VectorXd mean() const {
template <typename DummyType = FeatureType> Eigen::VectorXd mean() const {
static_assert(std::is_same<DummyType, FeatureType>::value,
"never do prediction.mean<T>()");
return MeanPredictor()._mean(model_, fit_, features_);
return make_prediction(model_, fit_, features_,
PredictTypeIdentity<Eigen::VectorXd>());
}

template <
typename DummyType = FeatureType,
typename std::enable_if<!can_predict_mean<MeanPredictor, ModelType,
DummyType, FitType>::value,
int>::type = 0>
Eigen::VectorXd mean() const = delete; // No valid predict method found.

// Marginal
template <
typename DummyType = FeatureType,
typename std::enable_if<can_predict_marginal<MarginalPredictor, ModelType,
DummyType, FitType>::value,
int>::type = 0>
template <typename DummyType = FeatureType>
MarginalDistribution marginal() const {
static_assert(std::is_same<DummyType, FeatureType>::value,
"never do prediction.mean<T>()");
return MarginalPredictor()._marginal(model_, fit_, features_);
return make_prediction(model_, fit_, features_,
PredictTypeIdentity<MarginalDistribution>());
}

template <typename DummyType = FeatureType,
typename std::enable_if<
!can_predict_marginal<MarginalPredictor, ModelType, DummyType,
FitType>::value,
int>::type = 0>
MarginalDistribution
marginal() const = delete; // No valid predict method found.

// Joint
template <
typename DummyType = FeatureType,
typename std::enable_if<can_predict_joint<JointPredictor, ModelType,
DummyType, FitType>::value,
int>::type = 0>
JointDistribution joint() const {
template <typename DummyType = FeatureType> JointDistribution joint() const {
static_assert(std::is_same<DummyType, FeatureType>::value,
"never do prediction.mean<T>()");
return JointPredictor()._joint(model_, fit_, features_);
return make_prediction(model_, fit_, features_,
PredictTypeIdentity<JointDistribution>());
}

template <
typename DummyType = FeatureType,
typename std::enable_if<!can_predict_joint<JointPredictor, ModelType,
DummyType, FitType>::value,
int>::type = 0>
JointDistribution joint() const = delete; // No valid predict method found.

template <typename PredictType>
PredictType get(PredictTypeIdentity<PredictType> =
PredictTypeIdentity<PredictType>()) const {
Expand Down
22 changes: 12 additions & 10 deletions include/albatross/src/core/traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,28 +148,30 @@ DEFINE_CLASS_METHOD_TRAITS(_mean);
template <typename T, typename ModelType, typename FeatureType,
typename FitType>
struct can_predict_mean
: public has__mean<T, typename const_ref<ModelType>::type,
typename const_ref<FitType>::type,
typename const_ref<std::vector<FeatureType>>::type> {};
: public has__mean_with_return_type<
T, Eigen::VectorXd, typename const_ref<ModelType>::type,
typename const_ref<FitType>::type,
typename const_ref<std::vector<FeatureType>>::type> {};

DEFINE_CLASS_METHOD_TRAITS(_marginal);

template <typename T, typename ModelType, typename FeatureType,
typename FitType>
struct can_predict_marginal
: public has__marginal<T, typename const_ref<ModelType>::type,
typename const_ref<FitType>::type,
typename const_ref<std::vector<FeatureType>>::type> {
};
: public has__marginal_with_return_type<
T, MarginalDistribution, typename const_ref<ModelType>::type,
typename const_ref<FitType>::type,
typename const_ref<std::vector<FeatureType>>::type> {};

DEFINE_CLASS_METHOD_TRAITS(_joint);

template <typename T, typename ModelType, typename FeatureType,
typename FitType>
struct can_predict_joint
: public has__joint<T, typename const_ref<ModelType>::type,
typename const_ref<FitType>::type,
typename const_ref<std::vector<FeatureType>>::type> {};
: public has__joint_with_return_type<
T, JointDistribution, typename const_ref<ModelType>::type,
typename const_ref<FitType>::type,
typename const_ref<std::vector<FeatureType>>::type> {};

/*
* Methods for inspecting `Prediction` types.
Expand Down
9 changes: 9 additions & 0 deletions include/albatross/src/evaluation/prediction_metrics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ template <typename RequiredPredictType> struct PredictionMetric {
return eval_(prediction, truth);
}

template <typename ModelType, typename FeatureType, typename FitType>
double operator()(const ModelType &model, const FitType &fit,
const std::vector<FeatureType> &features,
const MarginalDistribution &truth) const {
return (*this)(make_prediction(model, fit, features,
PredictTypeIdentity<RequiredPredictType>()),
truth);
}

template <typename ModelType, typename FeatureType, typename FitType>
double
operator()(const Prediction<ModelType, FeatureType, FitType> &prediction,
Expand Down
12 changes: 10 additions & 2 deletions include/albatross/src/models/conditional_gaussian.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,17 @@ struct ConditionalFit {
class ConditionalGaussian : public ModelBase<ConditionalGaussian> {

public:
ConditionalGaussian(JointDistribution &&prior,
const MarginalDistribution &truth)
: prior_(std::move(prior)), truth_(truth) {
std::cout << "FROM RVALUE" << std::endl;
}

ConditionalGaussian(const JointDistribution &prior,
const MarginalDistribution &truth)
: prior_(prior), truth_(truth) {}
: prior_(prior), truth_(truth) {
std::cout << "COPYING" << std::endl;
}

ConditionalFit fit_from_indices(const GroupIndices &indices) const {

Expand Down Expand Up @@ -71,7 +79,7 @@ class ConditionalGaussian : public ModelBase<ConditionalGaussian> {
cross, predict_prior.covariance, fit.information, fit.cov_ldlt);

conditional_pred.mean += predict_prior.mean;
return conditional_pred;
return std::move(conditional_pred);
}

MarginalDistribution
Expand Down
40 changes: 22 additions & 18 deletions include/albatross/src/models/ransac_gp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,53 +17,54 @@ namespace albatross {

template <typename GroupKey>
inline typename RansacFunctions<ConditionalFit, GroupKey>::FitterFunc
get_gp_ransac_fitter(const ConditionalGaussian &model,
get_gp_ransac_fitter(const std::shared_ptr<ConditionalGaussian> &model,
const GroupIndexer<GroupKey> &indexer) {

return [&, model, indexer](const std::vector<GroupKey> &groups) {
auto indices = indices_from_groups(indexer, groups);
return model.fit_from_indices(indices);
return model->fit_from_indices(indices);
};
}

template <typename IsValidCandidateMetric, typename GroupKey>
inline typename RansacFunctions<ConditionalFit, GroupKey>::IsValidCandidate
get_gp_ransac_is_valid_candidate(const ConditionalGaussian &model,
const GroupIndexer<GroupKey> &indexer,
const IsValidCandidateMetric &metric) {
get_gp_ransac_is_valid_candidate(
const std::shared_ptr<ConditionalGaussian> &model,
const GroupIndexer<GroupKey> &indexer,
const IsValidCandidateMetric &metric) {

return [&, model, indexer](const std::vector<GroupKey> &groups) {
const auto indices = indices_from_groups(indexer, groups);
const auto prior = model.get_prior(indices);
const auto truth = model.get_truth(indices);
const auto prior = model->get_prior(indices);
const auto truth = model->get_truth(indices);
return metric(prior, truth);
};
}

template <typename InlierMetricType, typename GroupKey>
inline typename RansacFunctions<ConditionalFit, GroupKey>::InlierMetric
get_gp_ransac_inlier_metric(const ConditionalGaussian &model,
get_gp_ransac_inlier_metric(const std::shared_ptr<ConditionalGaussian> &model,
const GroupIndexer<GroupKey> &indexer,
const InlierMetricType &metric) {

return [&, indexer, model](const GroupKey &group, const ConditionalFit &fit) {
const auto indices = indexer.at(group);
const auto pred = get_prediction_reference(model, fit, indices);
const auto truth = model.get_truth(indices);
const auto pred = get_prediction_reference(*model, fit, indices);
const auto truth = model->get_truth(indices);
return metric(pred, truth);
};
}

template <typename ConsensusMetric, typename GroupKey>
inline typename RansacFunctions<ConditionalFit, GroupKey>::ConsensusMetric
get_gp_ransac_consensus_metric(const ConditionalGaussian &model,
const GroupIndexer<GroupKey> &indexer,
const ConsensusMetric &metric) {
get_gp_ransac_consensus_metric(
const std::shared_ptr<ConditionalGaussian> &model,
const GroupIndexer<GroupKey> &indexer, const ConsensusMetric &metric) {

return [&, model, indexer](const std::vector<GroupKey> &groups) {
const auto indices = indices_from_groups(indexer, groups);
const auto prior = model.get_prior(indices);
const auto truth = model.get_truth(indices);
const auto prior = model->get_prior(indices);
const auto truth = model->get_truth(indices);
return metric(prior, truth);
};
}
Expand Down Expand Up @@ -109,17 +110,20 @@ struct AlwaysAcceptCandidateMetric {
};

template <typename InlierMetric, typename ConsensusMetric,
typename IsValidCandidateMetric, typename GroupKey>
typename IsValidCandidateMetric, typename GroupKey,
typename PriorDistribution>
inline RansacFunctions<ConditionalFit, GroupKey> get_gp_ransac_functions(
const JointDistribution &prior, const MarginalDistribution &truth,
PriorDistribution &&prior, const MarginalDistribution &truth,
const GroupIndexer<GroupKey> &indexer, const InlierMetric &inlier_metric,
const ConsensusMetric &consensus_metric,
const IsValidCandidateMetric &is_valid_candidate_metric) {

static_assert(is_prediction_metric<InlierMetric>::value,
"InlierMetric must be an PredictionMetric.");

const ConditionalGaussian model(prior, truth);
const std::shared_ptr<ConditionalGaussian> model =
std::make_shared<ConditionalGaussian>(
std::forward<PriorDistribution>(prior), truth);

const auto fitter = get_gp_ransac_fitter<GroupKey>(model, indexer);

Expand Down