Skip to content

Commit

Permalink
fix(gpu): enable large number of samples in pbs tbc
Browse files Browse the repository at this point in the history
  • Loading branch information
guillermo-oyarzun committed Feb 17, 2025
1 parent 1b7f7a5 commit 380a1d1
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 20 deletions.
49 changes: 49 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,55 @@ mul_ggsw_glwe_in_fourier_domain(double2 *fft, double2 *join_buffer,

synchronize_threads_in_block();
}
// This is a temporary function to have a working version of the tbc-PBS
// with large integers.
// TO DO: we should erase this function when the other flavors of the PBS also
// work with the number of samples in the first block dimension.
template <typename G, class params>
__device__ void mul_ggsw_glwe_in_fourier_domain_tbc(
double2 *fft, double2 *join_buffer,
const double2 *__restrict__ bootstrapping_key, int iteration, G &group,
bool support_dsm = false) {
const uint32_t polynomial_size = params::degree;
const uint32_t glwe_dimension = gridDim.y - 1;
const uint32_t level_count = gridDim.z;

// The first product is used to initialize level_join_buffer
auto this_block_rank = get_this_block_rank<G>(group, support_dsm);

// Continues multiplying fft by every polynomial in that particular bsk level
// Each y-block accumulates in a different polynomial at each iteration
auto bsk_slice = get_ith_mask_kth_block(
bootstrapping_key, iteration, blockIdx.y, blockIdx.z, polynomial_size,
glwe_dimension, level_count);
for (int j = 0; j < glwe_dimension + 1; j++) {
int idx = (j + this_block_rank) % (glwe_dimension + 1);

auto bsk_poly = bsk_slice + idx * polynomial_size / 2;
auto buffer_slice = get_join_buffer_element<G>(blockIdx.z, idx, group,
join_buffer, polynomial_size,
glwe_dimension, support_dsm);

polynomial_product_accumulate_in_fourier_domain<params, double2>(
buffer_slice, fft, bsk_poly, j == 0);
group.sync();
}

// -----------------------------------------------------------------
// All blocks are synchronized here; after this sync, level_join_buffer has
// the values needed from every other block

// accumulate rest of the products into fft buffer
for (int l = 0; l < level_count; l++) {
auto cur_src_acc = get_join_buffer_element<G>(l, blockIdx.y, group,
join_buffer, polynomial_size,
glwe_dimension, support_dsm);

polynomial_accumulate_in_fourier_domain<params>(fft, cur_src_acc, l == 0);
}

synchronize_threads_in_block();
}

template <typename Torus>
void execute_pbs_async(cudaStream_t const *streams, uint32_t const *gpu_indexes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ __global__ void __launch_bounds__(params::degree / params::opt)
if (support_dsm)
selected_memory += sizeof(Torus) * polynomial_size;
} else {
int block_index = blockIdx.x + blockIdx.y * gridDim.x +
blockIdx.z * gridDim.x * gridDim.y;
int block_index = blockIdx.z + blockIdx.y * gridDim.z +
blockIdx.x * gridDim.z * gridDim.y;
selected_memory = &device_mem[block_index * device_memory_size_per_block];
}

Expand All @@ -65,25 +65,25 @@ __global__ void __launch_bounds__(params::degree / params::opt)
accumulator_fft += sizeof(double2) * (polynomial_size / 2);
}

// The third dimension of the block is used to determine on which ciphertext
// The first dimension of the block is used to determine on which ciphertext
// this block is operating, in the case of batch bootstraps
const Torus *block_lwe_array_in =
&lwe_array_in[lwe_input_indexes[blockIdx.z] * (lwe_dimension + 1)];
&lwe_array_in[lwe_input_indexes[blockIdx.x] * (lwe_dimension + 1)];

const Torus *block_lut_vector =
&lut_vector[lut_vector_indexes[blockIdx.z] * params::degree *
&lut_vector[lut_vector_indexes[blockIdx.x] * params::degree *
(glwe_dimension + 1)];

double2 *block_join_buffer =
&join_buffer[blockIdx.z * level_count * (glwe_dimension + 1) *
&join_buffer[blockIdx.x * level_count * (glwe_dimension + 1) *
params::degree / 2];

Torus *global_accumulator_slice =
&global_accumulator[(blockIdx.y + blockIdx.z * (glwe_dimension + 1)) *
&global_accumulator[(blockIdx.y + blockIdx.x * (glwe_dimension + 1)) *
params::degree];

const double2 *keybundle =
&keybundle_array[blockIdx.z * keybundle_size_per_input];
&keybundle_array[blockIdx.x * keybundle_size_per_input];

if (lwe_offset == 0) {
// Put "b" in [0, 2N[
Expand Down Expand Up @@ -113,12 +113,12 @@ __global__ void __launch_bounds__(params::degree / params::opt)
// accumulator decomposed at level 0, 1 at 1, etc.)
GadgetMatrix<Torus, params> gadget_acc(base_log, level_count,
accumulator_rotated);
gadget_acc.decompose_and_compress_level(accumulator_fft, blockIdx.x);
gadget_acc.decompose_and_compress_level(accumulator_fft, blockIdx.z);
NSMFFT_direct<HalfDegree<params>>(accumulator_fft);
synchronize_threads_in_block();

// Perform G^-1(ACC) * GGSW -> GLWE
mul_ggsw_glwe_in_fourier_domain<cluster_group, params>(
mul_ggsw_glwe_in_fourier_domain_tbc<cluster_group, params>(
accumulator_fft, block_join_buffer, keybundle, i, cluster, support_dsm);
NSMFFT_inverse<HalfDegree<params>>(accumulator_fft);
synchronize_threads_in_block();
Expand All @@ -128,10 +128,10 @@ __global__ void __launch_bounds__(params::degree / params::opt)

auto accumulator = accumulator_rotated;

if (blockIdx.x == 0) {
if (blockIdx.z == 0) {
if (lwe_offset + lwe_chunk_size >= (lwe_dimension / grouping_factor)) {
auto block_lwe_array_out =
&lwe_array_out[lwe_output_indexes[blockIdx.z] *
&lwe_array_out[lwe_output_indexes[blockIdx.x] *
(glwe_dimension * polynomial_size + 1) +
blockIdx.y * polynomial_size];

Expand All @@ -145,9 +145,9 @@ __global__ void __launch_bounds__(params::degree / params::opt)
for (int i = 1; i < num_many_lut; i++) {
auto next_lwe_array_out =
lwe_array_out +
(i * gridDim.z * (glwe_dimension * polynomial_size + 1));
(i * gridDim.x * (glwe_dimension * polynomial_size + 1));
auto next_block_lwe_array_out =
&next_lwe_array_out[lwe_output_indexes[blockIdx.z] *
&next_lwe_array_out[lwe_output_indexes[blockIdx.x] *
(glwe_dimension * polynomial_size + 1) +
blockIdx.y * polynomial_size];

Expand All @@ -162,9 +162,9 @@ __global__ void __launch_bounds__(params::degree / params::opt)

auto next_lwe_array_out =
lwe_array_out +
(i * gridDim.z * (glwe_dimension * polynomial_size + 1));
(i * gridDim.x * (glwe_dimension * polynomial_size + 1));
auto next_block_lwe_array_out =
&next_lwe_array_out[lwe_output_indexes[blockIdx.z] *
&next_lwe_array_out[lwe_output_indexes[blockIdx.x] *
(glwe_dimension * polynomial_size + 1) +
blockIdx.y * polynomial_size];

Expand Down Expand Up @@ -334,7 +334,7 @@ __host__ void execute_tbc_external_product_loop(
auto global_accumulator = buffer->global_accumulator;
auto buffer_fft = buffer->global_join_buffer;

dim3 grid_accumulate(level_count, glwe_dimension + 1, num_samples);
dim3 grid_accumulate(num_samples, glwe_dimension + 1, level_count);
dim3 thds(polynomial_size / params::opt, 1, 1);

cudaLaunchConfig_t config = {0};
Expand All @@ -346,9 +346,9 @@ __host__ void execute_tbc_external_product_loop(

cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeClusterDimension;
attribute[0].val.clusterDim.x = level_count; // Cluster size in X-dimension
attribute[0].val.clusterDim.x = 1;
attribute[0].val.clusterDim.y = (glwe_dimension + 1);
attribute[0].val.clusterDim.z = 1;
attribute[0].val.clusterDim.z = level_count; // Cluster size in Z-dimension
config.attrs = attribute;
config.numAttrs = 1;
config.stream = stream;
Expand Down Expand Up @@ -463,7 +463,7 @@ __host__ bool supports_thread_block_clusters_on_multibit_programmable_bootstrap(

int cluster_size;

dim3 grid_accumulate(level_count, glwe_dimension + 1, num_samples);
dim3 grid_accumulate(num_samples, glwe_dimension + 1, level_count);
dim3 thds(polynomial_size / params::opt, 1, 1);

cudaLaunchConfig_t config = {0};
Expand Down

0 comments on commit 380a1d1

Please sign in to comment.