Skip to content

Commit 24088fd

Browse files
committed
chore(gpu): add scalar div and signed scalar div to hl api
Also add overflowing sub to hl
1 parent 48315dc commit 24088fd

File tree

4 files changed

+66
-25
lines changed

4 files changed

+66
-25
lines changed

tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,12 @@ impl<T: UnsignedInteger> CudaLweCiphertextList<T> {
6161

6262
// Copy to the GPU
6363
let h_input = h_ct.as_view().into_container();
64-
let mut d_vec = CudaVec::new(
64+
let mut d_vec = CudaVec::new_async(
6565
lwe_dimension.to_lwe_size().0 * lwe_ciphertext_count.0,
6666
streams,
6767
0,
6868
);
69-
unsafe {
70-
d_vec.copy_from_cpu_async(h_input.as_ref(), streams, 0);
71-
}
69+
d_vec.copy_from_cpu_async(h_input.as_ref(), streams, 0);
7270
let cuda_lwe_list = CudaLweList {
7371
d_vec,
7472
lwe_ciphertext_count,

tfhe/src/high_level_api/integers/signed/scalar_ops.rs

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,6 @@ where
365365
// DivRem is a bit special as it returns a tuple of quotient and remainder
366366
macro_rules! generic_integer_impl_scalar_div_rem {
367367
(
368-
key_method: $key_method:ident,
369368
// A 'list' of tuples, where the first element of each tuple is the concrete Fhe type
370369
// (e.g. FheUint8) and the remaining elements are scalar types (u8, u16, etc.)
371370
fhe_and_scalar_type: $(
@@ -393,15 +392,24 @@ macro_rules! generic_integer_impl_scalar_div_rem {
393392
InternalServerKey::Cpu(cpu_key) => {
394393
let (q, r) = cpu_key
395394
.pbs_key()
396-
.$key_method(&*self.ciphertext.on_cpu(), rhs);
395+
.signed_scalar_div_rem_parallelized(&*self.ciphertext.on_cpu(), rhs);
397396
(
398397
<$concrete_type>::new(q, cpu_key.tag.clone()),
399398
<$concrete_type>::new(r, cpu_key.tag.clone())
400399
)
401400
}
402401
#[cfg(feature = "gpu")]
403-
InternalServerKey::Cuda(_) => {
404-
panic!("Cuda devices does not support div rem yet")
402+
InternalServerKey::Cuda(cuda_key) => {
403+
let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
404+
cuda_key.key.signed_scalar_div_rem(
405+
&*self.ciphertext.on_gpu(), rhs, streams
406+
)
407+
});
408+
let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r));
409+
(
410+
<$concrete_type>::new(q, cuda_key.tag.clone()),
411+
<$concrete_type>::new(r, cuda_key.tag.clone())
412+
)
405413
}
406414
})
407415
}
@@ -410,8 +418,8 @@ macro_rules! generic_integer_impl_scalar_div_rem {
410418
)* // Closing first repeating pattern
411419
};
412420
}
421+
413422
generic_integer_impl_scalar_div_rem!(
414-
key_method: signed_scalar_div_rem_parallelized,
415423
fhe_and_scalar_type:
416424
(super::FheInt2, i8),
417425
(super::FheInt4, i8),
@@ -826,8 +834,13 @@ generic_integer_impl_scalar_operation!(
826834
RadixCiphertext::Cpu(inner_result)
827835
},
828836
#[cfg(feature = "gpu")]
829-
InternalServerKey::Cuda(_) => {
830-
panic!("Div '/' with clear value is not yet supported by Cuda devices")
837+
InternalServerKey::Cuda(cuda_key) => {
838+
let inner_result = with_thread_local_cuda_streams(|streams| {
839+
cuda_key.key.signed_scalar_div(
840+
&lhs.ciphertext.on_gpu(), rhs, streams
841+
)
842+
});
843+
RadixCiphertext::Cuda(inner_result)
831844
}
832845
})
833846
}
@@ -859,8 +872,13 @@ generic_integer_impl_scalar_operation!(
859872
RadixCiphertext::Cpu(inner_result)
860873
},
861874
#[cfg(feature = "gpu")]
862-
InternalServerKey::Cuda(_) => {
863-
panic!("Rem '%' with clear value is not yet supported by Cuda devices")
875+
InternalServerKey::Cuda(cuda_key) => {
876+
let inner_result = with_thread_local_cuda_streams(|streams| {
877+
cuda_key.key.signed_scalar_rem(
878+
&lhs.ciphertext.on_gpu(), rhs, streams
879+
)
880+
});
881+
RadixCiphertext::Cuda(inner_result)
864882
}
865883
})
866884
}

tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,17 @@ where
285285
)
286286
}
287287
#[cfg(feature = "gpu")]
288-
InternalServerKey::Cuda(_) => {
289-
panic!("Cuda devices do not support overflowing_sub yet");
290-
}
288+
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
289+
let inner_result = cuda_key.key.unsigned_overflowing_sub(
290+
&self.ciphertext.on_gpu(),
291+
&other.ciphertext.on_gpu(),
292+
streams,
293+
);
294+
(
295+
FheUint::<Id>::new(inner_result.0, cuda_key.tag.clone()),
296+
FheBool::new(inner_result.1, cuda_key.tag.clone()),
297+
)
298+
}),
291299
})
292300
}
293301
}

tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,6 @@ where
446446
// DivRem is a bit special as it returns a tuple of quotient and remainder
447447
macro_rules! generic_integer_impl_scalar_div_rem {
448448
(
449-
key_method: $key_method:ident,
450449
// A 'list' of tuples, where the first element of each tuple is the concrete Fhe type
451450
// (e.g. FheUint8) and the remaining elements are scalar types (u8, u16, etc.)
452451
fhe_and_scalar_type: $(
@@ -473,15 +472,24 @@ macro_rules! generic_integer_impl_scalar_div_rem {
473472
global_state::with_internal_keys(|key| {
474473
match key {
475474
InternalServerKey::Cpu(cpu_key) => {
476-
let (q, r) = cpu_key.pbs_key().$key_method(&*self.ciphertext.on_cpu(), rhs);
475+
let (q, r) = cpu_key.pbs_key().scalar_div_rem_parallelized(&*self.ciphertext.on_cpu(), rhs);
477476
(
478477
<$concrete_type>::new(q, cpu_key.tag.clone()),
479478
<$concrete_type>::new(r, cpu_key.tag.clone())
480479
)
481480
}
482481
#[cfg(feature = "gpu")]
483-
InternalServerKey::Cuda(_) => {
484-
panic!("Cuda devices do not support div_rem yet");
482+
InternalServerKey::Cuda(cuda_key) => {
483+
let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
484+
cuda_key.key.scalar_div_rem(
485+
&*self.ciphertext.on_gpu(), rhs, streams
486+
)
487+
});
488+
let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r));
489+
(
490+
<$concrete_type>::new(q, cuda_key.tag.clone()),
491+
<$concrete_type>::new(r, cuda_key.tag.clone())
492+
)
485493
}
486494
}
487495
})
@@ -492,7 +500,6 @@ macro_rules! generic_integer_impl_scalar_div_rem {
492500
};
493501
}
494502
generic_integer_impl_scalar_div_rem!(
495-
key_method: scalar_div_rem_parallelized,
496503
fhe_and_scalar_type:
497504
(super::FheUint2, u8),
498505
(super::FheUint4, u8),
@@ -978,8 +985,13 @@ generic_integer_impl_scalar_operation!(
978985
RadixCiphertext::Cpu(inner_result)
979986
},
980987
#[cfg(feature = "gpu")]
981-
InternalServerKey::Cuda(_) => {
982-
panic!("Div '/' with clear value is not yet supported by Cuda devices")
988+
InternalServerKey::Cuda(cuda_key) => {
989+
let inner_result = with_thread_local_cuda_streams(|streams| {
990+
cuda_key.key.scalar_div(
991+
&lhs.ciphertext.on_gpu(), rhs, streams
992+
)
993+
});
994+
RadixCiphertext::Cuda(inner_result)
983995
}
984996
})
985997
}
@@ -1014,8 +1026,13 @@ generic_integer_impl_scalar_operation!(
10141026
RadixCiphertext::Cpu(inner_result)
10151027
},
10161028
#[cfg(feature = "gpu")]
1017-
InternalServerKey::Cuda(_) => {
1018-
panic!("Rem '%' with clear value is not yet supported by Cuda devices")
1029+
InternalServerKey::Cuda(cuda_key) => {
1030+
let inner_result = with_thread_local_cuda_streams(|streams| {
1031+
cuda_key.key.scalar_rem(
1032+
&lhs.ciphertext.on_gpu(), rhs, streams
1033+
)
1034+
});
1035+
RadixCiphertext::Cuda(inner_result)
10191036
}
10201037
})
10211038
}

0 commit comments

Comments
 (0)