Skip to content

Commit

Permalink
refactor(core): rename GgswPerLweMultiBitBskElement MultiBitPowerSetSize
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeul-zama committed May 13, 2024
1 parent e09ea6c commit 3c3b4ca
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

template <typename Torus, class params>
__device__ Torus calculates_monomial_degree(Torus *lwe_array_group,
uint32_t ggsw_idx,
uint32_t power_set_index,
uint32_t grouping_factor) {
Torus x = 0;
for (int i = 0; i < grouping_factor; i++) {
uint32_t mask_position = grouping_factor - (i + 1);
int selection_bit = (ggsw_idx >> mask_position) & 1;
int selection_bit = (power_set_index >> mask_position) & 1;
x += selection_bit * lwe_array_group[i];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ use rayon::prelude::*;
/// &mut encryption_generator,
/// );
///
/// let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
/// let multi_bit_power_set_size = grouping_factor.multi_bit_power_set_size();
///
/// for (mut ggsw_group, input_key_elements) in bsk.chunks_exact(ggsw_per_multi_bit_element.0).zip(
/// for (mut ggsw_group, input_key_elements) in bsk.chunks_exact(multi_bit_power_set_size.0).zip(
/// input_lwe_secret_key
/// .as_ref()
/// .chunks_exact(grouping_factor.0),
Expand Down Expand Up @@ -144,10 +144,10 @@ pub fn generate_lwe_multi_bit_bootstrap_key<
let output_glwe_size = output.glwe_size();
let output_polynomial_size = output.polynomial_size();
let output_grouping_factor = output.grouping_factor();
let ggsw_per_multi_bit_element = output_grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = output_grouping_factor.multi_bit_power_set_size();

for ((mut ggsw_group, input_key_elements), mut loop_generator) in output
.chunks_exact_mut(ggsw_per_multi_bit_element.0)
.chunks_exact_mut(multi_bit_power_set_size.0)
.zip(
input_lwe_secret_key
.as_ref()
Expand Down Expand Up @@ -292,9 +292,9 @@ where
///
/// par_convert_standard_lwe_multi_bit_bootstrap_key_to_fourier(&bsk, &mut multi_bit_bsk);
///
/// let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
/// let multi_bit_power_set_size = grouping_factor.multi_bit_power_set_size();
///
/// for (mut ggsw_group, input_key_elements) in bsk.chunks_exact(ggsw_per_multi_bit_element.0).zip(
/// for (mut ggsw_group, input_key_elements) in bsk.chunks_exact(multi_bit_power_set_size.0).zip(
/// input_lwe_secret_key
/// .as_ref()
/// .chunks_exact(grouping_factor.0),
Expand Down Expand Up @@ -371,11 +371,11 @@ pub fn par_generate_lwe_multi_bit_bootstrap_key<
let output_glwe_size = output.glwe_size();
let output_polynomial_size = output.polynomial_size();
let output_grouping_factor = output.grouping_factor();
let ggsw_per_multi_bit_element = output_grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = output_grouping_factor.multi_bit_power_set_size();

output
.par_iter_mut()
.chunks(ggsw_per_multi_bit_element.0)
.chunks(multi_bit_power_set_size.0)
.zip(
input_lwe_secret_key
.as_ref()
Expand Down Expand Up @@ -583,10 +583,10 @@ pub fn generate_seeded_lwe_multi_bit_bootstrap_key<
let output_glwe_size = output.glwe_size();
let output_polynomial_size = output.polynomial_size();
let output_grouping_factor = output.grouping_factor();
let ggsw_per_multi_bit_element = output_grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = output_grouping_factor.multi_bit_power_set_size();

for ((mut ggsw_group, input_key_elements), mut loop_generator) in output
.chunks_exact_mut(ggsw_per_multi_bit_element.0)
.chunks_exact_mut(multi_bit_power_set_size.0)
.zip(
input_lwe_secret_key
.as_ref()
Expand Down Expand Up @@ -741,11 +741,11 @@ pub fn par_generate_seeded_lwe_multi_bit_bootstrap_key<
let output_glwe_size = output.glwe_size();
let output_polynomial_size = output.polynomial_size();
let output_grouping_factor = output.grouping_factor();
let ggsw_per_multi_bit_element = output_grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = output_grouping_factor.multi_bit_power_set_size();

output
.par_iter_mut()
.chunks(ggsw_per_multi_bit_element.0)
.chunks(multi_bit_power_set_size.0)
.zip(
input_lwe_secret_key
.as_ref()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ where
Scalar: UnsignedInteger + CastInto<usize> + CastFrom<usize>,
{
// Start at 1, the first ggsw is not rotated
(1..grouping_factor.ggsw_per_multi_bit_element().0).map(move |power_set_index| {
(1..grouping_factor.multi_bit_power_set_size().0).map(move |power_set_index| {
let mut monomial_degree = Scalar::ZERO;
for (&mask_element, selection_bit) in lwe_mask_elements
.iter()
Expand All @@ -49,7 +49,7 @@ pub(crate) fn selection_bit(
grouping_factor: LweBskGroupingFactor,
power_set_index: usize,
) -> impl Iterator<Item = usize> {
debug_assert!(power_set_index < grouping_factor.ggsw_per_multi_bit_element().0);
debug_assert!(power_set_index < grouping_factor.multi_bit_power_set_size().0);

(0..grouping_factor.0).map(move |mask_idx| {
let mask_position = grouping_factor.0 - (mask_idx + 1);
Expand Down Expand Up @@ -417,7 +417,7 @@ pub fn multi_bit_non_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCo
let ggsw_vec: Vec<_> = multi_bit_bsk.ggsw_iter().collect();

let grouping_factor = multi_bit_bsk.grouping_factor();
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = grouping_factor.multi_bit_power_set_size();

let input_lwe_dimension = multi_bit_bsk.input_lwe_dimension();

Expand Down Expand Up @@ -473,8 +473,8 @@ pub fn multi_bit_non_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCo
let switched_degrees =
switched_modulus_input.switched_modulus_input_mask_per_group(work_index);

let ggsw_group = &ggsw_vec[work_index * ggsw_per_multi_bit_element.0
..(work_index + 1) * ggsw_per_multi_bit_element.0];
let ggsw_group = &ggsw_vec[work_index * multi_bit_power_set_size.0
..(work_index + 1) * multi_bit_power_set_size.0];

let mut ready_for_consumer = ready_for_consumer_lock.lock().unwrap();

Expand Down Expand Up @@ -648,7 +648,7 @@ pub fn multi_bit_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCont>(
let ggsw_vec: Vec<_> = multi_bit_bsk.ggsw_iter().collect();

let grouping_factor = multi_bit_bsk.grouping_factor();
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = grouping_factor.multi_bit_power_set_size();

let input_lwe_dimension = multi_bit_bsk.input_lwe_dimension();

Expand Down Expand Up @@ -699,8 +699,8 @@ pub fn multi_bit_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCont>(
let switched_degrees =
switched_modulus_input.switched_modulus_input_mask_per_group(work_index);

let ggsw_group = &ggsw_vec[work_index * ggsw_per_multi_bit_element.0
..(work_index + 1) * ggsw_per_multi_bit_element.0];
let ggsw_group = &ggsw_vec[work_index * multi_bit_power_set_size.0
..(work_index + 1) * multi_bit_power_set_size.0];

let mut ready_for_consumer = ready_for_consumer_lock.lock().unwrap();

Expand Down Expand Up @@ -1236,7 +1236,7 @@ pub fn std_multi_bit_non_deterministic_blind_rotate_assign<Scalar, OutputCont, K
let ggsw_vec: Vec<_> = multi_bit_bsk.iter().collect();

let grouping_factor = multi_bit_bsk.grouping_factor();
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = grouping_factor.multi_bit_power_set_size();

let input_lwe_dimension = multi_bit_bsk.input_lwe_dimension();

Expand Down Expand Up @@ -1312,8 +1312,8 @@ pub fn std_multi_bit_non_deterministic_blind_rotate_assign<Scalar, OutputCont, K
let switched_degrees =
switched_modulus_input.switched_modulus_input_mask_per_group(work_index);

let ggsw_group = &ggsw_vec[work_index * ggsw_per_multi_bit_element.0
..(work_index + 1) * ggsw_per_multi_bit_element.0];
let ggsw_group = &ggsw_vec[work_index * multi_bit_power_set_size.0
..(work_index + 1) * multi_bit_power_set_size.0];

let mut ready_for_consumer = ready_for_consumer_lock.lock().unwrap();

Expand Down Expand Up @@ -1502,7 +1502,7 @@ pub fn std_multi_bit_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCo
let ggsw_vec: Vec<_> = multi_bit_bsk.iter().collect();

let grouping_factor = multi_bit_bsk.grouping_factor();
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = grouping_factor.multi_bit_power_set_size();

let input_lwe_dimension = multi_bit_bsk.input_lwe_dimension();

Expand Down Expand Up @@ -1573,8 +1573,8 @@ pub fn std_multi_bit_deterministic_blind_rotate_assign<Scalar, OutputCont, KeyCo
let switched_degrees =
switched_modulus_input.switched_modulus_input_mask_per_group(work_index);

let ggsw_group = &ggsw_vec[work_index * ggsw_per_multi_bit_element.0
..(work_index + 1) * ggsw_per_multi_bit_element.0];
let ggsw_group = &ggsw_vec[work_index * multi_bit_power_set_size.0
..(work_index + 1) * multi_bit_power_set_size.0];

let mut ready_for_consumer = ready_for_consumer_lock.lock().unwrap();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub fn decompress_seeded_lwe_multi_bit_bootstrap_key_with_existing_generator<
let output_glwe_size = output_bsk.glwe_size();
let output_polynomial_size = output_bsk.polynomial_size();
let output_grouping_factor = output_bsk.grouping_factor();
let ggsw_per_multi_bit_element = output_grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = output_grouping_factor.multi_bit_power_set_size();

let gen_iter = generator
.fork_multi_bit_bsk_to_ggsw_group::<Scalar>(
Expand All @@ -51,8 +51,8 @@ pub fn decompress_seeded_lwe_multi_bit_bootstrap_key_with_existing_generator<
.unwrap();

for ((mut output_ggsw_group, input_ggsw_group), mut loop_generator) in output_bsk
.chunks_exact_mut(ggsw_per_multi_bit_element.0)
.zip(input_bsk.chunks_exact(ggsw_per_multi_bit_element.0))
.chunks_exact_mut(multi_bit_power_set_size.0)
.zip(input_bsk.chunks_exact(multi_bit_power_set_size.0))
.zip(gen_iter)
{
let gen_iter = loop_generator
Expand Down Expand Up @@ -136,7 +136,7 @@ pub fn par_decompress_seeded_lwe_multi_bit_bootstrap_key_with_existing_generator
let output_glwe_size = output_bsk.glwe_size();
let output_polynomial_size = output_bsk.polynomial_size();
let output_grouping_factor = output_bsk.grouping_factor();
let ggsw_per_multi_bit_element = output_grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = output_grouping_factor.multi_bit_power_set_size();

let gen_iter = generator
.par_fork_multi_bit_bsk_to_ggsw_group::<Scalar>(
Expand All @@ -149,8 +149,8 @@ pub fn par_decompress_seeded_lwe_multi_bit_bootstrap_key_with_existing_generator
.unwrap();

output_bsk
.par_chunks_exact_mut(ggsw_per_multi_bit_element.0)
.zip(input_bsk.par_chunks_exact(ggsw_per_multi_bit_element.0))
.par_chunks_exact_mut(multi_bit_power_set_size.0)
.zip(input_bsk.par_chunks_exact(multi_bit_power_set_size.0))
.zip(gen_iter)
.for_each(
|((mut output_ggsw_group, input_ggsw_group), mut loop_generator)| {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ impl<G: ByteRandomGenerator> MaskRandomGenerator<G> {
polynomial_size: PolynomialSize,
grouping_factor: LweBskGroupingFactor,
) -> Result<impl Iterator<Item = Self>, ForkError> {
let ggsw_count = grouping_factor.ggsw_per_multi_bit_element();
let ggsw_count = grouping_factor.multi_bit_power_set_size();
let mask_bytes = mask_elements_per_ggsw(level, glwe_size, polynomial_size)
.to_mask_byte_count(mask_bytes_per_coef::<T>());
self.try_fork(ggsw_count.0, mask_bytes)
Expand Down Expand Up @@ -254,7 +254,7 @@ impl<G: ParallelByteRandomGenerator> MaskRandomGenerator<G> {
polynomial_size: PolynomialSize,
grouping_factor: LweBskGroupingFactor,
) -> Result<impl IndexedParallelIterator<Item = Self>, ForkError> {
let ggsw_count = grouping_factor.ggsw_per_multi_bit_element();
let ggsw_count = grouping_factor.multi_bit_power_set_size();
let mask_bytes = mask_elements_per_ggsw(level, glwe_size, polynomial_size)
.to_mask_byte_count(mask_bytes_per_coef::<T>());
self.par_try_fork(ggsw_count.0, mask_bytes)
Expand Down Expand Up @@ -419,7 +419,7 @@ fn mask_elements_per_multi_bit_bsk_ggsw_group(
grouping_factor: LweBskGroupingFactor,
) -> MaskElementCount {
MaskElementCount(
grouping_factor.ggsw_per_multi_bit_element().0
grouping_factor.multi_bit_power_set_size().0
* mask_elements_per_ggsw(level, glwe_size, poly_size).0,
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ impl<G: ByteRandomGenerator> NoiseRandomGenerator<G> {
polynomial_size: PolynomialSize,
grouping_factor: LweBskGroupingFactor,
) -> Result<impl Iterator<Item = Self>, ForkError> {
let ggsw_count = grouping_factor.ggsw_per_multi_bit_element();
let ggsw_count = grouping_factor.multi_bit_power_set_size();
let noise_bytes = noise_elements_per_ggsw(level, glwe_size, polynomial_size)
.to_noise_byte_count(noise_bytes_per_coef());
self.try_fork(ggsw_count.0, noise_bytes)
Expand Down Expand Up @@ -326,7 +326,7 @@ impl<G: ParallelByteRandomGenerator> NoiseRandomGenerator<G> {
polynomial_size: PolynomialSize,
grouping_factor: LweBskGroupingFactor,
) -> Result<impl IndexedParallelIterator<Item = Self>, ForkError> {
let ggsw_count = grouping_factor.ggsw_per_multi_bit_element();
let ggsw_count = grouping_factor.multi_bit_power_set_size();
let noise_bytes = noise_elements_per_ggsw(level, glwe_size, polynomial_size)
.to_noise_byte_count(noise_bytes_per_coef());
self.par_try_fork(ggsw_count.0, noise_bytes)
Expand Down Expand Up @@ -493,7 +493,7 @@ fn noise_elements_per_multi_bit_bsk_ggsw_group(
grouping_factor: LweBskGroupingFactor,
) -> NoiseElementCount {
NoiseElementCount(
grouping_factor.ggsw_per_multi_bit_element().0
grouping_factor.multi_bit_power_set_size().0
* noise_elements_per_ggsw(level, glwe_size, poly_size).0,
)
}
Expand Down
6 changes: 3 additions & 3 deletions tfhe/src/core_crypto/commons/parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,14 @@ pub struct ThreadCount(pub usize);
pub struct LweBskGroupingFactor(pub usize);

impl LweBskGroupingFactor {
pub fn ggsw_per_multi_bit_element(&self) -> GgswPerLweMultiBitBskElement {
GgswPerLweMultiBitBskElement(1 << self.0)
pub fn multi_bit_power_set_size(&self) -> MultiBitPowerSetSize {
MultiBitPowerSetSize(1 << self.0)
}
}

/// The number of GGSW ciphertexts required per multi_bit BSK element
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
pub struct GgswPerLweMultiBitBskElement(pub usize);
pub struct MultiBitPowerSetSize(pub usize);

#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
pub enum EncryptionKeyChoice {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,10 @@ impl<Scalar: UnsignedInteger + CastInto<usize> + CastFrom<usize>>
let mut diffs = vec![];

for lwe_mask_elements in input_lwe_mask.as_ref().chunks_exact(grouping_factor.0) {
for ggsw_idx in 1..grouping_factor.ggsw_per_multi_bit_element().0 {
for power_set_index in 1..grouping_factor.multi_bit_power_set_size().0 {
// We need to store the diff sums of more than one element as we store the
// individual modulus_switched elements
if ggsw_idx.count_ones() == 1 {
if power_set_index.count_ones() == 1 {
continue;
}

Expand All @@ -232,7 +232,7 @@ impl<Scalar: UnsignedInteger + CastInto<usize> + CastFrom<usize>>

for (&mask_element, selection_bit) in lwe_mask_elements
.iter()
.zip_eq(selection_bit(grouping_factor, ggsw_idx))
.zip_eq(selection_bit(grouping_factor, power_set_index))
{
let selection_bit: Scalar = Scalar::cast_from(selection_bit);

Expand Down Expand Up @@ -337,17 +337,17 @@ impl<Scalar: UnsignedInteger + CastInto<usize> + CastFrom<usize>>
let mut switched_modulus_input_mask_per_group: Vec<usize> = vec![];

for lwe_mask_elements in masks.chunks_exact(self.grouping_factor.0) {
for ggsw_idx in 1..self.grouping_factor.ggsw_per_multi_bit_element().0 {
for power_set_index in 1..self.grouping_factor.multi_bit_power_set_size().0 {
let mut monomial_degree = 0;
for (&mask_element, selection_bit) in lwe_mask_elements
.iter()
.zip_eq(selection_bit(self.grouping_factor, ggsw_idx))
.zip_eq(selection_bit(self.grouping_factor, power_set_index))
{
monomial_degree =
monomial_degree.wrapping_add(selection_bit.wrapping_mul(mask_element));
}

if ggsw_idx.count_ones() != 1 {
if power_set_index.count_ones() != 1 {
let diff = diffs(diff_index);

diff_index += 1;
Expand Down Expand Up @@ -387,9 +387,9 @@ impl MultiBitModulusSwitchedCt for FromCompressionMultiBitModulusSwitchedCt {
&self,
index: usize,
) -> impl Iterator<Item = usize> + '_ {
let ggsw_per_multi_bit_element = self.grouping_factor.ggsw_per_multi_bit_element();
let multi_bit_power_set_size = self.grouping_factor.multi_bit_power_set_size();

let chunk_size = ggsw_per_multi_bit_element.0 - 1;
let chunk_size = multi_bit_power_set_size.0 - 1;

self.switched_modulus_input_mask_per_group[index * chunk_size..(index + 1) * chunk_size]
.iter()
Expand Down
Loading

0 comments on commit 3c3b4ca

Please sign in to comment.