Skip to content

Commit 48315dc

Browse files
committed
feat(gpu): signed scalar div
1 parent 52b148a commit 48315dc

File tree

10 files changed

+1309
-403
lines changed

10 files changed

+1309
-403
lines changed

tfhe/benches/integer/signed_bench.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1813,6 +1813,12 @@ mod cuda {
18131813
rng_func: default_signed_scalar
18141814
);
18151815

1816+
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
1817+
method_name: unchecked_signed_scalar_div_rem,
1818+
display_name: div_rem,
1819+
rng_func: div_scalar
1820+
);
1821+
18161822
//===========================================
18171823
// Default
18181824
//===========================================
@@ -2035,6 +2041,12 @@ mod cuda {
20352041
rng_func: default_signed_scalar
20362042
);
20372043

2044+
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
2045+
method_name: signed_scalar_div_rem,
2046+
display_name: div_rem,
2047+
rng_func: div_scalar
2048+
);
2049+
20382050
criterion_group!(
20392051
unchecked_cuda_ops,
20402052
cuda_unchecked_add,
@@ -2081,6 +2093,7 @@ mod cuda {
20812093
cuda_unchecked_scalar_le,
20822094
cuda_unchecked_scalar_min,
20832095
cuda_unchecked_scalar_max,
2096+
cuda_unchecked_signed_scalar_div_rem,
20842097
);
20852098

20862099
criterion_group!(
@@ -2146,6 +2159,7 @@ mod cuda {
21462159
cuda_scalar_max,
21472160
cuda_signed_overflowing_scalar_add,
21482161
cuda_signed_overflowing_scalar_sub,
2162+
cuda_signed_scalar_div_rem,
21492163
);
21502164

21512165
fn cuda_bench_server_key_signed_cast_function<F>(

tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,19 @@ impl<T: UnsignedInteger> CudaLweCiphertextList<T> {
4141
pub fn from_lwe_ciphertext_list<C: Container<Element = T>>(
4242
h_ct: &LweCiphertextList<C>,
4343
streams: &CudaStreams,
44+
) -> Self {
45+
let res = unsafe { Self::from_lwe_ciphertext_list_async(h_ct, streams) };
46+
streams.synchronize();
47+
res
48+
}
49+
50+
/// # Safety
51+
///
52+
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
53+
/// not be dropped until `streams` is synchronized
54+
pub unsafe fn from_lwe_ciphertext_list_async<C: Container<Element = T>>(
55+
h_ct: &LweCiphertextList<C>,
56+
streams: &CudaStreams,
4457
) -> Self {
4558
let lwe_dimension = h_ct.lwe_size().to_lwe_dimension();
4659
let lwe_ciphertext_count = h_ct.lwe_ciphertext_count();
@@ -56,7 +69,6 @@ impl<T: UnsignedInteger> CudaLweCiphertextList<T> {
5669
unsafe {
5770
d_vec.copy_from_cpu_async(h_input.as_ref(), streams, 0);
5871
}
59-
streams.synchronize();
6072
let cuda_lwe_list = CudaLweList {
6173
d_vec,
6274
lwe_ciphertext_count,

0 commit comments

Comments
 (0)