Skip to content

Commit eb116fb

Browse files
committed
Add scalar soft-to-hard conversion helper
Signed-off-by: Melody Ren <melodyr@nvidia.com>
1 parent eb67075 commit eb116fb

6 files changed

Lines changed: 43 additions & 19 deletions

File tree

libs/qec/include/cudaq/qec/decoder.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,16 @@ class decoder
282282
std::vector<std::vector<uint32_t>> D_sparse;
283283
};
284284

285+
/// @brief Convert a single soft probability to a hard 0/1 decision.
286+
/// @param in Soft probability input in range [0.0, 1.0]
287+
/// @param thresh Values >= thresh return true; all others return false.
288+
template <typename t_soft,
289+
typename std::enable_if<std::is_floating_point<t_soft>::value,
290+
int>::type = 0>
291+
constexpr inline bool convert_soft_to_hard(t_soft in, t_soft thresh = 0.5) {
292+
return in >= thresh;
293+
}
294+
285295
/// @brief Convert a vector of soft probabilities to a vector of hard
286296
/// probabilities.
287297
/// @param in Soft probability input vector in range [0.0, 1.0]
@@ -298,7 +308,7 @@ inline void convert_vec_soft_to_hard(const std::vector<t_soft> &in,
298308
t_soft thresh = 0.5) {
299309
out.resize(in.size());
300310
for (std::size_t i = 0; i < in.size(); i++)
301-
out[i] = static_cast<t_hard>(in[i] >= thresh ? 1 : 0);
311+
out[i] = static_cast<t_hard>(convert_soft_to_hard(in[i], thresh));
302312
}
303313

304314
/// @brief Convert a vector of soft probabilities to a tensor<uint8_t> of hard
@@ -326,7 +336,7 @@ inline void convert_vec_soft_to_tensor_hard(const std::vector<t_soft> &in,
326336
"Vector to tensor conversion requires tensor dim == vector length");
327337
auto raw_ptr = out.data();
328338
for (size_t i = 0; i < in.size(); ++i)
329-
raw_ptr[i] = static_cast<t_hard>(in[i] >= thresh ? 1 : 0);
339+
raw_ptr[i] = static_cast<t_hard>(convert_soft_to_hard(in[i], thresh));
330340
}
331341

332342
/// @brief Convert a vector of hard probabilities to a vector of soft
@@ -392,7 +402,7 @@ inline void convert_vec_soft_to_hard(const std::vector<std::vector<t_soft>> &in,
392402
auto &out_row = out[row_index++];
393403
out_row.resize(r.size());
394404
for (std::size_t c = 0; c < r.size(); c++)
395-
out_row[c] = static_cast<t_hard>(r[c] >= thresh ? 1 : 0);
405+
out_row[c] = static_cast<t_hard>(convert_soft_to_hard(r[c], thresh));
396406
}
397407
}
398408

libs/qec/lib/decoders/lut.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,14 +172,12 @@ class multi_error_lut : public decoder {
172172
decoder_result result{false, std::vector<float_t>(block_size, 0.0)};
173173

174174
// Convert syndrome to a string
175-
std::vector<uint8_t> hard_syndrome;
176-
cudaq::qec::convert_vec_soft_to_hard(syndrome, hard_syndrome);
177-
std::string syndrome_str(hard_syndrome.size(), '0');
175+
std::string syndrome_str(syndrome.size(), '0');
178176
int syndrome_weight = 0;
179177
assert(syndrome_str.length() == syndrome_size);
180178
bool anyErrors = false;
181179
for (std::size_t i = 0; i < syndrome_size; i++) {
182-
if (hard_syndrome[i]) {
180+
if (cudaq::qec::convert_soft_to_hard(syndrome[i])) {
183181
syndrome_str[i] = '1';
184182
anyErrors = true;
185183
syndrome_weight++;

libs/qec/lib/decoders/plugins/example/single_error_lut_example.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,11 @@ class single_error_lut_example : public decoder {
4949
decoder_result result{false, std::vector<float_t>(block_size, 0.0)};
5050

5151
// Convert syndrome to a string
52-
std::vector<uint8_t> hard_syndrome;
53-
cudaq::qec::convert_vec_soft_to_hard(syndrome, hard_syndrome);
54-
std::string syndrome_str(hard_syndrome.size(), '0');
52+
std::string syndrome_str(syndrome.size(), '0');
5553
assert(syndrome_str.length() == syndrome_size);
5654
bool anyErrors = false;
5755
for (std::size_t i = 0; i < syndrome_size; i++) {
58-
if (hard_syndrome[i]) {
56+
if (cudaq::qec::convert_soft_to_hard(syndrome[i])) {
5957
syndrome_str[i] = '1';
6058
anyErrors = true;
6159
}

libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,12 +167,10 @@ class pymatching : public decoder {
167167
auto t1 = std::chrono::high_resolution_clock::now();
168168
#endif
169169

170-
std::vector<uint8_t> hard_syndrome;
171-
cudaq::qec::convert_vec_soft_to_hard(syndrome, hard_syndrome);
172170
std::vector<uint64_t> detection_events;
173-
detection_events.reserve(hard_syndrome.size());
174-
for (size_t i = 0; i < hard_syndrome.size(); i++)
175-
if (hard_syndrome[i])
171+
detection_events.reserve(syndrome.size());
172+
for (size_t i = 0; i < syndrome.size(); i++)
173+
if (cudaq::qec::convert_soft_to_hard(syndrome[i]))
176174
detection_events.push_back(i);
177175
#if PERFORM_TIMING
178176
auto t2 = std::chrono::high_resolution_clock::now();

libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -926,7 +926,6 @@ std::vector<decoder_result> trt_decoder::decode_batch_impl(
926926
// Prepare input batch. For float input we preserve soft (raw) values;
927927
// for uint8 we binarize to 0/1.
928928
std::vector<IoType> input_host(impl_->input_size);
929-
std::vector<uint8_t> hard_syndrome;
930929
for (size_t batch_idx = 0; batch_idx < actual_batch; ++batch_idx) {
931930
const auto &syndrome = syndromes[batch_start + batch_idx];
932931
if constexpr (std::is_same_v<IoType, float>) {
@@ -935,10 +934,10 @@ std::vector<decoder_result> trt_decoder::decode_batch_impl(
935934
static_cast<IoType>(syndrome[i]);
936935
}
937936
} else {
938-
cudaq::qec::convert_vec_soft_to_hard(syndrome, hard_syndrome);
939937
for (size_t i = 0; i < syndrome_size_per_sample_; ++i) {
940938
input_host[batch_idx * syndrome_size_per_sample_ + i] =
941-
static_cast<IoType>(hard_syndrome[i]);
939+
static_cast<IoType>(
940+
cudaq::qec::convert_soft_to_hard(syndrome[i]));
942941
}
943942
}
944943
}

libs/qec/unittests/test_decoders.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,27 @@ TEST(DecoderUtils, CovertSoftToHard) {
7575
}
7676
}
7777

78+
TEST(DecoderUtils, ConvertSoftToHardScalar) {
79+
// Default threshold (0.5): the canonical >= contract.
80+
EXPECT_TRUE(cudaq::qec::convert_soft_to_hard(0.6f));
81+
EXPECT_FALSE(cudaq::qec::convert_soft_to_hard(0.4f));
82+
EXPECT_TRUE(cudaq::qec::convert_soft_to_hard(0.5f));
83+
EXPECT_FALSE(cudaq::qec::convert_soft_to_hard(0.499f));
84+
EXPECT_TRUE(cudaq::qec::convert_soft_to_hard(0.501f));
85+
86+
// Double-precision input.
87+
EXPECT_TRUE(cudaq::qec::convert_soft_to_hard(0.5));
88+
EXPECT_FALSE(cudaq::qec::convert_soft_to_hard(0.499));
89+
90+
// Custom threshold.
91+
EXPECT_TRUE(cudaq::qec::convert_soft_to_hard(0.4f, 0.4f));
92+
EXPECT_FALSE(cudaq::qec::convert_soft_to_hard(0.3f, 0.4f));
93+
94+
// Usable in a constant-expression context.
95+
static_assert(cudaq::qec::convert_soft_to_hard(0.5f));
96+
static_assert(!cudaq::qec::convert_soft_to_hard(0.49f));
97+
}
98+
7899
TEST(DecoderUtils, ConvertVecSoftToTensorHard) {
79100
// Generate a million random floats between 0 and 1 using mt19937
80101
std::mt19937_64 gen(13);

0 commit comments

Comments
 (0)