Commit 7807162

fix(gpu): enforce tighter bounds on compression output
1 parent 0809eb9 commit 7807162

3 files changed: +56 -71 lines
backends/tfhe-cuda-backend/cuda/src/integer/compression/compression.cuh

Lines changed: 46 additions & 47 deletions
@@ -14,27 +14,26 @@
 
 template <typename Torus>
 __global__ void pack(Torus *array_out, Torus *array_in, uint32_t log_modulus,
-                     uint32_t num_coeffs, uint32_t in_len, uint32_t out_len) {
-  auto nbits = sizeof(Torus) * 8;
+                     uint32_t num_glwes, uint32_t in_len, uint32_t out_len) {
   auto tid = threadIdx.x + blockIdx.x * blockDim.x;
 
-  auto glwe_index = tid / out_len;
-  auto i = tid % out_len;
-  auto chunk_array_in = array_in + glwe_index * in_len;
-  auto chunk_array_out = array_out + glwe_index * out_len;
+  if (tid < num_glwes * out_len) {
+    auto NBITS = sizeof(Torus) * 8;
+    auto glwe_index = tid / out_len;
+    auto i = tid % out_len;
+    auto chunk_array_in = array_in + glwe_index * in_len;
+    auto chunk_array_out = array_out + glwe_index * out_len;
 
-  if (tid < num_coeffs) {
-
-    auto k = nbits * i / log_modulus;
+    auto k = NBITS * i / log_modulus;
     auto j = k;
 
-    auto start_shift = i * nbits - j * log_modulus;
+    auto start_shift = i * NBITS - j * log_modulus;
 
     auto value = chunk_array_in[j] >> start_shift;
     j++;
 
-    while (j * log_modulus < ((i + 1) * nbits) && j < in_len) {
-      auto shift = j * log_modulus - i * nbits;
+    while (j * log_modulus < ((i + 1) * NBITS) && j < in_len) {
+      auto shift = j * log_modulus - i * NBITS;
       value |= chunk_array_in[j] << shift;
       j++;
     }
@@ -51,30 +50,30 @@ __host__ void host_pack(cudaStream_t stream, uint32_t gpu_index,
     PANIC("Cuda error: Input and output must be different");
 
   cuda_set_device(gpu_index);
+  auto NBITS = sizeof(Torus) * 8;
   auto compression_params = mem_ptr->compression_params;
-
   auto log_modulus = mem_ptr->storage_log_modulus;
-  // [0..num_glwes-1) GLWEs
-  auto in_len = (compression_params.glwe_dimension + 1) *
-                compression_params.polynomial_size;
-  auto number_bits_to_pack = in_len * log_modulus;
-  auto nbits = sizeof(Torus) * 8;
-  // number_bits_to_pack.div_ceil(Scalar::BITS)
-  auto out_len = (number_bits_to_pack + nbits - 1) / nbits;
 
-  // Last GLWE
-  number_bits_to_pack = in_len * log_modulus;
-  auto last_out_len = (number_bits_to_pack + nbits - 1) / nbits;
+  auto glwe_ciphertext_size = (compression_params.glwe_dimension + 1) *
+                              compression_params.polynomial_size;
+  auto glwe_mask_size =
+      compression_params.glwe_dimension * compression_params.polynomial_size;
 
-  auto num_coeffs = (num_glwes - 1) * out_len + last_out_len;
+  auto uncompressed_len = num_glwes * glwe_mask_size + num_lwes;
+  auto number_bits_to_pack = uncompressed_len * log_modulus;
 
-  int num_blocks = 0, num_threads = 0;
-  getNumBlocksAndThreads(num_coeffs, 1024, num_blocks, num_threads);
+  // equivalent to number_bits_to_pack.div_ceil(Scalar::BITS)
+  auto compressed_len = (number_bits_to_pack + NBITS - 1) / NBITS;
 
+  // Kernel settings
+  int num_blocks = 0, num_threads = 0;
+  getNumBlocksAndThreads(num_glwes * compressed_len, 1024, num_blocks,
+                         num_threads);
   dim3 grid(num_blocks);
   dim3 threads(num_threads);
   pack<Torus><<<grid, threads, 0, stream>>>(array_out, array_in, log_modulus,
-                                            num_coeffs, in_len, out_len);
+                                            num_glwes, uncompressed_len,
+                                            compressed_len);
   check_cuda_error(cudaGetLastError());
 }
 
@@ -144,7 +143,7 @@ template <typename Torus>
 __global__ void extract(Torus *glwe_array_out, Torus const *array_in,
                         uint32_t index, uint32_t log_modulus,
                         uint32_t input_len, uint32_t initial_out_len) {
-  auto nbits = sizeof(Torus) * 8;
+  auto NBITS = sizeof(Torus) * 8;
 
   auto i = threadIdx.x + blockIdx.x * blockDim.x;
   auto chunk_array_in = array_in + index * input_len;
@@ -154,10 +153,10 @@ __global__ void extract(Torus *glwe_array_out, Torus const *array_in,
     auto start = i * log_modulus;
     auto end = (i + 1) * log_modulus;
 
-    auto start_block = start / nbits;
-    auto start_remainder = start % nbits;
+    auto start_block = start / NBITS;
+    auto start_remainder = start % NBITS;
 
-    auto end_block_inclusive = (end - 1) / nbits;
+    auto end_block_inclusive = (end - 1) / NBITS;
 
     Torus unpacked_i;
     if (start_block == end_block_inclusive) {
@@ -166,13 +165,13 @@ __global__ void extract(Torus *glwe_array_out, Torus const *array_in,
     } else {
       auto first_part = chunk_array_in[start_block] >> start_remainder;
       auto second_part = chunk_array_in[start_block + 1]
-                         << (nbits - start_remainder);
+                         << (NBITS - start_remainder);
 
       unpacked_i = (first_part | second_part) & mask;
    }
 
    // Extract
-    glwe_array_out[i] = unpacked_i << (nbits - log_modulus);
+    glwe_array_out[i] = unpacked_i << (NBITS - log_modulus);
  }
 }
 
@@ -186,38 +185,38 @@ __host__ void host_extract(cudaStream_t stream, uint32_t gpu_index,
     PANIC("Cuda error: Input and output must be different");
 
   cuda_set_device(gpu_index);
-
+  auto NBITS = sizeof(Torus) * 8;
   auto compression_params = mem_ptr->compression_params;
-
   auto log_modulus = mem_ptr->storage_log_modulus;
 
   uint32_t body_count =
       std::min(mem_ptr->body_count, compression_params.polynomial_size);
-  auto initial_out_len =
+  // num_glwes = 1 in this case
+  auto uncompressed_len =
       compression_params.glwe_dimension * compression_params.polynomial_size +
       body_count;
 
-  auto compressed_glwe_accumulator_size =
-      (compression_params.glwe_dimension + 1) *
-      compression_params.polynomial_size;
-  auto number_bits_to_unpack = compressed_glwe_accumulator_size * log_modulus;
-  auto nbits = sizeof(Torus) * 8;
+  auto glwe_ciphertext_size = (compression_params.glwe_dimension + 1) *
+                              compression_params.polynomial_size;
+  auto number_bits_to_unpack = uncompressed_len * log_modulus;
   // number_bits_to_unpack.div_ceil(Scalar::BITS)
-  auto input_len = (number_bits_to_unpack + nbits - 1) / nbits;
+  auto compressed_len = (number_bits_to_unpack + NBITS - 1) / NBITS;
 
   // We assure the tail of the glwe is zeroed
-  auto zeroed_slice = glwe_array_out + initial_out_len;
+  auto zeroed_slice = glwe_array_out + uncompressed_len;
   cuda_memset_async(zeroed_slice, 0,
                     (compression_params.polynomial_size - body_count) *
                         sizeof(Torus),
                     stream, gpu_index);
+
+  // Kernel settings
   int num_blocks = 0, num_threads = 0;
-  getNumBlocksAndThreads(initial_out_len, 128, num_blocks, num_threads);
+  getNumBlocksAndThreads(uncompressed_len, 128, num_blocks, num_threads);
   dim3 grid(num_blocks);
   dim3 threads(num_threads);
-  extract<Torus><<<grid, threads, 0, stream>>>(glwe_array_out, array_in,
-                                               glwe_index, log_modulus,
-                                               input_len, initial_out_len);
+  extract<Torus><<<grid, threads, 0, stream>>>(
+      glwe_array_out, array_in, glwe_index, log_modulus, compressed_len,
+      uncompressed_len);
   check_cuda_error(cudaGetLastError());
 }
 

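For context on the kernel above: pack concatenates the log_modulus significant low bits of each input coefficient into 64-bit output words, and the host now derives the grid size and output length from exactly that bit count instead of a whole-ciphertext upper bound. Below is a minimal Rust sketch of the same packing arithmetic (not the CUDA code itself; the input values are hypothetical):

// Minimal sketch of the bit-packing scheme the pack/extract kernels implement:
// each input word carries `log_modulus` significant low bits, and those bits
// are concatenated into 64-bit output words.
fn pack(input: &[u64], log_modulus: u32) -> Vec<u64> {
    const NBITS: u32 = u64::BITS;
    let number_bits_to_pack = input.len() * log_modulus as usize;
    let out_len = number_bits_to_pack.div_ceil(NBITS as usize);

    (0..out_len)
        .map(|i| {
            // First input word contributing to output word `i`.
            let mut j = (i as u32 * NBITS / log_modulus) as usize;
            let start_shift = i as u32 * NBITS - j as u32 * log_modulus;
            let mut value = input[j] >> start_shift;
            j += 1;
            // Remaining input words whose bits fall inside output word `i`.
            while (j as u32 * log_modulus) < (i as u32 + 1) * NBITS && j < input.len() {
                let shift = j as u32 * log_modulus - i as u32 * NBITS;
                value |= input[j] << shift;
                j += 1;
            }
            value
        })
        .collect()
}

fn main() {
    let log_modulus = 12u32;
    // Hypothetical coefficients, each with only `log_modulus` significant bits.
    let input: Vec<u64> = (0..10).map(|x| (x * 37) % (1 << log_modulus)).collect();
    let packed = pack(&input, log_modulus);
    println!("{} words packed into {} words", input.len(), packed.len());
}
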
tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs

Lines changed: 2 additions & 16 deletions
@@ -2,7 +2,7 @@ use crate::core_crypto::entities::packed_integers::PackedIntegers;
 use crate::core_crypto::gpu::vec::{CudaVec, GpuIndex};
 use crate::core_crypto::gpu::CudaStreams;
 use crate::core_crypto::prelude::compressed_modulus_switched_glwe_ciphertext::CompressedModulusSwitchedGlweCiphertext;
-use crate::core_crypto::prelude::{glwe_ciphertext_size, CiphertextCount, LweCiphertextCount};
+use crate::core_crypto::prelude::{CiphertextCount, LweCiphertextCount};
 use crate::integer::ciphertext::{CompressedCiphertextList, DataKind};
 use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
 use crate::integer::gpu::ciphertext::{
@@ -343,25 +343,11 @@ impl CompressedCiphertextList {
         let message_modulus = self.packed_list.message_modulus;
         let carry_modulus = self.packed_list.carry_modulus;
 
-        let mut flat_cpu_data = modulus_switched_glwe_ciphertext_list
+        let flat_cpu_data = modulus_switched_glwe_ciphertext_list
             .iter()
             .flat_map(|ct| ct.packed_integers.packed_coeffs.clone())
             .collect_vec();
 
-        let glwe_ciphertext_count = self.packed_list.modulus_switched_glwe_ciphertext_list.len();
-        let glwe_size = self.packed_list.modulus_switched_glwe_ciphertext_list[0]
-            .glwe_dimension()
-            .to_glwe_size();
-        let polynomial_size =
-            self.packed_list.modulus_switched_glwe_ciphertext_list[0].polynomial_size();
-
-        // FIXME: have a more precise memory handling, this is too long and should be "just" the
-        // original flat_cpu_data.len()
-        let unpacked_glwe_ciphertext_flat_len =
-            glwe_ciphertext_count * glwe_ciphertext_size(glwe_size, polynomial_size);
-
-        flat_cpu_data.resize(unpacked_glwe_ciphertext_flat_len, 0u64);
-
         let flat_gpu_data = unsafe {
             let v = CudaVec::from_cpu_async(flat_cpu_data.as_slice(), streams, 0);
             streams.synchronize();

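With the device buffer now sized from the exact packed length (see server_keys.rs below), the host no longer pads flat_cpu_data up to whole GLWE ciphertexts before the upload. A minimal sketch of the resulting invariant, using hypothetical packed lists in place of `ct.packed_integers.packed_coeffs`:

// Minimal sketch: the flattened upload is exactly the concatenation of the
// per-ciphertext packed coefficients, with no zero padding appended.
fn main() {
    // Hypothetical stand-ins for the packed coefficient vectors.
    let packed_lists: Vec<Vec<u64>> = vec![vec![1, 2, 3], vec![4, 5]];

    let flat_cpu_data: Vec<u64> = packed_lists.iter().flatten().copied().collect();

    // No resize up to `glwe_ciphertext_count * glwe_ciphertext_size` any more:
    // the length is just the sum of the packed lengths.
    assert_eq!(
        flat_cpu_data.len(),
        packed_lists.iter().map(Vec::len).sum::<usize>()
    );
}
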
tfhe/src/integer/gpu/list_compression/server_keys.rs

Lines changed: 8 additions & 8 deletions
@@ -3,8 +3,8 @@ use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
 use crate::core_crypto::gpu::vec::CudaVec;
 use crate::core_crypto::gpu::CudaStreams;
 use crate::core_crypto::prelude::{
-    glwe_ciphertext_size, CiphertextModulus, CiphertextModulusLog, GlweCiphertextCount,
-    LweCiphertextCount, PolynomialSize,
+    glwe_ciphertext_size, glwe_mask_size, CiphertextModulus, CiphertextModulusLog,
+    GlweCiphertextCount, LweCiphertextCount, PolynomialSize,
 };
 use crate::integer::ciphertext::DataKind;
 use crate::integer::compression_keys::CompressionKey;
@@ -173,15 +173,15 @@ impl CudaCompressionKey {
             .sum();
 
         let num_glwes = num_lwes.div_ceil(self.lwe_per_glwe.0);
-        let glwe_ciphertext_size =
-            glwe_ciphertext_size(compressed_glwe_size, compressed_polynomial_size);
+        let glwe_mask_size = glwe_mask_size(
+            compressed_glwe_size.to_glwe_dimension(),
+            compressed_polynomial_size,
+        );
         // The number of u64 (both mask and bodies)
-        // FIXME: have a more precise memory handling, this is too long and should be
-        // num_glwes * glwe_mask_size + num_lwes
-        let uncompressed_len = num_glwes * glwe_ciphertext_size;
+        let uncompressed_len = num_glwes * glwe_mask_size + num_lwes;
         let number_bits_to_pack = uncompressed_len * self.storage_log_modulus.0;
         let compressed_len = number_bits_to_pack.div_ceil(u64::BITS as usize);
-        let mut packed_glwe_list = CudaVec::new(compressed_len, streams, 0);
+        let mut packed_glwe_list = CudaVec::new(num_glwes * compressed_len, streams, 0);
 
         unsafe {
             let input_lwes = Self::flatten_async(ciphertexts, streams);

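The removed FIXME spells out the intent of this last change: the uncompressed element count should be `num_glwes * glwe_mask_size + num_lwes` (every mask, plus only the bodies actually packed) rather than `num_glwes` full GLWE ciphertexts. A minimal sketch contrasting the two counts, with hypothetical parameters in place of the real compression key values:

// Minimal sketch contrasting the old (over-sized) and new (exact) element
// counts used to size the packed GLWE buffer. Parameter values are hypothetical.
fn main() {
    let glwe_dimension: usize = 1; // hypothetical
    let polynomial_size: usize = 2048; // hypothetical
    let num_lwes: usize = 3000; // hypothetical
    let lwe_per_glwe: usize = 2048; // hypothetical

    let num_glwes = num_lwes.div_ceil(lwe_per_glwe);
    let glwe_size = glwe_dimension + 1;

    // Before: every GLWE counted in full (mask plus whole body polynomial).
    let old_uncompressed_len = num_glwes * glwe_size * polynomial_size;

    // After: every mask, but only the LWE bodies that were actually packed.
    let glwe_mask_size = glwe_dimension * polynomial_size;
    let new_uncompressed_len = num_glwes * glwe_mask_size + num_lwes;

    assert!(new_uncompressed_len <= old_uncompressed_len);
    println!("old = {old_uncompressed_len}, new = {new_uncompressed_len}");
}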