|
| 1 | +use super::zq::Zq; |
| 2 | +use std::ops::{Add, Mul, Neg, Sub}; |
| 3 | + |
| 4 | +/// Polynomial |
| 5 | +/// |
| 6 | +#[derive(Debug, Clone, PartialEq, Eq)] |
| 7 | +pub struct Poly<const Q: u64> { |
| 8 | + coeffs: Vec<Zq<Q>>, |
| 9 | +} |
| 10 | + |
| 11 | +impl<const Q: u64> Poly<Q> { |
| 12 | + pub fn new(coeffs: impl Into<Vec<Zq<Q>>>) -> Self { |
| 13 | + Poly { |
| 14 | + coeffs: coeffs.into(), |
| 15 | + } |
| 16 | + } |
| 17 | + |
| 18 | + pub fn zero() -> Self { |
| 19 | + Poly { coeffs: Vec::new() } |
| 20 | + } |
| 21 | + |
| 22 | + pub fn one() -> Self { |
| 23 | + Poly { |
| 24 | + coeffs: vec![Zq::one()], |
| 25 | + } |
| 26 | + } |
| 27 | +} |
| 28 | + |
| 29 | +impl<const Q: u64> Add for Poly<Q> { |
| 30 | + type Output = Self; |
| 31 | + |
| 32 | + fn add(self, rhs: Self) -> Self { |
| 33 | + let len = self.coeffs.len().max(rhs.coeffs.len()); |
| 34 | + let mut new_coeffs = vec![Zq::<Q>::zero(); len]; |
| 35 | + for (i, &c) in self.coeffs.iter().enumerate() { |
| 36 | + new_coeffs[i] = new_coeffs[i] + c; |
| 37 | + } |
| 38 | + for (i, &c) in rhs.coeffs.iter().enumerate() { |
| 39 | + new_coeffs[i] = new_coeffs[i] + c; |
| 40 | + } |
| 41 | + |
| 42 | + // trim 0s in the top terms |
| 43 | + while new_coeffs.last() == Some(&Zq::zero()) { |
| 44 | + new_coeffs.pop(); |
| 45 | + } |
| 46 | + Poly { coeffs: new_coeffs } |
| 47 | + } |
| 48 | +} |
| 49 | + |
| 50 | +impl<const Q: u64> Neg for Poly<Q> { |
| 51 | + type Output = Self; |
| 52 | + |
| 53 | + fn neg(self) -> Self { |
| 54 | + Poly { |
| 55 | + coeffs: self.coeffs.iter().map(|&c| -c).collect(), |
| 56 | + } |
| 57 | + } |
| 58 | +} |
| 59 | + |
| 60 | +impl<const Q: u64> Sub for Poly<Q> { |
| 61 | + type Output = Self; |
| 62 | + |
| 63 | + fn sub(self, rhs: Self) -> Self { |
| 64 | + self + (-rhs) |
| 65 | + } |
| 66 | +} |
| 67 | + |
| 68 | +impl<const Q: u64> Mul for Poly<Q> { |
| 69 | + type Output = Self; |
| 70 | + |
| 71 | + #[allow(clippy::needless_range_loop)] |
| 72 | + fn mul(self, rhs: Self) -> Self { |
| 73 | + if self.coeffs.is_empty() || rhs.coeffs.is_empty() { |
| 74 | + return Poly::zero(); |
| 75 | + } |
| 76 | + |
| 77 | + let new_len = self.coeffs.len() + rhs.coeffs.len() - 1; |
| 78 | + let mut new_coeffs = vec![Zq::<Q>::zero(); new_len]; |
| 79 | + for i in 0..self.coeffs.len() { |
| 80 | + for j in 0..rhs.coeffs.len() { |
| 81 | + new_coeffs[i + j] = new_coeffs[i + j] + self.coeffs[i] * rhs.coeffs[j] |
| 82 | + } |
| 83 | + } |
| 84 | + |
| 85 | + // trim 0s in the top terms |
| 86 | + while new_coeffs.last() == Some(&Zq::zero()) { |
| 87 | + new_coeffs.pop(); |
| 88 | + } |
| 89 | + |
| 90 | + Poly { coeffs: new_coeffs } |
| 91 | + } |
| 92 | +} |
| 93 | + |
| 94 | +#[cfg(test)] |
| 95 | +mod tests { |
| 96 | + // import everything in this module |
| 97 | + use super::*; |
| 98 | + |
| 99 | + const Q: u64 = 17; |
| 100 | + type F = Zq<Q>; |
| 101 | + type R = Poly<Q>; |
| 102 | + |
| 103 | + // helper: build poly from raw u64 coefficients |
| 104 | + fn p(coeffs: &[u64]) -> R { |
| 105 | + R::new(coeffs.iter().map(|&c| F::new(c)).collect::<Vec<_>>()) |
| 106 | + } |
| 107 | + |
| 108 | + #[test] |
| 109 | + fn test_new() { |
| 110 | + assert_eq!(R::new(vec![]), R::zero()); |
| 111 | + } |
| 112 | + |
| 113 | + #[test] |
| 114 | + fn test_one() { |
| 115 | + assert_eq!(R::one(), p(&[1])); |
| 116 | + } |
| 117 | + |
| 118 | + #[test] |
| 119 | + fn test_add_same_len() { |
| 120 | + // (3 + 5x) + (2 + 4x) = 5 + 9x |
| 121 | + assert_eq!(p(&[3, 5]) + p(&[2, 4]), p(&[5, 9])); |
| 122 | + } |
| 123 | + |
| 124 | + #[test] |
| 125 | + fn test_add_different_len() { |
| 126 | + // (1 + 2x + 3x^2) + (4 + 5x) = 5 + 7x + 3x^2 |
| 127 | + assert_eq!(p(&[1, 2, 3]) + p(&[4, 5]), p(&[5, 7, 3])); |
| 128 | + } |
| 129 | + |
| 130 | + #[test] |
| 131 | + fn test_add_with_cancellation() { |
| 132 | + // (1 + 16x) + (2 + x) = 3 + 0x = 3 (16 + 1 = 17 ≡ 0 mod 17) |
| 133 | + // trailing zero should be trimmed |
| 134 | + assert_eq!(p(&[1, 16]) + p(&[2, 1]), p(&[3])); |
| 135 | + } |
| 136 | + |
| 137 | + #[test] |
| 138 | + fn test_add_zero() { |
| 139 | + assert_eq!(p(&[3, 5]) + R::zero(), p(&[3, 5])); |
| 140 | + assert_eq!(R::zero() + p(&[3, 5]), p(&[3, 5])); |
| 141 | + } |
| 142 | + |
| 143 | + #[test] |
| 144 | + fn test_neg() { |
| 145 | + // -(3 + 5x) = 14 + 12x (mod 17) |
| 146 | + assert_eq!(-p(&[3, 5]), p(&[14, 12])); |
| 147 | + } |
| 148 | + |
| 149 | + #[test] |
| 150 | + fn test_neg_zero() { |
| 151 | + assert_eq!(-R::zero(), R::zero()); |
| 152 | + } |
| 153 | + |
| 154 | + #[test] |
| 155 | + fn test_sub() { |
| 156 | + // (10 + 3x) - (5 + 7x) = 5 + (-4 mod 17)x = 5 + 13x |
| 157 | + assert_eq!(p(&[10, 3]) - p(&[5, 7]), p(&[5, 13])); |
| 158 | + } |
| 159 | + |
| 160 | + #[test] |
| 161 | + fn test_add_sub_inverse() { |
| 162 | + let a = p(&[3, 5, 7]); |
| 163 | + assert_eq!(a.clone() + (-a), R::zero()); |
| 164 | + } |
| 165 | + |
| 166 | + #[test] |
| 167 | + fn test_mul_basic() { |
| 168 | + // (1 + 2x) * (3 + 4x) = 3 + 4x + 6x + 8x^2 = 3 + 10x + 8x^2 |
| 169 | + assert_eq!(p(&[1, 2]) * p(&[3, 4]), p(&[3, 10, 8])); |
| 170 | + } |
| 171 | + |
| 172 | + #[test] |
| 173 | + fn test_mul_with_mod() { |
| 174 | + // (5 + 4x) * (4 + 3x) = 20 + 15x + 16x + 12x^2 |
| 175 | + // = 3 + 14x + 12x^2 (mod 17) |
| 176 | + assert_eq!(p(&[5, 4]) * p(&[4, 3]), p(&[3, 14, 12])); |
| 177 | + } |
| 178 | + |
| 179 | + #[test] |
| 180 | + fn test_mul_by_zero() { |
| 181 | + assert_eq!(p(&[3, 5]) * R::zero(), R::zero()); |
| 182 | + assert_eq!(R::zero() * R::zero(), R::zero()); |
| 183 | + } |
| 184 | + |
| 185 | + #[test] |
| 186 | + fn test_mul_by_one() { |
| 187 | + assert_eq!(p(&[3, 5, 7]) * R::one(), p(&[3, 5, 7])); |
| 188 | + } |
| 189 | + |
| 190 | + #[test] |
| 191 | + fn test_mul_by_constant() { |
| 192 | + // (1 + 2x + 3x^2) * (2) = 2 + 4x + 6x^2 |
| 193 | + assert_eq!(p(&[1, 2, 3]) * p(&[2]), p(&[2, 4, 6])); |
| 194 | + } |
| 195 | +} |
0 commit comments