From 53a1f35d3bda35b4df2f43f950b9db3534b0c1c4 Mon Sep 17 00:00:00 2001 From: "Mayeul@Zama" <69792125+mayeul-zama@users.noreply.github.com> Date: Mon, 10 Feb 2025 18:03:33 +0100 Subject: [PATCH] feat: update noise reduction to take input noise into account --- tests/Cargo.toml | 2 +- tests/backward_compatibility/shortint.rs | 4 +- .../modulus_switch_noise_reduction.rs | 2 + tfhe/src/c_api/shortint/parameters.rs | 20 ++++-- .../modulus_switch_noise_reduction.rs | 17 ++++- .../test/modulus_switch_noise_reduction.rs | 28 ++++++-- tfhe/src/shortint/parameters/mod.rs | 4 +- .../modulus_switch_noise_reduction.rs | 68 ++++++++++++------- 8 files changed, 104 insertions(+), 41 deletions(-) diff --git a/tests/Cargo.toml b/tests/Cargo.toml index cc3c925b8f..6ecac9ad8c 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -7,7 +7,7 @@ publish = false [dev-dependencies] tfhe = { path = "../tfhe" } tfhe-versionable = { path = "../utils/tfhe-versionable" } -tfhe-backward-compat-data = { git = "https://github.com/zama-ai/tfhe-backward-compat-data.git", branch = "v0.6", default-features = false, features = [ +tfhe-backward-compat-data = { git = "https://github.com/zama-ai/tfhe-backward-compat-data.git", branch = "v0.7", default-features = false, features = [ "load", ] } ron = "0.8" diff --git a/tests/backward_compatibility/shortint.rs b/tests/backward_compatibility/shortint.rs index e578abf028..3bd6fa0e15 100644 --- a/tests/backward_compatibility/shortint.rs +++ b/tests/backward_compatibility/shortint.rs @@ -1,6 +1,6 @@ use std::path::Path; use tfhe::core_crypto::prelude::{ - LweCiphertextCount, NoiseEstimationMeasureBound, RSigmaFactor, TUniform, + LweCiphertextCount, NoiseEstimationMeasureBound, RSigmaFactor, TUniform, Variance, }; use tfhe::shortint::parameters::ModulusSwitchNoiseReductionParams; use tfhe_backward_compat_data::load::{ @@ -50,11 +50,13 @@ pub fn load_params(test_params: &TestParameterSet) -> ClassicPBSParameters { modulus_switch_zeros_count, ms_bound, ms_r_sigma_factor, + ms_input_variance, }| { ModulusSwitchNoiseReductionParams { modulus_switch_zeros_count: LweCiphertextCount(*modulus_switch_zeros_count), ms_bound: NoiseEstimationMeasureBound(*ms_bound), ms_r_sigma_factor: RSigmaFactor(*ms_r_sigma_factor), + ms_input_variance: Variance(*ms_input_variance), } }, ); diff --git a/tfhe/benches/core_crypto/modulus_switch_noise_reduction.rs b/tfhe/benches/core_crypto/modulus_switch_noise_reduction.rs index 1ed4375fbe..4a70440b4b 100644 --- a/tfhe/benches/core_crypto/modulus_switch_noise_reduction.rs +++ b/tfhe/benches/core_crypto/modulus_switch_noise_reduction.rs @@ -11,6 +11,7 @@ fn modulus_switch_noise_reduction(c: &mut Criterion) { let bound = NoiseEstimationMeasureBound((1_u64 << (64 - 1 - 4 - 1)) as f64); let r_sigma_factor = RSigmaFactor(14.658999256586121); let log_modulus = PolynomialSize(2048).to_blind_rotation_input_modulus_log(); + let input_variance = Variance(0.); for count in [10, 50, 100, 1_000, 10_000, 100_000] { let mut boxed_seeder = new_seeder(); @@ -71,6 +72,7 @@ fn modulus_switch_noise_reduction(c: &mut Criterion) { &encryptions_of_zero, r_sigma_factor, bound, + input_variance, log_modulus, ); diff --git a/tfhe/src/c_api/shortint/parameters.rs b/tfhe/src/c_api/shortint/parameters.rs index da71c358c4..64c1ac3584 100644 --- a/tfhe/src/c_api/shortint/parameters.rs +++ b/tfhe/src/c_api/shortint/parameters.rs @@ -3,7 +3,7 @@ pub use crate::core_crypto::commons::parameters::{ DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, }; use crate::core_crypto::commons::parameters::{NoiseEstimationMeasureBound, RSigmaFactor}; -use crate::core_crypto::prelude::LweCiphertextCount; +use crate::core_crypto::prelude::{LweCiphertextCount, Variance}; pub use crate::shortint::parameters::compact_public_key_only::p_fail_2_minus_64::ks_pbs::V0_11_PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; pub use crate::shortint::parameters::key_switching::p_fail_2_minus_64::ks_pbs::V0_11_PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::parameters::ModulusSwitchNoiseReductionParams as RustModulusSwitchNoiseReductionParams; @@ -50,6 +50,7 @@ pub struct ModulusSwitchNoiseReductionParams { pub modulus_switch_zeros_count: u32, pub ms_bound: f64, pub ms_r_sigma_factor: f64, + pub ms_input_variance: f64, } #[repr(C)] @@ -67,6 +68,7 @@ impl ModulusSwitchNoiseReductionParamsOption { modulus_switch_zeros_count: 0, ms_bound: 0.0, ms_r_sigma_factor: 0.0, + ms_input_variance: 0.0, }, } } @@ -96,12 +98,17 @@ pub unsafe extern "C" fn modulus_switch_noise_reduction_params_option_some( impl From for RustModulusSwitchNoiseReductionParams { fn from(value: ModulusSwitchNoiseReductionParams) -> Self { + let ModulusSwitchNoiseReductionParams { + modulus_switch_zeros_count, + ms_bound, + ms_r_sigma_factor, + ms_input_variance, + } = value; Self { - modulus_switch_zeros_count: LweCiphertextCount( - value.modulus_switch_zeros_count as usize, - ), - ms_bound: NoiseEstimationMeasureBound(value.ms_bound), - ms_r_sigma_factor: RSigmaFactor(value.ms_r_sigma_factor), + modulus_switch_zeros_count: LweCiphertextCount(modulus_switch_zeros_count as usize), + ms_bound: NoiseEstimationMeasureBound(ms_bound), + ms_r_sigma_factor: RSigmaFactor(ms_r_sigma_factor), + ms_input_variance: Variance(ms_input_variance), } } } @@ -127,6 +134,7 @@ impl RustModulusSwitchNoiseReductionParams { modulus_switch_zeros_count: self.modulus_switch_zeros_count.0 as u32, ms_bound: self.ms_bound.0, ms_r_sigma_factor: self.ms_r_sigma_factor.0, + ms_input_variance: self.ms_input_variance.0, } } } diff --git a/tfhe/src/core_crypto/algorithms/modulus_switch_noise_reduction.rs b/tfhe/src/core_crypto/algorithms/modulus_switch_noise_reduction.rs index 36ae2f9352..1075ff18f6 100644 --- a/tfhe/src/core_crypto/algorithms/modulus_switch_noise_reduction.rs +++ b/tfhe/src/core_crypto/algorithms/modulus_switch_noise_reduction.rs @@ -1,10 +1,13 @@ use super::lwe_ciphertext_add_assign; +use crate::core_crypto::commons::dispersion::{ModularVariance, Variance}; use crate::core_crypto::commons::numeric::CastInto; use crate::core_crypto::commons::parameters::{NoiseEstimationMeasureBound, RSigmaFactor}; use crate::core_crypto::commons::traits::{Container, ContainerMut, UnsignedInteger}; use crate::core_crypto::entities::{LweCiphertext, LweCiphertextList}; use crate::core_crypto::fft_impl::common::modulus_switch; -use crate::core_crypto::prelude::{CiphertextModulusLog, ContiguousEntityContainer}; +use crate::core_crypto::prelude::{ + CiphertextModulusLog, ContiguousEntityContainer, DispersionParameter, +}; use itertools::Itertools; /// Only works on power of 2 moduli @@ -67,6 +70,7 @@ fn measure_modulus_switch_noise_expectancy_variance_for_binary_key( r_sigma_factor: RSigmaFactor, + input_variance: ModularVariance, log_modulus: CiphertextModulusLog, masks: impl Iterator, body: Scalar, @@ -76,7 +80,7 @@ pub fn measure_modulus_switch_noise_estimation_for_binary_key, r_sigma_factor: RSigmaFactor, bound: NoiseEstimationMeasureBound, + input_variance: Variance, log_modulus: CiphertextModulusLog, ) -> CandidateResult where @@ -124,12 +129,17 @@ where "Expected at least one encryption of zero" ); + let modulus = lwe.ciphertext_modulus().raw_modulus_float(); + + let input_variance = input_variance.get_modular_variance(modulus); + let mask = lwe.get_mask(); let mask = mask.as_ref(); let base_measure = measure_modulus_switch_noise_estimation_for_binary_key( r_sigma_factor, + input_variance, log_modulus, mask.iter().copied(), *lwe.get_body().data, @@ -159,6 +169,7 @@ where let measure = measure_modulus_switch_noise_estimation_for_binary_key( r_sigma_factor, + input_variance, log_modulus, mask_sum, body_sum, @@ -187,6 +198,7 @@ pub fn improve_lwe_ciphertext_modulus_switch_noise_for_binary_key, r_sigma_factor: RSigmaFactor, bound: NoiseEstimationMeasureBound, + input_variance: Variance, log_modulus: CiphertextModulusLog, ) where Scalar: UnsignedInteger, @@ -198,6 +210,7 @@ pub fn improve_lwe_ciphertext_modulus_switch_noise_for_binary_key candidate, @@ -357,6 +367,7 @@ fn check_noise_improve_modulus_switch_noise( expected_individual_check_p_success: _, expected_variance_improved, target_upper_bound_p_all_fail_log2: _, + input_variance, } = ms_noise_reduction_test_params; let number_loops = 100_000; @@ -414,6 +425,7 @@ fn check_noise_improve_modulus_switch_noise( &encryptions_of_zero, r_sigma_factor, bound, + input_variance, log_modulus, ); @@ -467,7 +479,9 @@ fn check_noise_improve_modulus_switch_noise( "Expected {expected_base_variance}, got {base_variance}", ); - let expected_variance_improved = expected_variance_improved.0; + let expected_variance_improved = Variance(expected_variance_improved.0 - input_variance.0) + .get_modular_variance(2_f64.powi(64)) + .value; assert!( check_both_ratio_under(variance_improved, expected_variance_improved, 1.03_f64), diff --git a/tfhe/src/shortint/parameters/mod.rs b/tfhe/src/shortint/parameters/mod.rs index 2ac5156fd2..a508a9c5de 100644 --- a/tfhe/src/shortint/parameters/mod.rs +++ b/tfhe/src/shortint/parameters/mod.rs @@ -15,7 +15,7 @@ use crate::core_crypto::fft_impl::fft64::crypto::bootstrap::BootstrapKeyConforma use crate::core_crypto::prelude::{ GlweCiphertextConformanceParameters, KeyswitchKeyConformanceParams, LweCiphertextCount, LweCiphertextListParameters, LweCiphertextParameters, MsDecompressionType, - NoiseEstimationMeasureBound, RSigmaFactor, + NoiseEstimationMeasureBound, RSigmaFactor, Variance, }; use crate::shortint::backward_compatibility::parameters::*; #[cfg(feature = "zk-pok")] @@ -441,6 +441,7 @@ impl PBSParameters { #[derive(Serialize, Copy, Clone, Deserialize, Debug, PartialEq, Versionize)] #[versionize(ShortintParameterSetInnerVersions)] +#[allow(clippy::large_enum_variant)] pub(crate) enum ShortintParameterSetInner { PBSOnly(PBSParameters), WopbsOnly(WopbsParameters), @@ -838,4 +839,5 @@ pub struct ModulusSwitchNoiseReductionParams { pub modulus_switch_zeros_count: LweCiphertextCount, pub ms_bound: NoiseEstimationMeasureBound, pub ms_r_sigma_factor: RSigmaFactor, + pub ms_input_variance: Variance, } diff --git a/tfhe/src/shortint/server_key/modulus_switch_noise_reduction.rs b/tfhe/src/shortint/server_key/modulus_switch_noise_reduction.rs index 7a28606b60..f5153fae35 100644 --- a/tfhe/src/shortint/server_key/modulus_switch_noise_reduction.rs +++ b/tfhe/src/shortint/server_key/modulus_switch_noise_reduction.rs @@ -8,7 +8,7 @@ use crate::core_crypto::commons::parameters::{ use crate::core_crypto::commons::traits::*; use crate::core_crypto::entities::*; use crate::core_crypto::prelude::modulus_switch_noise_reduction::improve_lwe_ciphertext_modulus_switch_noise_for_binary_key; -use crate::core_crypto::prelude::CiphertextModulusLog; +use crate::core_crypto::prelude::{CiphertextModulusLog, Variance}; use crate::shortint::backward_compatibility::server_key::modulus_switch_noise_reduction::*; use crate::shortint::engine::ShortintEngine; use crate::shortint::parameters::ModulusSwitchNoiseReductionParams; @@ -56,6 +56,7 @@ pub struct ModulusSwitchNoiseReductionKey { pub modulus_switch_zeros: LweCiphertextListOwned, pub ms_bound: NoiseEstimationMeasureBound, pub ms_r_sigma_factor: RSigmaFactor, + pub ms_input_variance: Variance, } impl ParameterSetConformant for ModulusSwitchNoiseReductionKey { @@ -66,19 +67,26 @@ impl ParameterSetConformant for ModulusSwitchNoiseReductionKey { modulus_switch_zeros, ms_bound, ms_r_sigma_factor, + ms_input_variance, } = self; - *ms_bound == parameter_set.modulus_switch_noise_reduction_params.ms_bound - && *ms_r_sigma_factor - == parameter_set - .modulus_switch_noise_reduction_params - .ms_r_sigma_factor - && modulus_switch_zeros.entity_count() - == parameter_set - .modulus_switch_noise_reduction_params - .modulus_switch_zeros_count - .0 - && modulus_switch_zeros.lwe_size().to_lwe_dimension() == parameter_set.lwe_dimension + let ModulusSwitchNoiseReductionKeyConformanceParameters { + modulus_switch_noise_reduction_params, + lwe_dimension, + } = parameter_set; + + let ModulusSwitchNoiseReductionParams { + modulus_switch_zeros_count: param_modulus_switch_zeros_count, + ms_bound: param_ms_bound, + ms_r_sigma_factor: param_ms_r_sigma_factor, + ms_input_variance: param_ms_input_variance, + } = modulus_switch_noise_reduction_params; + + ms_bound == param_ms_bound + && ms_r_sigma_factor == param_ms_r_sigma_factor + && ms_input_variance == param_ms_input_variance + && modulus_switch_zeros.entity_count() == param_modulus_switch_zeros_count.0 + && modulus_switch_zeros.lwe_size().to_lwe_dimension() == *lwe_dimension } } @@ -95,6 +103,7 @@ impl ModulusSwitchNoiseReductionKey { &self.modulus_switch_zeros, self.ms_r_sigma_factor, self.ms_bound, + self.ms_input_variance, log_modulus, ); } @@ -106,6 +115,7 @@ pub struct CompressedModulusSwitchNoiseReductionKey { pub modulus_switch_zeros: SeededLweCiphertextListOwned, pub ms_bound: NoiseEstimationMeasureBound, pub ms_r_sigma_factor: RSigmaFactor, + pub ms_input_variance: Variance, } impl ParameterSetConformant for CompressedModulusSwitchNoiseReductionKey { @@ -116,19 +126,26 @@ impl ParameterSetConformant for CompressedModulusSwitchNoiseReductionKey { modulus_switch_zeros, ms_bound, ms_r_sigma_factor, + ms_input_variance, } = self; - *ms_bound == parameter_set.modulus_switch_noise_reduction_params.ms_bound - && *ms_r_sigma_factor - == parameter_set - .modulus_switch_noise_reduction_params - .ms_r_sigma_factor - && modulus_switch_zeros.entity_count() - == parameter_set - .modulus_switch_noise_reduction_params - .modulus_switch_zeros_count - .0 - && modulus_switch_zeros.lwe_size().to_lwe_dimension() == parameter_set.lwe_dimension + let ModulusSwitchNoiseReductionKeyConformanceParameters { + modulus_switch_noise_reduction_params, + lwe_dimension, + } = parameter_set; + + let ModulusSwitchNoiseReductionParams { + modulus_switch_zeros_count: param_modulus_switch_zeros_count, + ms_bound: param_ms_bound, + ms_r_sigma_factor: param_ms_r_sigma_factor, + ms_input_variance: param_ms_input_variance, + } = modulus_switch_noise_reduction_params; + + ms_bound == param_ms_bound + && ms_r_sigma_factor == param_ms_r_sigma_factor + && ms_input_variance == param_ms_input_variance + && modulus_switch_zeros.entity_count() == param_modulus_switch_zeros_count.0 + && modulus_switch_zeros.lwe_size().to_lwe_dimension() == *lwe_dimension } } @@ -144,6 +161,7 @@ impl ModulusSwitchNoiseReductionKey { modulus_switch_zeros_count: count, ms_bound, ms_r_sigma_factor, + ms_input_variance, } = modulus_switch_noise_reduction_params; let lwe_size = secret_key.lwe_dimension().to_lwe_size(); @@ -177,6 +195,7 @@ impl ModulusSwitchNoiseReductionKey { modulus_switch_zeros, ms_bound, ms_r_sigma_factor, + ms_input_variance, } } } @@ -194,6 +213,7 @@ impl CompressedModulusSwitchNoiseReductionKey { modulus_switch_zeros_count: count, ms_bound, ms_r_sigma_factor, + ms_input_variance, } = modulus_switch_noise_reduction_params; let lwe_size = secret_key.lwe_dimension().to_lwe_size(); @@ -227,6 +247,7 @@ impl CompressedModulusSwitchNoiseReductionKey { modulus_switch_zeros, ms_bound, ms_r_sigma_factor, + ms_input_variance, } } @@ -238,6 +259,7 @@ impl CompressedModulusSwitchNoiseReductionKey { .decompress_into_lwe_ciphertext_list(), ms_bound: self.ms_bound, ms_r_sigma_factor: self.ms_r_sigma_factor, + ms_input_variance: self.ms_input_variance, } } }