fix(gpu): enable large number of samples in pbs tbc #2068

Merged: 1 commit, merged Feb 18, 2025
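The single commit transposes the launch geometry of the tbc (thread-block-cluster) multi-bit PBS so that the sample index moves from blockIdx.z to blockIdx.x. Below is a minimal sketch of the before/after launch configuration, reconstructed from the diff that follows; the helper sketch_tbc_grid and its parameter list are illustrative only, not part of the backend, and the rationale (CUDA caps gridDim.y and gridDim.z at 65535 blocks while gridDim.x allows up to 2^31 - 1) is an assumption inferred from the PR title.

#include <cstdint>
#include <cuda_runtime.h>

// Illustrative only: shows the launch-geometry change made by this PR.
// The real launch is set up in execute_tbc_external_product_loop (see diff).
static void sketch_tbc_grid(uint32_t num_samples, uint32_t glwe_dimension,
                            uint32_t level_count, uint32_t polynomial_size,
                            uint32_t opt, cudaStream_t stream) {
  // Before: samples were indexed by blockIdx.z, capped at 65535 blocks.
  dim3 grid_before(level_count, glwe_dimension + 1, num_samples);

  // After: samples are indexed by blockIdx.x (up to 2^31 - 1 blocks); the
  // decomposition level moves to blockIdx.z.
  dim3 grid_after(num_samples, glwe_dimension + 1, level_count);

  // The thread-block cluster still spans (glwe_dimension + 1) * level_count
  // blocks per sample; its shape is transposed to match the new grid.
  cudaLaunchAttribute attr[1];
  attr[0].id = cudaLaunchAttributeClusterDimension;
  attr[0].val.clusterDim.x = 1;
  attr[0].val.clusterDim.y = glwe_dimension + 1;
  attr[0].val.clusterDim.z = level_count;

  cudaLaunchConfig_t config = {};
  config.gridDim = grid_after;
  config.blockDim = dim3(polynomial_size / opt, 1, 1);
  config.attrs = attr;
  config.numAttrs = 1;
  config.stream = stream;

  (void)grid_before; // kept only for the before/after comparison
  (void)config;      // a real call would hand this to cudaLaunchKernelEx
}

With samples in gridDim.x, the cluster still groups the (glwe_dimension + 1) * level_count blocks that cooperate on one sample, which is why the kernel's uses of blockIdx.x and blockIdx.z are swapped throughout the diff below.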
49 changes: 49 additions & 0 deletions backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh
@@ -76,6 +76,55 @@ mul_ggsw_glwe_in_fourier_domain(double2 *fft, double2 *join_buffer,

   synchronize_threads_in_block();
 }
+// This is a temporary function to have a working version of the tbc-PBS
+// with large integers.
+// TODO: remove this function once the other flavors of the PBS also work
+// with the number of samples in the first block dimension.
+template <typename G, class params>
+__device__ void mul_ggsw_glwe_in_fourier_domain_tbc(
+    double2 *fft, double2 *join_buffer,
+    const double2 *__restrict__ bootstrapping_key, int iteration, G &group,
+    bool support_dsm = false) {
+  const uint32_t polynomial_size = params::degree;
+  const uint32_t glwe_dimension = gridDim.y - 1;
+  const uint32_t level_count = gridDim.z;
+
+  // The first product is used to initialize level_join_buffer
+  auto this_block_rank = get_this_block_rank<G>(group, support_dsm);
+
+  // Continues multiplying fft by every polynomial in that particular bsk level
+  // Each y-block accumulates in a different polynomial at each iteration
+  auto bsk_slice = get_ith_mask_kth_block(
+      bootstrapping_key, iteration, blockIdx.y, blockIdx.z, polynomial_size,
+      glwe_dimension, level_count);
+  for (int j = 0; j < glwe_dimension + 1; j++) {
+    int idx = (j + this_block_rank) % (glwe_dimension + 1);
+
+    auto bsk_poly = bsk_slice + idx * polynomial_size / 2;
+    auto buffer_slice = get_join_buffer_element<G>(blockIdx.z, idx, group,
+                                                   join_buffer, polynomial_size,
+                                                   glwe_dimension, support_dsm);
+
+    polynomial_product_accumulate_in_fourier_domain<params, double2>(
+        buffer_slice, fft, bsk_poly, j == 0);
+    group.sync();
+  }
+
+  // -----------------------------------------------------------------
+  // All blocks are synchronized here; after this sync, level_join_buffer has
+  // the values needed from every other block
+
+  // Accumulate the rest of the products into the fft buffer
+  for (int l = 0; l < level_count; l++) {
+    auto cur_src_acc = get_join_buffer_element<G>(l, blockIdx.y, group,
+                                                  join_buffer, polynomial_size,
+                                                  glwe_dimension, support_dsm);
+
+    polynomial_accumulate_in_fourier_domain<params>(fft, cur_src_acc, l == 0);
+  }
+
+  synchronize_threads_in_block();
+}
 
 template <typename Torus>
 void execute_pbs_async(cudaStream_t const *streams, uint32_t const *gpu_indexes,
@@ -49,8 +49,8 @@ __global__ void __launch_bounds__(params::degree / params::opt)
     if (support_dsm)
       selected_memory += sizeof(Torus) * polynomial_size;
   } else {
-    int block_index = blockIdx.x + blockIdx.y * gridDim.x +
-                      blockIdx.z * gridDim.x * gridDim.y;
+    int block_index = blockIdx.z + blockIdx.y * gridDim.z +
+                      blockIdx.x * gridDim.z * gridDim.y;
     selected_memory = &device_mem[block_index * device_memory_size_per_block];
   }

@@ -65,25 +65,25 @@ __global__ void __launch_bounds__(params::degree / params::opt)
     accumulator_fft += sizeof(double2) * (polynomial_size / 2);
   }
 
-  // The third dimension of the block is used to determine on which ciphertext
+  // The first dimension of the block is used to determine on which ciphertext
   // this block is operating, in the case of batch bootstraps
   const Torus *block_lwe_array_in =
-      &lwe_array_in[lwe_input_indexes[blockIdx.z] * (lwe_dimension + 1)];
+      &lwe_array_in[lwe_input_indexes[blockIdx.x] * (lwe_dimension + 1)];
 
   const Torus *block_lut_vector =
-      &lut_vector[lut_vector_indexes[blockIdx.z] * params::degree *
+      &lut_vector[lut_vector_indexes[blockIdx.x] * params::degree *
                   (glwe_dimension + 1)];
 
   double2 *block_join_buffer =
-      &join_buffer[blockIdx.z * level_count * (glwe_dimension + 1) *
+      &join_buffer[blockIdx.x * level_count * (glwe_dimension + 1) *
                    params::degree / 2];
 
   Torus *global_accumulator_slice =
-      &global_accumulator[(blockIdx.y + blockIdx.z * (glwe_dimension + 1)) *
+      &global_accumulator[(blockIdx.y + blockIdx.x * (glwe_dimension + 1)) *
                           params::degree];
 
   const double2 *keybundle =
-      &keybundle_array[blockIdx.z * keybundle_size_per_input];
+      &keybundle_array[blockIdx.x * keybundle_size_per_input];
 
   if (lwe_offset == 0) {
     // Put "b" in [0, 2N[
@@ -113,12 +113,12 @@ __global__ void __launch_bounds__(params::degree / params::opt)
     // accumulator decomposed at level 0, 1 at 1, etc.)
     GadgetMatrix<Torus, params> gadget_acc(base_log, level_count,
                                            accumulator_rotated);
-    gadget_acc.decompose_and_compress_level(accumulator_fft, blockIdx.x);
+    gadget_acc.decompose_and_compress_level(accumulator_fft, blockIdx.z);
     NSMFFT_direct<HalfDegree<params>>(accumulator_fft);
     synchronize_threads_in_block();
 
     // Perform G^-1(ACC) * GGSW -> GLWE
-    mul_ggsw_glwe_in_fourier_domain<cluster_group, params>(
+    mul_ggsw_glwe_in_fourier_domain_tbc<cluster_group, params>(
         accumulator_fft, block_join_buffer, keybundle, i, cluster, support_dsm);
     NSMFFT_inverse<HalfDegree<params>>(accumulator_fft);
     synchronize_threads_in_block();
@@ -128,10 +128,10 @@ __global__ void __launch_bounds__(params::degree / params::opt)

   auto accumulator = accumulator_rotated;
 
-  if (blockIdx.x == 0) {
+  if (blockIdx.z == 0) {
     if (lwe_offset + lwe_chunk_size >= (lwe_dimension / grouping_factor)) {
       auto block_lwe_array_out =
-          &lwe_array_out[lwe_output_indexes[blockIdx.z] *
+          &lwe_array_out[lwe_output_indexes[blockIdx.x] *
                          (glwe_dimension * polynomial_size + 1) +
                          blockIdx.y * polynomial_size];

@@ -145,9 +145,9 @@ __global__ void __launch_bounds__(params::degree / params::opt)
         for (int i = 1; i < num_many_lut; i++) {
           auto next_lwe_array_out =
               lwe_array_out +
-              (i * gridDim.z * (glwe_dimension * polynomial_size + 1));
+              (i * gridDim.x * (glwe_dimension * polynomial_size + 1));
           auto next_block_lwe_array_out =
-              &next_lwe_array_out[lwe_output_indexes[blockIdx.z] *
+              &next_lwe_array_out[lwe_output_indexes[blockIdx.x] *
                                   (glwe_dimension * polynomial_size + 1) +
                                   blockIdx.y * polynomial_size];

@@ -162,9 +162,9 @@ __global__ void __launch_bounds__(params::degree / params::opt)

       auto next_lwe_array_out =
           lwe_array_out +
-          (i * gridDim.z * (glwe_dimension * polynomial_size + 1));
+          (i * gridDim.x * (glwe_dimension * polynomial_size + 1));
       auto next_block_lwe_array_out =
-          &next_lwe_array_out[lwe_output_indexes[blockIdx.z] *
+          &next_lwe_array_out[lwe_output_indexes[blockIdx.x] *
                               (glwe_dimension * polynomial_size + 1) +
                               blockIdx.y * polynomial_size];

@@ -334,7 +334,7 @@ __host__ void execute_tbc_external_product_loop(
   auto global_accumulator = buffer->global_accumulator;
   auto buffer_fft = buffer->global_join_buffer;
 
-  dim3 grid_accumulate(level_count, glwe_dimension + 1, num_samples);
+  dim3 grid_accumulate(num_samples, glwe_dimension + 1, level_count);
   dim3 thds(polynomial_size / params::opt, 1, 1);
 
   cudaLaunchConfig_t config = {0};
@@ -346,9 +346,9 @@

   cudaLaunchAttribute attribute[1];
   attribute[0].id = cudaLaunchAttributeClusterDimension;
-  attribute[0].val.clusterDim.x = level_count; // Cluster size in X-dimension
+  attribute[0].val.clusterDim.x = 1;
   attribute[0].val.clusterDim.y = (glwe_dimension + 1);
-  attribute[0].val.clusterDim.z = 1;
+  attribute[0].val.clusterDim.z = level_count; // Cluster size in Z-dimension
   config.attrs = attribute;
   config.numAttrs = 1;
   config.stream = stream;
@@ -463,7 +463,7 @@ __host__ bool supports_thread_block_clusters_on_multibit_programmable_bootstrap(

   int cluster_size;
 
-  dim3 grid_accumulate(level_count, glwe_dimension + 1, num_samples);
+  dim3 grid_accumulate(num_samples, glwe_dimension + 1, level_count);
   dim3 thds(polynomial_size / params::opt, 1, 1);
 
   cudaLaunchConfig_t config = {0};