Skip to content

Commit

Permalink
chore(gpu): add scalar div and signed scalar div to hl api
Browse files Browse the repository at this point in the history
Also add overflowing sub to hl
  • Loading branch information
agnesLeroy committed Sep 19, 2024
1 parent 48315dc commit 24088fd
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 25 deletions.
6 changes: 2 additions & 4 deletions tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,12 @@ impl<T: UnsignedInteger> CudaLweCiphertextList<T> {

// Copy to the GPU
let h_input = h_ct.as_view().into_container();
let mut d_vec = CudaVec::new(
let mut d_vec = CudaVec::new_async(
lwe_dimension.to_lwe_size().0 * lwe_ciphertext_count.0,
streams,
0,
);
unsafe {
d_vec.copy_from_cpu_async(h_input.as_ref(), streams, 0);
}
d_vec.copy_from_cpu_async(h_input.as_ref(), streams, 0);
let cuda_lwe_list = CudaLweList {
d_vec,
lwe_ciphertext_count,
Expand Down
36 changes: 27 additions & 9 deletions tfhe/src/high_level_api/integers/signed/scalar_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,6 @@ where
// DivRem is a bit special as it returns a tuple of quotient and remainder
macro_rules! generic_integer_impl_scalar_div_rem {
(
key_method: $key_method:ident,
// A 'list' of tuples, where the first element is the concrete Fhe type
// (e.g. FheUint8) and the rest are scalar types (u8, u16, etc.)
fhe_and_scalar_type: $(
Expand Down Expand Up @@ -393,15 +392,24 @@ macro_rules! generic_integer_impl_scalar_div_rem {
InternalServerKey::Cpu(cpu_key) => {
let (q, r) = cpu_key
.pbs_key()
.$key_method(&*self.ciphertext.on_cpu(), rhs);
.signed_scalar_div_rem_parallelized(&*self.ciphertext.on_cpu(), rhs);
(
<$concrete_type>::new(q, cpu_key.tag.clone()),
<$concrete_type>::new(r, cpu_key.tag.clone())
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices does not support div rem yet")
InternalServerKey::Cuda(cuda_key) => {
let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
cuda_key.key.signed_scalar_div_rem(
&*self.ciphertext.on_gpu(), rhs, streams
)
});
let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r));
(
<$concrete_type>::new(q, cuda_key.tag.clone()),
<$concrete_type>::new(r, cuda_key.tag.clone())
)
}
})
}
Expand All @@ -410,8 +418,8 @@ macro_rules! generic_integer_impl_scalar_div_rem {
)* // Closing first repeating pattern
};
}

generic_integer_impl_scalar_div_rem!(
key_method: signed_scalar_div_rem_parallelized,
fhe_and_scalar_type:
(super::FheInt2, i8),
(super::FheInt4, i8),
Expand Down Expand Up @@ -826,8 +834,13 @@ generic_integer_impl_scalar_operation!(
RadixCiphertext::Cpu(inner_result)
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Div '/' with clear value is not yet supported by Cuda devices")
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
cuda_key.key.signed_scalar_div(
&lhs.ciphertext.on_gpu(), rhs, streams
)
});
RadixCiphertext::Cuda(inner_result)
}
})
}
Expand Down Expand Up @@ -859,8 +872,13 @@ generic_integer_impl_scalar_operation!(
RadixCiphertext::Cpu(inner_result)
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Rem '%' with clear value is not yet supported by Cuda devices")
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
cuda_key.key.signed_scalar_rem(
&lhs.ciphertext.on_gpu(), rhs, streams
)
});
RadixCiphertext::Cuda(inner_result)
}
})
}
Expand Down
14 changes: 11 additions & 3 deletions tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,17 @@ where
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support overflowing_sub yet");
}
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let inner_result = cuda_key.key.unsigned_overflowing_sub(
&self.ciphertext.on_gpu(),
&other.ciphertext.on_gpu(),
streams,
);
(
FheUint::<Id>::new(inner_result.0, cuda_key.tag.clone()),
FheBool::new(inner_result.1, cuda_key.tag.clone()),
)
}),
})
}
}
Expand Down
35 changes: 26 additions & 9 deletions tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,6 @@ where
// DivRem is a bit special as it returns a tuple of quotient and remainder
macro_rules! generic_integer_impl_scalar_div_rem {
(
key_method: $key_method:ident,
// A 'list' of tuples, where the first element is the concrete Fhe type
// (e.g. FheUint8) and the rest are scalar types (u8, u16, etc.)
fhe_and_scalar_type: $(
Expand All @@ -473,15 +472,24 @@ macro_rules! generic_integer_impl_scalar_div_rem {
global_state::with_internal_keys(|key| {
match key {
InternalServerKey::Cpu(cpu_key) => {
let (q, r) = cpu_key.pbs_key().$key_method(&*self.ciphertext.on_cpu(), rhs);
let (q, r) = cpu_key.pbs_key().scalar_div_rem_parallelized(&*self.ciphertext.on_cpu(), rhs);
(
<$concrete_type>::new(q, cpu_key.tag.clone()),
<$concrete_type>::new(r, cpu_key.tag.clone())
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support div_rem yet");
InternalServerKey::Cuda(cuda_key) => {
let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| {
cuda_key.key.scalar_div_rem(
&*self.ciphertext.on_gpu(), rhs, streams
)
});
let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r));
(
<$concrete_type>::new(q, cuda_key.tag.clone()),
<$concrete_type>::new(r, cuda_key.tag.clone())
)
}
}
})
Expand All @@ -492,7 +500,6 @@ macro_rules! generic_integer_impl_scalar_div_rem {
};
}
generic_integer_impl_scalar_div_rem!(
key_method: scalar_div_rem_parallelized,
fhe_and_scalar_type:
(super::FheUint2, u8),
(super::FheUint4, u8),
Expand Down Expand Up @@ -978,8 +985,13 @@ generic_integer_impl_scalar_operation!(
RadixCiphertext::Cpu(inner_result)
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Div '/' with clear value is not yet supported by Cuda devices")
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
cuda_key.key.scalar_div(
&lhs.ciphertext.on_gpu(), rhs, streams
)
});
RadixCiphertext::Cuda(inner_result)
}
})
}
Expand Down Expand Up @@ -1014,8 +1026,13 @@ generic_integer_impl_scalar_operation!(
RadixCiphertext::Cpu(inner_result)
},
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Rem '%' with clear value is not yet supported by Cuda devices")
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_streams(|streams| {
cuda_key.key.scalar_rem(
&lhs.ciphertext.on_gpu(), rhs, streams
)
});
RadixCiphertext::Cuda(inner_result)
}
})
}
Expand Down

0 comments on commit 24088fd

Please sign in to comment.