Skip to content

Commit 1ff3a90

Browse files
authored
Add Mul<NonNativeFieldVar> for Group (arkworks-rs#134)
1 parent 6164009 commit 1ff3a90

File tree

14 files changed

+392
-294
lines changed

14 files changed

+392
-294
lines changed

.github/workflows/ci.yml

+14-15
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,14 @@ jobs:
109109
- name: Checkout
110110
uses: actions/checkout@v3
111111

112-
- name: Install Rust (${{ matrix.rust }})
112+
- name: Install Rust
113113
uses: dtolnay/rust-toolchain@stable
114114
id: toolchain-thumbv6m
115115
with:
116116
target: thumbv6m-none-eabi
117117
- run: rustup override set ${{steps.toolchain-thumbv6m.outputs.name}}
118118

119-
- name: Install Rust ARM64 (${{ matrix.rust }})
119+
- name: Install Rust ARM64
120120
uses: dtolnay/rust-toolchain@stable
121121
id: toolchain-aarch64
122122
with:
@@ -152,12 +152,12 @@ jobs:
152152
- ed_on_bls12_381
153153
steps:
154154
- name: Checkout curves
155-
uses: actions/checkout@v2
155+
uses: actions/checkout@v4
156156
with:
157-
repository: arkworks-rs/curves
157+
repository: arkworks-rs/algebra
158158

159159
- name: Checkout r1cs-std
160-
uses: actions/checkout@v2
160+
uses: actions/checkout@v4
161161
with:
162162
path: r1cs-std
163163

@@ -166,22 +166,21 @@ jobs:
166166

167167
- name: Patch cargo.toml
168168
run: |
169+
cd curves
169170
if grep -q "\[patch.crates-io\]" Cargo.toml ; then
170171
MATCH=$(awk '/\[patch.crates-io\]/{ print NR; exit }' Cargo.toml);
171172
sed -i "$MATCH,\$d" Cargo.toml
172173
fi
173174
{
174175
echo "[patch.crates-io]"
175176
echo "ark-std = { git = 'https://github.com/arkworks-rs/std' }"
176-
echo "ark-ec = { git = 'https://github.com/arkworks-rs/algebra' }"
177-
echo "ark-ff = { git = 'https://github.com/arkworks-rs/algebra' }"
178-
echo "ark-poly = { git = 'https://github.com/arkworks-rs/algebra' }"
177+
echo "ark-ec = { path = '../ec' }"
178+
echo "ark-ff = { path = '../ff' }"
179+
echo "ark-poly = { path = '../poly' }"
179180
echo "ark-relations = { git = 'https://github.com/arkworks-rs/snark' }"
180-
echo "ark-serialize = { git = 'https://github.com/arkworks-rs/algebra' }"
181-
echo "ark-algebra-bench-templates = { git = 'https://github.com/arkworks-rs/algebra' }"
182-
echo "ark-algebra-test-templates = { git = 'https://github.com/arkworks-rs/algebra' }"
183-
echo "ark-r1cs-std = { path = 'r1cs-std' }"
181+
echo "ark-serialize = { path = '../serialize' }"
182+
echo "ark-algebra-bench-templates = { path = '../bench-templates' }"
183+
echo "ark-algebra-test-templates = { path = '../test-templates' }"
184+
echo "ark-r1cs-std = { path = '../r1cs-std' }"
184185
} >> Cargo.toml
185-
186-
- name: Test on ${{ matrix.curve }}
187-
run: "cd ${{ matrix.curve }} && cargo test --features 'r1cs'"
186+
cd ${{ matrix.curve }} && cargo test --features 'r1cs'

src/bits/uint8.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ impl<ConstraintF: Field> AllocVar<u8, ConstraintF> for UInt8<ConstraintF> {
342342
/// `ConstraintF::MODULUS_BIT_SIZE - 1` chunks and converts each chunk, which is
343343
/// assumed to be little-endian, to its `FpVar<ConstraintF>` representation.
344344
/// This is the gadget counterpart to the `[u8]` implementation of
345-
/// [ToConstraintField](ark_ff::ToConstraintField).
345+
/// [`ToConstraintField`].
346346
impl<ConstraintF: PrimeField> ToConstraintFieldGadget<ConstraintF> for [UInt8<ConstraintF>] {
347347
#[tracing::instrument(target = "r1cs")]
348348
fn to_constraint_field(&self) -> Result<Vec<FpVar<ConstraintF>>, SynthesisError> {

src/fields/nonnative/allocated_field_var.rs

+17-22
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,13 @@ impl<TargetField: PrimeField, BaseField: PrimeField>
5757
optimization_type,
5858
);
5959

60-
let mut base_repr: <TargetField as PrimeField>::BigInt = TargetField::one().into_bigint();
61-
6260
// Convert 2^{(params.bits_per_limb - 1)} into the TargetField and then double
6361
// the base This is because 2^{(params.bits_per_limb)} might indeed be
6462
// larger than the target field's prime.
65-
base_repr.muln((params.bits_per_limb - 1) as u32);
66-
let mut base: TargetField = TargetField::from_bigint(base_repr).unwrap();
67-
base = base + &base;
63+
let base_repr = TargetField::ONE.into_bigint() << (params.bits_per_limb - 1) as u32;
64+
65+
let mut base = TargetField::from_bigint(base_repr).unwrap();
66+
base.double_in_place();
6867

6968
let mut result = TargetField::zero();
7069
let mut power = TargetField::one();
@@ -206,25 +205,21 @@ impl<TargetField: PrimeField, BaseField: PrimeField>
206205
> BaseField::MODULUS_BIT_SIZE as usize - 1)
207206
{
208207
Reducer::reduce(&mut other)?;
209-
surfeit = overhead!(other.num_of_additions_over_normal_form + BaseField::one()) + 1;
208+
surfeit = overhead!(other.num_of_additions_over_normal_form + BaseField::ONE) + 1;
210209
}
211210

212211
// Step 2: construct the padding
213-
let mut pad_non_top_limb_repr: <BaseField as PrimeField>::BigInt =
214-
BaseField::one().into_bigint();
215-
let mut pad_top_limb_repr: <BaseField as PrimeField>::BigInt = pad_non_top_limb_repr;
212+
let mut pad_non_top_limb = BaseField::ONE.into_bigint();
213+
let mut pad_top_limb = pad_non_top_limb;
216214

217-
pad_non_top_limb_repr.muln((surfeit + params.bits_per_limb) as u32);
218-
let pad_non_top_limb = BaseField::from_bigint(pad_non_top_limb_repr).unwrap();
215+
pad_non_top_limb <<= (surfeit + params.bits_per_limb) as u32;
216+
let pad_non_top_limb = BaseField::from_bigint(pad_non_top_limb).unwrap();
219217

220-
pad_top_limb_repr.muln(
221-
(surfeit
222-
+ (TargetField::MODULUS_BIT_SIZE as usize
223-
- params.bits_per_limb * (params.num_limbs - 1))) as u32,
224-
);
225-
let pad_top_limb = BaseField::from_bigint(pad_top_limb_repr).unwrap();
218+
pad_top_limb <<= (surfeit + TargetField::MODULUS_BIT_SIZE as usize
219+
- params.bits_per_limb * (params.num_limbs - 1)) as u32;
220+
let pad_top_limb = BaseField::from_bigint(pad_top_limb).unwrap();
226221

227-
let mut pad_limbs = Vec::new();
222+
let mut pad_limbs = Vec::with_capacity(self.limbs.len());
228223
pad_limbs.push(pad_top_limb);
229224
for _ in 0..self.limbs.len() - 1 {
230225
pad_limbs.push(pad_non_top_limb);
@@ -236,12 +231,12 @@ impl<TargetField: PrimeField, BaseField: PrimeField>
236231
Self::get_limbs_representations(&pad_to_kp_gap, self.get_optimization_type())?;
237232

238233
// Step 4: the result is self + pad + pad_to_kp - other
239-
let mut limbs = Vec::new();
234+
let mut limbs = Vec::with_capacity(self.limbs.len());
240235
for (i, ((this_limb, other_limb), pad_to_kp_limb)) in self
241236
.limbs
242237
.iter()
243-
.zip(other.limbs.iter())
244-
.zip(pad_to_kp_limbs.iter())
238+
.zip(&other.limbs)
239+
.zip(&pad_to_kp_limbs)
245240
.enumerate()
246241
{
247242
if i != 0 {
@@ -341,7 +336,7 @@ impl<TargetField: PrimeField, BaseField: PrimeField>
341336
&cur_bits[cur_bits.len() - params.bits_per_limb..],
342337
); // therefore, the lowest `bits_per_non_top_limb` bits is what we want.
343338
limbs.push(BaseField::from_bigint(cur_mod_r).unwrap());
344-
cur.divn(params.bits_per_limb as u32);
339+
cur >>= params.bits_per_limb as u32;
345340
}
346341

347342
// then we reserve, so that the limbs are ``big limb first''

src/fields/nonnative/reduce.rs

+11-14
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ impl<TargetField: PrimeField, BaseField: PrimeField> Reducer<TargetField, BaseFi
240240
let mut cur = BaseField::one().into_bigint();
241241
for _ in 0..num_limb_in_a_group {
242242
array.push(BaseField::from_bigint(cur).unwrap());
243-
cur.muln(shift_per_limb as u32);
243+
cur <<= shift_per_limb as u32;
244244
}
245245

246246
array
@@ -280,16 +280,13 @@ impl<TargetField: PrimeField, BaseField: PrimeField> Reducer<TargetField, BaseFi
280280
for (group_id, (left_total_limb, right_total_limb, num_limb_in_this_group)) in
281281
groupped_limb_pairs.iter().enumerate()
282282
{
283-
let mut pad_limb_repr: <BaseField as PrimeField>::BigInt =
284-
BaseField::one().into_bigint();
285-
286-
pad_limb_repr.muln(
287-
(surfeit
288-
+ (bits_per_limb - shift_per_limb)
289-
+ shift_per_limb * num_limb_in_this_group
290-
+ 1
291-
+ 1) as u32,
292-
);
283+
let mut pad_limb_repr = BaseField::ONE.into_bigint();
284+
285+
pad_limb_repr <<= (surfeit
286+
+ (bits_per_limb - shift_per_limb)
287+
+ shift_per_limb * num_limb_in_this_group
288+
+ 1
289+
+ 1) as u32;
293290
let pad_limb = BaseField::from_bigint(pad_limb_repr).unwrap();
294291

295292
let left_total_limb_value = left_total_limb.value().unwrap_or_default();
@@ -298,12 +295,12 @@ impl<TargetField: PrimeField, BaseField: PrimeField> Reducer<TargetField, BaseFi
298295
let mut carry_value =
299296
left_total_limb_value + carry_in_value + pad_limb - right_total_limb_value;
300297

301-
let mut carry_repr = carry_value.into_bigint();
302-
carry_repr.divn((shift_per_limb * num_limb_in_this_group) as u32);
298+
let carry_repr =
299+
carry_value.into_bigint() >> (shift_per_limb * num_limb_in_this_group) as u32;
303300

304301
carry_value = BaseField::from_bigint(carry_repr).unwrap();
305302

306-
let carry = FpVar::<BaseField>::new_witness(cs.clone(), || Ok(carry_value))?;
303+
let carry = FpVar::new_witness(cs.clone(), || Ok(carry_value))?;
307304

308305
accumulated_extra += limbs_to_bigint(bits_per_limb, &[pad_limb]);
309306

0 commit comments

Comments
 (0)