@@ -255,20 +255,28 @@ impl<Id: FheIntId> IfThenElse<FheInt<Id>> for FheBool {
255
255
impl IfThenElse < Self > for FheBool {
256
256
fn if_then_else ( & self , ct_then : & Self , ct_else : & Self ) -> Self {
257
257
let ct_condition = self ;
258
- global_state:: with_internal_keys ( |key| match key {
258
+ let ( ciphertext , tag ) = global_state:: with_internal_keys ( |key| match key {
259
259
InternalServerKey :: Cpu ( key) => {
260
260
let new_ct = key. pbs_key ( ) . if_then_else_parallelized (
261
261
& ct_condition. ciphertext . on_cpu ( ) ,
262
262
& * ct_then. ciphertext . on_cpu ( ) ,
263
263
& * ct_else. ciphertext . on_cpu ( ) ,
264
264
) ;
265
- Self :: new ( new_ct, key. tag . clone ( ) )
265
+ ( InnerBoolean :: Cpu ( new_ct) , key. tag . clone ( ) )
266
266
}
267
267
#[ cfg( feature = "gpu" ) ]
268
- InternalServerKey :: Cuda ( _) => {
269
- panic ! ( "Cuda devices do not support signed integers" )
270
- }
271
- } )
268
+ InternalServerKey :: Cuda ( cuda_key) => with_thread_local_cuda_streams ( |streams| {
269
+ let inner = cuda_key. key . key . if_then_else (
270
+ & CudaBooleanBlock ( self . ciphertext . on_gpu ( streams) . duplicate ( streams) ) ,
271
+ & * ct_then. ciphertext . on_gpu ( streams) ,
272
+ & * ct_else. ciphertext . on_gpu ( streams) ,
273
+ streams,
274
+ ) ;
275
+ let boolean_inner = CudaBooleanBlock ( inner) ;
276
+ ( InnerBoolean :: Cuda ( boolean_inner) , cuda_key. tag . clone ( ) )
277
+ } ) ,
278
+ } ) ;
279
+ Self :: new ( ciphertext, tag)
272
280
}
273
281
}
274
282
0 commit comments