@@ -45,19 +45,19 @@ keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
45
45
const Torus *__restrict__ lwe_input_indexes,
46
46
const Torus *__restrict__ ksk, uint32_t lwe_dimension_in,
47
47
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count) {
48
- const int tid = threadIdx .x + blockIdx .x * blockDim .x ;
48
+ const int tid = threadIdx .x + blockIdx .y * blockDim .x ;
49
49
const int shmem_index = threadIdx .x + threadIdx .y * blockDim .x ;
50
50
51
51
extern __shared__ int8_t sharedmem[];
52
52
Torus *lwe_acc_out = (Torus *)sharedmem;
53
53
auto block_lwe_array_out = get_chunk (
54
- lwe_array_out, lwe_output_indexes[blockIdx .y ], lwe_dimension_out + 1 );
54
+ lwe_array_out, lwe_output_indexes[blockIdx .x ], lwe_dimension_out + 1 );
55
55
56
56
if (tid <= lwe_dimension_out) {
57
57
58
58
Torus local_lwe_out = 0 ;
59
59
auto block_lwe_array_in = get_chunk (
60
- lwe_array_in, lwe_input_indexes[blockIdx .y ], lwe_dimension_in + 1 );
60
+ lwe_array_in, lwe_input_indexes[blockIdx .x ], lwe_dimension_in + 1 );
61
61
62
62
if (tid == lwe_dimension_out && threadIdx .y == 0 ) {
63
63
local_lwe_out = block_lwe_array_in[lwe_dimension_in];
@@ -108,13 +108,19 @@ __host__ void host_keyswitch_lwe_ciphertext_vector(
108
108
cuda_set_device (gpu_index);
109
109
110
110
constexpr int num_threads_y = 32 ;
111
- int num_blocks , num_threads_x;
111
+ int num_blocks_per_sample , num_threads_x;
112
112
113
113
getNumBlocksAndThreads2D (lwe_dimension_out + 1 , 512 , num_threads_y,
114
- num_blocks , num_threads_x);
114
+ num_blocks_per_sample , num_threads_x);
115
115
116
116
int shared_mem = sizeof (Torus) * num_threads_y * num_threads_x;
117
- dim3 grid (num_blocks, num_samples, 1 );
117
+ if (num_blocks_per_sample > 65536 )
118
+ PANIC (" Cuda error (Keyswith): number of blocks per sample is too large" );
119
+
120
+ // In multiplication of large integers (512, 1024, 2048), the number of
121
+ // samples can exceed 65535, the CUDA limit for gridDim.y, so the sample
122
+ // index is mapped to the first (x) dimension of the grid instead
123
+ dim3 grid (num_samples, num_blocks_per_sample, 1 );
118
124
dim3 threads (num_threads_x, num_threads_y, 1 );
119
125
120
126
keyswitch<Torus><<<grid, threads, shared_mem, stream>>> (
0 commit comments