Skip to content

Commit

Permalink
fix(gpu): enable larger number of samples in the keyswitch
Browse files Browse the repository at this point in the history
  • Loading branch information
guillermo-oyarzun committed Feb 14, 2025
1 parent 0d3d23d commit 6d0f664
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,19 @@ keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
const Torus *__restrict__ lwe_input_indexes,
const Torus *__restrict__ ksk, uint32_t lwe_dimension_in,
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count) {
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int tid = threadIdx.x + blockIdx.y * blockDim.x;
const int shmem_index = threadIdx.x + threadIdx.y * blockDim.x;

extern __shared__ int8_t sharedmem[];
Torus *lwe_acc_out = (Torus *)sharedmem;
auto block_lwe_array_out = get_chunk(
lwe_array_out, lwe_output_indexes[blockIdx.y], lwe_dimension_out + 1);
lwe_array_out, lwe_output_indexes[blockIdx.x], lwe_dimension_out + 1);

if (tid <= lwe_dimension_out) {

Torus local_lwe_out = 0;
auto block_lwe_array_in = get_chunk(
lwe_array_in, lwe_input_indexes[blockIdx.y], lwe_dimension_in + 1);
lwe_array_in, lwe_input_indexes[blockIdx.x], lwe_dimension_in + 1);

if (tid == lwe_dimension_out && threadIdx.y == 0) {
local_lwe_out = block_lwe_array_in[lwe_dimension_in];
Expand Down Expand Up @@ -108,13 +108,19 @@ __host__ void host_keyswitch_lwe_ciphertext_vector(
cuda_set_device(gpu_index);

constexpr int num_threads_y = 32;
int num_blocks, num_threads_x;
int num_blocks_per_sample, num_threads_x;

getNumBlocksAndThreads2D(lwe_dimension_out + 1, 512, num_threads_y,
num_blocks, num_threads_x);
num_blocks_per_sample, num_threads_x);

int shared_mem = sizeof(Torus) * num_threads_y * num_threads_x;
dim3 grid(num_blocks, num_samples, 1);
if (num_blocks_per_sample > 65536)
PANIC("Cuda error (Keyswith): number of blocks per sample is too large");

// In multiplication of large integers (512, 1024, 2048), the number of
// samples can exceed 65535, the maximum size of the grid's y dimension, so
// the samples are mapped to the first (x) dimension of the grid instead
dim3 grid(num_samples, num_blocks_per_sample, 1);
dim3 threads(num_threads_x, num_threads_y, 1);

keyswitch<Torus><<<grid, threads, shared_mem, stream>>>(
Expand Down

0 comments on commit 6d0f664

Please sign in to comment.