From 565ca7cf4ad83423c144d78560bcc78351d84e89 Mon Sep 17 00:00:00 2001 From: AlexandreBelling Date: Wed, 26 Feb 2025 00:09:42 +0100 Subject: [PATCH] feat(witness): implements the module witness generation --- .../maths/common/smartvectors/smartvectors.go | 23 +++ .../distributed/experiment/distribute.go | 73 ++++++- .../distributed/experiment/module_gl.go | 18 ++ .../distributed/experiment/module_lpp.go | 18 ++ .../distributed/experiment/module_witness.go | 190 +++++++++++++++++- 5 files changed, 320 insertions(+), 2 deletions(-) diff --git a/prover/maths/common/smartvectors/smartvectors.go b/prover/maths/common/smartvectors/smartvectors.go index 43e33db10..6a4a0d540 100644 --- a/prover/maths/common/smartvectors/smartvectors.go +++ b/prover/maths/common/smartvectors/smartvectors.go @@ -1,6 +1,7 @@ package smartvectors import ( + "errors" "fmt" "math/rand" @@ -183,6 +184,28 @@ func Density(v SmartVector) int { } } +// PaddingOrientationOf returns an integer indicating the orientation of the +// padding of a column. '0' indicates an unresolved orientation. '1' indicates +// that the columns if right-padded and '-1' indicates that it is left-padded. +// +// The function returns an error if the vector is not a padded-circular window. +func PaddingOrientationOf(v SmartVector) (int, error) { + + switch w := v.(type) { + case *PaddedCircularWindow: + if w.offset == 0 { + return 1, nil + } + if w.offset+len(w.window) == w.totLen { + return -1, nil + } + default: + return 0, errors.New("vector is not a padded-circular window") + } + + return 0, nil +} + // Window returns the effective window of the vector, // if the vector is Padded with zeroes it return the window. // Namely, the part without zero pads. diff --git a/prover/protocol/distributed/experiment/distribute.go b/prover/protocol/distributed/experiment/distribute.go index 7cdc6ad7f..5b02ab4cf 100644 --- a/prover/protocol/distributed/experiment/distribute.go +++ b/prover/protocol/distributed/experiment/distribute.go @@ -9,15 +9,86 @@ import ( "github.com/consensys/linea-monorepo/prover/protocol/wizard" ) +// DistributedWizard represents a wizard protocol that has undergone a +// distributed compilation process. +type DistributedWizard struct { + + // ModuleNames is the list of the names of the modules that compose + // the distributed protocol. + ModuleNames []ModuleName + + // LPPs is the list of the LPP parts for every modules + LPPs []*ModuleLPP + + // GLs is the list of the GL parts for every modules + GLs []*ModuleGL + + // Bootstrapper is the original compiledIOP precompiled with a few + // preparation steps. + Bootstrapper *wizard.CompiledIOP + + // Disc is the [ModuleDiscoverer] used to delimitate the scope for + // each module. + Disc ModuleDiscoverer +} + +// Distribute returns a [DistributedWizard] from a [wizard.CompiledIOP]. It +// takes ownership of the input [wizard.CompiledIOP]. And uses disc to design +// the scope of each module. +func Distribute(comp *wizard.CompiledIOP, disc ModuleDiscoverer) DistributedWizard { + + distributedWizard := DistributedWizard{ + Bootstrapper: precompileInitialWizard(comp), + } + + disc.Analyze(distributedWizard.Bootstrapper) + distributedWizard.ModuleNames = disc.ModuleList() + + for _, moduleName := range distributedWizard.ModuleNames { + + moduleFilter := moduleFilter{ + Disc: disc, + Module: moduleName, + } + + filteredModuleInputs := moduleFilter.FilterCompiledIOP( + distributedWizard.Bootstrapper, + ) + + distributedWizard.LPPs = append( + distributedWizard.LPPs, + BuildModuleLPP(&filteredModuleInputs), + ) + + distributedWizard.GLs = append( + distributedWizard.GLs, + BuildModuleGL(&filteredModuleInputs), + ) + } + + return distributedWizard +} + +// CompileModules applies the compilation steps to each modules identically. +func (dist *DistributedWizard) CompileModules(compilers ...func(*wizard.CompiledIOP)) { + for i := range dist.ModuleNames { + for _, compile := range compilers { + compile(dist.LPPs[i].Wiop) + compile(dist.GLs[i].Wiop) + } + } +} + // precompileInitialWizard pre-compiles the initial wizard protocol by applying all the // compilation steps needing to be run before the splitting phase. Its role is to // ensure that all of the queries that cannot be processed by the splitting phase // are removed from the compiled IOP. -func precompileInitialWizard(comp *wizard.CompiledIOP) { +func precompileInitialWizard(comp *wizard.CompiledIOP) *wizard.CompiledIOP { mimc.CompileMiMC(comp) specialqueries.RangeProof(comp) specialqueries.CompileFixedPermutations(comp) logderivativesum.LookupIntoLogDerivativeSum(comp) permutation.CompileIntoGdProduct(comp) horner.ProjectionToHorner(comp) + return comp } diff --git a/prover/protocol/distributed/experiment/module_gl.go b/prover/protocol/distributed/experiment/module_gl.go index d262b7b00..de54dd362 100644 --- a/prover/protocol/distributed/experiment/module_gl.go +++ b/prover/protocol/distributed/experiment/module_gl.go @@ -114,6 +114,24 @@ type ModuleGLAssignGL struct { *ModuleGL } +// BuildModuleGL builds a [ModuleGL] from scratch from a [FilteredModuleInputs]. +// The function works by creating a define function that will call [NewModuleGL] +// / and then it calls [wizard.Compile] without passing compilers. +func BuildModuleGL(moduleInput *FilteredModuleInputs) *ModuleGL { + + var ( + moduleGL *ModuleGL + defineFunc = func(b *wizard.Builder) { + moduleGL = NewModuleGL(b, moduleInput) + } + // Since the NewModuleGL contains a pointer to the compiled IOP already + // there is no need to use the one returned by [wizard.Compile]. + _ = wizard.Compile(defineFunc) + ) + + return moduleGL +} + // NewModuleGL declares and constructs a new ModuleGL from a [wizard.Builder] // and a [FilteredModuleInput]. The function performs all the necessary // declarations to build the GL part of the module and returns the constructed diff --git a/prover/protocol/distributed/experiment/module_lpp.go b/prover/protocol/distributed/experiment/module_lpp.go index 8714f0b8b..1ab2811d6 100644 --- a/prover/protocol/distributed/experiment/module_lpp.go +++ b/prover/protocol/distributed/experiment/module_lpp.go @@ -55,6 +55,24 @@ type CheckNxHash struct { skipped bool } +// BuildModuleLPP builds a [ModuleLPP] from scratch from a [FilteredModuleInputs]. +// The function works by creating a define function that will call [NewModuleLPP] +// / and then it calls [wizard.Compile] without passing compilers. +func BuildModuleLPP(moduleInput *FilteredModuleInputs) *ModuleLPP { + + var ( + moduleLPP *ModuleLPP + defineFunc = func(b *wizard.Builder) { + moduleLPP = NewModuleLPP(b, moduleInput) + } + // Since the NewModuleLPP contains a pointer to the compiled IOP already + // there is no need to use the one returned by [wizard.Compile]. + _ = wizard.Compile(defineFunc) + ) + + return moduleLPP +} + // NewModuleLPP declares and constructs a new [ModuleLPP] from a [wizard.Builder] // and a [FilteredModuleInput]. The function performs all the necessary // declarations to build the LPP part of the module and returns the constructed diff --git a/prover/protocol/distributed/experiment/module_witness.go b/prover/protocol/distributed/experiment/module_witness.go index 99466ca08..704e1be46 100644 --- a/prover/protocol/distributed/experiment/module_witness.go +++ b/prover/protocol/distributed/experiment/module_witness.go @@ -3,14 +3,17 @@ package experiment import ( "github.com/consensys/linea-monorepo/prover/maths/common/smartvectors" "github.com/consensys/linea-monorepo/prover/maths/field" + "github.com/consensys/linea-monorepo/prover/protocol/column" "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "github.com/consensys/linea-monorepo/prover/protocol/wizard" + "github.com/consensys/linea-monorepo/prover/utils" ) // ModuleWitness is a structure collecting the witness of a module. And // stores all the informations that are necessary to build the witness. type ModuleWitness struct { // ModuleName indicates the name of the module - ModuleName string + ModuleName ModuleName // IsLPP indicates if the current instance of [ModuleWitness] is for // an LPP segment. In the contrary case, it is understood to be for // a GL segment. @@ -29,3 +32,188 @@ type ModuleWitness struct { // as in the [FilteredModuleInputs.HornerArgs] N0Values []int } + +// SegmentRuntime scans a [wizard.ProverRuntime] and returns a list of +// [ModuleWitness] that contains the witness for each segment of each +// module. +func SegmentRuntime(runtime *wizard.ProverRuntime, distributedWizard *DistributedWizard) (witnessesGL, witnessesLPP []*ModuleWitness) { + + for i := range distributedWizard.ModuleNames { + wGL, wLPP := SegmentModule(runtime, distributedWizard.GLs[i], distributedWizard.LPPs[i]) + witnessesGL = append(witnessesGL, wGL...) + witnessesLPP = append(witnessesLPP, wLPP...) + } + + return witnessesGL, witnessesLPP +} + +// SegmentModule produces the list of the [ModuleWitness] for a given module +func SegmentModule(runtime *wizard.ProverRuntime, moduleGL *ModuleGL, moduleLPP *ModuleLPP) (witnessesGL, witnessesLPP []*ModuleWitness) { + + var ( + fmi = moduleGL.definitionInput + cols = runtime.Spec.Columns.AllKeys() + nbSegmentModule = NbSegmentOfModule(runtime, fmi.Disc, fmi.ModuleName) + n0 = make([]int, len(fmi.HornerArgs)) + receivedValuesGlobal = make([]field.Element, len(moduleGL.ReceivedValuesGlobalAccs)) + ) + + witnessesLPP = make([]*ModuleWitness, nbSegmentModule) + witnessesGL = make([]*ModuleWitness, nbSegmentModule) + + for moduleIndex := range witnessesLPP { + + moduleWitnessGL := &ModuleWitness{ + ModuleName: fmi.ModuleName, + ModuleIndex: moduleIndex, + IsFirst: moduleIndex == 0, + IsLast: moduleIndex == nbSegmentModule-1, + Columns: make(map[ifaces.ColID]smartvectors.SmartVector), + ReceivedValuesGlobal: receivedValuesGlobal, + } + + moduleWitnessLPP := &ModuleWitness{ + ModuleName: fmi.ModuleName, + ModuleIndex: moduleIndex, + IsFirst: moduleIndex == 0, + IsLast: moduleIndex == nbSegmentModule-1, + IsLPP: true, + Columns: make(map[ifaces.ColID]smartvectors.SmartVector), + N0Values: n0, + } + + for _, col := range cols { + + col := runtime.Spec.Columns.GetHandle(col) + + if ModuleOfColumn(fmi.Disc, col) != fmi.ModuleName { + continue + } + + segment := SegmentOfColumn(runtime, fmi.Disc, col, moduleIndex, nbSegmentModule) + moduleWitnessGL.Columns[col.GetColID()] = segment + + if _, ok := fmi.ColumnsLPPSet[col.GetColID()]; ok { + moduleWitnessLPP.Columns[col.GetColID()] = segment + } + } + + witnessesGL[moduleIndex] = moduleWitnessGL + witnessesLPP[moduleIndex] = moduleWitnessLPP + + n0 = moduleWitnessLPP.NextN0s(moduleLPP) + receivedValuesGlobal = moduleWitnessGL.NextReceivedValuesGlobal(moduleGL) + } + + return witnessesGL, witnessesLPP +} + +// NbSegmentOfModule returns the number of segments for a given module +func NbSegmentOfModule(runtime *wizard.ProverRuntime, disc ModuleDiscoverer, moduleName ModuleName) int { + + var ( + cols = runtime.Spec.Columns.AllKeys() + nbSegmentModule = -1 + ) + + for _, col := range cols { + + col := runtime.Spec.Columns.GetHandle(col) + + if ModuleOfColumn(disc, col) != moduleName { + continue + } + + var ( + newSize = NewSizeOfColumn(disc, col) + assignment = col.GetColAssignment(runtime) + density = smartvectors.Density(assignment) + _, orientErr = smartvectors.PaddingOrientationOf(assignment) + nbSegmentCol = utils.DivCeil(newSize, density) + ) + + if orientErr != nil { + // the column cannot be taken into account for the segmentation + continue + } + + nbSegmentModule = max(nbSegmentModule, nbSegmentCol) + } + + if nbSegmentModule == -1 { + utils.Panic("could not resolve the number of segment for module %v", moduleName) + } + + return nbSegmentModule +} + +// SegmentColumn returns the segment of a given column for given index. The +// function also takes a maxNbSegment value which is useful in case +func SegmentOfColumn(runtime *wizard.ProverRuntime, disc ModuleDiscoverer, + col ifaces.Column, index, totalNbSegment int) smartvectors.SmartVector { + + var ( + newSize = NewSizeOfColumn(disc, col) + assignment = col.GetColAssignment(runtime) + orientiation, orientErr = smartvectors.PaddingOrientationOf(assignment) + start = index * newSize + end = start + newSize + ) + + if orientErr != nil { + // If a column is assigned to a plain-vector, then it is assumed to + // be right-padded. The reason for this assumption is that the + // columns from the arithmetization are systematically padded on the + // left while the columns from the prover are all right-padded and the + // sometime they (suboptimally) assigned to plain-vectors. + orientiation = 1 + } + + if orientiation == -1 { + start += assignment.Len() - totalNbSegment*newSize + end += assignment.Len() - totalNbSegment*newSize + } + + return assignment.SubVector(start, end) +} + +// NextN0s returns the next value of N0, from the current one and the witness +// of the current module. +func (mw *ModuleWitness) NextN0s(moduleLPP *ModuleLPP) []int { + + newN0s := append([]int{}, mw.N0Values...) + args := moduleLPP.Horner.Parts + + for i := range newN0s { + + sel := mw.Columns[args[i].Selector.GetColID()].IntoRegVecSaveAlloc() + + for j := range sel { + if sel[j].IsOne() { + newN0s[i]++ + } + } + } + + return newN0s +} + +// NextReceivedValuesGlobal returns the next value of ReceivedValuesGlobal, from +// the witness of the current module. +func (mw *ModuleWitness) NextReceivedValuesGlobal(moduleGL *ModuleGL) []field.Element { + + newReceivedValuesGlobal := make([]field.Element, len(mw.ReceivedValuesGlobal)) + + for i, loc := range moduleGL.SentValuesGlobal { + + var ( + col = column.RootParents(loc.Pol)[0] + pos = column.StackOffsets(loc.Pol) + ) + + pos = utils.PositiveMod(pos, col.Size()) + newReceivedValuesGlobal[i] = mw.Columns[col.GetColID()].Get(pos) + } + + return newReceivedValuesGlobal +}