Skip to content

Commit 348403b

Browse files
Fixed incorrect clipping of label data
1 parent 5ffc2f6 commit 348403b

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

graphium/graphium_cpp/labels.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,7 +1421,7 @@ static void save_label_data(
14211421

14221422
// temp_data is used for normalization
14231423
std::vector<char> temp_data;
1424-
temp_data.reserve(total_num_cols*sizeof(double));
1424+
temp_data.resize(total_num_cols*sizeof(double));
14251425

14261426
std::vector<char> data;
14271427
data.reserve(num_mols_per_file*(total_num_cols*sizeof(double) + (1+2*num_tasks)*sizeof(uint64_t)));
@@ -1517,6 +1517,7 @@ static void save_label_data(
15171517
const size_t in_bytes_per_float,
15181518
const size_t out_bytes_per_float,
15191519
const NormalizationMethod normalization_method,
1520+
const bool do_clipping,
15201521
const double* task_stats) {
15211522

15221523
if (size_t(col_stride) == in_bytes_per_float) {
@@ -1540,8 +1541,10 @@ static void save_label_data(
15401541
assert(in_bytes_per_float == sizeof(uint16_t));
15411542
value = c10::detail::fp16_ieee_to_fp32_value(((const uint16_t*)(temp_data.data()))[col]);
15421543
}
1543-
value = std::max(value, task_stats[stat_min_offset]);
1544-
value = std::min(value, task_stats[stat_max_offset]);
1544+
if (do_clipping) {
1545+
value = std::max(value, task_stats[stat_min_offset]);
1546+
value = std::min(value, task_stats[stat_max_offset]);
1547+
}
15451548
if (normalization_method == NormalizationMethod::NORMAL) {
15461549
if (task_stats[stat_std_offset] != 0) {
15471550
value = (value - task_stats[stat_mean_offset])/task_stats[stat_std_offset];
@@ -1599,6 +1602,9 @@ static void save_label_data(
15991602
const size_t task_first_col = task_col_starts[task_index];
16001603
const size_t task_num_cols = task_col_starts[task_index+1] - task_first_col;
16011604
const NormalizationOptions& normalization = task_normalization_options[task_index];
1605+
const bool do_clipping =
1606+
(normalization.min_clipping > -std::numeric_limits<double>::infinity()) &&
1607+
(normalization.max_clipping < std::numeric_limits<double>::infinity());
16021608
const double* task_stats = all_task_stats + num_stats*task_first_col;
16031609

16041610
const size_t bytes_per_float = task_bytes_per_float[task_index];
@@ -1703,36 +1709,36 @@ static void save_label_data(
17031709
const intptr_t offsets_stride = label_offsets_numpy_array ? PyArray_STRIDES(label_offsets_numpy_array)[0] : 0;
17041710
if (offsets_raw_data == nullptr) {
17051711
const char* row_data = raw_data + strides[0]*task_mol_index;
1706-
store_single_row(row_data, task_num_cols, strides[1], bytes_per_float, bytes_per_float, normalization.method, task_stats);
1712+
store_single_row(row_data, task_num_cols, strides[1], bytes_per_float, bytes_per_float, normalization.method, do_clipping, task_stats);
17071713
}
17081714
else {
17091715
size_t begin_offset = *reinterpret_cast<const int64_t*>(offsets_raw_data + offsets_stride*task_mol_index);
17101716
size_t end_offset = *reinterpret_cast<const int64_t*>(offsets_raw_data + offsets_stride*(task_mol_index+1));
17111717
const char* row_data = raw_data + strides[0]*begin_offset;
17121718
if (same_order_as_first) {
17131719
for (size_t row = begin_offset; row < end_offset; ++row, row_data += strides[0]) {
1714-
store_single_row(row_data, task_num_cols, strides[1], bytes_per_float, bytes_per_float, normalization.method, task_stats);
1720+
store_single_row(row_data, task_num_cols, strides[1], bytes_per_float, bytes_per_float, normalization.method, do_clipping, task_stats);
17151721
}
17161722
}
17171723
else if (task_levels[task_index] == FeatureLevel::NODE) {
17181724
assert(end_offset - begin_offset == current_atom_order.size());
17191725
for (unsigned int current_index : current_atom_order) {
1720-
store_single_row(row_data + current_index*strides[0], task_num_cols, strides[1], bytes_per_float, bytes_per_float, normalization.method, task_stats);
1726+
store_single_row(row_data + current_index*strides[0], task_num_cols, strides[1], bytes_per_float, bytes_per_float, normalization.method, do_clipping, task_stats);
17211727
}
17221728
}
17231729
else if (task_levels[task_index] == FeatureLevel::NODEPAIR) {
17241730
const size_t n = current_atom_order.size();
17251731
assert(end_offset - begin_offset == n*n);
17261732
for (unsigned int current_index0 : current_atom_order) {
17271733
for (unsigned int current_index1 : current_atom_order) {
1728-
store_single_row(row_data + (current_index0*n + current_index1)*strides[0], task_num_cols, strides[1], bytes_per_float, bytes_per_float, normalization.method, task_stats);
1734+
store_single_row(row_data + (current_index0*n + current_index1)*strides[0], task_num_cols, strides[1], bytes_per_float, bytes_per_float, normalization.method, do_clipping, task_stats);
17291735
}
17301736
}
17311737
}
17321738
else {
17331739
assert(task_levels[task_index] == FeatureLevel::EDGE);
17341740
for (unsigned int current_index : current_bond_order) {
1735-
store_single_row(row_data + current_index*strides[0], task_num_cols, strides[1], bytes_per_float, bytes_per_float, normalization.method, task_stats);
1741+
store_single_row(row_data + current_index*strides[0], task_num_cols, strides[1], bytes_per_float, bytes_per_float, normalization.method, do_clipping, task_stats);
17361742
}
17371743
}
17381744
}

0 commit comments

Comments
 (0)