Skip to content

Commit 0192b97

Browse files
committed
halo2_poseidon: Refactor code so it compiles in its new crate
1 parent de7219d commit 0192b97

File tree

5 files changed

+124
-58
lines changed

5 files changed

+124
-58
lines changed

halo2_gadgets/src/poseidon.rs

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
//! The Poseidon algebraic hash function.
22
3-
use std::convert::TryInto;
43
use std::fmt;
54
use std::marker::PhantomData;
65

@@ -148,13 +147,7 @@ impl<
148147
pub fn new(chip: PoseidonChip, mut layouter: impl Layouter<F>) -> Result<Self, Error> {
149148
chip.initial_state(&mut layouter).map(|state| Sponge {
150149
chip,
151-
mode: Absorbing(
152-
(0..RATE)
153-
.map(|_| None)
154-
.collect::<Vec<_>>()
155-
.try_into()
156-
.unwrap(),
157-
),
150+
mode: Absorbing::init_empty(),
158151
state,
159152
_marker: PhantomData::default(),
160153
})
@@ -166,12 +159,10 @@ impl<
166159
mut layouter: impl Layouter<F>,
167160
value: PaddedWord<F>,
168161
) -> Result<(), Error> {
169-
for entry in self.mode.0.iter_mut() {
170-
if entry.is_none() {
171-
*entry = Some(value);
172-
return Ok(());
173-
}
174-
}
162+
let value = match self.mode.absorb(value) {
163+
Ok(()) => return Ok(()),
164+
Err(value) => value,
165+
};
175166

176167
// We've already absorbed as many elements as we can
177168
let _ = poseidon_sponge(
@@ -180,7 +171,8 @@ impl<
180171
&mut self.state,
181172
Some(&self.mode),
182173
)?;
183-
self.mode = Absorbing::init_with(value);
174+
self.mode = Absorbing::init_empty();
175+
self.mode.absorb(value).expect("state is not full");
184176

185177
Ok(())
186178
}
@@ -220,10 +212,8 @@ impl<
220212
/// Squeezes an element from the sponge.
221213
pub fn squeeze(&mut self, mut layouter: impl Layouter<F>) -> Result<AssignedCell<F, F>, Error> {
222214
loop {
223-
for entry in self.mode.0.iter_mut() {
224-
if let Some(inner) = entry.take() {
225-
return Ok(inner.into());
226-
}
215+
if let Ok(value) = self.mode.squeeze() {
216+
return Ok(value.into());
227217
}
228218

229219
// We've already squeezed out all available elements

halo2_gadgets/src/poseidon/pow5.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -340,19 +340,20 @@ impl<
340340
let initial_state = initial_state?;
341341

342342
// Load the input into this region.
343-
let load_input_word = |i: usize| {
344-
let (cell, value) = match input.0[i].clone() {
343+
let load_input_word = |(i, input_word): (usize, &Option<PaddedWord<F>>)| {
344+
let (cell, value) = match input_word {
345345
Some(PaddedWord::Message(word)) => (word.cell(), word.value().copied()),
346346
Some(PaddedWord::Padding(padding_value)) => {
347+
let value = Value::known(padding_value.clone());
347348
let cell = region
348349
.assign_fixed(
349350
|| format!("load pad_{}", i),
350351
config.rc_b[i],
351352
1,
352-
|| Value::known(padding_value),
353+
|| value.clone(),
353354
)?
354355
.cell();
355-
(cell, Value::known(padding_value))
356+
(cell, value)
356357
}
357358
_ => panic!("Input is not padded"),
358359
};
@@ -366,7 +367,12 @@ impl<
366367

367368
Ok(StateWord(var))
368369
};
369-
let input: Result<Vec<_>, Error> = (0..RATE).map(load_input_word).collect();
370+
let input: Result<Vec<_>, Error> = input
371+
.expose_inner()
372+
.into_iter()
373+
.enumerate()
374+
.map(load_input_word)
375+
.collect();
370376
let input = input?;
371377

372378
// Constrain the output.
@@ -394,14 +400,8 @@ impl<
394400
}
395401

396402
fn get_output(state: &State<Self::Word, WIDTH>) -> Squeezing<Self::Word, RATE> {
397-
Squeezing(
398-
state[..RATE]
399-
.iter()
400-
.map(|word| Some(word.clone()))
401-
.collect::<Vec<_>>()
402-
.try_into()
403-
.unwrap(),
404-
)
403+
let vals = state[..RATE].iter().cloned().collect::<Vec<_>>();
404+
Squeezing::init_full(vals.try_into().expect("correct length"))
405405
}
406406
}
407407

@@ -687,7 +687,7 @@ mod tests {
687687
.try_into()
688688
.unwrap();
689689
let (round_constants, mds, _) = S::constants();
690-
poseidon::permute::<_, S, WIDTH, RATE>(
690+
poseidon::test_only_permute::<_, S, WIDTH, RATE>(
691691
&mut expected_final_state,
692692
&mds,
693693
&round_constants,

halo2_poseidon/src/lib.rs

Lines changed: 83 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@ pub(crate) mod fq;
1212
pub(crate) mod grain;
1313
pub(crate) mod mds;
1414

15-
#[cfg(test)]
16-
pub(crate) mod test_vectors;
15+
#[cfg(any(test, feature = "test-dependencies"))]
16+
pub mod test_vectors;
1717

1818
mod p128pow5t3;
1919
pub use p128pow5t3::P128Pow5T3;
2020

2121
use grain::SboxType;
2222

2323
/// The type used to hold permutation state.
24-
pub(crate) type State<F, const T: usize> = [F; T];
24+
pub type State<F, const T: usize> = [F; T];
2525

2626
/// The type used to hold sponge rate.
2727
pub(crate) type SpongeRate<F, const RATE: usize> = [Option<F>; RATE];
@@ -83,6 +83,18 @@ pub fn generate_constants<
8383
(round_constants, mds, mds_inv)
8484
}
8585

86+
/// Runs the Poseidon permutation on the given state.
87+
///
88+
/// Exposed for testing purposes only.
89+
#[cfg(feature = "test-dependencies")]
90+
pub fn test_only_permute<F: Field, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>(
91+
state: &mut State<F, T>,
92+
mds: &Mds<F, T>,
93+
round_constants: &[[F; T]],
94+
) {
95+
permute::<F, S, T, RATE>(state, mds, round_constants);
96+
}
97+
8698
/// Runs the Poseidon permutation on the given state.
8799
pub(crate) fn permute<F: Field, S: Spec<F, T, RATE>, const T: usize, const RATE: usize>(
88100
state: &mut State<F, T>,
@@ -176,16 +188,82 @@ impl<F, const RATE: usize> SpongeMode for Squeezing<F, RATE> {}
176188

177189
impl<F: fmt::Debug, const RATE: usize> Absorbing<F, RATE> {
178190
pub(crate) fn init_with(val: F) -> Self {
191+
let mut state = Self::init_empty();
192+
state.absorb(val).expect("state is not full");
193+
state
194+
}
195+
196+
/// Initializes an empty sponge in the absorbing state.
197+
pub fn init_empty() -> Self {
179198
Self(
180-
iter::once(Some(val))
181-
.chain((1..RATE).map(|_| None))
199+
(0..RATE)
200+
.map(|_| None)
182201
.collect::<Vec<_>>()
183202
.try_into()
184203
.unwrap(),
185204
)
186205
}
187206
}
188207

208+
impl<F, const RATE: usize> Absorbing<F, RATE> {
209+
/// Attempts to absorb a value into the sponge state.
210+
///
211+
/// Returns the value if it was not absorbed because the sponge is full.
212+
pub fn absorb(&mut self, value: F) -> Result<(), F> {
213+
for entry in self.0.iter_mut() {
214+
if entry.is_none() {
215+
*entry = Some(value);
216+
return Ok(());
217+
}
218+
}
219+
// Sponge is full.
220+
Err(value)
221+
}
222+
223+
/// Exposes the inner state of the sponge.
224+
///
225+
/// This is a low-level API, requiring a detailed understanding of this specific
226+
/// Poseidon implementation to use correctly and securely. It is exposed for use by
227+
/// the circuit implementation in `halo2_gadgets`, and may be removed from the public
228+
/// API if refactoring enables the circuit implementation to move into this crate.
229+
pub fn expose_inner(&self) -> &SpongeRate<F, RATE> {
230+
&self.0
231+
}
232+
}
233+
234+
impl<F: fmt::Debug, const RATE: usize> Squeezing<F, RATE> {
235+
/// Initializes a full sponge in the squeezing state.
236+
///
237+
/// This is a low-level API, requiring a detailed understanding of this specific
238+
/// Poseidon implementation to use correctly and securely. It is exposed for use by
239+
/// the circuit implementation in `halo2_gadgets`, and may be removed from the public
240+
/// API if refactoring enables the circuit implementation to move into this crate.
241+
pub fn init_full(vals: [F; RATE]) -> Self {
242+
Self(
243+
vals.into_iter()
244+
.map(Some)
245+
.collect::<Vec<_>>()
246+
.try_into()
247+
.unwrap(),
248+
)
249+
}
250+
}
251+
252+
impl<F, const RATE: usize> Squeezing<F, RATE> {
253+
/// Attempts to squeeze a value from the sponge state.
254+
///
255+
/// Returns an error if the sponge is empty.
256+
pub fn squeeze(&mut self) -> Result<F, ()> {
257+
for entry in self.0.iter_mut() {
258+
if let Some(inner) = entry.take() {
259+
return Ok(inner);
260+
}
261+
}
262+
// Sponge is empty.
263+
Err(())
264+
}
265+
}
266+
189267
/// A Poseidon sponge.
190268
pub(crate) struct Sponge<
191269
F: Field,

halo2_poseidon/src/p128pow5t3.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use halo2_proofs::arithmetic::Field;
1+
use ff::Field;
22
use pasta_curves::{pallas::Base as Fp, vesta::Base as Fq};
33

44
use super::{Mds, Spec};
@@ -73,9 +73,7 @@ mod tests {
7373
super::{fp, fq},
7474
Fp, Fq,
7575
};
76-
use crate::poseidon::primitives::{
77-
generate_constants, permute, ConstantLength, Hash, Mds, Spec,
78-
};
76+
use crate::{generate_constants, permute, ConstantLength, Hash, Mds, Spec};
7977

8078
/// The same Poseidon specification as poseidon::P128Pow5T3, but constructed
8179
/// such that its constants will be generated at runtime.
@@ -257,7 +255,7 @@ mod tests {
257255
{
258256
let (round_constants, mds, _) = super::P128Pow5T3::constants();
259257

260-
for tv in crate::poseidon::primitives::test_vectors::fp::permute() {
258+
for tv in crate::test_vectors::fp::permute() {
261259
let mut state = [
262260
Fp::from_repr(tv.initial_state[0]).unwrap(),
263261
Fp::from_repr(tv.initial_state[1]).unwrap(),
@@ -275,7 +273,7 @@ mod tests {
275273
{
276274
let (round_constants, mds, _) = super::P128Pow5T3::constants();
277275

278-
for tv in crate::poseidon::primitives::test_vectors::fq::permute() {
276+
for tv in crate::test_vectors::fq::permute() {
279277
let mut state = [
280278
Fq::from_repr(tv.initial_state[0]).unwrap(),
281279
Fq::from_repr(tv.initial_state[1]).unwrap(),
@@ -293,7 +291,7 @@ mod tests {
293291

294292
#[test]
295293
fn hash_test_vectors() {
296-
for tv in crate::poseidon::primitives::test_vectors::fp::hash() {
294+
for tv in crate::test_vectors::fp::hash() {
297295
let message = [
298296
Fp::from_repr(tv.input[0]).unwrap(),
299297
Fp::from_repr(tv.input[1]).unwrap(),
@@ -305,7 +303,7 @@ mod tests {
305303
assert_eq!(result.to_repr(), tv.output);
306304
}
307305

308-
for tv in crate::poseidon::primitives::test_vectors::fq::hash() {
306+
for tv in crate::test_vectors::fq::hash() {
309307
let message = [
310308
Fq::from_repr(tv.input[0]).unwrap(),
311309
Fq::from_repr(tv.input[1]).unwrap(),

halo2_poseidon/src/test_vectors.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
//! Test vectors for [`OrchardNullifier`].
22
3-
pub(crate) struct PermuteTestVector {
4-
pub(crate) initial_state: [[u8; 32]; 3],
5-
pub(crate) final_state: [[u8; 32]; 3],
3+
pub struct PermuteTestVector {
4+
pub initial_state: [[u8; 32]; 3],
5+
pub final_state: [[u8; 32]; 3],
66
}
77

8-
pub(crate) struct HashTestVector {
9-
pub(crate) input: [[u8; 32]; 2],
10-
pub(crate) output: [u8; 32],
8+
pub struct HashTestVector {
9+
pub input: [[u8; 32]; 2],
10+
pub output: [u8; 32],
1111
}
1212

13-
pub(crate) mod fp {
13+
pub mod fp {
1414
use super::*;
1515

16-
pub(crate) fn permute() -> Vec<PermuteTestVector> {
16+
pub fn permute() -> Vec<PermuteTestVector> {
1717
use PermuteTestVector as TestVector;
1818

1919
// From https://github.com/zcash-hackworks/zcash-test-vectors/blob/master/orchard_poseidon/permute/fp.py
@@ -417,7 +417,7 @@ pub(crate) mod fp {
417417
]
418418
}
419419

420-
pub(crate) fn hash() -> Vec<HashTestVector> {
420+
pub fn hash() -> Vec<HashTestVector> {
421421
use HashTestVector as TestVector;
422422

423423
// From https://github.com/zcash-hackworks/zcash-test-vectors/blob/master/orchard_poseidon/hash/fp.py
@@ -635,10 +635,10 @@ pub(crate) mod fp {
635635
}
636636
}
637637

638-
pub(crate) mod fq {
638+
pub mod fq {
639639
use super::*;
640640

641-
pub(crate) fn permute() -> Vec<PermuteTestVector> {
641+
pub fn permute() -> Vec<PermuteTestVector> {
642642
use PermuteTestVector as TestVector;
643643

644644
// From https://github.com/zcash-hackworks/zcash-test-vectors/blob/master/orchard_poseidon/permute/fq.py
@@ -1042,7 +1042,7 @@ pub(crate) mod fq {
10421042
]
10431043
}
10441044

1045-
pub(crate) fn hash() -> Vec<HashTestVector> {
1045+
pub fn hash() -> Vec<HashTestVector> {
10461046
use HashTestVector as TestVector;
10471047

10481048
// From https://github.com/zcash-hackworks/zcash-test-vectors/blob/master/orchard_poseidon/hash/fq.py

0 commit comments

Comments
 (0)