Skip to content

Commit cd36ffe

Browse files
authored
[R-package] Fix inefficiency in retrieving pointers (#6208)
1 parent 516bde9 commit cd36ffe

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

Diff for: R-package/src/lightgbm_R.cpp

+16-8
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,10 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
226226
int32_t len = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
227227
std::vector<int32_t> idxvec(len);
228228
// convert from one-based to zero-based index
229+
const int *used_row_indices_ = INTEGER(used_row_indices);
229230
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (len >= 1024)
230231
for (int32_t i = 0; i < len; ++i) {
231-
idxvec[i] = static_cast<int32_t>(INTEGER(used_row_indices)[i] - 1);
232+
idxvec[i] = static_cast<int32_t>(used_row_indices_[i] - 1);
232233
}
233234
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
234235
DatasetHandle res = nullptr;
@@ -339,18 +340,20 @@ SEXP LGBM_DatasetSetField_R(SEXP handle,
339340
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
340341
if (!strcmp("group", name) || !strcmp("query", name)) {
341342
std::vector<int32_t> vec(len);
343+
const int *field_data_ = INTEGER(field_data);
342344
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (len >= 1024)
343345
for (int i = 0; i < len; ++i) {
344-
vec[i] = static_cast<int32_t>(INTEGER(field_data)[i]);
346+
vec[i] = static_cast<int32_t>(field_data_[i]);
345347
}
346348
CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_INT32));
347349
} else if (!strcmp("init_score", name)) {
348350
CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, REAL(field_data), len, C_API_DTYPE_FLOAT64));
349351
} else {
350352
std::vector<float> vec(len);
353+
const double *field_data_ = REAL(field_data);
351354
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (len >= 1024)
352355
for (int i = 0; i < len; ++i) {
353-
vec[i] = static_cast<float>(REAL(field_data)[i]);
356+
vec[i] = static_cast<float>(field_data_[i]);
354357
}
355358
CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32));
356359
}
@@ -372,21 +375,24 @@ SEXP LGBM_DatasetGetField_R(SEXP handle,
372375
if (!strcmp("group", name) || !strcmp("query", name)) {
373376
auto p_data = reinterpret_cast<const int32_t*>(res);
374377
// convert from boundaries to size
378+
int *field_data_ = INTEGER(field_data);
375379
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (out_len >= 1024)
376380
for (int i = 0; i < out_len - 1; ++i) {
377-
INTEGER(field_data)[i] = p_data[i + 1] - p_data[i];
381+
field_data_[i] = p_data[i + 1] - p_data[i];
378382
}
379383
} else if (!strcmp("init_score", name)) {
380384
auto p_data = reinterpret_cast<const double*>(res);
385+
double *field_data_ = REAL(field_data);
381386
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (out_len >= 1024)
382387
for (int i = 0; i < out_len; ++i) {
383-
REAL(field_data)[i] = p_data[i];
388+
field_data_[i] = p_data[i];
384389
}
385390
} else {
386391
auto p_data = reinterpret_cast<const float*>(res);
392+
double *field_data_ = REAL(field_data);
387393
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (out_len >= 1024)
388394
for (int i = 0; i < out_len; ++i) {
389-
REAL(field_data)[i] = p_data[i];
395+
field_data_[i] = p_data[i];
390396
}
391397
}
392398
UNPROTECT(1);
@@ -611,10 +617,12 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
611617
int is_finished = 0;
612618
int int_len = Rf_asInteger(len);
613619
std::vector<float> tgrad(int_len), thess(int_len);
620+
const double *grad_ = REAL(grad);
621+
const double *hess_ = REAL(hess);
614622
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (int_len >= 1024)
615623
for (int j = 0; j < int_len; ++j) {
616-
tgrad[j] = static_cast<float>(REAL(grad)[j]);
617-
thess[j] = static_cast<float>(REAL(hess)[j]);
624+
tgrad[j] = static_cast<float>(grad_[j]);
625+
thess[j] = static_cast<float>(hess_[j]);
618626
}
619627
CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_ExternalPtrAddr(handle), tgrad.data(), thess.data(), &is_finished));
620628
return R_NilValue;

0 commit comments

Comments
 (0)