Skip to content

Commit

Permalink
feat: update noise reduction to take input noise into account
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeul-zama committed Feb 13, 2025
1 parent 4305f8d commit 53a1f35
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 41 deletions.
2 changes: 1 addition & 1 deletion tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ publish = false
[dev-dependencies]
tfhe = { path = "../tfhe" }
tfhe-versionable = { path = "../utils/tfhe-versionable" }
tfhe-backward-compat-data = { git = "https://github.com/zama-ai/tfhe-backward-compat-data.git", branch = "v0.6", default-features = false, features = [
tfhe-backward-compat-data = { git = "https://github.com/zama-ai/tfhe-backward-compat-data.git", branch = "v0.7", default-features = false, features = [
"load",
] }
ron = "0.8"
Expand Down
4 changes: 3 additions & 1 deletion tests/backward_compatibility/shortint.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::path::Path;
use tfhe::core_crypto::prelude::{
LweCiphertextCount, NoiseEstimationMeasureBound, RSigmaFactor, TUniform,
LweCiphertextCount, NoiseEstimationMeasureBound, RSigmaFactor, TUniform, Variance,
};
use tfhe::shortint::parameters::ModulusSwitchNoiseReductionParams;
use tfhe_backward_compat_data::load::{
Expand Down Expand Up @@ -50,11 +50,13 @@ pub fn load_params(test_params: &TestParameterSet) -> ClassicPBSParameters {
modulus_switch_zeros_count,
ms_bound,
ms_r_sigma_factor,
ms_input_variance,
}| {
ModulusSwitchNoiseReductionParams {
modulus_switch_zeros_count: LweCiphertextCount(*modulus_switch_zeros_count),
ms_bound: NoiseEstimationMeasureBound(*ms_bound),
ms_r_sigma_factor: RSigmaFactor(*ms_r_sigma_factor),
ms_input_variance: Variance(*ms_input_variance),
}
},
);
Expand Down
2 changes: 2 additions & 0 deletions tfhe/benches/core_crypto/modulus_switch_noise_reduction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ fn modulus_switch_noise_reduction(c: &mut Criterion) {
let bound = NoiseEstimationMeasureBound((1_u64 << (64 - 1 - 4 - 1)) as f64);
let r_sigma_factor = RSigmaFactor(14.658999256586121);
let log_modulus = PolynomialSize(2048).to_blind_rotation_input_modulus_log();
let input_variance = Variance(0.);

for count in [10, 50, 100, 1_000, 10_000, 100_000] {
let mut boxed_seeder = new_seeder();
Expand Down Expand Up @@ -71,6 +72,7 @@ fn modulus_switch_noise_reduction(c: &mut Criterion) {
&encryptions_of_zero,
r_sigma_factor,
bound,
input_variance,
log_modulus,
);

Expand Down
20 changes: 14 additions & 6 deletions tfhe/src/c_api/shortint/parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pub use crate::core_crypto::commons::parameters::{
DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize,
};
use crate::core_crypto::commons::parameters::{NoiseEstimationMeasureBound, RSigmaFactor};
use crate::core_crypto::prelude::LweCiphertextCount;
use crate::core_crypto::prelude::{LweCiphertextCount, Variance};
pub use crate::shortint::parameters::compact_public_key_only::p_fail_2_minus_64::ks_pbs::V0_11_PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
pub use crate::shortint::parameters::key_switching::p_fail_2_minus_64::ks_pbs::V0_11_PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::parameters::ModulusSwitchNoiseReductionParams as RustModulusSwitchNoiseReductionParams;
Expand Down Expand Up @@ -50,6 +50,7 @@ pub struct ModulusSwitchNoiseReductionParams {
pub modulus_switch_zeros_count: u32,
pub ms_bound: f64,
pub ms_r_sigma_factor: f64,
pub ms_input_variance: f64,
}

#[repr(C)]
Expand All @@ -67,6 +68,7 @@ impl ModulusSwitchNoiseReductionParamsOption {
modulus_switch_zeros_count: 0,
ms_bound: 0.0,
ms_r_sigma_factor: 0.0,
ms_input_variance: 0.0,
},
}
}
Expand Down Expand Up @@ -96,12 +98,17 @@ pub unsafe extern "C" fn modulus_switch_noise_reduction_params_option_some(

impl From<ModulusSwitchNoiseReductionParams> for RustModulusSwitchNoiseReductionParams {
fn from(value: ModulusSwitchNoiseReductionParams) -> Self {
let ModulusSwitchNoiseReductionParams {
modulus_switch_zeros_count,
ms_bound,
ms_r_sigma_factor,
ms_input_variance,
} = value;
Self {
modulus_switch_zeros_count: LweCiphertextCount(
value.modulus_switch_zeros_count as usize,
),
ms_bound: NoiseEstimationMeasureBound(value.ms_bound),
ms_r_sigma_factor: RSigmaFactor(value.ms_r_sigma_factor),
modulus_switch_zeros_count: LweCiphertextCount(modulus_switch_zeros_count as usize),
ms_bound: NoiseEstimationMeasureBound(ms_bound),
ms_r_sigma_factor: RSigmaFactor(ms_r_sigma_factor),
ms_input_variance: Variance(ms_input_variance),
}
}
}
Expand All @@ -127,6 +134,7 @@ impl RustModulusSwitchNoiseReductionParams {
modulus_switch_zeros_count: self.modulus_switch_zeros_count.0 as u32,
ms_bound: self.ms_bound.0,
ms_r_sigma_factor: self.ms_r_sigma_factor.0,
ms_input_variance: self.ms_input_variance.0,
}
}
}
Expand Down
17 changes: 15 additions & 2 deletions tfhe/src/core_crypto/algorithms/modulus_switch_noise_reduction.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use super::lwe_ciphertext_add_assign;
use crate::core_crypto::commons::dispersion::{ModularVariance, Variance};
use crate::core_crypto::commons::numeric::CastInto;
use crate::core_crypto::commons::parameters::{NoiseEstimationMeasureBound, RSigmaFactor};
use crate::core_crypto::commons::traits::{Container, ContainerMut, UnsignedInteger};
use crate::core_crypto::entities::{LweCiphertext, LweCiphertextList};
use crate::core_crypto::fft_impl::common::modulus_switch;
use crate::core_crypto::prelude::{CiphertextModulusLog, ContiguousEntityContainer};
use crate::core_crypto::prelude::{
CiphertextModulusLog, ContiguousEntityContainer, DispersionParameter,
};
use itertools::Itertools;

/// Only works on power of 2 moduli
Expand Down Expand Up @@ -67,6 +70,7 @@ fn measure_modulus_switch_noise_expectancy_variance_for_binary_key<Scalar: Unsig

pub fn measure_modulus_switch_noise_estimation_for_binary_key<Scalar: UnsignedInteger>(
r_sigma_factor: RSigmaFactor,
input_variance: ModularVariance,
log_modulus: CiphertextModulusLog,
masks: impl Iterator<Item = Scalar>,
body: Scalar,
Expand All @@ -76,7 +80,7 @@ pub fn measure_modulus_switch_noise_estimation_for_binary_key<Scalar: UnsignedIn
variance,
} = measure_modulus_switch_noise_expectancy_variance_for_binary_key(masks, body, log_modulus);

let std_dev = variance.sqrt();
let std_dev = (variance + input_variance.value).sqrt();

expectancy.abs() + std_dev * r_sigma_factor.0
}
Expand All @@ -97,6 +101,7 @@ pub fn choose_candidate_to_improve_modulus_switch_noise_for_binary_key<Scalar, C
encryptions_of_zero: &LweCiphertextList<C2>,
r_sigma_factor: RSigmaFactor,
bound: NoiseEstimationMeasureBound,
input_variance: Variance,
log_modulus: CiphertextModulusLog,
) -> CandidateResult
where
Expand Down Expand Up @@ -124,12 +129,17 @@ where
"Expected at least one encryption of zero"
);

let modulus = lwe.ciphertext_modulus().raw_modulus_float();

let input_variance = input_variance.get_modular_variance(modulus);

let mask = lwe.get_mask();

let mask = mask.as_ref();

let base_measure = measure_modulus_switch_noise_estimation_for_binary_key(
r_sigma_factor,
input_variance,
log_modulus,
mask.iter().copied(),
*lwe.get_body().data,
Expand Down Expand Up @@ -159,6 +169,7 @@ where

let measure = measure_modulus_switch_noise_estimation_for_binary_key(
r_sigma_factor,
input_variance,
log_modulus,
mask_sum,
body_sum,
Expand Down Expand Up @@ -187,6 +198,7 @@ pub fn improve_lwe_ciphertext_modulus_switch_noise_for_binary_key<Scalar, C1, C2
encryptions_of_zero: &LweCiphertextList<C2>,
r_sigma_factor: RSigmaFactor,
bound: NoiseEstimationMeasureBound,
input_variance: Variance,
log_modulus: CiphertextModulusLog,
) where
Scalar: UnsignedInteger,
Expand All @@ -198,6 +210,7 @@ pub fn improve_lwe_ciphertext_modulus_switch_noise_for_binary_key<Scalar, C1, C2
encryptions_of_zero,
r_sigma_factor,
bound,
input_variance,
log_modulus,
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ struct MsNoiseReductionTestParams {
pub modulus_switch_zeros_count: LweCiphertextCount,
pub bound: NoiseEstimationMeasureBound,
pub r_sigma_factor: RSigmaFactor,
pub input_variance: Variance,
pub log_modulus: CiphertextModulusLog,
pub expected_individual_check_p_success: f64,
pub expected_variance_improved: Variance,
Expand All @@ -30,15 +31,16 @@ struct MsNoiseReductionTestParams {

const TEST_PARAM: MsNoiseReductionTestParams = MsNoiseReductionTestParams {
lwe_dimension: LweDimension(918),
lwe_noise_distribution: DynamicDistribution::new_t_uniform(46),
lwe_noise_distribution: DynamicDistribution::new_t_uniform(45),
ciphertext_modulus: CiphertextModulus::new_native(),
modulus_switch_zeros_count: LweCiphertextCount(1452),
bound: NoiseEstimationMeasureBound((1_u64 << (64 - 1 - 4 - 1)) as f64),
r_sigma_factor: RSigmaFactor(14.658999256586121),
modulus_switch_zeros_count: LweCiphertextCount(1449),
bound: NoiseEstimationMeasureBound(288230376151711744_f64),
r_sigma_factor: RSigmaFactor(13.179852282053789f64),
log_modulus: PolynomialSize(2048).to_blind_rotation_input_modulus_log(),
expected_individual_check_p_success: 0.059282589,
expected_variance_improved: Variance(4.834651119161795e32 - 9.68570987092478e+31),
expected_individual_check_p_success: 0.060923874,
expected_variance_improved: Variance(1.40546154228955e-6),
target_upper_bound_p_all_fail_log2: -130.,
input_variance: Variance(2.63039184094559e-7f64),
};

thread_local! {
Expand Down Expand Up @@ -66,8 +68,13 @@ fn improve_modulus_switch_noise_test_individual_check_p_success(
expected_individual_check_p_success,
expected_variance_improved: _,
target_upper_bound_p_all_fail_log2,
input_variance,
} = params;

let modulus = ciphertext_modulus.raw_modulus_float();

let input_variance = input_variance.get_modular_variance(modulus);

let number_loops = 100_000;

let mut rsc = TestResources::new();
Expand Down Expand Up @@ -132,6 +139,7 @@ fn improve_modulus_switch_noise_test_individual_check_p_success(

let measure = measure_modulus_switch_noise_estimation_for_binary_key(
r_sigma_factor,
input_variance,
log_modulus,
mask_sum,
body_sum,
Expand Down Expand Up @@ -232,6 +240,7 @@ fn improve_modulus_switch_noise_test_average_number_checks(params: MsNoiseReduct
expected_individual_check_p_success,
expected_variance_improved: _,
target_upper_bound_p_all_fail_log2: _,
input_variance,
} = params;

let expected_average_number_checks = 1. / expected_individual_check_p_success;
Expand Down Expand Up @@ -280,6 +289,7 @@ fn improve_modulus_switch_noise_test_average_number_checks(params: MsNoiseReduct
&encryptions_of_zero,
r_sigma_factor,
bound,
input_variance,
log_modulus,
) {
CandidateResult::SatisfiyingBound(candidate) => candidate,
Expand Down Expand Up @@ -357,6 +367,7 @@ fn check_noise_improve_modulus_switch_noise(
expected_individual_check_p_success: _,
expected_variance_improved,
target_upper_bound_p_all_fail_log2: _,
input_variance,
} = ms_noise_reduction_test_params;

let number_loops = 100_000;
Expand Down Expand Up @@ -414,6 +425,7 @@ fn check_noise_improve_modulus_switch_noise(
&encryptions_of_zero,
r_sigma_factor,
bound,
input_variance,
log_modulus,
);

Expand Down Expand Up @@ -467,7 +479,9 @@ fn check_noise_improve_modulus_switch_noise(
"Expected {expected_base_variance}, got {base_variance}",
);

let expected_variance_improved = expected_variance_improved.0;
let expected_variance_improved = Variance(expected_variance_improved.0 - input_variance.0)
.get_modular_variance(2_f64.powi(64))
.value;

assert!(
check_both_ratio_under(variance_improved, expected_variance_improved, 1.03_f64),
Expand Down
4 changes: 3 additions & 1 deletion tfhe/src/shortint/parameters/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::core_crypto::fft_impl::fft64::crypto::bootstrap::BootstrapKeyConforma
use crate::core_crypto::prelude::{
GlweCiphertextConformanceParameters, KeyswitchKeyConformanceParams, LweCiphertextCount,
LweCiphertextListParameters, LweCiphertextParameters, MsDecompressionType,
NoiseEstimationMeasureBound, RSigmaFactor,
NoiseEstimationMeasureBound, RSigmaFactor, Variance,
};
use crate::shortint::backward_compatibility::parameters::*;
#[cfg(feature = "zk-pok")]
Expand Down Expand Up @@ -441,6 +441,7 @@ impl PBSParameters {

#[derive(Serialize, Copy, Clone, Deserialize, Debug, PartialEq, Versionize)]
#[versionize(ShortintParameterSetInnerVersions)]
#[allow(clippy::large_enum_variant)]
pub(crate) enum ShortintParameterSetInner {
PBSOnly(PBSParameters),
WopbsOnly(WopbsParameters),
Expand Down Expand Up @@ -838,4 +839,5 @@ pub struct ModulusSwitchNoiseReductionParams {
pub modulus_switch_zeros_count: LweCiphertextCount,
pub ms_bound: NoiseEstimationMeasureBound,
pub ms_r_sigma_factor: RSigmaFactor,
pub ms_input_variance: Variance,
}
Loading

0 comments on commit 53a1f35

Please sign in to comment.