Skip to content

Commit

Permalink
feat(gpu): refactor the sample extract entry point so the user can pass how many LWEs should be extracted per GLWE
Browse files Browse the repository at this point in the history
  • Loading branch information
pdroalves committed Feb 20, 2025
1 parent 6b21bff commit 6711611
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 27 deletions.
2 changes: 1 addition & 1 deletion backends/tfhe-cuda-backend/cuda/include/ciphertext.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void cuda_convert_lwe_ciphertext_vector_to_cpu_64(void *stream,
void cuda_glwe_sample_extract_64(void *stream, uint32_t gpu_index,
void *lwe_array_out, void const *glwe_array_in,
uint32_t const *nth_array, uint32_t num_nths,
uint32_t glwe_dimension,
uint32_t lwe_per_glwe, uint32_t glwe_dimension,
uint32_t polynomial_size);
}
#endif
16 changes: 8 additions & 8 deletions backends/tfhe-cuda-backend/cuda/src/crypto/ciphertext.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,51 +24,51 @@ void cuda_convert_lwe_ciphertext_vector_to_cpu_64(void *stream,
void cuda_glwe_sample_extract_64(void *stream, uint32_t gpu_index,
void *lwe_array_out, void const *glwe_array_in,
uint32_t const *nth_array, uint32_t num_nths,
uint32_t glwe_dimension,
uint32_t lwe_per_glwe, uint32_t glwe_dimension,
uint32_t polynomial_size) {

switch (polynomial_size) {
case 256:
host_sample_extract<uint64_t, AmortizedDegree<256>>(
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
(uint64_t const *)glwe_array_in, (uint32_t const *)nth_array, num_nths,
glwe_dimension);
lwe_per_glwe, glwe_dimension);
break;
case 512:
host_sample_extract<uint64_t, AmortizedDegree<512>>(
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
(uint64_t const *)glwe_array_in, (uint32_t const *)nth_array, num_nths,
glwe_dimension);
lwe_per_glwe, glwe_dimension);
break;
case 1024:
host_sample_extract<uint64_t, AmortizedDegree<1024>>(
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
(uint64_t const *)glwe_array_in, (uint32_t const *)nth_array, num_nths,
glwe_dimension);
lwe_per_glwe, glwe_dimension);
break;
case 2048:
host_sample_extract<uint64_t, AmortizedDegree<2048>>(
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
(uint64_t const *)glwe_array_in, (uint32_t const *)nth_array, num_nths,
glwe_dimension);
lwe_per_glwe, glwe_dimension);
break;
case 4096:
host_sample_extract<uint64_t, AmortizedDegree<4096>>(
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
(uint64_t const *)glwe_array_in, (uint32_t const *)nth_array, num_nths,
glwe_dimension);
lwe_per_glwe, glwe_dimension);
break;
case 8192:
host_sample_extract<uint64_t, AmortizedDegree<8192>>(
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
(uint64_t const *)glwe_array_in, (uint32_t const *)nth_array, num_nths,
glwe_dimension);
lwe_per_glwe, glwe_dimension);
break;
case 16384:
host_sample_extract<uint64_t, AmortizedDegree<16384>>(
static_cast<cudaStream_t>(stream), gpu_index, (uint64_t *)lwe_array_out,
(uint64_t const *)glwe_array_in, (uint32_t const *)nth_array, num_nths,
glwe_dimension);
lwe_per_glwe, glwe_dimension);
break;
default:
PANIC("Cuda error: unsupported polynomial size. Supported "
Expand Down
17 changes: 9 additions & 8 deletions backends/tfhe-cuda-backend/cuda/src/crypto/ciphertext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ void cuda_convert_lwe_ciphertext_vector_to_cpu(cudaStream_t stream,

template <typename Torus, class params>
__global__ void sample_extract(Torus *lwe_array_out, Torus const *glwe_array_in,
uint32_t const *nth_array,
uint32_t const *nth_array, uint32_t lwe_per_glwe,
uint32_t glwe_dimension) {

const int input_id = blockIdx.x;
Expand All @@ -39,7 +39,6 @@ __global__ void sample_extract(Torus *lwe_array_out, Torus const *glwe_array_in,
auto lwe_out = lwe_array_out + input_id * lwe_output_size;

// We assume each GLWE will store the first lwe_per_glwe inputs
uint32_t lwe_per_glwe = params::degree;
auto glwe_in = glwe_array_in + (input_id / lwe_per_glwe) * glwe_input_size;

// nth is a monomial degree and is expected to be in [0, polynomial_size)
Expand All @@ -49,18 +48,20 @@ __global__ void sample_extract(Torus *lwe_array_out, Torus const *glwe_array_in,
sample_extract_body<Torus, params>(lwe_out, glwe_in, glwe_dimension, nth);
}

// lwe_per_glwe LWEs will be extracted per GLWE ciphertext, thus we need to have
// enough indexes
template <typename Torus, class params>
__host__ void host_sample_extract(cudaStream_t stream, uint32_t gpu_index,
Torus *lwe_array_out,
Torus const *glwe_array_in,
uint32_t const *nth_array, uint32_t num_nths,
uint32_t glwe_dimension) {
__host__ void
host_sample_extract(cudaStream_t stream, uint32_t gpu_index,
Torus *lwe_array_out, Torus const *glwe_array_in,
uint32_t const *nth_array, uint32_t num_nths,
uint32_t lwe_per_glwe, uint32_t glwe_dimension) {
cuda_set_device(gpu_index);

dim3 grid(num_nths);
dim3 thds(params::degree / params::opt);
sample_extract<Torus, params><<<grid, thds, 0, stream>>>(
lwe_array_out, glwe_array_in, nth_array, glwe_dimension);
lwe_array_out, glwe_array_in, nth_array, lwe_per_glwe, glwe_dimension);
check_cuda_error(cudaGetLastError());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,10 @@ __host__ void host_integer_decompress(
extracted_glwe = max_idx_and_glwe.second;

auto num_lwes = last_idx + 1 - current_idx;
cuda_glwe_sample_extract_64(streams[0], gpu_indexes[0], extracted_lwe,
extracted_glwe, d_indexes_array_chunk, num_lwes,
compression_params.glwe_dimension,
compression_params.polynomial_size);
cuda_glwe_sample_extract_64(
streams[0], gpu_indexes[0], extracted_lwe, extracted_glwe,
d_indexes_array_chunk, num_lwes, compression_params.polynomial_size,
compression_params.glwe_dimension, compression_params.polynomial_size);
d_indexes_array_chunk += num_lwes;
extracted_lwe += num_lwes * lwe_accumulator_size;
current_idx = last_idx;
Expand Down
1 change: 1 addition & 0 deletions backends/tfhe-cuda-backend/src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ unsafe extern "C" {
glwe_array_in: *const ffi::c_void,
nth_array: *const u32,
num_nths: u32,
lwe_per_glwe: u32,
glwe_dimension: u32,
polynomial_size: u32,
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub unsafe fn cuda_extract_lwe_samples_from_glwe_ciphertext_list_async<Scalar>(
input_glwe_list: &CudaGlweCiphertextList<Scalar>,
output_lwe_list: &mut CudaLweCiphertextList<Scalar>,
vec_nth: &[MonomialDegree],
lwe_per_glwe: u32,
streams: &CudaStreams,
) where
Scalar: UnsignedTorus,
Expand All @@ -29,9 +30,10 @@ pub unsafe fn cuda_extract_lwe_samples_from_glwe_ciphertext_list_async<Scalar>(
Got {in_lwe_dim:?} for input and {out_lwe_dim:?} for output.",
);

// lwe_per_glwe LWEs will be extracted per GLWE ciphertext, thus we need to have enough indexes
assert_eq!(
vec_nth.len(),
input_glwe_list.glwe_ciphertext_count().0 * input_glwe_list.polynomial_size().0,
input_glwe_list.glwe_ciphertext_count().0 * lwe_per_glwe as usize,
"Mismatch between number of nths and number of GLWEs provided.",
);

Expand All @@ -53,6 +55,7 @@ pub unsafe fn cuda_extract_lwe_samples_from_glwe_ciphertext_list_async<Scalar>(
&input_glwe_list.0.d_vec,
&d_nth_array,
vec_nth.len() as u32,
lwe_per_glwe,
input_glwe_list.glwe_dimension(),
input_glwe_list.polynomial_size(),
);
Expand All @@ -66,6 +69,7 @@ pub fn cuda_extract_lwe_samples_from_glwe_ciphertext_list<Scalar>(
input_glwe_list: &CudaGlweCiphertextList<Scalar>,
output_lwe_list: &mut CudaLweCiphertextList<Scalar>,
vec_nth: &[MonomialDegree],
lwe_per_glwe: u32,
streams: &CudaStreams,
) where
Scalar: UnsignedTorus,
Expand All @@ -75,6 +79,7 @@ pub fn cuda_extract_lwe_samples_from_glwe_ciphertext_list<Scalar>(
input_glwe_list,
output_lwe_list,
vec_nth,
lwe_per_glwe,
streams,
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,22 @@ fn glwe_encrypt_sample_extract_decrypt_custom_mod<Scalar: UnsignedTorus + Send +
let input_cuda_glwe_list =
CudaGlweCiphertextList::from_glwe_ciphertext_list(&glwe_list, &streams);

let lwe_per_glwe = 2;
let mut output_cuda_lwe_ciphertext_list = CudaLweCiphertextList::new(
equivalent_lwe_sk.lwe_dimension(),
LweCiphertextCount(msgs.len() * glwe_list.polynomial_size().0),
LweCiphertextCount(msgs.len() * lwe_per_glwe),
ciphertext_modulus,
&streams,
);

let nths = (0..(msgs.len() * glwe_list.polynomial_size().0))
let nths = (0..(msgs.len() * lwe_per_glwe))
.map(|x| MonomialDegree(x % glwe_list.polynomial_size().0))
.collect_vec();

cuda_extract_lwe_samples_from_glwe_ciphertext_list(
&input_cuda_glwe_list,
&mut output_cuda_lwe_ciphertext_list,
nths.as_slice(),
lwe_per_glwe as u32,
&streams,
);

Expand All @@ -107,7 +108,7 @@ fn glwe_encrypt_sample_extract_decrypt_custom_mod<Scalar: UnsignedTorus + Send +
&mut output_plaintext_list,
);

let mut decoded = vec![Scalar::ZERO; plaintext_list.plaintext_count().0];
let mut decoded = vec![Scalar::ZERO; msgs.len() * lwe_per_glwe];

decoded
.iter_mut()
Expand All @@ -116,7 +117,7 @@ fn glwe_encrypt_sample_extract_decrypt_custom_mod<Scalar: UnsignedTorus + Send +

let mut count = msg_modulus;
count = count.wrapping_sub(Scalar::ONE);
for result in decoded.chunks_exact(glwe_list.polynomial_size().0) {
for result in decoded.chunks_exact(lwe_per_glwe) {
assert!(result.iter().all(|&x| x == count));
count = count.wrapping_sub(Scalar::ONE);
}
Expand Down
2 changes: 2 additions & 0 deletions tfhe/src/core_crypto/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ pub unsafe fn extract_lwe_samples_from_glwe_ciphertext_list_async<T: UnsignedInt
glwe_array_in: &CudaVec<T>,
nth_array: &CudaVec<u32>,
num_nths: u32,
lwe_per_glwe: u32,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,
) {
Expand All @@ -412,6 +413,7 @@ pub unsafe fn extract_lwe_samples_from_glwe_ciphertext_list_async<T: UnsignedInt
glwe_array_in.as_c_ptr(0),
nth_array.as_c_ptr(0).cast::<u32>(),
num_nths,
lwe_per_glwe,
glwe_dimension.0 as u32,
polynomial_size.0 as u32,
);
Expand Down

0 comments on commit 6711611

Please sign in to comment.