Skip to content

Commit 6711611

Browse files
committed
feat(gpu): refactor the sample extract entry point so the user can pass how many LWEs should be extracted per GLWE
1 parent 6b21bff commit 6711611

File tree

8 files changed

+37
-27
lines changed

8 files changed

+37
-27
lines changed

backends/tfhe-cuda-backend/cuda/include/ciphertext.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ void cuda_convert_lwe_ciphertext_vector_to_cpu_64(void *stream,
1818
void cuda_glwe_sample_extract_64(void *stream, uint32_t gpu_index,
1919
void *lwe_array_out, void const *glwe_array_in,
2020
uint32_t const *nth_array, uint32_t num_nths,
21-
uint32_t glwe_dimension,
21+
uint32_t lwe_per_glwe, uint32_t glwe_dimension,
2222
uint32_t polynomial_size);
2323
}
2424
#endif

backends/tfhe-cuda-backend/cuda/src/crypto/ciphertext.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,51 +24,51 @@ void cuda_convert_lwe_ciphertext_vector_to_cpu_64(void *stream,
2424
void cuda_glwe_sample_extract_64(void *stream, uint32_t gpu_index,
2525
void *lwe_array_out, void const *glwe_array_in,
2626
uint32_t const *nth_array, uint32_t num_nths,
27-
uint32_t glwe_dimension,
27+
uint32_t lwe_per_glwe, uint32_t glwe_dimension,
2828
uint32_t polynomial_size) {
2929

3030
switch (polynomial_size) {
3131
case 256:
3232
host_sample_extract<uint64_t, AmortizedDegree<256>>(
3333
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
3434
(uint64_t const *)glwe_array_in, (uint32_t const *)nth_array, num_nths,
35-
glwe_dimension);
35+
lwe_per_glwe, glwe_dimension);
3636
break;
3737
case 512:
3838
host_sample_extract<uint64_t, AmortizedDegree<512>>(
3939
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
4040
(uint64_t const *)glwe_array_in, (uint32_t const *)nth_array, num_nths,
41-
glwe_dimension);
41+
lwe_per_glwe, glwe_dimension);
4242
break;
4343
case 1024:
4444
host_sample_extract<uint64_t, AmortizedDegree<1024>>(
4545
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
4646
(uint64_t const *)glwe_array_in, (uint32_t const *)nth_array, num_nths,
47-
glwe_dimension);
47+
lwe_per_glwe, glwe_dimension);
4848
break;
4949
case 2048:
5050
host_sample_extract<uint64_t, AmortizedDegree<2048>>(
5151
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
5252
(uint64_t const *)glwe_array_in, (uint32_t const *)nth_array, num_nths,
53-
glwe_dimension);
53+
lwe_per_glwe, glwe_dimension);
5454
break;
5555
case 4096:
5656
host_sample_extract<uint64_t, AmortizedDegree<4096>>(
5757
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
5858
(uint64_t const *)glwe_array_in, (uint32_t const *)nth_array, num_nths,
59-
glwe_dimension);
59+
lwe_per_glwe, glwe_dimension);
6060
break;
6161
case 8192:
6262
host_sample_extract<uint64_t, AmortizedDegree<8192>>(
6363
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
6464
(uint64_t const *)glwe_array_in, (uint32_t const *)nth_array, num_nths,
65-
glwe_dimension);
65+
lwe_per_glwe, glwe_dimension);
6666
break;
6767
case 16384:
6868
host_sample_extract<uint64_t, AmortizedDegree<16384>>(
6969
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
7070
(uint64_t const *)glwe_array_in, (uint32_t const *)nth_array, num_nths,
71-
glwe_dimension);
71+
lwe_per_glwe, glwe_dimension);
7272
break;
7373
default:
7474
PANIC("Cuda error: unsupported polynomial size. Supported "

backends/tfhe-cuda-backend/cuda/src/crypto/ciphertext.cuh

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ void cuda_convert_lwe_ciphertext_vector_to_cpu(cudaStream_t stream,
2828

2929
template <typename Torus, class params>
3030
__global__ void sample_extract(Torus *lwe_array_out, Torus const *glwe_array_in,
31-
uint32_t const *nth_array,
31+
uint32_t const *nth_array, uint32_t lwe_per_glwe,
3232
uint32_t glwe_dimension) {
3333

3434
const int input_id = blockIdx.x;
@@ -39,7 +39,6 @@ __global__ void sample_extract(Torus *lwe_array_out, Torus const *glwe_array_in,
3939
auto lwe_out = lwe_array_out + input_id * lwe_output_size;
4040

4141
// We assume lwe_per_glwe LWEs are extracted from each GLWE (caller-provided)
42-
uint32_t lwe_per_glwe = params::degree;
4342
auto glwe_in = glwe_array_in + (input_id / lwe_per_glwe) * glwe_input_size;
4443

4544
// nth is a monomial degree, expected to be in [0, params::degree); it is not
// bounded by lwe_per_glwe
@@ -49,18 +48,20 @@ __global__ void sample_extract(Torus *lwe_array_out, Torus const *glwe_array_in,
4948
sample_extract_body<Torus, params>(lwe_out, glwe_in, glwe_dimension, nth);
5049
}
5150

51+
// lwe_per_glwe LWEs will be extracted per GLWE ciphertext, thus we need to have
52+
// enough indexes
5253
template <typename Torus, class params>
53-
__host__ void host_sample_extract(cudaStream_t stream, uint32_t gpu_index,
54-
Torus *lwe_array_out,
55-
Torus const *glwe_array_in,
56-
uint32_t const *nth_array, uint32_t num_nths,
57-
uint32_t glwe_dimension) {
54+
__host__ void
55+
host_sample_extract(cudaStream_t stream, uint32_t gpu_index,
56+
Torus *lwe_array_out, Torus const *glwe_array_in,
57+
uint32_t const *nth_array, uint32_t num_nths,
58+
uint32_t lwe_per_glwe, uint32_t glwe_dimension) {
5859
cuda_set_device(gpu_index);
5960

6061
dim3 grid(num_nths);
6162
dim3 thds(params::degree / params::opt);
6263
sample_extract<Torus, params><<<grid, thds, 0, stream>>>(
63-
lwe_array_out, glwe_array_in, nth_array, glwe_dimension);
64+
lwe_array_out, glwe_array_in, nth_array, lwe_per_glwe, glwe_dimension);
6465
check_cuda_error(cudaGetLastError());
6566
}
6667

backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,10 @@ __host__ void host_integer_decompress(
279279
extracted_glwe = max_idx_and_glwe.second;
280280

281281
auto num_lwes = last_idx + 1 - current_idx;
282-
cuda_glwe_sample_extract_64(streams[0], gpu_indexes[0], extracted_lwe,
283-
extracted_glwe, d_indexes_array_chunk, num_lwes,
284-
compression_params.glwe_dimension,
285-
compression_params.polynomial_size);
282+
cuda_glwe_sample_extract_64(
283+
streams[0], gpu_indexes[0], extracted_lwe, extracted_glwe,
284+
d_indexes_array_chunk, num_lwes, compression_params.polynomial_size,
285+
compression_params.glwe_dimension, compression_params.polynomial_size);
286286
d_indexes_array_chunk += num_lwes;
287287
extracted_lwe += num_lwes * lwe_accumulator_size;
288288
current_idx = last_idx;

backends/tfhe-cuda-backend/src/bindings.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ unsafe extern "C" {
3030
glwe_array_in: *const ffi::c_void,
3131
nth_array: *const u32,
3232
num_nths: u32,
33+
lwe_per_glwe: u32,
3334
glwe_dimension: u32,
3435
polynomial_size: u32,
3536
);

tfhe/src/core_crypto/gpu/algorithms/glwe_sample_extraction.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub unsafe fn cuda_extract_lwe_samples_from_glwe_ciphertext_list_async<Scalar>(
1313
input_glwe_list: &CudaGlweCiphertextList<Scalar>,
1414
output_lwe_list: &mut CudaLweCiphertextList<Scalar>,
1515
vec_nth: &[MonomialDegree],
16+
lwe_per_glwe: u32,
1617
streams: &CudaStreams,
1718
) where
1819
Scalar: UnsignedTorus,
@@ -29,9 +30,10 @@ pub unsafe fn cuda_extract_lwe_samples_from_glwe_ciphertext_list_async<Scalar>(
2930
Got {in_lwe_dim:?} for input and {out_lwe_dim:?} for output.",
3031
);
3132

33+
// lwe_per_glwe LWEs will be extracted per GLWE ciphertext, thus we need to have enough indexes
3234
assert_eq!(
3335
vec_nth.len(),
34-
input_glwe_list.glwe_ciphertext_count().0 * input_glwe_list.polynomial_size().0,
36+
input_glwe_list.glwe_ciphertext_count().0 * lwe_per_glwe as usize,
3537
"Mismatch between number of nths and number of GLWEs provided.",
3638
);
3739

@@ -53,6 +55,7 @@ pub unsafe fn cuda_extract_lwe_samples_from_glwe_ciphertext_list_async<Scalar>(
5355
&input_glwe_list.0.d_vec,
5456
&d_nth_array,
5557
vec_nth.len() as u32,
58+
lwe_per_glwe,
5659
input_glwe_list.glwe_dimension(),
5760
input_glwe_list.polynomial_size(),
5861
);
@@ -66,6 +69,7 @@ pub fn cuda_extract_lwe_samples_from_glwe_ciphertext_list<Scalar>(
6669
input_glwe_list: &CudaGlweCiphertextList<Scalar>,
6770
output_lwe_list: &mut CudaLweCiphertextList<Scalar>,
6871
vec_nth: &[MonomialDegree],
72+
lwe_per_glwe: u32,
6973
streams: &CudaStreams,
7074
) where
7175
Scalar: UnsignedTorus,
@@ -75,6 +79,7 @@ pub fn cuda_extract_lwe_samples_from_glwe_ciphertext_list<Scalar>(
7579
input_glwe_list,
7680
output_lwe_list,
7781
vec_nth,
82+
lwe_per_glwe,
7883
streams,
7984
);
8085
}

tfhe/src/core_crypto/gpu/algorithms/test/glwe_sample_extraction.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,21 +75,22 @@ fn glwe_encrypt_sample_extract_decrypt_custom_mod<Scalar: UnsignedTorus + Send +
7575
let input_cuda_glwe_list =
7676
CudaGlweCiphertextList::from_glwe_ciphertext_list(&glwe_list, &streams);
7777

78+
let lwe_per_glwe = 2;
7879
let mut output_cuda_lwe_ciphertext_list = CudaLweCiphertextList::new(
7980
equivalent_lwe_sk.lwe_dimension(),
80-
LweCiphertextCount(msgs.len() * glwe_list.polynomial_size().0),
81+
LweCiphertextCount(msgs.len() * lwe_per_glwe),
8182
ciphertext_modulus,
8283
&streams,
8384
);
8485

85-
let nths = (0..(msgs.len() * glwe_list.polynomial_size().0))
86+
let nths = (0..(msgs.len() * lwe_per_glwe))
8687
.map(|x| MonomialDegree(x % glwe_list.polynomial_size().0))
8788
.collect_vec();
88-
8989
cuda_extract_lwe_samples_from_glwe_ciphertext_list(
9090
&input_cuda_glwe_list,
9191
&mut output_cuda_lwe_ciphertext_list,
9292
nths.as_slice(),
93+
lwe_per_glwe as u32,
9394
&streams,
9495
);
9596

@@ -107,7 +108,7 @@ fn glwe_encrypt_sample_extract_decrypt_custom_mod<Scalar: UnsignedTorus + Send +
107108
&mut output_plaintext_list,
108109
);
109110

110-
let mut decoded = vec![Scalar::ZERO; plaintext_list.plaintext_count().0];
111+
let mut decoded = vec![Scalar::ZERO; msgs.len() * lwe_per_glwe];
111112

112113
decoded
113114
.iter_mut()
@@ -116,7 +117,7 @@ fn glwe_encrypt_sample_extract_decrypt_custom_mod<Scalar: UnsignedTorus + Send +
116117

117118
let mut count = msg_modulus;
118119
count = count.wrapping_sub(Scalar::ONE);
119-
for result in decoded.chunks_exact(glwe_list.polynomial_size().0) {
120+
for result in decoded.chunks_exact(lwe_per_glwe) {
120121
assert!(result.iter().all(|&x| x == count));
121122
count = count.wrapping_sub(Scalar::ONE);
122123
}

tfhe/src/core_crypto/gpu/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,7 @@ pub unsafe fn extract_lwe_samples_from_glwe_ciphertext_list_async<T: UnsignedInt
402402
glwe_array_in: &CudaVec<T>,
403403
nth_array: &CudaVec<u32>,
404404
num_nths: u32,
405+
lwe_per_glwe: u32,
405406
glwe_dimension: GlweDimension,
406407
polynomial_size: PolynomialSize,
407408
) {
@@ -412,6 +413,7 @@ pub unsafe fn extract_lwe_samples_from_glwe_ciphertext_list_async<T: UnsignedInt
412413
glwe_array_in.as_c_ptr(0),
413414
nth_array.as_c_ptr(0).cast::<u32>(),
414415
num_nths,
416+
lwe_per_glwe,
415417
glwe_dimension.0 as u32,
416418
polynomial_size.0 as u32,
417419
);

0 commit comments

Comments
 (0)