Skip to content

Commit bfd3773

Browse files
committed
chore(gpu): refactor arithmetic scalar shift
1 parent a7c9357 commit bfd3773

File tree

8 files changed

+159
-34
lines changed

8 files changed

+159
-34
lines changed

backends/tfhe-cuda-backend/cuda/include/integer/integer.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ void scratch_cuda_integer_radix_arithmetic_scalar_shift_kb_64(
177177

178178
void cuda_integer_radix_arithmetic_scalar_shift_kb_64_inplace(
179179
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
180-
void *lwe_array, uint32_t shift, int8_t *mem_ptr, void *const *bsks,
181-
void *const *ksks, uint32_t num_blocks);
180+
CudaRadixCiphertextFFI *lwe_array, uint32_t shift, int8_t *mem_ptr,
181+
void *const *bsks, void *const *ksks);
182182

183183
void cleanup_cuda_integer_radix_logical_scalar_shift(
184184
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,

backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2877,7 +2877,7 @@ template <typename Torus> struct int_arithmetic_scalar_shift_buffer {
28772877

28782878
SHIFT_OR_ROTATE_TYPE shift_type;
28792879

2880-
Torus *tmp_rotated;
2880+
CudaRadixCiphertextFFI *tmp_rotated;
28812881

28822882
cudaStream_t *local_streams_1;
28832883
cudaStream_t *local_streams_2;
@@ -2909,13 +2909,10 @@ template <typename Torus> struct int_arithmetic_scalar_shift_buffer {
29092909
uint32_t big_lwe_size = params.big_lwe_dimension + 1;
29102910
uint32_t big_lwe_size_bytes = big_lwe_size * sizeof(Torus);
29112911

2912-
tmp_rotated = (Torus *)cuda_malloc_async((num_radix_blocks + 3) *
2913-
big_lwe_size_bytes,
2914-
streams[0], gpu_indexes[0]);
2915-
2916-
cuda_memset_async(tmp_rotated, 0,
2917-
(num_radix_blocks + 3) * big_lwe_size_bytes, streams[0],
2918-
gpu_indexes[0]);
2912+
tmp_rotated = new CudaRadixCiphertextFFI;
2913+
create_zero_radix_ciphertext_async<Torus>(
2914+
streams[0], gpu_indexes[0], tmp_rotated, num_radix_blocks + 3,
2915+
params.big_lwe_dimension);
29192916

29202917
uint32_t num_bits_in_block = (uint32_t)std::log2(params.message_modulus);
29212918

@@ -3051,7 +3048,8 @@ template <typename Torus> struct int_arithmetic_scalar_shift_buffer {
30513048
lut_buffers_bivariate.clear();
30523049
lut_buffers_univariate.clear();
30533050

3054-
cuda_drop_async(tmp_rotated, streams[0], gpu_indexes[0]);
3051+
release_radix_ciphertext(streams[0], gpu_indexes[0], tmp_rotated);
3052+
delete tmp_rotated;
30553053
}
30563054
};
30573055

backends/tfhe-cuda-backend/cuda/src/integer/abs.cuh

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ __host__ void legacy_host_integer_abs_kb_async(
4848
cuda_memcpy_async_gpu_to_gpu(mask, ct, num_blocks * big_lwe_size_bytes,
4949
streams[0], gpu_indexes[0]);
5050

51-
host_integer_radix_arithmetic_scalar_shift_kb_inplace<Torus>(
51+
legacy_host_integer_radix_arithmetic_scalar_shift_kb_inplace<Torus>(
5252
streams, gpu_indexes, gpu_count, mask, num_bits_in_ciphertext - 1,
5353
mem_ptr->arithmetic_scalar_shift_mem, bsks, ksks, num_blocks);
5454
legacy_host_addition<Torus>(streams[0], gpu_indexes[0], ct, mask, ct,
@@ -84,9 +84,8 @@ host_integer_abs_kb(cudaStream_t const *streams, uint32_t const *gpu_indexes,
8484
copy_radix_ciphertext_async<Torus>(streams[0], gpu_indexes[0], mask, ct);
8585

8686
host_integer_radix_arithmetic_scalar_shift_kb_inplace<Torus>(
87-
streams, gpu_indexes, gpu_count, (Torus *)(mask->ptr),
88-
num_bits_in_ciphertext - 1, mem_ptr->arithmetic_scalar_shift_mem, bsks,
89-
ksks, ct->num_radix_blocks);
87+
streams, gpu_indexes, gpu_count, mask, num_bits_in_ciphertext - 1,
88+
mem_ptr->arithmetic_scalar_shift_mem, bsks, ksks);
9089
host_addition<Torus>(streams[0], gpu_indexes[0], ct, mask, ct,
9190
ct->num_radix_blocks);
9291

backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cu

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,13 @@ void scratch_cuda_integer_radix_arithmetic_scalar_shift_kb_64(
6464
/// zeros as would be done in the logical shift.
6565
void cuda_integer_radix_arithmetic_scalar_shift_kb_64_inplace(
6666
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
67-
void *lwe_array, uint32_t shift, int8_t *mem_ptr, void *const *bsks,
68-
void *const *ksks, uint32_t num_blocks) {
67+
CudaRadixCiphertextFFI *lwe_array, uint32_t shift, int8_t *mem_ptr,
68+
void *const *bsks, void *const *ksks) {
6969

7070
host_integer_radix_arithmetic_scalar_shift_kb_inplace<uint64_t>(
71-
(cudaStream_t *)(streams), gpu_indexes, gpu_count,
72-
static_cast<uint64_t *>(lwe_array), shift,
71+
(cudaStream_t *)(streams), gpu_indexes, gpu_count, lwe_array, shift,
7372
(int_arithmetic_scalar_shift_buffer<uint64_t> *)mem_ptr, bsks,
74-
(uint64_t **)(ksks), num_blocks);
73+
(uint64_t **)(ksks));
7574
}
7675

7776
void cleanup_cuda_integer_radix_logical_scalar_shift(

backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cuh

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ __host__ void scratch_cuda_integer_radix_arithmetic_scalar_shift_kb(
224224
}
225225

226226
template <typename Torus>
227-
__host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
227+
__host__ void legacy_host_integer_radix_arithmetic_scalar_shift_kb_inplace(
228228
cudaStream_t const *streams, uint32_t const *gpu_indexes,
229229
uint32_t gpu_count, Torus *lwe_array, uint32_t shift,
230230
int_arithmetic_scalar_shift_buffer<Torus> *mem, void *const *bsks,
@@ -248,7 +248,7 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
248248
size_t rotations = std::min(shift / num_bits_in_block, (size_t)num_blocks);
249249
size_t shift_within_block = shift % num_bits_in_block;
250250

251-
Torus *rotated_buffer = mem->tmp_rotated;
251+
Torus *rotated_buffer = (Torus *)mem->tmp_rotated->ptr;
252252
Torus *padding_block = &rotated_buffer[(num_blocks + 1) * big_lwe_size];
253253
Torus *last_block_copy = &padding_block[big_lwe_size];
254254

@@ -339,4 +339,119 @@ __host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
339339
}
340340
}
341341

342+
template <typename Torus>
343+
__host__ void host_integer_radix_arithmetic_scalar_shift_kb_inplace(
344+
cudaStream_t const *streams, uint32_t const *gpu_indexes,
345+
uint32_t gpu_count, CudaRadixCiphertextFFI *lwe_array, uint32_t shift,
346+
int_arithmetic_scalar_shift_buffer<Torus> *mem, void *const *bsks,
347+
Torus *const *ksks) {
348+
349+
auto num_blocks = lwe_array->num_radix_blocks;
350+
auto params = mem->params;
351+
auto message_modulus = params.message_modulus;
352+
353+
size_t num_bits_in_block = (size_t)log2_int(message_modulus);
354+
size_t total_num_bits = num_bits_in_block * num_blocks;
355+
shift = shift % total_num_bits;
356+
357+
if (shift == 0) {
358+
return;
359+
}
360+
size_t rotations = std::min(shift / num_bits_in_block, (size_t)num_blocks);
361+
size_t shift_within_block = shift % num_bits_in_block;
362+
363+
CudaRadixCiphertextFFI padding_block;
364+
as_radix_ciphertext_slice<Torus>(&padding_block, mem->tmp_rotated,
365+
num_blocks + 1, num_blocks + 2);
366+
CudaRadixCiphertextFFI last_block_copy;
367+
as_radix_ciphertext_slice<Torus>(&last_block_copy, mem->tmp_rotated,
368+
num_blocks + 2, num_blocks + 3);
369+
370+
if (mem->shift_type == RIGHT_SHIFT) {
371+
host_radix_blocks_rotate_left<Torus>(streams, gpu_indexes, gpu_count,
372+
mem->tmp_rotated, lwe_array, rotations,
373+
num_blocks);
374+
copy_radix_ciphertext_slice_async<Torus>(streams[0], gpu_indexes[0],
375+
lwe_array, 0, num_blocks,
376+
mem->tmp_rotated, 0, num_blocks);
377+
378+
if (num_bits_in_block == 1) {
379+
// if there is only 1 bit in the msg part, it means shift_within block is
380+
// 0 thus only rotations is required.
381+
382+
// We still need to pad with the value of the sign bit.
383+
// And here since a block only has 1 bit of message
384+
// we can optimize things by not doing the pbs to extract this sign bit
385+
for (uint i = 0; i < num_blocks; i++) {
386+
copy_radix_ciphertext_slice_async<Torus>(
387+
streams[0], gpu_indexes[0], mem->tmp_rotated,
388+
num_blocks - rotations + i, num_blocks - rotations + i + 1,
389+
mem->tmp_rotated, num_blocks - rotations - 1,
390+
num_blocks - rotations);
391+
}
392+
return;
393+
}
394+
395+
if (num_blocks != rotations) {
396+
// In the arithmetic shift case we have to pad with the value of the sign
397+
// bit. This creates the need for a different shifting lut than in the
398+
// logical shift case. We also need another PBS to create the padding
399+
// block.
400+
CudaRadixCiphertextFFI last_block;
401+
as_radix_ciphertext_slice<Torus>(&last_block, lwe_array,
402+
num_blocks - rotations - 1,
403+
num_blocks - rotations);
404+
copy_radix_ciphertext_slice_async<Torus>(
405+
streams[0], gpu_indexes[0], &last_block_copy, 0, 1, mem->tmp_rotated,
406+
num_blocks - rotations - 1, num_blocks - rotations);
407+
if (shift_within_block != 0) {
408+
auto partial_current_blocks = lwe_array;
409+
CudaRadixCiphertextFFI partial_next_blocks;
410+
as_radix_ciphertext_slice<Torus>(&partial_next_blocks, mem->tmp_rotated,
411+
1, mem->tmp_rotated->num_radix_blocks);
412+
size_t partial_block_count = num_blocks - rotations;
413+
auto lut_bivariate = mem->lut_buffers_bivariate[shift_within_block - 1];
414+
415+
integer_radix_apply_bivariate_lookup_table_kb<Torus>(
416+
streams, gpu_indexes, gpu_count, partial_current_blocks,
417+
partial_current_blocks, &partial_next_blocks, bsks, ksks,
418+
lut_bivariate, partial_block_count,
419+
lut_bivariate->params.message_modulus);
420+
}
421+
// Since our CPU threads will be working on different streams we shall
422+
// assert the work in the main stream is completed
423+
for (uint j = 0; j < gpu_count; j++) {
424+
cuda_synchronize_stream(streams[j], gpu_indexes[j]);
425+
}
426+
auto lut_univariate_padding_block =
427+
mem->lut_buffers_univariate[num_bits_in_block - 1];
428+
integer_radix_apply_univariate_lookup_table_kb<Torus>(
429+
mem->local_streams_1, gpu_indexes, gpu_count, &padding_block,
430+
&last_block_copy, bsks, ksks, lut_univariate_padding_block, 1);
431+
// Replace blocks 'pulled' from the left with the correct padding
432+
// block
433+
for (uint i = 0; i < rotations; i++) {
434+
copy_radix_ciphertext_slice_async<Torus>(
435+
mem->local_streams_1[0], gpu_indexes[0], lwe_array,
436+
num_blocks - rotations + i, num_blocks - rotations + i + 1,
437+
&padding_block, 0, 1);
438+
}
439+
if (shift_within_block != 0) {
440+
auto lut_univariate_shift_last_block =
441+
mem->lut_buffers_univariate[shift_within_block - 1];
442+
integer_radix_apply_univariate_lookup_table_kb<Torus>(
443+
mem->local_streams_2, gpu_indexes, gpu_count, &last_block,
444+
&last_block_copy, bsks, ksks, lut_univariate_shift_last_block, 1);
445+
}
446+
for (uint j = 0; j < mem->active_gpu_count; j++) {
447+
cuda_synchronize_stream(mem->local_streams_1[j], gpu_indexes[j]);
448+
cuda_synchronize_stream(mem->local_streams_2[j], gpu_indexes[j]);
449+
}
450+
}
451+
} else {
452+
PANIC("Cuda error (scalar shift): left scalar shift is never of the "
453+
"arithmetic type")
454+
}
455+
}
456+
342457
#endif // CUDA_SCALAR_SHIFT_CUH

backends/tfhe-cuda-backend/src/bindings.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,12 +467,11 @@ unsafe extern "C" {
467467
streams: *const *mut ffi::c_void,
468468
gpu_indexes: *const u32,
469469
gpu_count: u32,
470-
lwe_array: *mut ffi::c_void,
470+
lwe_array: *mut CudaRadixCiphertextFFI,
471471
shift: u32,
472472
mem_ptr: *mut i8,
473473
bsks: *const *mut ffi::c_void,
474474
ksks: *const *mut ffi::c_void,
475-
num_blocks: u32,
476475
);
477476
}
478477
unsafe extern "C" {

tfhe/src/integer/gpu/mod.rs

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1671,7 +1671,7 @@ pub unsafe fn unchecked_scalar_arithmetic_right_shift_integer_radix_kb_assign_as
16711671
B: Numeric,
16721672
>(
16731673
streams: &CudaStreams,
1674-
radix_lwe_left: &mut CudaVec<T>,
1674+
radix_lwe_left: &mut CudaRadixCiphertext,
16751675
shift: u32,
16761676
bootstrapping_key: &CudaVec<B>,
16771677
keyswitch_key: &CudaVec<T>,
@@ -1685,13 +1685,12 @@ pub unsafe fn unchecked_scalar_arithmetic_right_shift_integer_radix_kb_assign_as
16851685
ks_base_log: DecompositionBaseLog,
16861686
pbs_level: DecompositionLevelCount,
16871687
pbs_base_log: DecompositionBaseLog,
1688-
num_blocks: u32,
16891688
pbs_type: PBSType,
16901689
grouping_factor: LweBskGroupingFactor,
16911690
) {
16921691
assert_eq!(
16931692
streams.gpu_indexes[0],
1694-
radix_lwe_left.gpu_index(0),
1693+
radix_lwe_left.d_blocks.0.d_vec.gpu_index(0),
16951694
"GPU error: all data should reside on the same GPU."
16961695
);
16971696
assert_eq!(
@@ -1705,6 +1704,24 @@ pub unsafe fn unchecked_scalar_arithmetic_right_shift_integer_radix_kb_assign_as
17051704
"GPU error: all data should reside on the same GPU."
17061705
);
17071706
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
1707+
let mut radix_lwe_left_degrees = radix_lwe_left
1708+
.info
1709+
.blocks
1710+
.iter()
1711+
.map(|b| b.degree.0)
1712+
.collect();
1713+
let mut radix_lwe_left_noise_levels = radix_lwe_left
1714+
.info
1715+
.blocks
1716+
.iter()
1717+
.map(|b| b.noise_level.0)
1718+
.collect();
1719+
let mut cuda_ffi_radix_lwe_left = prepare_cuda_radix_ffi(
1720+
radix_lwe_left,
1721+
&mut radix_lwe_left_degrees,
1722+
&mut radix_lwe_left_noise_levels,
1723+
);
1724+
17081725
scratch_cuda_integer_radix_arithmetic_scalar_shift_kb_64(
17091726
streams.ptr.as_ptr(),
17101727
streams.gpu_indexes_ptr(),
@@ -1719,7 +1736,7 @@ pub unsafe fn unchecked_scalar_arithmetic_right_shift_integer_radix_kb_assign_as
17191736
pbs_level.0 as u32,
17201737
pbs_base_log.0 as u32,
17211738
grouping_factor.0 as u32,
1722-
num_blocks,
1739+
radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32,
17231740
message_modulus.0 as u32,
17241741
carry_modulus.0 as u32,
17251742
pbs_type as u32,
@@ -1730,19 +1747,19 @@ pub unsafe fn unchecked_scalar_arithmetic_right_shift_integer_radix_kb_assign_as
17301747
streams.ptr.as_ptr(),
17311748
streams.gpu_indexes_ptr(),
17321749
streams.len() as u32,
1733-
radix_lwe_left.as_mut_c_ptr(0),
1750+
&mut cuda_ffi_radix_lwe_left,
17341751
shift,
17351752
mem_ptr,
17361753
bootstrapping_key.ptr.as_ptr(),
17371754
keyswitch_key.ptr.as_ptr(),
1738-
num_blocks,
17391755
);
17401756
cleanup_cuda_integer_radix_arithmetic_scalar_shift(
17411757
streams.ptr.as_ptr(),
17421758
streams.gpu_indexes_ptr(),
17431759
streams.len() as u32,
17441760
std::ptr::addr_of_mut!(mem_ptr),
17451761
);
1762+
update_noise_degree(radix_lwe_left, &cuda_ffi_radix_lwe_left);
17461763
}
17471764

17481765
#[allow(clippy::too_many_arguments)]

tfhe/src/integer/gpu/server_key/radix/scalar_shift.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ impl CudaServerKey {
194194
CudaBootstrappingKey::Classic(d_bsk) => {
195195
unchecked_scalar_arithmetic_right_shift_integer_radix_kb_assign_async(
196196
streams,
197-
&mut ct.as_mut().d_blocks.0.d_vec,
197+
ct.as_mut(),
198198
u32::cast_from(shift),
199199
&d_bsk.d_vec,
200200
&self.key_switching_key.d_vec,
@@ -212,15 +212,14 @@ impl CudaServerKey {
212212
self.key_switching_key.decomposition_base_log(),
213213
d_bsk.decomp_level_count,
214214
d_bsk.decomp_base_log,
215-
lwe_ciphertext_count.0 as u32,
216215
PBSType::Classical,
217216
LweBskGroupingFactor(0),
218217
);
219218
}
220219
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
221220
unchecked_scalar_arithmetic_right_shift_integer_radix_kb_assign_async(
222221
streams,
223-
&mut ct.as_mut().d_blocks.0.d_vec,
222+
ct.as_mut(),
224223
u32::cast_from(shift),
225224
&d_multibit_bsk.d_vec,
226225
&self.key_switching_key.d_vec,
@@ -238,7 +237,6 @@ impl CudaServerKey {
238237
self.key_switching_key.decomposition_base_log(),
239238
d_multibit_bsk.decomp_level_count,
240239
d_multibit_bsk.decomp_base_log,
241-
lwe_ciphertext_count.0 as u32,
242240
PBSType::MultiBit,
243241
d_multibit_bsk.grouping_factor,
244242
);

0 commit comments

Comments
 (0)