Skip to content

Commit

Permalink
feat(witness): implements the module witness generation
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexandreBelling committed Feb 25, 2025
1 parent 9ed1365 commit 565ca7c
Show file tree
Hide file tree
Showing 5 changed files with 320 additions and 2 deletions.
23 changes: 23 additions & 0 deletions prover/maths/common/smartvectors/smartvectors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package smartvectors

import (
"errors"
"fmt"
"math/rand"

Expand Down Expand Up @@ -183,6 +184,28 @@ func Density(v SmartVector) int {
}
}

// PaddingOrientationOf returns an integer indicating the orientation of the
// padding of a column. '0' indicates an unresolved orientation. '1' indicates
// that the columns if right-padded and '-1' indicates that it is left-padded.
//
// The function returns an error if the vector is not a padded-circular window.
func PaddingOrientationOf(v SmartVector) (int, error) {

switch w := v.(type) {
case *PaddedCircularWindow:
if w.offset == 0 {
return 1, nil
}
if w.offset+len(w.window) == w.totLen {
return -1, nil
}
default:
return 0, errors.New("vector is not a padded-circular window")
}

return 0, nil
}

// Window returns the effective window of the vector,
// if the vector is Padded with zeroes it return the window.
// Namely, the part without zero pads.
Expand Down
73 changes: 72 additions & 1 deletion prover/protocol/distributed/experiment/distribute.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,86 @@ import (
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
)

// DistributedWizard represents a wizard protocol that has undergone a
// distributed compilation process.
type DistributedWizard struct {

// ModuleNames is the list of the names of the modules that compose
// the distributed protocol.
ModuleNames []ModuleName

// LPPs is the list of the LPP parts for every modules
LPPs []*ModuleLPP

// GLs is the list of the GL parts for every modules
GLs []*ModuleGL

// Bootstrapper is the original compiledIOP precompiled with a few
// preparation steps.
Bootstrapper *wizard.CompiledIOP

// Disc is the [ModuleDiscoverer] used to delimitate the scope for
// each module.
Disc ModuleDiscoverer
}

// Distribute returns a [DistributedWizard] from a [wizard.CompiledIOP]. It
// takes ownership of the input [wizard.CompiledIOP]. And uses disc to design
// the scope of each module.
func Distribute(comp *wizard.CompiledIOP, disc ModuleDiscoverer) DistributedWizard {

distributedWizard := DistributedWizard{
Bootstrapper: precompileInitialWizard(comp),
}

disc.Analyze(distributedWizard.Bootstrapper)
distributedWizard.ModuleNames = disc.ModuleList()

for _, moduleName := range distributedWizard.ModuleNames {

moduleFilter := moduleFilter{
Disc: disc,
Module: moduleName,
}

filteredModuleInputs := moduleFilter.FilterCompiledIOP(
distributedWizard.Bootstrapper,
)

distributedWizard.LPPs = append(
distributedWizard.LPPs,
BuildModuleLPP(&filteredModuleInputs),
)

distributedWizard.GLs = append(
distributedWizard.GLs,
BuildModuleGL(&filteredModuleInputs),
)
}

return distributedWizard
}

// CompileModules applies the compilation steps to each modules identically.
func (dist *DistributedWizard) CompileModules(compilers ...func(*wizard.CompiledIOP)) {
for i := range dist.ModuleNames {
for _, compile := range compilers {
compile(dist.LPPs[i].Wiop)
compile(dist.GLs[i].Wiop)
}
}
}

// precompileInitialWizard pre-compiles the initial wizard protocol by applying all the
// compilation steps needing to be run before the splitting phase. Its role is to
// ensure that all of the queries that cannot be processed by the splitting phase
// are removed from the compiled IOP.
func precompileInitialWizard(comp *wizard.CompiledIOP) {
func precompileInitialWizard(comp *wizard.CompiledIOP) *wizard.CompiledIOP {
mimc.CompileMiMC(comp)
specialqueries.RangeProof(comp)
specialqueries.CompileFixedPermutations(comp)
logderivativesum.LookupIntoLogDerivativeSum(comp)
permutation.CompileIntoGdProduct(comp)
horner.ProjectionToHorner(comp)
return comp
}
18 changes: 18 additions & 0 deletions prover/protocol/distributed/experiment/module_gl.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,24 @@ type ModuleGLAssignGL struct {
*ModuleGL
}

// BuildModuleGL builds a [ModuleGL] from scratch from a [FilteredModuleInputs].
// The function works by creating a define function that will call [NewModuleGL]
// / and then it calls [wizard.Compile] without passing compilers.
func BuildModuleGL(moduleInput *FilteredModuleInputs) *ModuleGL {

var (
moduleGL *ModuleGL
defineFunc = func(b *wizard.Builder) {
moduleGL = NewModuleGL(b, moduleInput)
}
// Since the NewModuleGL contains a pointer to the compiled IOP already
// there is no need to use the one returned by [wizard.Compile].
_ = wizard.Compile(defineFunc)
)

return moduleGL
}

// NewModuleGL declares and constructs a new ModuleGL from a [wizard.Builder]
// and a [FilteredModuleInput]. The function performs all the necessary
// declarations to build the GL part of the module and returns the constructed
Expand Down
18 changes: 18 additions & 0 deletions prover/protocol/distributed/experiment/module_lpp.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,24 @@ type CheckNxHash struct {
skipped bool
}

// BuildModuleLPP builds a [ModuleLPP] from scratch from a [FilteredModuleInputs].
// The function works by creating a define function that will call [NewModuleLPP]
// / and then it calls [wizard.Compile] without passing compilers.
func BuildModuleLPP(moduleInput *FilteredModuleInputs) *ModuleLPP {

var (
moduleLPP *ModuleLPP
defineFunc = func(b *wizard.Builder) {
moduleLPP = NewModuleLPP(b, moduleInput)
}
// Since the NewModuleLPP contains a pointer to the compiled IOP already
// there is no need to use the one returned by [wizard.Compile].
_ = wizard.Compile(defineFunc)
)

return moduleLPP
}

// NewModuleLPP declares and constructs a new [ModuleLPP] from a [wizard.Builder]
// and a [FilteredModuleInput]. The function performs all the necessary
// declarations to build the LPP part of the module and returns the constructed
Expand Down
190 changes: 189 additions & 1 deletion prover/protocol/distributed/experiment/module_witness.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@ package experiment
import (
"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
"github.com/consensys/linea-monorepo/prover/maths/field"
"github.com/consensys/linea-monorepo/prover/protocol/column"
"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
"github.com/consensys/linea-monorepo/prover/utils"
)

// ModuleWitness is a structure collecting the witness of a module. And
// stores all the informations that are necessary to build the witness.
type ModuleWitness struct {
// ModuleName indicates the name of the module
ModuleName string
ModuleName ModuleName
// IsLPP indicates if the current instance of [ModuleWitness] is for
// an LPP segment. In the contrary case, it is understood to be for
// a GL segment.
Expand All @@ -29,3 +32,188 @@ type ModuleWitness struct {
// as in the [FilteredModuleInputs.HornerArgs]
N0Values []int
}

// SegmentRuntime scans a [wizard.ProverRuntime] and returns a list of
// [ModuleWitness] that contains the witness for each segment of each
// module.
func SegmentRuntime(runtime *wizard.ProverRuntime, distributedWizard *DistributedWizard) (witnessesGL, witnessesLPP []*ModuleWitness) {

for i := range distributedWizard.ModuleNames {
wGL, wLPP := SegmentModule(runtime, distributedWizard.GLs[i], distributedWizard.LPPs[i])
witnessesGL = append(witnessesGL, wGL...)
witnessesLPP = append(witnessesLPP, wLPP...)
}

return witnessesGL, witnessesLPP
}

// SegmentModule produces the list of the [ModuleWitness] for a given module
func SegmentModule(runtime *wizard.ProverRuntime, moduleGL *ModuleGL, moduleLPP *ModuleLPP) (witnessesGL, witnessesLPP []*ModuleWitness) {

var (
fmi = moduleGL.definitionInput
cols = runtime.Spec.Columns.AllKeys()
nbSegmentModule = NbSegmentOfModule(runtime, fmi.Disc, fmi.ModuleName)
n0 = make([]int, len(fmi.HornerArgs))
receivedValuesGlobal = make([]field.Element, len(moduleGL.ReceivedValuesGlobalAccs))
)

witnessesLPP = make([]*ModuleWitness, nbSegmentModule)
witnessesGL = make([]*ModuleWitness, nbSegmentModule)

for moduleIndex := range witnessesLPP {

moduleWitnessGL := &ModuleWitness{
ModuleName: fmi.ModuleName,
ModuleIndex: moduleIndex,
IsFirst: moduleIndex == 0,
IsLast: moduleIndex == nbSegmentModule-1,
Columns: make(map[ifaces.ColID]smartvectors.SmartVector),
ReceivedValuesGlobal: receivedValuesGlobal,
}

moduleWitnessLPP := &ModuleWitness{
ModuleName: fmi.ModuleName,
ModuleIndex: moduleIndex,
IsFirst: moduleIndex == 0,
IsLast: moduleIndex == nbSegmentModule-1,
IsLPP: true,
Columns: make(map[ifaces.ColID]smartvectors.SmartVector),
N0Values: n0,
}

for _, col := range cols {

col := runtime.Spec.Columns.GetHandle(col)

if ModuleOfColumn(fmi.Disc, col) != fmi.ModuleName {
continue
}

segment := SegmentOfColumn(runtime, fmi.Disc, col, moduleIndex, nbSegmentModule)
moduleWitnessGL.Columns[col.GetColID()] = segment

if _, ok := fmi.ColumnsLPPSet[col.GetColID()]; ok {
moduleWitnessLPP.Columns[col.GetColID()] = segment
}
}

witnessesGL[moduleIndex] = moduleWitnessGL
witnessesLPP[moduleIndex] = moduleWitnessLPP

n0 = moduleWitnessLPP.NextN0s(moduleLPP)
receivedValuesGlobal = moduleWitnessGL.NextReceivedValuesGlobal(moduleGL)
}

return witnessesGL, witnessesLPP
}

// NbSegmentOfModule returns the number of segments for a given module
func NbSegmentOfModule(runtime *wizard.ProverRuntime, disc ModuleDiscoverer, moduleName ModuleName) int {

var (
cols = runtime.Spec.Columns.AllKeys()
nbSegmentModule = -1
)

for _, col := range cols {

col := runtime.Spec.Columns.GetHandle(col)

if ModuleOfColumn(disc, col) != moduleName {
continue
}

var (
newSize = NewSizeOfColumn(disc, col)
assignment = col.GetColAssignment(runtime)
density = smartvectors.Density(assignment)
_, orientErr = smartvectors.PaddingOrientationOf(assignment)
nbSegmentCol = utils.DivCeil(newSize, density)
)

if orientErr != nil {
// the column cannot be taken into account for the segmentation
continue
}

nbSegmentModule = max(nbSegmentModule, nbSegmentCol)
}

if nbSegmentModule == -1 {
utils.Panic("could not resolve the number of segment for module %v", moduleName)
}

return nbSegmentModule
}

// SegmentColumn returns the segment of a given column for given index. The
// function also takes a maxNbSegment value which is useful in case
func SegmentOfColumn(runtime *wizard.ProverRuntime, disc ModuleDiscoverer,
col ifaces.Column, index, totalNbSegment int) smartvectors.SmartVector {

var (
newSize = NewSizeOfColumn(disc, col)
assignment = col.GetColAssignment(runtime)
orientiation, orientErr = smartvectors.PaddingOrientationOf(assignment)
start = index * newSize
end = start + newSize
)

if orientErr != nil {
// If a column is assigned to a plain-vector, then it is assumed to
// be right-padded. The reason for this assumption is that the
// columns from the arithmetization are systematically padded on the
// left while the columns from the prover are all right-padded and the
// sometime they (suboptimally) assigned to plain-vectors.
orientiation = 1
}

if orientiation == -1 {
start += assignment.Len() - totalNbSegment*newSize
end += assignment.Len() - totalNbSegment*newSize
}

return assignment.SubVector(start, end)
}

// NextN0s returns the next value of N0, from the current one and the witness
// of the current module.
func (mw *ModuleWitness) NextN0s(moduleLPP *ModuleLPP) []int {

newN0s := append([]int{}, mw.N0Values...)
args := moduleLPP.Horner.Parts

for i := range newN0s {

sel := mw.Columns[args[i].Selector.GetColID()].IntoRegVecSaveAlloc()

for j := range sel {
if sel[j].IsOne() {
newN0s[i]++
}
}
}

return newN0s
}

// NextReceivedValuesGlobal returns the next value of ReceivedValuesGlobal, from
// the witness of the current module.
func (mw *ModuleWitness) NextReceivedValuesGlobal(moduleGL *ModuleGL) []field.Element {

newReceivedValuesGlobal := make([]field.Element, len(mw.ReceivedValuesGlobal))

for i, loc := range moduleGL.SentValuesGlobal {

var (
col = column.RootParents(loc.Pol)[0]
pos = column.StackOffsets(loc.Pol)
)

pos = utils.PositiveMod(pos, col.Size())
newReceivedValuesGlobal[i] = mw.Columns[col.GetColID()].Get(pos)
}

return newReceivedValuesGlobal
}

0 comments on commit 565ca7c

Please sign in to comment.