Skip to content

Commit

Permalink
clean(prover): cleans the file structure and the wizard package
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexandreBelling committed Feb 26, 2025
1 parent 565ca7c commit d978b01
Show file tree
Hide file tree
Showing 45 changed files with 400 additions and 4,233 deletions.
26 changes: 19 additions & 7 deletions prover/crypto/fiatshamir/fiatshamir.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
"github.com/consensys/linea-monorepo/prover/maths/field"
"github.com/consensys/linea-monorepo/prover/utils"
"golang.org/x/crypto/blake2b"
)

// State holds a Fiat-Shamir state. The Fiat-Shamir state can be updated by
Expand Down Expand Up @@ -122,16 +123,27 @@ func (fs *State) RandomField() field.Element {

// RandomField generates and returns a single field element from the seed and the given name.
func (fs *State) RandomFieldFromSeed(seed field.Element, name string) field.Element {
challBytes := []byte(name)
seedBytes := seed.Bytes()
challBytes = append(challBytes, seedBytes[:]...)

var res field.Element
res.SetBytes(challBytes)
// The first step encodes the 'name' into a single field element. The
// seed is then obtained by calling the compression function over the
// encoded name and the
nameBytes := []byte(name)
hasher, _ := blake2b.New256(nil)
hasher.Write(nameBytes)
nameBytes = hasher.Sum(nil)

// The seed is then obtained by calling the compression function over
// the seed and the encoded name.
oldState := fs.State()
defer fs.SetState(oldState)

fs.SetState([]field.Element{seed})
fs.hasher.Write(nameBytes)
challBytes := fs.hasher.Sum(nil)
res := new(field.Element).SetBytes(challBytes)

// increase the counter by one
fs.NumCoinGenerated++
return res
return *res
}

// RandomManyIntegers returns a list of challenge small integers. That is, a
Expand Down
28 changes: 27 additions & 1 deletion prover/crypto/fiatshamir/snark.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/consensys/linea-monorepo/prover/crypto/mimc/gkrmimc"
"github.com/consensys/linea-monorepo/prover/maths/field"
"github.com/consensys/linea-monorepo/prover/utils"
"golang.org/x/crypto/blake2b"
)

// GnarkFiatShamir mirrors [State] in a gnark circuit. It provides analogous
Expand Down Expand Up @@ -67,7 +68,8 @@ func (fs *GnarkFiatShamir) SetState(state []frontend.Variable) {
}
}

// State mutates the fiat-shamir state of
// State mutates returns the state of the fiat-shamir hasher. The
// function will also updates its own state with unprocessed inputs.
func (fs *GnarkFiatShamir) State() []frontend.Variable {

switch hsh := fs.hasher.(type) {
Expand Down Expand Up @@ -163,7 +165,31 @@ func (fs *GnarkFiatShamir) RandomManyIntegers(num, upperBound int) []frontend.Va

fs.safeguardUpdate()
}
}

// RandomFieldFromSeed generates a new field element from the given seed
// and a name. The 'fs' is left unchanged by the call (aside from the
// underlying [frontend.API]).
func (fs *GnarkFiatShamir) RandomFieldFromSeed(seed frontend.Variable, name string) frontend.Variable {

// The first step encodes the 'name' into a single field element. The
// seed is then obtained by calling the compression function over the
// encoded name and the
nameBytes := []byte(name)
hasher, _ := blake2b.New256(nil)
hasher.Write(nameBytes)
nameBytes = hasher.Sum(nil)
nameField := new(field.Element).SetBytes(nameBytes)

// The seed is then obtained by calling the compression function over
// the seed and the encoded name.
oldState := fs.State()
defer fs.SetState(oldState)

fs.SetState([]frontend.Variable{seed})
fs.hasher.Write(nameField)

return fs.hasher.Sum()
}

// safeguardUpdate updates the state as a safeguard by appending a field element
Expand Down
4 changes: 2 additions & 2 deletions prover/example/test_cases/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"github.com/consensys/linea-monorepo/prover/protocol/compiler/globalcs"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/innerproduct"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/localcs"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/lookup"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/logderivativesum"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/permutation"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/specialqueries"
"github.com/consensys/linea-monorepo/prover/protocol/compiler/splitter"
Expand Down Expand Up @@ -61,7 +61,7 @@ var (
ALL_SPECIALS = compilationSuite{
specialqueries.RangeProof,
specialqueries.CompileFixedPermutations,
lookup.CompileLogDerivative,
logderivativesum.CompileLookups,
permutation.CompileViaGrandProduct,
innerproduct.Compile,
}
Expand Down
22 changes: 17 additions & 5 deletions prover/protocol/coin/coin.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"strconv"

"github.com/consensys/gnark/frontend"
"github.com/consensys/linea-monorepo/prover/crypto/fiatshamir"
"github.com/consensys/linea-monorepo/prover/maths/field"
"github.com/consensys/linea-monorepo/prover/utils"
Expand Down Expand Up @@ -90,17 +91,28 @@ func (t *Type) UnmarshalJSON(b []byte) error {
/*
Sample a random coin, according to its `spec`
*/
func (info *Info) Sample(fs *fiatshamir.State, seed ...field.Element) interface{} {
func (info *Info) Sample(fs *fiatshamir.State, seed field.Element) interface{} {
switch info.Type {
case Field:
return fs.RandomField()
case IntegerVec:
return fs.RandomManyIntegers(info.Size, info.UpperBound)
case FieldFromSeed:
if len(seed) == 0 {
panic("expected a SEED as the input")
}
return fs.RandomFieldFromSeed(seed[0], string(info.Name))
return fs.RandomFieldFromSeed(seed, string(info.Name))
}
panic("Unreachable")
}

// SampleGnark samples a random coin in a gnark circuit. The seed can optionally be
// passed by the caller is used for [FieldFromSeed] coins. The function returns
func (info *Info) SampleGnark(fs *fiatshamir.GnarkFiatShamir, seed frontend.Variable) interface{} {
switch info.Type {
case Field:
return fs.RandomField()
case IntegerVec:
return fs.RandomManyIntegers(info.Size, info.UpperBound)
case FieldFromSeed:
return fs.RandomFieldFromSeed(seed, string(info.Name))
}
panic("Unreachable")
}
Expand Down
34 changes: 24 additions & 10 deletions prover/protocol/compiler/fullrecursion/circuit.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/linea-monorepo/prover/crypto/fiatshamir"
"github.com/consensys/linea-monorepo/prover/crypto/mimc/gkrmimc"
"github.com/consensys/linea-monorepo/prover/protocol/coin"
"github.com/consensys/linea-monorepo/prover/protocol/query"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
)
Expand Down Expand Up @@ -47,6 +46,12 @@ func allocateGnarkCircuit(comp *wizard.CompiledIOP, ctx *fullRecursionCtx) *gnar
wizardVerifier.AllocInnerProduct(qInfo.ID, qInfo)
case query.LocalOpening:
wizardVerifier.AllocLocalOpening(qInfo.ID, qInfo)
case query.LogDerivativeSum:
wizardVerifier.AllocLogDerivativeSum(qInfo.ID, qInfo)
case query.GrandProduct:
wizardVerifier.AllocGrandProduct(qInfo.ID, qInfo)
case *query.Horner:
wizardVerifier.AllocHorner(qInfo.ID, qInfo)
}
}
}
Expand Down Expand Up @@ -145,18 +150,18 @@ func (c *gnarkCircuit) generateAllRandomCoins(api frontend.API) {
}
}

for _, fsHook := range ctx.PreSamplingFsHooks[currRound] {
fsHook.RunGnark(api, w)
}

seed := w.FS.State()[0]

for _, info := range ctx.Coins[currRound] {
switch info.Type {
case coin.Field:
value := w.FS.RandomField()
w.Coins.InsertNew(info.Name, value)
case coin.IntegerVec:
value := w.FS.RandomManyIntegers(info.Size, info.UpperBound)
w.Coins.InsertNew(info.Name, value)
}
value := info.SampleGnark(w.FS, seed)
w.Coins.InsertNew(info.Name, value)
}

for _, fsHook := range ctx.FsHooks[currRound] {
for _, fsHook := range ctx.PostSamplingFsHooks[currRound] {
fsHook.RunGnark(api, w)
}

Expand Down Expand Up @@ -194,6 +199,15 @@ func AssignGnarkCircuit(ctx *fullRecursionCtx, comp *wizard.CompiledIOP, run *wi
case query.LocalOpening:
params := run.GetLocalPointEvalParams(qInfo.ID)
wizardVerifier.AssignLocalOpening(qInfo.ID, params)
case query.LogDerivativeSum:
params := run.GetLogDerivSumParams(qInfo.ID)
wizardVerifier.AssignLogDerivativeSum(qInfo.ID, params)
case query.GrandProduct:
params := run.GetGrandProductParams(qInfo.ID)
wizardVerifier.AssignGrandProduct(qInfo.ID, params)
case *query.Horner:
params := run.GetHornerParams(qInfo.ID)
wizardVerifier.AssignHorner(qInfo.ID, params)
}
}
}
Expand Down
28 changes: 22 additions & 6 deletions prover/protocol/compiler/fullrecursion/full_recursion.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func FullRecursion(withoutGkr bool) func(comp *wizard.CompiledIOP) {
)
}

comp.FiatShamirHooks.AppendToInner(ctx.LastRound, &ResetFsActions{fullRecursionCtx: *ctx})
comp.FiatShamirHooksPostSampling.AppendToInner(ctx.LastRound, &ResetFsActions{fullRecursionCtx: *ctx})
comp.RegisterProverAction(ctx.LastRound, CircuitAssignment(*ctx))
comp.RegisterProverAction(ctx.LastRound, ReplacementAssignment(*ctx))
comp.RegisterProverAction(ctx.PlonkInWizard.PI.Round(), LocalOpeningAssignment(*ctx))
Expand All @@ -84,7 +84,8 @@ type fullRecursionCtx struct {
Columns [][]ifaces.Column
VerifierActions [][]wizard.VerifierAction
Coins [][]coin.Info
FsHooks [][]wizard.VerifierAction
PostSamplingFsHooks [][]wizard.VerifierAction
PreSamplingFsHooks [][]wizard.VerifierAction
PlonkInWizard struct {
ProverAction plonk.PlonkInWizardProverAction
PI ifaces.Column
Expand Down Expand Up @@ -115,7 +116,8 @@ func captureCtx(comp *wizard.CompiledIOP) *fullRecursionCtx {
ctx.Columns = append(ctx.Columns, []ifaces.Column{})
ctx.VerifierActions = append(ctx.VerifierActions, []wizard.VerifierAction{})
ctx.Coins = append(ctx.Coins, []coin.Info{})
ctx.FsHooks = append(ctx.FsHooks, []wizard.VerifierAction{})
ctx.PostSamplingFsHooks = append(ctx.PostSamplingFsHooks, []wizard.VerifierAction{})
ctx.PreSamplingFsHooks = append(ctx.PreSamplingFsHooks, []wizard.VerifierAction{})

for _, colName := range comp.Columns.AllKeysAt(round) {

Expand Down Expand Up @@ -176,16 +178,30 @@ func captureCtx(comp *wizard.CompiledIOP) *fullRecursionCtx {
va.Skip()
}

if comp.FiatShamirHooks.Len() > round {
resetFs := comp.FiatShamirHooks.Inner()[round]
if comp.FiatShamirHooksPreSampling.Len() > round {
resetFs := comp.FiatShamirHooksPreSampling.Inner()[round]
for i := range resetFs {

fsHook := resetFs[i]
if fsHook.IsSkipped() {
continue
}

ctx.FsHooks[round] = append(ctx.VerifierActions[round], fsHook)
ctx.PreSamplingFsHooks[round] = append(ctx.PreSamplingFsHooks[round], fsHook)
fsHook.Skip()
}
}

if comp.FiatShamirHooksPostSampling.Len() > round {
resetFs := comp.FiatShamirHooksPostSampling.Inner()[round]
for i := range resetFs {

fsHook := resetFs[i]
if fsHook.IsSkipped() {
continue
}

ctx.PostSamplingFsHooks[round] = append(ctx.PostSamplingFsHooks[round], fsHook)
fsHook.Skip()
}
}
Expand Down
Loading

0 comments on commit d978b01

Please sign in to comment.