Skip to content

Commit

Permalink
At least 2 features are chosen in subcolumn (#2409)
Browse files Browse the repository at this point in the history
* at least 2 features are chosen in subcolumn

* Update serial_tree_learner.cpp

* rounding
  • Loading branch information
guolinke authored Sep 19, 2019
1 parent a119639 commit a3a353d
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions src/treelearner/serial_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,10 @@ std::vector<int8_t> SerialTreeLearner::GetUsedFeatures(bool is_tree_level) {
return ret;
}
std::memset(ret.data(), 0, sizeof(int8_t) * num_features_);
const int min_used_features = std::min(2, static_cast<int>(valid_feature_indices_.size()));
if (is_tree_level) {
int used_feature_cnt = static_cast<int>(valid_feature_indices_.size() * config_->feature_fraction);
used_feature_cnt = std::max(used_feature_cnt, 1);
int used_feature_cnt = static_cast<int>(std::round(valid_feature_indices_.size() * config_->feature_fraction));
used_feature_cnt = std::max(used_feature_cnt, min_used_features);
used_feature_indices_ = random_.Sample(static_cast<int>(valid_feature_indices_.size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(used_feature_indices_.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
Expand All @@ -290,8 +291,8 @@ std::vector<int8_t> SerialTreeLearner::GetUsedFeatures(bool is_tree_level) {
ret[inner_feature_index] = 1;
}
} else if(used_feature_indices_.size() <= 0) {
int used_feature_cnt = static_cast<int>(valid_feature_indices_.size() * config_->feature_fraction_bynode);
used_feature_cnt = std::max(used_feature_cnt, 1);
int used_feature_cnt = static_cast<int>(std::round(valid_feature_indices_.size() * config_->feature_fraction_bynode));
used_feature_cnt = std::max(used_feature_cnt, min_used_features);
auto sampled_indices = random_.Sample(static_cast<int>(valid_feature_indices_.size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(sampled_indices.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
Expand All @@ -302,8 +303,8 @@ std::vector<int8_t> SerialTreeLearner::GetUsedFeatures(bool is_tree_level) {
ret[inner_feature_index] = 1;
}
} else {
int used_feature_cnt = static_cast<int>(used_feature_indices_.size() * config_->feature_fraction_bynode);
used_feature_cnt = std::max(used_feature_cnt, 1);
int used_feature_cnt = static_cast<int>(std::round(used_feature_indices_.size() * config_->feature_fraction_bynode));
used_feature_cnt = std::max(used_feature_cnt, min_used_features);
auto sampled_indices = random_.Sample(static_cast<int>(used_feature_indices_.size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(sampled_indices.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
Expand Down

0 comments on commit a3a353d

Please sign in to comment.