From a305a7c2ccfa01868bb380b086d463e180aea316 Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Tue, 19 Nov 2024 14:02:01 +0100 Subject: [PATCH] chore!: use u64 for NoiseLevel Change from usize to u64 for MaxNoiseLevel and NoiseLevel This is an API break as `new` and `get` handle/returns u64 instead of usize This is also a potential serialization break depending on the serializer used (bincode should be fine as it serializes usize as u64) --- tfhe/src/c_api/shortint/parameters.rs | 4 +-- tfhe/src/integer/server_key/mod.rs | 2 +- .../integer/server_key/radix_parallel/add.rs | 3 +- .../server_key/radix_parallel/block_shift.rs | 2 +- .../modulus_switch_compression.rs | 2 +- .../server_key/radix_parallel/scalar_add.rs | 3 +- .../server_key/radix_parallel/scalar_sub.rs | 3 +- .../integer/server_key/radix_parallel/sub.rs | 3 +- tfhe/src/js_on_wasm_api/shortint.rs | 2 +- tfhe/src/shortint/ciphertext/common.rs | 31 ++++++++++--------- .../shortint/list_compression/compression.rs | 2 +- tfhe/src/shortint/server_key/bivariate_pbs.rs | 6 ++-- tfhe/src/shortint/server_key/scalar_mul.rs | 4 +-- .../shortint/server_key/tests/noise_level.rs | 4 +-- tfhe/tests/backward_compatibility/shortint.rs | 2 +- 15 files changed, 39 insertions(+), 34 deletions(-) diff --git a/tfhe/src/c_api/shortint/parameters.rs b/tfhe/src/c_api/shortint/parameters.rs index 24ee4f251e..85c865207e 100644 --- a/tfhe/src/c_api/shortint/parameters.rs +++ b/tfhe/src/c_api/shortint/parameters.rs @@ -62,7 +62,7 @@ impl TryFrom for crate::shortint::ClassicPBSParameters { c_params.modulus_power_of_2_exponent, )?, max_noise_level: crate::shortint::parameters::MaxNoiseLevel::new( - c_params.max_noise_level, + c_params.max_noise_level as u64, ), log2_p_fail: c_params.log2_p_fail, encryption_key_choice: c_params.encryption_key_choice.into(), @@ -113,7 +113,7 @@ impl ShortintPBSParameters { ks_level: rust_params.ks_level.0, message_modulus: rust_params.message_modulus.0, carry_modulus: rust_params.carry_modulus.0, - max_noise_level: rust_params.max_noise_level.get(), + max_noise_level: rust_params.max_noise_level.get() as usize, log2_p_fail: rust_params.log2_p_fail, modulus_power_of_2_exponent: convert_modulus(rust_params.ciphertext_modulus), encryption_key_choice: ShortintEncryptionKeyChoice::convert( diff --git a/tfhe/src/integer/server_key/mod.rs b/tfhe/src/integer/server_key/mod.rs index 376a4594a3..464c8a8620 100644 --- a/tfhe/src/integer/server_key/mod.rs +++ b/tfhe/src/integer/server_key/mod.rs @@ -246,7 +246,7 @@ impl ServerKey { MaxDegree::from_msg_carry_modulus(self.message_modulus(), self.carry_modulus()); let max_sum_to_full_carry = max_degree.get() / degree.get(); - max_sum_to_full_carry.min(self.key.max_noise_level.get()) + max_sum_to_full_carry.min(self.key.max_noise_level.get() as usize) } } diff --git a/tfhe/src/integer/server_key/radix_parallel/add.rs b/tfhe/src/integer/server_key/radix_parallel/add.rs index 71516125f3..2fbaa37731 100644 --- a/tfhe/src/integer/server_key/radix_parallel/add.rs +++ b/tfhe/src/integer/server_key/radix_parallel/add.rs @@ -854,7 +854,8 @@ impl ServerKey { // Just in case we compare with max noise level, but it should always be num_bits_in_blocks // with the parameters we provide - let grouping_size = (num_bits_in_block as usize).min(self.key.max_noise_level.get()); + let grouping_size = + (num_bits_in_block as usize).min(self.key.max_noise_level.get() as usize); let mut output_flag = None; diff --git a/tfhe/src/integer/server_key/radix_parallel/block_shift.rs b/tfhe/src/integer/server_key/radix_parallel/block_shift.rs index ecf714610a..432f03b32f 100644 --- a/tfhe/src/integer/server_key/radix_parallel/block_shift.rs +++ b/tfhe/src/integer/server_key/radix_parallel/block_shift.rs @@ -67,7 +67,7 @@ impl ServerKey { assert!( self.key .max_noise_level - .validate(NoiseLevel::NOMINAL * 3usize) + .validate(NoiseLevel::NOMINAL * 3u64) .is_ok(), "Parameters must support 2 additions before a PBS" ); diff --git a/tfhe/src/integer/server_key/radix_parallel/modulus_switch_compression.rs b/tfhe/src/integer/server_key/radix_parallel/modulus_switch_compression.rs index 7539efc769..715c0ff817 100644 --- a/tfhe/src/integer/server_key/radix_parallel/modulus_switch_compression.rs +++ b/tfhe/src/integer/server_key/radix_parallel/modulus_switch_compression.rs @@ -66,7 +66,7 @@ impl ServerKey { "Compression does not support message_modulus > carry_modulus" ); assert!( - self.key.max_noise_level.get() >= self.message_modulus().0 + 1, + self.key.max_noise_level.get() >= self.message_modulus().0 as u64 + 1, "Compression does not support max_noise_level < message_modulus + 1" ); diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs index c4a200ba70..b14756d30b 100644 --- a/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_add.rs @@ -555,7 +555,8 @@ impl ServerKey { let num_bits_in_block = packed_modulus.ilog2(); // Just in case we compare with max noise level, but it should always be num_bits_in_blocks // with the parameters we provide - let grouping_size = (num_bits_in_block as usize).min(self.key.max_noise_level.get()); + let grouping_size = + (num_bits_in_block as usize).min(self.key.max_noise_level.get() as usize); // In this, we store lookup tables to be used on each 'packing'. // These LUTs will generate an output that tells whether the packing diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs index 30a8a51172..00df516a8d 100644 --- a/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs @@ -394,7 +394,8 @@ impl ServerKey { let num_bits_in_block = packed_modulus.ilog2(); // Just in case we compare with max noise level, but it should always be num_bits_in_blocks // with the parameters we provide - let grouping_size = (num_bits_in_block as usize).min(self.key.max_noise_level.get()); + let grouping_size = + (num_bits_in_block as usize).min(self.key.max_noise_level.get() as usize); // In this, we store lookup tables to be used on each 'packing'. // These LUTs will generate an output that tells whether the packing diff --git a/tfhe/src/integer/server_key/radix_parallel/sub.rs b/tfhe/src/integer/server_key/radix_parallel/sub.rs index 47d8b74d6e..5fb8ba3928 100644 --- a/tfhe/src/integer/server_key/radix_parallel/sub.rs +++ b/tfhe/src/integer/server_key/radix_parallel/sub.rs @@ -386,7 +386,8 @@ impl ServerKey { // Just in case we compare with max noise level, but it should always be num_bits_in_blocks // with the parameters we provide - let grouping_size = (num_bits_in_block as usize).min(self.key.max_noise_level.get()); + let grouping_size = + (num_bits_in_block as usize).min(self.key.max_noise_level.get() as usize); // Second step let (mut prepared_blocks, resolved_borrows) = { diff --git a/tfhe/src/js_on_wasm_api/shortint.rs b/tfhe/src/js_on_wasm_api/shortint.rs index d25c3b5fb3..d149b19303 100644 --- a/tfhe/src/js_on_wasm_api/shortint.rs +++ b/tfhe/src/js_on_wasm_api/shortint.rs @@ -417,7 +417,7 @@ impl Shortint { ks_level: usize, message_modulus: usize, carry_modulus: usize, - max_noise_level: usize, + max_noise_level: u64, log2_p_fail: f64, modulus_power_of_2_exponent: usize, encryption_key_choice: ShortintEncryptionKeyChoice, diff --git a/tfhe/src/shortint/ciphertext/common.rs b/tfhe/src/shortint/ciphertext/common.rs index 33cca6325f..d48265f853 100644 --- a/tfhe/src/shortint/ciphertext/common.rs +++ b/tfhe/src/shortint/ciphertext/common.rs @@ -23,16 +23,16 @@ impl std::error::Error for NotTrivialCiphertextError {} /// that guarantees the target p-error when doing a PBS on it #[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize, Versionize)] #[versionize(MaxNoiseLevelVersions)] -pub struct MaxNoiseLevel(usize); +pub struct MaxNoiseLevel(u64); impl MaxNoiseLevel { - pub(crate) const UNKNOWN: Self = Self(usize::MAX); + pub(crate) const UNKNOWN: Self = Self(u64::MAX); - pub const fn new(value: usize) -> Self { + pub const fn new(value: u64) -> Self { Self(value) } - pub const fn get(&self) -> usize { + pub const fn get(&self) -> u64 { self.0 } @@ -45,7 +45,8 @@ impl MaxNoiseLevel { msg_modulus: MessageModulus, carry_modulus: CarryModulus, ) -> Self { - Self((carry_modulus.0 * msg_modulus.0 - 1) / (msg_modulus.0 - 1)) + let level = (carry_modulus.0 * msg_modulus.0 - 1) / (msg_modulus.0 - 1); + Self(level as u64) } pub const fn validate(&self, noise_level: NoiseLevel) -> Result<(), CheckError> { @@ -64,17 +65,17 @@ impl MaxNoiseLevel { Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone, Serialize, Deserialize, Versionize, )] #[versionize(NoiseLevelVersions)] -pub struct NoiseLevel(usize); +pub struct NoiseLevel(u64); impl NoiseLevel { pub const NOMINAL: Self = Self(1); pub const ZERO: Self = Self(0); // As a safety measure the unknown noise level is set to the max value - pub const UNKNOWN: Self = Self(usize::MAX); + pub const UNKNOWN: Self = Self(u64::MAX); } impl NoiseLevel { - pub fn get(&self) -> usize { + pub fn get(&self) -> u64 { self.0 } } @@ -94,16 +95,16 @@ impl std::ops::Add for NoiseLevel { } } -impl std::ops::MulAssign for NoiseLevel { - fn mul_assign(&mut self, rhs: usize) { +impl std::ops::MulAssign for NoiseLevel { + fn mul_assign(&mut self, rhs: u64) { self.0 = self.0.saturating_mul(rhs); } } -impl std::ops::Mul for NoiseLevel { +impl std::ops::Mul for NoiseLevel { type Output = Self; - fn mul(mut self, rhs: usize) -> Self::Output { + fn mul(mut self, rhs: u64) -> Self::Output { self *= rhs; self @@ -272,14 +273,14 @@ mod tests { let mut rng = thread_rng(); - assert_eq!(NoiseLevel::UNKNOWN.0, usize::MAX); + assert_eq!(NoiseLevel::UNKNOWN.0, u64::MAX); let max_noise_level = NoiseLevel::UNKNOWN; - let random_addend = rng.gen::(); + let random_addend = rng.gen::(); let add = max_noise_level + NoiseLevel(random_addend); assert_eq!(add, NoiseLevel::UNKNOWN); - let random_positive_multiplier = rng.gen_range(1usize..=usize::MAX); + let random_positive_multiplier = rng.gen_range(1u64..=u64::MAX); let mul = max_noise_level * random_positive_multiplier; assert_eq!(mul, NoiseLevel::UNKNOWN); } diff --git a/tfhe/src/shortint/list_compression/compression.rs b/tfhe/src/shortint/list_compression/compression.rs index a38f93461f..9dd5c6cc7f 100644 --- a/tfhe/src/shortint/list_compression/compression.rs +++ b/tfhe/src/shortint/list_compression/compression.rs @@ -88,7 +88,7 @@ impl CompressionKey { let mut ct = ct.clone(); let max_noise_level = - MaxNoiseLevel::new((ct.noise_level() * message_modulus.0).get()); + MaxNoiseLevel::new((ct.noise_level() * message_modulus.0 as u64).get()); unchecked_scalar_mul_assign(&mut ct, message_modulus.0 as u8, max_noise_level); list.extend(ct.ct.as_ref()); diff --git a/tfhe/src/shortint/server_key/bivariate_pbs.rs b/tfhe/src/shortint/server_key/bivariate_pbs.rs index d11b18065b..47cfb0f25d 100644 --- a/tfhe/src/shortint/server_key/bivariate_pbs.rs +++ b/tfhe/src/shortint/server_key/bivariate_pbs.rs @@ -36,7 +36,7 @@ fn ciphertexts_can_be_packed_without_exceeding_space_or_noise( max_degree.validate(final_degree)?; - let final_noise_level = (lhs.noise_level * factor) + rhs.noise_level; + let final_noise_level = (lhs.noise_level * factor as u64) + rhs.noise_level; server_key.max_noise_level.validate(final_noise_level)?; @@ -513,7 +513,7 @@ impl ScalingOperation { impl Ciphertext { fn noise_degree_if_scaled(&self, scale: u8) -> CiphertextNoiseDegree { CiphertextNoiseDegree { - noise_level: self.noise_level() * scale as usize, + noise_level: self.noise_level() * u64::from(scale), degree: self.degree * scale as usize, } } @@ -524,7 +524,7 @@ impl Ciphertext { } = self.noise_degree_if_bootstrapped(); CiphertextNoiseDegree { - noise_level: noise * scale as usize, + noise_level: noise * u64::from(scale), degree: degree * scale as usize, } } diff --git a/tfhe/src/shortint/server_key/scalar_mul.rs b/tfhe/src/shortint/server_key/scalar_mul.rs index 24df9d3f44..7209e3f0e0 100644 --- a/tfhe/src/shortint/server_key/scalar_mul.rs +++ b/tfhe/src/shortint/server_key/scalar_mul.rs @@ -298,7 +298,7 @@ impl ServerKey { self.max_degree.validate(Degree::new(final_degree))?; self.max_noise_level - .validate(ct.noise_level * scalar as usize)?; + .validate(ct.noise_level * u64::from(scalar))?; Ok(()) } @@ -521,7 +521,7 @@ pub(crate) fn unchecked_scalar_mul_assign( scalar: u8, max_noise_level: MaxNoiseLevel, ) { - ct.set_noise_level(ct.noise_level() * scalar as usize, max_noise_level); + ct.set_noise_level(ct.noise_level() * u64::from(scalar), max_noise_level); ct.degree = Degree::new(ct.degree.get() * scalar as usize); match scalar { diff --git a/tfhe/src/shortint/server_key/tests/noise_level.rs b/tfhe/src/shortint/server_key/tests/noise_level.rs index 669b5f703c..c242180ccc 100644 --- a/tfhe/src/shortint/server_key/tests/noise_level.rs +++ b/tfhe/src/shortint/server_key/tests/noise_level.rs @@ -234,7 +234,7 @@ fn test_ct_scalar_op_noise_level_propagation(sk: &ServerKey, ct: &Ciphertext, sc test_fn(&ServerKey::unchecked_scalar_add, &|ct_noise, _| ct_noise); test_fn(&ServerKey::unchecked_scalar_sub, &|ct_noise, _| ct_noise); test_fn(&ServerKey::unchecked_scalar_mul, &|ct_noise, scalar| { - ct_noise * scalar as usize + ct_noise * u64::from(scalar) }); let expected_pbs_noise = if ct.is_trivial() { NoiseLevel::ZERO @@ -281,7 +281,7 @@ fn test_ct_scalar_op_assign_noise_level_propagation(sk: &ServerKey, ct: &Ciphert }); test_fn( &ServerKey::unchecked_scalar_mul_assign, - &|ct_noise, scalar| ct_noise * scalar as usize, + &|ct_noise, scalar| ct_noise * u64::from(scalar), ); let expected_pbs_noise = if ct.is_trivial() { NoiseLevel::ZERO diff --git a/tfhe/tests/backward_compatibility/shortint.rs b/tfhe/tests/backward_compatibility/shortint.rs index 0f267e6d28..dde8823b52 100644 --- a/tfhe/tests/backward_compatibility/shortint.rs +++ b/tfhe/tests/backward_compatibility/shortint.rs @@ -36,7 +36,7 @@ pub fn load_params(test_params: &TestParameterSet) -> ClassicPBSParameters { ks_level: DecompositionLevelCount(test_params.ks_level), message_modulus: MessageModulus(test_params.message_modulus), carry_modulus: CarryModulus(test_params.carry_modulus), - max_noise_level: MaxNoiseLevel::new(test_params.max_noise_level), + max_noise_level: MaxNoiseLevel::new(test_params.max_noise_level as u64), log2_p_fail: test_params.log2_p_fail, ciphertext_modulus: CiphertextModulus::try_new(test_params.ciphertext_modulus).unwrap(), encryption_key_choice: {