Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

integrate babybear #107

Closed
wants to merge 58 commits into from
Closed
Changes from 1 commit
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
7f13794
feat: BabyBear base field
Sep 16, 2024
cd30560
feat: BabyBearExt4
Sep 16, 2024
13b7184
GKR BabyBear config (no SIMD)
Sep 16, 2024
795b8bc
fix: avx256 feature
Sep 16, 2024
c93b8b8
feat: BabyBear AVX512
Sep 16, 2024
e203f8f
feat: BabyBearExt4 AVX512
Sep 16, 2024
46fcba2
fix: sumcheck_helper poly degree
Sep 17, 2024
6af2318
temp: cargo config
Sep 17, 2024
69bd157
feat: BabyBear config
Sep 17, 2024
0f11134
wip
zhenfeizhang Sep 19, 2024
ef6eb50
wip
zhenfeizhang Sep 19, 2024
0fba45a
wip
zhenfeizhang Sep 19, 2024
04e5cb5
feat: avx256 BabyBear
Sep 19, 2024
5368adb
feat: neon BabyBear
Sep 19, 2024
756edfc
wip
zhenfeizhang Sep 19, 2024
b4e7899
Test vectors for BabyBearExt4
Sep 19, 2024
f104b47
FIX: field for babybearext4x16
Sep 19, 2024
54c2026
feat: Babybear ext3
Sep 16, 2024
dfb0e42
chore: add missing babybear configs
Sep 19, 2024
7271121
update
zhenfeizhang Sep 19, 2024
044d649
fmt
zhenfeizhang Sep 19, 2024
e1fcc02
chore: add babybear to benchmark, correctness test
Sep 19, 2024
6ddbc57
wip
zhenfeizhang Sep 19, 2024
8505543
chore: move p3 imports to workspace
Sep 19, 2024
56f31a2
chore: double instead of mul by W
Sep 19, 2024
bbbc463
chore: remove dummy simd impls
Sep 19, 2024
9172aad
update CI
zhenfeizhang Sep 23, 2024
443119d
ipdate
zhenfeizhang Sep 23, 2024
d23b587
simplify CI
zhenfeizhang Sep 23, 2024
d9dc03c
fix
zhenfeizhang Sep 23, 2024
3fb9815
fix
zhenfeizhang Sep 23, 2024
c8703ce
add arith
zhenfeizhang Sep 23, 2024
e19d26c
reenable bench
zhenfeizhang Sep 23, 2024
940cf17
Merge branch 'zz/fix-ci-avx-target' into HEAD
zhenfeizhang Sep 23, 2024
0bf673b
Merge branch 'zz/fix-ci-avx-target' into zz/field_refactor
zhenfeizhang Sep 23, 2024
05e2b1e
wip
zhenfeizhang Sep 23, 2024
e129e64
almost there
zhenfeizhang Sep 24, 2024
e49820e
clippy and fmt
zhenfeizhang Sep 24, 2024
6064eb7
more fixes
zhenfeizhang Sep 24, 2024
1781036
fmt
zhenfeizhang Sep 24, 2024
8a7011d
fix neon
zhenfeizhang Sep 24, 2024
cc18ea1
Update rust-toolchain
zhenfeizhang Sep 24, 2024
b96f7be
wip
zhenfeizhang Sep 24, 2024
2290c0a
Merge remote-tracking branch 'refs/remotes/origin/zz/field_refactor' …
zhenfeizhang Sep 24, 2024
342445b
wip
zhenfeizhang Sep 24, 2024
404c0af
older nightly version
zhenfeizhang Sep 24, 2024
7e5a14f
fix bench module
zhenfeizhang Sep 24, 2024
eb45863
Merge branch 'main' into zz/field_refactor
zhenfeizhang Sep 24, 2024
3681256
Merge branch 'baby-bear' into zz/field_refactor
zhenfeizhang Sep 24, 2024
35f4987
move babybear to its own crate
zhenfeizhang Sep 24, 2024
8d56926
fmt
zhenfeizhang Sep 24, 2024
b4e46f9
fix neon
zhenfeizhang Sep 24, 2024
b641093
fmt
zhenfeizhang Sep 24, 2024
f8e59e6
fix size bug
zhenfeizhang Sep 24, 2024
9bfbbec
fix CI
zhenfeizhang Sep 24, 2024
59d4645
Merge branch 'main' into zz/integrate_babybear
zhenfeizhang Sep 25, 2024
4115aa0
Update ci.yml
zhenfeizhang Sep 25, 2024
ad99dc7
clippy
zhenfeizhang Sep 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: Babybear ext3
enpsi committed Sep 19, 2024
commit 54c2026766c71b3cc774f2d6c9bef17e9ff73869
8 changes: 6 additions & 2 deletions arith/src/extension_field.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
mod baby_bear_ext;
mod baby_bear_ext3;
mod baby_bear_ext3x16;
mod baby_bear_ext4;
mod baby_bear_ext4x16;
mod fr_ext;
// mod gf2_127;
@@ -8,7 +10,9 @@ mod m31_ext;
mod m31_ext3x16;
use crate::{Field, FieldSerde};

pub use baby_bear_ext::*;
pub use baby_bear_ext3::*;
pub use baby_bear_ext3x16::BabyBearExt3x16;
pub use baby_bear_ext4::*;
pub use baby_bear_ext4x16::BabyBearExt4x16;
// pub use gf2_127::*;
pub use gf2_128::*;
364 changes: 364 additions & 0 deletions arith/src/extension_field/baby_bear_ext3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,364 @@
use super::ExtensionField;
use crate::{
field_common, BabyBear, Field, FieldSerde, FieldSerdeResult, SimdField, BABYBEAR_MODULUS,
};
use core::{
iter::{Product, Sum},
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
};

#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct BabyBearExt3 {
pub v: [BabyBear; 3],
}

field_common!(BabyBearExt3);

impl FieldSerde for BabyBearExt3 {
const SERIALIZED_SIZE: usize = 32 / 8 * 3;

#[inline(always)]
fn serialize_into<W: std::io::Write>(&self, mut writer: W) -> FieldSerdeResult<()> {
self.v[0].serialize_into(&mut writer)?;
self.v[1].serialize_into(&mut writer)?;
self.v[2].serialize_into(&mut writer)?;
Ok(())
}

#[inline(always)]
fn deserialize_from<R: std::io::Read>(mut reader: R) -> FieldSerdeResult<Self> {
Ok(Self {
v: [
BabyBear::deserialize_from(&mut reader)?,
BabyBear::deserialize_from(&mut reader)?,
BabyBear::deserialize_from(&mut reader)?,
],
})
}

#[inline]
fn try_deserialize_from_ecc_format<R: std::io::Read>(mut reader: R) -> FieldSerdeResult<Self> {
let mut buf = [0u8; 32];
reader.read_exact(&mut buf)?;
assert!(
buf.iter().skip(4).all(|&x| x == 0),
"non-zero byte found in witness byte"
);
Ok(Self::from(u32::from_le_bytes(buf[..4].try_into().unwrap())))
}
}

impl Field for BabyBearExt3 {
const NAME: &'static str = "Baby Bear Extension 4";

const SIZE: usize = 32 / 8 * 4;

const FIELD_SIZE: usize = 32 * 4;

const ZERO: Self = Self {
v: [BabyBear::ZERO; 3],
};

const ONE: Self = Self {
v: [BabyBear::ONE, BabyBear::ZERO, BabyBear::ZERO],
};

const INV_2: Self = Self {
v: [BabyBear::INV_2, BabyBear::ZERO, BabyBear::ZERO],
};

fn zero() -> Self {
Self::ZERO
}

fn is_zero(&self) -> bool {
*self == Self::ZERO
}

fn one() -> Self {
Self::ONE
}

fn random_unsafe(mut rng: impl rand::RngCore) -> Self {
Self {
v: [
BabyBear::random_unsafe(&mut rng),
BabyBear::random_unsafe(&mut rng),
BabyBear::random_unsafe(&mut rng),
],
}
}

fn random_bool(rng: impl rand::RngCore) -> Self {
Self {
v: [BabyBear::random_bool(rng), BabyBear::ZERO, BabyBear::ZERO],
}
}

fn exp(&self, exponent: u128) -> Self {
let mut e = exponent;
let mut res = Self::one();
let mut t = *self;
while e != 0 {
let b = e & 1;
if b == 1 {
res *= t;
}
t = t * t;
e >>= 1;
}
res
}

fn inv(&self) -> Option<Self> {
if self.is_zero() {
None
} else {
// TODO: Implement a more efficient inversion
let e = (BABYBEAR_MODULUS as u128).pow(3) - 2;
Some(self.exp(e as u128))
}
}

#[inline(always)]
fn square(&self) -> Self {
Self {
v: square_internal(&self.v),
}
}

fn as_u32_unchecked(&self) -> u32 {
self.v[0].as_u32_unchecked()
}

fn from_uniform_bytes(bytes: &[u8; 32]) -> Self {
let v1 = BabyBear::from(u32::from_be_bytes(bytes[0..4].try_into().unwrap()));
let v2 = BabyBear::from(u32::from_be_bytes(bytes[4..8].try_into().unwrap()));
let v3 = BabyBear::from(u32::from_be_bytes(bytes[8..12].try_into().unwrap()));
Self { v: [v1, v2, v3] }
}
}

impl ExtensionField for BabyBearExt3 {
const DEGREE: usize = 3;

const W: u32 = 2;

const X: Self = Self {
v: [BabyBear::ZERO, BabyBear::ONE, BabyBear::ZERO],
};

type BaseField = BabyBear;

#[inline(always)]
fn mul_by_base_field(&self, base: &Self::BaseField) -> Self {
let mut res = self.v;
res[0] *= base;
res[1] *= base;
res[2] *= base;
Self { v: res }
}

#[inline(always)]
fn add_by_base_field(&self, base: &Self::BaseField) -> Self {
let mut res = self.v;
res[0] += base;
Self { v: res }
}

#[inline(always)]
fn mul_by_x(&self) -> Self {
let w = BabyBear::from(Self::W);
Self {
v: [self.v[2] * w, self.v[0], self.v[1]],
}
}
}

// TODO: Actual SIMD impl
// This is a dummy implementation to satisfy trait bounds
impl SimdField for BabyBearExt3 {
type Scalar = Self;

fn scale(&self, challenge: &Self::Scalar) -> Self {
self * challenge
}

fn pack(base_vec: &[Self::Scalar]) -> Self {
debug_assert!(base_vec.len() == 1);
base_vec[0]
}

fn unpack(&self) -> Vec<Self::Scalar> {
vec![*self]
}

fn pack_size() -> usize {
1
}
}

impl Add<BabyBear> for BabyBearExt3 {
type Output = Self;

fn add(self, rhs: BabyBear) -> Self::Output {
self + BabyBearExt3::from(rhs)
}
}

impl Neg for BabyBearExt3 {
type Output = Self;

fn neg(self) -> Self::Output {
let mut v = self.v;
v[0] = -v[0];
v[1] = -v[1];
v[2] = -v[2];
Self { v }
}
}

impl From<u32> for BabyBearExt3 {
fn from(val: u32) -> Self {
Self {
v: [BabyBear::new(val), BabyBear::ZERO, BabyBear::ZERO],
}
}
}

impl BabyBearExt3 {
#[inline(always)]
pub fn to_base_field(&self) -> BabyBear {
assert!(
self.v[1].is_zero() && self.v[2].is_zero(),
"BabyBearExt3 cannot be converted to base field"
);

self.to_base_field_unsafe()
}

#[inline(always)]
pub fn to_base_field_unsafe(&self) -> BabyBear {
self.v[0]
}

#[inline(always)]
pub fn as_u32_array(&self) -> [u32; 3] {
// Note: as_u32_unchecked converts to canonical form
[
self.v[0].as_u32_unchecked(),
self.v[1].as_u32_unchecked(),
self.v[2].as_u32_unchecked(),
]
}
}

impl From<BabyBear> for BabyBearExt3 {
#[inline(always)]
fn from(val: BabyBear) -> Self {
Self {
v: [val, BabyBear::ZERO, BabyBear::ZERO],
}
}
}

impl From<&BabyBear> for BabyBearExt3 {
#[inline(always)]
fn from(val: &BabyBear) -> Self {
(*val).into()
}
}

impl From<BabyBearExt3> for BabyBear {
#[inline(always)]
fn from(x: BabyBearExt3) -> Self {
x.to_base_field()
}
}

impl From<&BabyBearExt3> for BabyBear {
#[inline(always)]
fn from(x: &BabyBearExt3) -> Self {
x.to_base_field()
}
}

#[inline(always)]
fn add_internal(a: &BabyBearExt3, b: &BabyBearExt3) -> BabyBearExt3 {
let mut vv = a.v;
vv[0] += b.v[0];
vv[1] += b.v[1];
vv[2] += b.v[2];
BabyBearExt3 { v: vv }
}

#[inline(always)]
fn sub_internal(a: &BabyBearExt3, b: &BabyBearExt3) -> BabyBearExt3 {
let mut vv = a.v;
vv[0] -= b.v[0];
vv[1] -= b.v[1];
vv[2] -= b.v[2];
BabyBearExt3 { v: vv }
}

// polynomial mod x^3 - w
//
// (a0 + a1 x + a2 x^2) * (b0 + b1 x + b2 x^2)
// = a0 b0 + (a0 b1 + a1 b0) x + (a0 b2 + a1 b1 + a2 b0) x^2 + (a1 b2 + a2 b1) x^3 + a2 b2 x^4
// = a0 b0 + w * (a1 b2 + a2 b1)
// + {(a0 b1 + a1 b0) + w * a2 b2} x
// + {(a0 b2 + a1 b1 + a2 b0)} x^2
#[inline(always)]
fn mul_internal(a: &BabyBearExt3, b: &BabyBearExt3) -> BabyBearExt3 {
let w = BabyBear::new(BabyBearExt3::W);
let a = a.v;
let b = b.v;
let mut res = [BabyBear::default(); 3];
res[0] = a[0] * b[0] + w * (a[1] * b[2] + a[2] * b[1]);
res[1] = (a[0] * b[1] + a[1] * b[0]) + w * a[2] * b[2];
res[2] = a[0] * b[2] + a[1] * b[1] + a[2] * b[0];
BabyBearExt3 { v: res }
}

#[inline(always)]
fn square_internal(a: &[BabyBear; 3]) -> [BabyBear; 3] {
let w = BabyBear::new(BabyBearExt3::W);
let mut res = [BabyBear::default(); 3];
res[0] = a[0].square() + w * (a[1] * a[2]).double();
res[1] = (a[0] * a[1]).double() + w * a[2].square();
res[2] = a[0] * a[2].double() + a[1].square();
res
}

/// Compare to test vectors generated using SageMath
#[test]
fn test_compare_sage() {
let a = BabyBearExt3 {
v: [BabyBear::new(1), BabyBear::new(2), BabyBear::new(3)],
};
let b = BabyBearExt3 {
v: [BabyBear::new(4), BabyBear::new(5), BabyBear::new(6)],
};
let expected_prod = BabyBearExt3 {
v: [BabyBear::new(58), BabyBear::new(49), BabyBear::new(28)],
};
assert_eq!(a * b, expected_prod);

let a_inv = BabyBearExt3 {
v: [
BabyBear::new(1628709509),
BabyBear::new(1108427305),
BabyBear::new(950080547),
],
};
assert_eq!(a.inv().unwrap(), a_inv);

let a_to_eleven = BabyBearExt3 {
v: [
BabyBear::new(164947539),
BabyBear::new(1313663563),
BabyBear::new(627537568),
],
};
assert_eq!(a.exp(11), a_to_eleven);
}
356 changes: 356 additions & 0 deletions arith/src/extension_field/baby_bear_ext3x16.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,356 @@
use crate::{
field_common, BabyBear, BabyBearExt3, BabyBearx16, ExtensionField, Field, FieldSerde,
FieldSerdeResult, SimdField,
};
use std::{
io::{Read, Write},
iter::{Product, Sum},
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
};

#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct BabyBearExt3x16 {
pub v: [BabyBearx16; 3],
}

field_common!(BabyBearExt3x16);

impl FieldSerde for BabyBearExt3x16 {
const SERIALIZED_SIZE: usize = 512 / 8 * 3;

#[inline(always)]
fn serialize_into<W: Write>(&self, mut writer: W) -> FieldSerdeResult<()> {
self.v[0].serialize_into(&mut writer)?;
self.v[1].serialize_into(&mut writer)?;
self.v[2].serialize_into(&mut writer)
}

#[inline(always)]
fn deserialize_from<R: Read>(mut reader: R) -> FieldSerdeResult<Self> {
Ok(Self {
v: [
BabyBearx16::deserialize_from(&mut reader)?,
BabyBearx16::deserialize_from(&mut reader)?,
BabyBearx16::deserialize_from(&mut reader)?,
],
})
}

#[inline(always)]
fn try_deserialize_from_ecc_format<R: Read>(mut reader: R) -> FieldSerdeResult<Self> {
Ok(Self {
v: [
BabyBearx16::try_deserialize_from_ecc_format(&mut reader)?,
BabyBearx16::zero(),
BabyBearx16::zero(),
],
})
}
}

impl SimdField for BabyBearExt3x16 {
type Scalar = BabyBearExt3;

#[inline]
fn scale(&self, challenge: &Self::Scalar) -> Self {
*self * *challenge
}

#[inline(always)]
fn pack(base_vec: &[Self::Scalar]) -> Self {
debug_assert!(base_vec.len() == Self::pack_size());
let mut v0s = vec![];
let mut v1s = vec![];
let mut v2s = vec![];

for scalar in base_vec {
v0s.push(scalar.v[0]);
v1s.push(scalar.v[1]);
v2s.push(scalar.v[2]);
}

Self {
v: [
BabyBearx16::pack(&v0s),
BabyBearx16::pack(&v1s),
BabyBearx16::pack(&v2s),
],
}
}

#[inline(always)]
fn unpack(&self) -> Vec<Self::Scalar> {
let v0s = self.v[0].unpack();
let v1s = self.v[1].unpack();
let v2s = self.v[2].unpack();

v0s.into_iter()
.zip(v1s)
.zip(v2s)
.map(|((v0, v1), v2)| BabyBearExt3 { v: [v0, v1, v2] })
.collect()
}

#[inline(always)]
fn pack_size() -> usize {
BabyBearx16::pack_size()
}
}

impl From<BabyBearx16> for BabyBearExt3x16 {
#[inline(always)]
fn from(x: BabyBearx16) -> Self {
Self {
v: [x, BabyBearx16::ZERO, BabyBearx16::ZERO],
}
}
}

impl ExtensionField for BabyBearExt3x16 {
const DEGREE: usize = 3;

const W: u32 = 2;

const X: Self = Self {
v: [BabyBearx16::ZERO, BabyBearx16::ONE, BabyBearx16::ZERO],
};

type BaseField = BabyBearx16;

#[inline(always)]
fn mul_by_base_field(&self, base: &Self::BaseField) -> Self {
Self {
v: [self.v[0] * base, self.v[1] * base, self.v[2] * base],
}
}

#[inline(always)]
fn add_by_base_field(&self, base: &Self::BaseField) -> Self {
Self {
v: [self.v[0] + base, self.v[1], self.v[2]],
}
}

#[inline(always)]
fn mul_by_x(&self) -> Self {
Self {
v: [self.v[2] * BabyBearx16::from(Self::W), self.v[0], self.v[1]],
}
}
}

impl From<BabyBearExt3> for BabyBearExt3x16 {
#[inline(always)]
fn from(x: BabyBearExt3) -> Self {
Self {
v: [
BabyBearx16::pack_full(x.v[0]),
BabyBearx16::pack_full(x.v[1]),
BabyBearx16::pack_full(x.v[2]),
],
}
}
}

impl Field for BabyBearExt3x16 {
#[cfg(target_arch = "x86_64")]
const NAME: &'static str = "AVX Vectorized BabyBear Extension 3";

#[cfg(target_arch = "aarch64")]
const NAME: &'static str = "NEON Vectorized BabyBear Extension 3";

const SIZE: usize = 512 / 8 * 3;

const FIELD_SIZE: usize = 32 * 3;

const ZERO: Self = Self {
v: [BabyBearx16::ZERO; 3],
};

const ONE: Self = Self {
v: [BabyBearx16::ONE, BabyBearx16::ZERO, BabyBearx16::ZERO],
};

const INV_2: Self = Self {
v: [BabyBearx16::INV_2, BabyBearx16::ZERO, BabyBearx16::ZERO],
};

#[inline(always)]
fn zero() -> Self {
Self::ZERO
}

#[inline(always)]
fn is_zero(&self) -> bool {
*self == Self::ZERO
}

#[inline(always)]
fn one() -> Self {
Self::ONE
}

#[inline(always)]
fn random_unsafe(mut rng: impl rand::RngCore) -> Self {
Self {
v: [
BabyBearx16::random_unsafe(&mut rng),
BabyBearx16::random_unsafe(&mut rng),
BabyBearx16::random_unsafe(&mut rng),
],
}
}

#[inline(always)]
fn random_bool(mut rng: impl rand::RngCore) -> Self {
Self {
v: [
BabyBearx16::random_bool(&mut rng),
BabyBearx16::random_bool(&mut rng),
BabyBearx16::random_bool(&mut rng),
],
}
}

#[inline(always)]
fn square(&self) -> Self {
Self {
v: square_internal(&self.v),
}
}

fn exp(&self, _: u128) -> Self {
unimplemented!()
}

fn inv(&self) -> Option<Self> {
unimplemented!()
}

fn as_u32_unchecked(&self) -> u32 {
unimplemented!("self is a vector, cannot convert to u32")
}

fn from_uniform_bytes(_: &[u8; 32]) -> Self {
unimplemented!("vec babybear: cannot convert from 32 bytes")
}
}

impl Mul<BabyBearExt3> for BabyBearExt3x16 {
type Output = Self;

#[inline(always)]
fn mul(self, rhs: BabyBearExt3) -> Self {
// polynomial mod x^3 - w
//
// (a0 + a1 x + a2 x^2) * (b0 + b1 x + b2 x^2)
// = a0 b0 + (a0 b1 + a1 b0) x + (a0 b2 + a1 b1 + a2 b0) x^2 + (a1 b2 + a2 b1) x^3 + a2 b2 x^4
// = a0 b0 + w * (a1 b2 + a2 b1)
// + {(a0 b1 + a1 b0) + w * a2 b2} x
// + {(a0 b2 + a1 b1 + a2 b0)} x^2

let w = BabyBear::from(BabyBearExt3x16::W);
let mut res = [BabyBearx16::ZERO; 3];
res[0] = self.v[0] * rhs.v[0] + (self.v[1] * rhs.v[2] + self.v[2] * rhs.v[1]) * w;
res[1] = self.v[0] * rhs.v[1] + self.v[1] * rhs.v[0] + self.v[2] * rhs.v[2] * w;
res[2] = self.v[0] * rhs.v[2] + self.v[1] * rhs.v[1] + self.v[2] * rhs.v[0];
Self { v: res }
}
}

impl Mul<BabyBear> for BabyBearExt3x16 {
type Output = Self;

#[inline(always)]
fn mul(self, rhs: BabyBear) -> Self {
Self {
v: [self.v[0] * rhs, self.v[1] * rhs, self.v[2] * rhs],
}
}
}

impl Add<BabyBear> for BabyBearExt3x16 {
type Output = Self;

#[inline(always)]
fn add(self, rhs: BabyBear) -> Self {
Self {
v: [self.v[0] + rhs, self.v[1], self.v[2]],
}
}
}

impl Neg for BabyBearExt3x16 {
type Output = Self;

#[inline(always)]
fn neg(self) -> Self {
Self {
v: [-self.v[0], -self.v[1], -self.v[2]],
}
}
}

impl From<u32> for BabyBearExt3x16 {
#[inline(always)]
fn from(value: u32) -> Self {
Self {
v: [
BabyBearx16::from(value),
BabyBearx16::ZERO,
BabyBearx16::ZERO,
],
}
}
}

#[inline(always)]
fn add_internal(a: &BabyBearExt3x16, b: &BabyBearExt3x16) -> BabyBearExt3x16 {
let mut vv = a.v;
vv[0] += b.v[0];
vv[1] += b.v[1];
vv[2] += b.v[2];

BabyBearExt3x16 { v: vv }
}

#[inline(always)]
fn sub_internal(a: &BabyBearExt3x16, b: &BabyBearExt3x16) -> BabyBearExt3x16 {
let mut vv = a.v;
vv[0] -= b.v[0];
vv[1] -= b.v[1];
vv[2] -= b.v[2];

BabyBearExt3x16 { v: vv }
}

#[inline(always)]
fn mul_internal(a: &BabyBearExt3x16, b: &BabyBearExt3x16) -> BabyBearExt3x16 {
// polynomial mod x^3 - w
//
// (a0 + a1 x + a2 x^2) * (b0 + b1 x + b2 x^2)
// = a0 b0 + (a0 b1 + a1 b0) x + (a0 b2 + a1 b1 + a2 b0) x^2 + (a1 b2 + a2 b1) x^3 + a2 b2 x^4
// = a0 b0 + w * (a1 b2 + a2 b1)
// + {(a0 b1 + a1 b0) + w * a2 b2} x
// + {(a0 b2 + a1 b1 + a2 b0)} x^2
let a = &a.v;
let b = &b.v;
let mut res = [BabyBearx16::default(); 3];
let w = BabyBear::from(BabyBearExt3x16::W);
res[0] = a[0] * b[0] + (a[1] * b[2] + a[2] * b[1]) * w;
res[1] = (a[0] * b[1] + a[1] * b[0]) + a[2] * b[2] * w;
res[2] = a[0] * b[2] + a[1] * b[1] + a[2] * b[0];

BabyBearExt3x16 { v: res }
}

#[inline(always)]
fn square_internal(a: &[BabyBearx16; 3]) -> [BabyBearx16; 3] {
let mut res = [BabyBearx16::default(); 3];
let w = BabyBear::from(BabyBearExt3x16::W);
res[0] = a[0].square() + (a[1] * a[2]).double() * w;
res[1] = (a[0] * a[1]).double() + a[2].square() * w;
res[2] = a[0] * a[2].double() + a[1].square();

res
}
Original file line number Diff line number Diff line change
@@ -45,6 +45,7 @@ impl FieldSerde for BabyBearExt4 {
buf.iter().skip(4).all(|&x| x == 0),
"non-zero byte found in witness byte"
);
// ? this can only read in a base field element, do we ever need to read in an ext'n field?
Ok(Self::from(u32::from_le_bytes(buf[..4].try_into().unwrap())))
}
}
2 changes: 2 additions & 0 deletions arith/src/field/baby_bear.rs
Original file line number Diff line number Diff line change
@@ -26,6 +26,8 @@ pub use baby_bear_avx::AVXBabyBear;
#[repr(transparent)]
pub struct BabyBear(P3BabyBear);

pub const BABYBEAR_MODULUS: u32 = 0x78000001;

field_common!(BabyBear);

impl BabyBear {
11 changes: 10 additions & 1 deletion arith/src/tests/baby_bear_ext.rs
Original file line number Diff line number Diff line change
@@ -2,11 +2,20 @@ use super::{
extension_field::random_extension_field_tests, field::random_field_tests,
simd_field::random_simd_field_tests,
};
use crate::{BabyBearExt4, BabyBearExt4x16};
use crate::{BabyBearExt3, BabyBearExt3x16, BabyBearExt4, BabyBearExt4x16};

// CMD: RUSTFLAGS="-C target-feature=+avx512f" cargo test --package arith --lib -- tests::baby_bear_ext::test_field --exact --show-output
#[test]
fn test_field() {
// Deg 3
random_field_tests::<BabyBearExt3>("Baby Bear Ext3".to_string());
random_extension_field_tests::<BabyBearExt3>("Baby Bear Ext3".to_string());

random_field_tests::<BabyBearExt3x16>("Simd Baby Bear Ext3".to_string());
random_extension_field_tests::<BabyBearExt3x16>("Simd Baby Bear Ext3".to_string());
random_simd_field_tests::<BabyBearExt3x16>("Simd Baby Bear Ext3".to_string());

// Deg 4
random_field_tests::<BabyBearExt4>("Baby Bear Ext4".to_string());
random_extension_field_tests::<BabyBearExt4>("Baby Bear Ext4".to_string());

11 changes: 11 additions & 0 deletions arith/src/tests/test_vectors.sage
Original file line number Diff line number Diff line change
@@ -6,6 +6,17 @@ p = 2**31 - 2**27 + 1
F = GF(p)
R.<x> = F[]

# Degree 3 extension
K.<a> = F.extension(x^3 - 2)
b = 1 + 2*a + 3*a^2
c = 4 + 5*a + 6*a^2
print("BabyBear Degree 3 Extension")
print(f"b = {b}")
print(f"c = {c}")
print(f"b*c = {b*c}")
print(f"b^(-1) = {b^(-1)}")
print(f"b^(11) = {b^(11)}")

# Degree 4 extension
K.<a> = F.extension(x^4 - 11)
b = 1 + 2*a + 3*a^2 + 4*a^3
2 changes: 1 addition & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@ mod gf2_ext_sha2;
mod m31_ext_keccak;
mod m31_ext_sha2;

pub use baby_bear_keccak::BabyBearConfigKeccak;
pub use baby_bear_keccak::BabyBearExt4ConfigKeccak;
pub use bn254_keccak::BN254ConfigKeccak;
pub use bn254_sha2::BN254ConfigSha2;
pub use gf2_ext_keccak::GF2ExtConfigKeccak;
4 changes: 2 additions & 2 deletions src/config/baby_bear_keccak.rs
Original file line number Diff line number Diff line change
@@ -3,9 +3,9 @@ use crate::Keccak256hasher;
use arith::{BabyBear, BabyBearExt4, BabyBearExt4x16, BabyBearx16, ExtensionField};

#[derive(Debug, Clone, PartialEq, Default)]
pub struct BabyBearConfigKeccak;
pub struct BabyBearExt4ConfigKeccak;

impl GKRConfig for BabyBearConfigKeccak {
impl GKRConfig for BabyBearExt4ConfigKeccak {
type CircuitField = BabyBear;

type ChallengeField = BabyBearExt4;
8 changes: 4 additions & 4 deletions tests/gkr_correctness.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use expander_rs::{utils::*, FieldType};
use expander_rs::{
BN254ConfigKeccak, BN254ConfigSha2, BabyBearConfigKeccak, Circuit, CircuitLayer, Config,
BN254ConfigKeccak, BN254ConfigSha2, BabyBearExt4ConfigKeccak, Circuit, CircuitLayer, Config,
GF2ExtConfigKeccak, GF2ExtConfigSha2, GKRConfig, GKRScheme, GateAdd, GateMul,
M31ExtConfigKeccak, M31ExtConfigSha2, Prover, Verifier,
};
@@ -51,9 +51,9 @@ fn gen_simple_circuit<C: GKRConfig>() -> Circuit<C> {

#[test]
fn test_gkr_correctness() {
test_gkr_correctness_helper::<BabyBearConfigKeccak>(&Config::<BabyBearConfigKeccak>::new(
GKRScheme::Vanilla,
));
test_gkr_correctness_helper::<BabyBearExt4ConfigKeccak>(
&Config::<BabyBearExt4ConfigKeccak>::new(GKRScheme::Vanilla),
);
test_gkr_correctness_helper::<GF2ExtConfigSha2>(&Config::<GF2ExtConfigSha2>::new(
GKRScheme::Vanilla,
));