Skip to content

Commit 6d0f664

Browse files
fix(gpu): enable larger number of samples in the keyswitch
1 parent 0d3d23d commit 6d0f664

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,19 @@ keyswitch(Torus *lwe_array_out, const Torus *__restrict__ lwe_output_indexes,
4545
const Torus *__restrict__ lwe_input_indexes,
4646
const Torus *__restrict__ ksk, uint32_t lwe_dimension_in,
4747
uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count) {
48-
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
48+
const int tid = threadIdx.x + blockIdx.y * blockDim.x;
4949
const int shmem_index = threadIdx.x + threadIdx.y * blockDim.x;
5050

5151
extern __shared__ int8_t sharedmem[];
5252
Torus *lwe_acc_out = (Torus *)sharedmem;
5353
auto block_lwe_array_out = get_chunk(
54-
lwe_array_out, lwe_output_indexes[blockIdx.y], lwe_dimension_out + 1);
54+
lwe_array_out, lwe_output_indexes[blockIdx.x], lwe_dimension_out + 1);
5555

5656
if (tid <= lwe_dimension_out) {
5757

5858
Torus local_lwe_out = 0;
5959
auto block_lwe_array_in = get_chunk(
60-
lwe_array_in, lwe_input_indexes[blockIdx.y], lwe_dimension_in + 1);
60+
lwe_array_in, lwe_input_indexes[blockIdx.x], lwe_dimension_in + 1);
6161

6262
if (tid == lwe_dimension_out && threadIdx.y == 0) {
6363
local_lwe_out = block_lwe_array_in[lwe_dimension_in];
@@ -108,13 +108,19 @@ __host__ void host_keyswitch_lwe_ciphertext_vector(
108108
cuda_set_device(gpu_index);
109109

110110
constexpr int num_threads_y = 32;
111-
int num_blocks, num_threads_x;
111+
int num_blocks_per_sample, num_threads_x;
112112

113113
getNumBlocksAndThreads2D(lwe_dimension_out + 1, 512, num_threads_y,
114-
num_blocks, num_threads_x);
114+
num_blocks_per_sample, num_threads_x);
115115

116116
int shared_mem = sizeof(Torus) * num_threads_y * num_threads_x;
117-
dim3 grid(num_blocks, num_samples, 1);
117+
if (num_blocks_per_sample > 65536)
118+
PANIC("Cuda error (Keyswith): number of blocks per sample is too large");
119+
120+
// In multiplication of large integers (512, 1024, 2048), the number of
121+
// samples can be larger than 65536, so we need to set it in the first
122+
// dimension of the grid
123+
dim3 grid(num_samples, num_blocks_per_sample, 1);
118124
dim3 threads(num_threads_x, num_threads_y, 1);
119125

120126
keyswitch<Torus><<<grid, threads, shared_mem, stream>>>(

0 commit comments

Comments
 (0)