Skip to content

Commit 56f31a2

Browse files
committed
chore: double instead of mul by W
1 parent 8505543 commit 56f31a2

File tree

2 files changed

+19
-18
lines changed

2 files changed

+19
-18
lines changed

arith/src/extension_field/baby_bear_ext3.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,9 @@ impl ExtensionField for BabyBearExt3 {
168168

169169
#[inline(always)]
170170
fn mul_by_x(&self) -> Self {
171-
let w = BabyBear::from(Self::W);
171+
// Note: W = 2
172172
Self {
173-
v: [self.v[2] * w, self.v[0], self.v[1]],
173+
v: [self.v[2].double(), self.v[0], self.v[1]],
174174
}
175175
}
176176
}
@@ -310,22 +310,22 @@ fn sub_internal(a: &BabyBearExt3, b: &BabyBearExt3) -> BabyBearExt3 {
310310
// + {(a0 b2 + a1 b1 + a2 b0)} x^2
311311
#[inline(always)]
312312
fn mul_internal(a: &BabyBearExt3, b: &BabyBearExt3) -> BabyBearExt3 {
313-
let w = BabyBear::new(BabyBearExt3::W);
313+
// Note: W = 2
314314
let a = a.v;
315315
let b = b.v;
316316
let mut res = [BabyBear::default(); 3];
317-
res[0] = a[0] * b[0] + w * (a[1] * b[2] + a[2] * b[1]);
318-
res[1] = (a[0] * b[1] + a[1] * b[0]) + w * a[2] * b[2];
317+
res[0] = a[0] * b[0] + (a[1] * b[2] + a[2] * b[1]).double();
318+
res[1] = (a[0] * b[1] + a[1] * b[0]) + a[2] * b[2].double();
319319
res[2] = a[0] * b[2] + a[1] * b[1] + a[2] * b[0];
320320
BabyBearExt3 { v: res }
321321
}
322322

323323
#[inline(always)]
324324
fn square_internal(a: &[BabyBear; 3]) -> [BabyBear; 3] {
325-
let w = BabyBear::new(BabyBearExt3::W);
325+
// Note: W = 2
326326
let mut res = [BabyBear::default(); 3];
327-
res[0] = a[0].square() + w * (a[1] * a[2]).double();
328-
res[1] = (a[0] * a[1]).double() + w * a[2].square();
327+
res[0] = a[0].square() + (a[1] * a[2]).double().double();
328+
res[1] = (a[0] * a[1]).double() + a[2].square().double();
329329
res[2] = a[0] * a[2].double() + a[1].square();
330330
res
331331
}

arith/src/extension_field/baby_bear_ext3x16.rs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ impl ExtensionField for BabyBearExt3x16 {
134134
#[inline(always)]
135135
fn mul_by_x(&self) -> Self {
136136
Self {
137-
v: [self.v[2] * BabyBearx16::from(Self::W), self.v[0], self.v[1]],
137+
// Note: W = 2
138+
v: [self.v[2].double(), self.v[0], self.v[1]],
138139
}
139140
}
140141
}
@@ -249,10 +250,10 @@ impl Mul<BabyBearExt3> for BabyBearExt3x16 {
249250
// + {(a0 b1 + a1 b0) + w * a2 b2} x
250251
// + {(a0 b2 + a1 b1 + a2 b0)} x^2
251252

252-
let w = BabyBear::from(BabyBearExt3x16::W);
253+
// Note: W = 2
253254
let mut res = [BabyBearx16::ZERO; 3];
254-
res[0] = self.v[0] * rhs.v[0] + (self.v[1] * rhs.v[2] + self.v[2] * rhs.v[1]) * w;
255-
res[1] = self.v[0] * rhs.v[1] + self.v[1] * rhs.v[0] + self.v[2] * rhs.v[2] * w;
255+
res[0] = self.v[0] * rhs.v[0] + (self.v[1] * rhs.v[2] + self.v[2] * rhs.v[1]).double();
256+
res[1] = self.v[0] * rhs.v[1] + self.v[1] * rhs.v[0] + self.v[2] * rhs.v[2].double();
256257
res[2] = self.v[0] * rhs.v[2] + self.v[1] * rhs.v[1] + self.v[2] * rhs.v[0];
257258
Self { v: res }
258259
}
@@ -336,9 +337,9 @@ fn mul_internal(a: &BabyBearExt3x16, b: &BabyBearExt3x16) -> BabyBearExt3x16 {
336337
let a = &a.v;
337338
let b = &b.v;
338339
let mut res = [BabyBearx16::default(); 3];
339-
let w = BabyBear::from(BabyBearExt3x16::W);
340-
res[0] = a[0] * b[0] + (a[1] * b[2] + a[2] * b[1]) * w;
341-
res[1] = (a[0] * b[1] + a[1] * b[0]) + a[2] * b[2] * w;
340+
// Note: W = 2
341+
res[0] = a[0] * b[0] + (a[1] * b[2] + a[2] * b[1]).double();
342+
res[1] = (a[0] * b[1] + a[1] * b[0]) + a[2] * b[2].double();
342343
res[2] = a[0] * b[2] + a[1] * b[1] + a[2] * b[0];
343344

344345
BabyBearExt3x16 { v: res }
@@ -347,9 +348,9 @@ fn mul_internal(a: &BabyBearExt3x16, b: &BabyBearExt3x16) -> BabyBearExt3x16 {
347348
#[inline(always)]
348349
fn square_internal(a: &[BabyBearx16; 3]) -> [BabyBearx16; 3] {
349350
let mut res = [BabyBearx16::default(); 3];
350-
let w = BabyBear::from(BabyBearExt3x16::W);
351-
res[0] = a[0].square() + (a[1] * a[2]).double() * w;
352-
res[1] = (a[0] * a[1]).double() + a[2].square() * w;
351+
// Note: W = 2
352+
res[0] = a[0].square() + (a[1] * a[2]).double().double();
353+
res[1] = (a[0] * a[1]).double() + a[2].square().double();
353354
res[2] = a[0] * a[2].double() + a[1].square();
354355

355356
res

0 commit comments

Comments
 (0)