Skip to content

Commit 6e732ec

Browse files
feat(gpu): enable if then else for boolean ciphertexts in hlapi
1 parent 0809eb9 commit 6e732ec

File tree

1 file changed

+14
-6
lines changed
  • tfhe/src/high_level_api/booleans

1 file changed

+14
-6
lines changed

tfhe/src/high_level_api/booleans/base.rs

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -255,20 +255,28 @@ impl<Id: FheIntId> IfThenElse<FheInt<Id>> for FheBool {
255255
impl IfThenElse<Self> for FheBool {
256256
fn if_then_else(&self, ct_then: &Self, ct_else: &Self) -> Self {
257257
let ct_condition = self;
258-
global_state::with_internal_keys(|key| match key {
258+
let (ciphertext, tag) = global_state::with_internal_keys(|key| match key {
259259
InternalServerKey::Cpu(key) => {
260260
let new_ct = key.pbs_key().if_then_else_parallelized(
261261
&ct_condition.ciphertext.on_cpu(),
262262
&*ct_then.ciphertext.on_cpu(),
263263
&*ct_else.ciphertext.on_cpu(),
264264
);
265-
Self::new(new_ct, key.tag.clone())
265+
(InnerBoolean::Cpu(new_ct), key.tag.clone())
266266
}
267267
#[cfg(feature = "gpu")]
268-
InternalServerKey::Cuda(_) => {
269-
panic!("Cuda devices do not support signed integers")
270-
}
271-
})
268+
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
269+
let inner = cuda_key.key.key.if_then_else(
270+
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
271+
&*ct_then.ciphertext.on_gpu(streams),
272+
&*ct_else.ciphertext.on_gpu(streams),
273+
streams,
274+
);
275+
let boolean_inner = CudaBooleanBlock(inner);
276+
(InnerBoolean::Cuda(boolean_inner), cuda_key.tag.clone())
277+
}),
278+
});
279+
Self::new(ciphertext, tag)
272280
}
273281
}
274282

0 commit comments

Comments
 (0)