@@ -45,19 +45,19 @@ keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
4545 const Torus *__restrict__ lwe_input_indexes,
4646 const Torus *__restrict__ ksk, uint32_t lwe_dimension_in,
4747 uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count) {
48- const int tid = threadIdx .x + blockIdx .x * blockDim .x ;
48+ const int tid = threadIdx .x + blockIdx .y * blockDim .x ;
4949 const int shmem_index = threadIdx .x + threadIdx .y * blockDim .x ;
5050
5151 extern __shared__ int8_t sharedmem[];
5252 Torus *lwe_acc_out = (Torus *)sharedmem;
5353 auto block_lwe_array_out = get_chunk (
54- lwe_array_out, lwe_output_indexes[blockIdx .y ], lwe_dimension_out + 1 );
54+ lwe_array_out, lwe_output_indexes[blockIdx .x ], lwe_dimension_out + 1 );
5555
5656 if (tid <= lwe_dimension_out) {
5757
5858 Torus local_lwe_out = 0 ;
5959 auto block_lwe_array_in = get_chunk (
60- lwe_array_in, lwe_input_indexes[blockIdx .y ], lwe_dimension_in + 1 );
60+ lwe_array_in, lwe_input_indexes[blockIdx .x ], lwe_dimension_in + 1 );
6161
6262 if (tid == lwe_dimension_out && threadIdx .y == 0 ) {
6363 local_lwe_out = block_lwe_array_in[lwe_dimension_in];
@@ -108,13 +108,19 @@ __host__ void host_keyswitch_lwe_ciphertext_vector(
108108 cuda_set_device (gpu_index);
109109
110110 constexpr int num_threads_y = 32 ;
111- int num_blocks , num_threads_x;
111+ int num_blocks_per_sample , num_threads_x;
112112
113113 getNumBlocksAndThreads2D (lwe_dimension_out + 1 , 512 , num_threads_y,
114- num_blocks , num_threads_x);
114+ num_blocks_per_sample , num_threads_x);
115115
116116 int shared_mem = sizeof (Torus) * num_threads_y * num_threads_x;
117- dim3 grid (num_blocks, num_samples, 1 );
117+ if (num_blocks_per_sample > 65536 )
118+ PANIC (" Cuda error (Keyswith): number of blocks per sample is too large" );
119+
120+ // In multiplication of large integers (512, 1024, 2048), the number of
121+ // samples can be larger than 65536, so we need to set it in the first
122+ // dimension of the grid
123+ dim3 grid (num_samples, num_blocks_per_sample, 1 );
118124 dim3 threads (num_threads_x, num_threads_y, 1 );
119125
120126 keyswitch<Torus><<<grid, threads, shared_mem, stream>>> (
0 commit comments