Skip to content

Commit

Permalink
Let users pass the CSPRNG
Browse files Browse the repository at this point in the history
  • Loading branch information
dvdplm committed Nov 13, 2024
1 parent a4b4546 commit ddc4730
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 25 deletions.
9 changes: 6 additions & 3 deletions benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,17 @@ fn make_presieved_num<const L: usize>(rng: &mut impl CryptoRngCore) -> Odd<Uint<
fn bench_uniform_sieve(c: &mut Criterion) {
use crypto_primes::uniform_sieve::UniformGeneratePrime;
let mut group = c.benchmark_group("Uniform sieve");
let mut rng = make_rng();
group.bench_function("(U128) Random prime", |b| {
b.iter(|| U128::generate_prime());
b.iter(|| U128::generate_prime_with_rng(&mut rng));
});
let mut rng = make_rng();
group.bench_function("(U1024) Random prime", |b| {
b.iter(|| U1024::generate_prime());
b.iter(|| U1024::generate_prime_with_rng(&mut rng));
});
let mut rng = make_rng();
group.bench_function("(U2048) Random prime", |b| {
b.iter(|| U2048::generate_prime());
b.iter(|| U2048::generate_prime_with_rng(&mut rng));
});
}
fn bench_sieve(c: &mut Criterion) {
Expand Down
65 changes: 43 additions & 22 deletions src/uniform_sieve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crypto_bigint::{
Bounded, Constants, FixedInteger, Integer, Monty, NonZero, Odd, PowBoundedExp, RandomBits, RandomMod, U1024, U128,
U2048, U256, U4096, U512, U64,
};
use rand::thread_rng;
use rand_core::CryptoRngCore;

use crate::hazmat::precomputed::SMALL_PRIMES;
use crate::{hazmat::binary_gcd, is_prime};
Expand All @@ -24,7 +24,7 @@ where
T: Integer + Constants + Bounded + RandomBits + RandomMod + Copy,
{
/// Generate a prime.
fn generate_prime() -> T;
fn generate_prime_with_rng(rng: &mut impl CryptoRngCore) -> T;
// TODO(dp): missing
// generate_prime_with_rng
// generate_safe_prime
Expand All @@ -35,15 +35,15 @@ macro_rules! impl_generate_prime {
($(($name:ident, $m:expr, $lambda_m:expr, $a_max:expr)),+) => {
$(
impl UniformGeneratePrime<$name> for $name {
fn generate_prime() -> $name {
fn generate_prime_with_rng(rng: &mut impl CryptoRngCore) -> $name {
debug_assert!($m.len() == (2*$name::BITS/8) as usize, "expected m to be {} long, instead it's {}", 2*$name::BITS/8, $m.len());
const M: $name = $name::from_be_hex($m);
const LAMBDA_M: $name = $name::from_be_hex($lambda_m);
const A_MAX: $name = $name::from_be_hex($a_max);
let unit = jp06_unitgen(M, LAMBDA_M);
let unit = jp06_unitgen(rng, M, LAMBDA_M);
let a_max = NonZero::new(A_MAX).expect("A_MAX is pre-calculated and known-good");
algorithm2_faster_but_why(unit, M, &a_max)
// algorithm2(unit, M, &a_max)
algorithm2_faster_but_why(rng, unit, M, &a_max)
// algorithm2(rng, unit, M, &a_max)

}
}
Expand Down Expand Up @@ -111,17 +111,15 @@ impl_generate_prime! {
// 4. Output k
// TODO(dp): probably unify this with "algorithm2" yeah?
#[inline(always)]
fn jp06_unitgen<T>(m: T, lambda_m: T) -> T
fn jp06_unitgen<T>(rng: &mut impl CryptoRngCore, m: T, lambda_m: T) -> T
where
T: FixedInteger + RandomMod,
T::Monty: Retrieve<Output = T> + Copy,
<<T as Integer>::Monty as Monty>::Params: Copy,
{
let mut rng = thread_rng();
// 1. sample k in [1, m[
let m_nz = NonZero::new(m - T::ONE).expect("m is a known, pre-calculated non-zero value");
// TODO(dp): can this also be sped up 4x by switching to random_bits + bound-check&retries?
let k = T::random_mod(&mut rng, &m_nz) + T::ONE;
let k = T::random_mod(rng, &m_nz) + T::ONE;

// 2. set `u = 1-(k^lambda_m) mod m`; rewrite as `(1 - k^lambda_m + m) mod m` if k^lambda_m < m, else compute (k^lambda_m) mod m
let prms = <T as Integer>::Monty::new_params_vartime(
Expand All @@ -139,8 +137,7 @@ where
// go to step 2
// }
while u != zero {
// TODO(dp): can this also be sped up 4x by switching to random_bits + bound-check&retries?
let r = T::random_mod(&mut rng, &m_nz) + T::ONE;
let r = T::random_mod(rng, &m_nz) + T::ONE;
let r = T::Monty::new(r, prms);
k = k + r * u;
u = one - k.pow_bounded_exp(&lambda_m, 32);
Expand All @@ -166,15 +163,14 @@ where
// TODO(dp): probably should unify the two functions.
#[allow(unused)]
#[inline(always)]
fn algorithm2<T>(b: T, m: T, a_max: &NonZero<T>) -> T
fn algorithm2<T>(rng: &mut impl CryptoRngCore, b: T, m: T, a_max: &NonZero<T>) -> T
where
T: Integer + Constants + RandomMod + Copy,
{
let mut rng = thread_rng();
let mut a = T::random_mod(&mut rng, a_max);
let mut a = T::random_mod(rng, a_max);
let mut p = a * m + b;
while !is_prime(&p) {
a = T::random_mod(&mut rng, a_max);
a = T::random_mod(rng, a_max);
p = a * m + b;
}
p
Expand All @@ -184,21 +180,20 @@ where
// It seems like `T::random_mod` is a lot slower than `T::random_bits` + bounds check&retries. A regression? Expected?
#[allow(unused)]
#[inline(always)]
fn algorithm2_faster_but_why<T>(b: T, m: T, a_max: &NonZero<T>) -> T
fn algorithm2_faster_but_why<T>(rng: &mut impl CryptoRngCore, b: T, m: T, a_max: &NonZero<T>) -> T
where
T: Integer + Bounded + RandomBits + RandomMod + Copy,
{
let mut rng = thread_rng();
let a_max_bits = T::BITS - m.leading_zeros();
let mut a = T::random_bits(&mut rng, a_max_bits);
let mut a = T::random_bits(rng, a_max_bits);
while a >= a_max.get() {
a = T::random_bits(&mut rng, a_max_bits)
a = T::random_bits(rng, a_max_bits)
}
let mut p = a * m + b;
while !is_prime(&p) {
a = T::random_bits(&mut rng, a_max_bits);
a = T::random_bits(rng, a_max_bits);
while a >= a_max.get() {
a = T::random_bits(&mut rng, a_max_bits)
a = T::random_bits(rng, a_max_bits)
}
p = a * m + b;
}
Expand Down Expand Up @@ -286,6 +281,32 @@ mod tests {
use core::ops::Div;
use crypto_bigint::{U1024, U128, U2048, U256, U512, U64};
use tracing::{debug, info};

#[test_log::test]
fn debug_jp06() {
type T = U64;
let hex_m = "00000000C0CFD797";
let hex_lambda_m = "000000000000D890";
// type T = U1024;
// let hex_m =
// "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005797D47C51681549D734E4FC4C3EAF7F";
// let hex_lambda_m =
// "0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000002DE3CB9560";
let m = T::from_be_hex(hex_m);
let lambda_m = T::from_be_hex(hex_lambda_m);

let mut count = 0u64;
let mut start = std::time::Instant::now();
loop {
let _unit = jp06_unitgen(m, lambda_m);
count += 1;
if count % 10_000 == 0 {
debug!("10k iters in {:?}", start.elapsed());
start = std::time::Instant::now();
}
}
}

#[test_log::test]
fn debug_algo2() {
type T = U64;
Expand Down

0 comments on commit ddc4730

Please sign in to comment.