Skip to content

Commit

Permalink
chore!: use u64 for NoiseLevel
Browse files Browse the repository at this point in the history
Change from usize to u64 for MaxNoiseLevel and NoiseLevel

This is an API break as `new` and `get` handle/returns u64
instead of usize

This is also a potential serialization break depending on the
serializer used (bincode should be fine as it serializes usize as u64)
  • Loading branch information
tmontaigu committed Nov 20, 2024
1 parent 12e7a71 commit a305a7c
Show file tree
Hide file tree
Showing 15 changed files with 39 additions and 34 deletions.
4 changes: 2 additions & 2 deletions tfhe/src/c_api/shortint/parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl TryFrom<ShortintPBSParameters> for crate::shortint::ClassicPBSParameters {
c_params.modulus_power_of_2_exponent,
)?,
max_noise_level: crate::shortint::parameters::MaxNoiseLevel::new(
c_params.max_noise_level,
c_params.max_noise_level as u64,
),
log2_p_fail: c_params.log2_p_fail,
encryption_key_choice: c_params.encryption_key_choice.into(),
Expand Down Expand Up @@ -113,7 +113,7 @@ impl ShortintPBSParameters {
ks_level: rust_params.ks_level.0,
message_modulus: rust_params.message_modulus.0,
carry_modulus: rust_params.carry_modulus.0,
max_noise_level: rust_params.max_noise_level.get(),
max_noise_level: rust_params.max_noise_level.get() as usize,
log2_p_fail: rust_params.log2_p_fail,
modulus_power_of_2_exponent: convert_modulus(rust_params.ciphertext_modulus),
encryption_key_choice: ShortintEncryptionKeyChoice::convert(
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/integer/server_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ impl ServerKey {
MaxDegree::from_msg_carry_modulus(self.message_modulus(), self.carry_modulus());
let max_sum_to_full_carry = max_degree.get() / degree.get();

max_sum_to_full_carry.min(self.key.max_noise_level.get())
max_sum_to_full_carry.min(self.key.max_noise_level.get() as usize)
}
}

Expand Down
3 changes: 2 additions & 1 deletion tfhe/src/integer/server_key/radix_parallel/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,8 @@ impl ServerKey {

// Just in case we compare with max noise level, but it should always be num_bits_in_blocks
// with the parameters we provide
let grouping_size = (num_bits_in_block as usize).min(self.key.max_noise_level.get());
let grouping_size =
(num_bits_in_block as usize).min(self.key.max_noise_level.get() as usize);

let mut output_flag = None;

Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/integer/server_key/radix_parallel/block_shift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl ServerKey {
assert!(
self.key
.max_noise_level
.validate(NoiseLevel::NOMINAL * 3usize)
.validate(NoiseLevel::NOMINAL * 3u64)
.is_ok(),
"Parameters must support 2 additions before a PBS"
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl ServerKey {
"Compression does not support message_modulus > carry_modulus"
);
assert!(
self.key.max_noise_level.get() >= self.message_modulus().0 + 1,
self.key.max_noise_level.get() >= self.message_modulus().0 as u64 + 1,
"Compression does not support max_noise_level < message_modulus + 1"
);

Expand Down
3 changes: 2 additions & 1 deletion tfhe/src/integer/server_key/radix_parallel/scalar_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,8 @@ impl ServerKey {
let num_bits_in_block = packed_modulus.ilog2();
// Just in case we compare with max noise level, but it should always be num_bits_in_blocks
// with the parameters we provide
let grouping_size = (num_bits_in_block as usize).min(self.key.max_noise_level.get());
let grouping_size =
(num_bits_in_block as usize).min(self.key.max_noise_level.get() as usize);

// In this, we store lookup tables to be used on each 'packing'.
// These LUTs will generate an output that tells whether the packing
Expand Down
3 changes: 2 additions & 1 deletion tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,8 @@ impl ServerKey {
let num_bits_in_block = packed_modulus.ilog2();
// Just in case we compare with max noise level, but it should always be num_bits_in_blocks
// with the parameters we provide
let grouping_size = (num_bits_in_block as usize).min(self.key.max_noise_level.get());
let grouping_size =
(num_bits_in_block as usize).min(self.key.max_noise_level.get() as usize);

// In this, we store lookup tables to be used on each 'packing'.
// These LUTs will generate an output that tells whether the packing
Expand Down
3 changes: 2 additions & 1 deletion tfhe/src/integer/server_key/radix_parallel/sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,8 @@ impl ServerKey {

// Just in case we compare with max noise level, but it should always be num_bits_in_blocks
// with the parameters we provide
let grouping_size = (num_bits_in_block as usize).min(self.key.max_noise_level.get());
let grouping_size =
(num_bits_in_block as usize).min(self.key.max_noise_level.get() as usize);

// Second step
let (mut prepared_blocks, resolved_borrows) = {
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/js_on_wasm_api/shortint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ impl Shortint {
ks_level: usize,
message_modulus: usize,
carry_modulus: usize,
max_noise_level: usize,
max_noise_level: u64,
log2_p_fail: f64,
modulus_power_of_2_exponent: usize,
encryption_key_choice: ShortintEncryptionKeyChoice,
Expand Down
31 changes: 16 additions & 15 deletions tfhe/src/shortint/ciphertext/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@ impl std::error::Error for NotTrivialCiphertextError {}
/// that guarantees the target p-error when doing a PBS on it
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize, Versionize)]
#[versionize(MaxNoiseLevelVersions)]
pub struct MaxNoiseLevel(usize);
pub struct MaxNoiseLevel(u64);

impl MaxNoiseLevel {
pub(crate) const UNKNOWN: Self = Self(usize::MAX);
pub(crate) const UNKNOWN: Self = Self(u64::MAX);

pub const fn new(value: usize) -> Self {
pub const fn new(value: u64) -> Self {
Self(value)
}

pub const fn get(&self) -> usize {
pub const fn get(&self) -> u64 {
self.0
}

Expand All @@ -45,7 +45,8 @@ impl MaxNoiseLevel {
msg_modulus: MessageModulus,
carry_modulus: CarryModulus,
) -> Self {
Self((carry_modulus.0 * msg_modulus.0 - 1) / (msg_modulus.0 - 1))
let level = (carry_modulus.0 * msg_modulus.0 - 1) / (msg_modulus.0 - 1);
Self(level as u64)
}

pub const fn validate(&self, noise_level: NoiseLevel) -> Result<(), CheckError> {
Expand All @@ -64,17 +65,17 @@ impl MaxNoiseLevel {
Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone, Serialize, Deserialize, Versionize,
)]
#[versionize(NoiseLevelVersions)]
pub struct NoiseLevel(usize);
pub struct NoiseLevel(u64);

impl NoiseLevel {
pub const NOMINAL: Self = Self(1);
pub const ZERO: Self = Self(0);
// As a safety measure the unknown noise level is set to the max value
pub const UNKNOWN: Self = Self(usize::MAX);
pub const UNKNOWN: Self = Self(u64::MAX);
}

impl NoiseLevel {
pub fn get(&self) -> usize {
pub fn get(&self) -> u64 {
self.0
}
}
Expand All @@ -94,16 +95,16 @@ impl std::ops::Add for NoiseLevel {
}
}

impl std::ops::MulAssign<usize> for NoiseLevel {
fn mul_assign(&mut self, rhs: usize) {
impl std::ops::MulAssign<u64> for NoiseLevel {
fn mul_assign(&mut self, rhs: u64) {
self.0 = self.0.saturating_mul(rhs);
}
}

impl std::ops::Mul<usize> for NoiseLevel {
impl std::ops::Mul<u64> for NoiseLevel {
type Output = Self;

fn mul(mut self, rhs: usize) -> Self::Output {
fn mul(mut self, rhs: u64) -> Self::Output {
self *= rhs;

self
Expand Down Expand Up @@ -272,14 +273,14 @@ mod tests {

let mut rng = thread_rng();

assert_eq!(NoiseLevel::UNKNOWN.0, usize::MAX);
assert_eq!(NoiseLevel::UNKNOWN.0, u64::MAX);

let max_noise_level = NoiseLevel::UNKNOWN;
let random_addend = rng.gen::<usize>();
let random_addend = rng.gen::<u64>();
let add = max_noise_level + NoiseLevel(random_addend);
assert_eq!(add, NoiseLevel::UNKNOWN);

let random_positive_multiplier = rng.gen_range(1usize..=usize::MAX);
let random_positive_multiplier = rng.gen_range(1u64..=u64::MAX);
let mul = max_noise_level * random_positive_multiplier;
assert_eq!(mul, NoiseLevel::UNKNOWN);
}
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/shortint/list_compression/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ impl CompressionKey {

let mut ct = ct.clone();
let max_noise_level =
MaxNoiseLevel::new((ct.noise_level() * message_modulus.0).get());
MaxNoiseLevel::new((ct.noise_level() * message_modulus.0 as u64).get());
unchecked_scalar_mul_assign(&mut ct, message_modulus.0 as u8, max_noise_level);

list.extend(ct.ct.as_ref());
Expand Down
6 changes: 3 additions & 3 deletions tfhe/src/shortint/server_key/bivariate_pbs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ fn ciphertexts_can_be_packed_without_exceeding_space_or_noise(

max_degree.validate(final_degree)?;

let final_noise_level = (lhs.noise_level * factor) + rhs.noise_level;
let final_noise_level = (lhs.noise_level * factor as u64) + rhs.noise_level;

server_key.max_noise_level.validate(final_noise_level)?;

Expand Down Expand Up @@ -513,7 +513,7 @@ impl ScalingOperation {
impl Ciphertext {
fn noise_degree_if_scaled(&self, scale: u8) -> CiphertextNoiseDegree {
CiphertextNoiseDegree {
noise_level: self.noise_level() * scale as usize,
noise_level: self.noise_level() * u64::from(scale),
degree: self.degree * scale as usize,
}
}
Expand All @@ -524,7 +524,7 @@ impl Ciphertext {
} = self.noise_degree_if_bootstrapped();

CiphertextNoiseDegree {
noise_level: noise * scale as usize,
noise_level: noise * u64::from(scale),
degree: degree * scale as usize,
}
}
Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/shortint/server_key/scalar_mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ impl ServerKey {
self.max_degree.validate(Degree::new(final_degree))?;

self.max_noise_level
.validate(ct.noise_level * scalar as usize)?;
.validate(ct.noise_level * u64::from(scalar))?;

Ok(())
}
Expand Down Expand Up @@ -521,7 +521,7 @@ pub(crate) fn unchecked_scalar_mul_assign(
scalar: u8,
max_noise_level: MaxNoiseLevel,
) {
ct.set_noise_level(ct.noise_level() * scalar as usize, max_noise_level);
ct.set_noise_level(ct.noise_level() * u64::from(scalar), max_noise_level);
ct.degree = Degree::new(ct.degree.get() * scalar as usize);

match scalar {
Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/shortint/server_key/tests/noise_level.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ fn test_ct_scalar_op_noise_level_propagation(sk: &ServerKey, ct: &Ciphertext, sc
test_fn(&ServerKey::unchecked_scalar_add, &|ct_noise, _| ct_noise);
test_fn(&ServerKey::unchecked_scalar_sub, &|ct_noise, _| ct_noise);
test_fn(&ServerKey::unchecked_scalar_mul, &|ct_noise, scalar| {
ct_noise * scalar as usize
ct_noise * u64::from(scalar)
});
let expected_pbs_noise = if ct.is_trivial() {
NoiseLevel::ZERO
Expand Down Expand Up @@ -281,7 +281,7 @@ fn test_ct_scalar_op_assign_noise_level_propagation(sk: &ServerKey, ct: &Ciphert
});
test_fn(
&ServerKey::unchecked_scalar_mul_assign,
&|ct_noise, scalar| ct_noise * scalar as usize,
&|ct_noise, scalar| ct_noise * u64::from(scalar),
);
let expected_pbs_noise = if ct.is_trivial() {
NoiseLevel::ZERO
Expand Down
2 changes: 1 addition & 1 deletion tfhe/tests/backward_compatibility/shortint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub fn load_params(test_params: &TestParameterSet) -> ClassicPBSParameters {
ks_level: DecompositionLevelCount(test_params.ks_level),
message_modulus: MessageModulus(test_params.message_modulus),
carry_modulus: CarryModulus(test_params.carry_modulus),
max_noise_level: MaxNoiseLevel::new(test_params.max_noise_level),
max_noise_level: MaxNoiseLevel::new(test_params.max_noise_level as u64),
log2_p_fail: test_params.log2_p_fail,
ciphertext_modulus: CiphertextModulus::try_new(test_params.ciphertext_modulus).unwrap(),
encryption_key_choice: {
Expand Down

0 comments on commit a305a7c

Please sign in to comment.