diff --git a/prover/protocol/distributed/compiler/global/global.go b/prover/protocol/distributed/compiler/global/global.go index 7f08c1cc1..579497dad 100644 --- a/prover/protocol/distributed/compiler/global/global.go +++ b/prover/protocol/distributed/compiler/global/global.go @@ -36,13 +36,23 @@ func DistributeGlobal(in DistributionInputs) { var ( bInputs = boundaryInputs{ - moduleComp: in.ModuleComp, - numSegments: in.NumSegments, - provider: in.ModuleComp.Columns.GetHandle("PROVIDER"), - receiver: in.ModuleComp.Columns.GetHandle("RECEIVER"), - providerOpenings: []query.LocalOpening{}, - receiverOpenings: []query.LocalOpening{}, + moduleComp: in.ModuleComp, + numSegments: in.NumSegments, + + provider: boundaries{ + boundaryCol: in.ModuleComp.Columns.GetHandle("PROVIDER"), + lastPosOnBoundaryCol: 0, + boundaryOpenings: collection.NewMapping[query.LocalOpening, int](), + }, + receiver: boundaries{ + boundaryCol: in.ModuleComp.Columns.GetHandle("RECEIVER"), + boundaryOpenings: collection.NewMapping[query.LocalOpening, int](), + lastPosOnBoundaryCol: 0, + }, } + + provider = bInputs.provider.boundaryCol + receiver = bInputs.receiver.boundaryCol ) for _, qName := range in.InitialComp.QueriesNoParams.AllUnignoredKeys() { @@ -55,7 +65,7 @@ func DistributeGlobal(in DistributionInputs) { if in.Disc.ExpressionIsInModule(q.Expression, in.ModuleName) { // apply global constraint over the segment. - in.ModuleComp.InsertGlobal(0, + in.ModuleComp.InsertGlobal(constants.RoundGL, q.ID, AdjustExpressionForGlobal(in.ModuleComp, q.Expression, in.NumSegments), ) @@ -73,17 +83,17 @@ func DistributeGlobal(in DistributionInputs) { // get the hash of the provider and the receiver var ( - colOnes = verifiercol.NewConstantCol(field.One(), bInputs.provider.Size()) - mimcHasherProvider = edc.NewMIMCHasher(in.ModuleComp, bInputs.provider, colOnes, "MIMC_HASHER_PROVIDER") - mimicHasherReceiver = edc.NewMIMCHasher(in.ModuleComp, bInputs.receiver, colOnes, "MIMC_HASHER_RECEIVER") + colOnes = verifiercol.NewConstantCol(field.One(), provider.Size()) + mimcHasherProvider = edc.NewMIMCHasher(in.ModuleComp, provider, colOnes, "MIMC_HASHER_PROVIDER") + mimcHasherReceiver = edc.NewMIMCHasher(in.ModuleComp, receiver, colOnes, "MIMC_HASHER_RECEIVER") ) mimcHasherProvider.DefineHasher(in.ModuleComp, "DISTRIBUTED_GLOBAL_QUERY_MIMC_HASHER_PROVIDER") - mimcHasherProvider.DefineHasher(in.ModuleComp, "DISTRIBUTED_GLOBAL_QUERY_MIMC_HASHER_RECEIVER") + mimcHasherReceiver.DefineHasher(in.ModuleComp, "DISTRIBUTED_GLOBAL_QUERY_MIMC_HASHER_RECEIVER") var ( - openingHashProvider = in.ModuleComp.InsertLocalOpening(0, "ACCESSOR_FROM_HASH_PROVIDER", mimcHasherProvider.HashFinal) - openingHashReceiver = in.ModuleComp.InsertLocalOpening(0, "ACCESSOR_FROM_HASH_RECEIVER", mimicHasherReceiver.HashFinal) + openingHashProvider = in.ModuleComp.InsertLocalOpening(constants.RoundGL, "ACCESSOR_FROM_HASH_PROVIDER", mimcHasherProvider.HashFinal) + openingHashReceiver = in.ModuleComp.InsertLocalOpening(constants.RoundGL, "ACCESSOR_FROM_HASH_RECEIVER", mimcHasherReceiver.HashFinal) ) // declare the hash of the provider/receiver as the public inputs. @@ -96,30 +106,37 @@ func DistributeGlobal(in DistributionInputs) { in.ModuleComp.PublicInputs = append(in.ModuleComp.PublicInputs, wizard.PublicInput{ Name: constants.GlobalReceiverPublicInput, - Acc: accessors.NewLocalOpeningAccessor(openingHashReceiver, 0), + Acc: accessors.NewLocalOpeningAccessor(openingHashReceiver, constants.RoundGL), }) - in.ModuleComp.RegisterProverAction(0, &proverActionForBoundaries{ - provider: bInputs.provider, - receiver: bInputs.receiver, - providerOpenings: bInputs.providerOpenings, - receiverOpenings: bInputs.receiverOpenings, - - mimicHasherProvider: *mimcHasherProvider, - mimicHasherReceiver: *mimicHasherReceiver, - hashOpeningProvider: openingHashProvider, - hashOpeningReceiver: openingHashReceiver, + in.ModuleComp.RegisterProverAction(constants.RoundGL, &proverActionForBoundaries{ + provider: boundaryAssignments{ + boundaries: bInputs.provider, + hashOpening: openingHashProvider, + mimcHash: *mimcHasherProvider, + }, + + receiver: boundaryAssignments{ + boundaries: bInputs.receiver, + hashOpening: openingHashReceiver, + mimcHash: *mimcHasherReceiver, + }, }) } type boundaryInputs struct { - moduleComp *wizard.CompiledIOP - numSegments int - provider ifaces.Column - receiver ifaces.Column - providerOpenings, receiverOpenings []query.LocalOpening - segID int + moduleComp *wizard.CompiledIOP + provider boundaries + receiver boundaries + numSegments int + segID int +} + +type boundaries struct { + boundaryCol ifaces.Column + boundaryOpenings collection.Mapping[query.LocalOpening, int] + lastPosOnBoundaryCol int } func AdjustExpressionForGlobal( @@ -168,7 +185,7 @@ func AdjustExpressionForGlobal( if m.T > segSize { - panic("unsupported, since this depends on the segment ID, unless the module discoverer can detect such cases") + panic("unsupported") } translationMap.InsertNew(m.String(), symbolic.NewVariable(metadata)) default: @@ -185,10 +202,9 @@ func BoundariesForProvider(in *boundaryInputs, q query.GlobalConstraint) { var ( board = q.Board() offsetRange = q.MinMaxOffset() - provider = in.provider maxShift = offsetRange.Max colsInExpr = distributed.ListColumnsFromExpr(q.Expression, false) - colsOnProvider = onBoundaries(colsInExpr, maxShift) + colsOnProvider = onBoundaries(colsInExpr, maxShift, &in.provider) numBoundaries = offsetRange.Max - offsetRange.Min size = column.ExprIsOnSameLengthHandles(&board) segSize = size / in.numSegments @@ -204,9 +220,9 @@ func BoundariesForProvider(in *boundaryInputs, q query.GlobalConstraint) { // take it via accessor. var ( index = pos[0] + i - name = ifaces.QueryIDf("%v_%v", "FROM_PROVIDER_AT", index) - loProvider = in.moduleComp.InsertLocalOpening(0, name, column.Shift(provider, index)) - accessorProvider = accessors.NewLocalOpeningAccessor(loProvider, 0) + name = ifaces.QueryIDf("%v_%v_%v", q.ID, "FROM_PROVIDER_AT", index) + loProvider = in.moduleComp.InsertLocalOpening(constants.RoundGL, name, column.Shift(in.provider.boundaryCol, index)) + accessorProvider = accessors.NewLocalOpeningAccessor(loProvider, constants.RoundGL) indexOnCol = segSize - (maxShift - column.StackOffsets(col) - i) nameExpr = ifaces.QueryIDf("%v_%v_%v", "CONSISTENCY_AGAINST_PROVIDER", col.GetColID(), i) colInModule ifaces.Column @@ -219,10 +235,10 @@ func BoundariesForProvider(in *boundaryInputs, q query.GlobalConstraint) { colInModule = in.moduleComp.Columns.GetHandle(col.GetColID()) } - // add the localOpening to the list - in.providerOpenings = append(in.providerOpenings, loProvider) + // add the localOpening to the map + in.provider.boundaryOpenings.InsertNew(loProvider, index) // impose that loProvider = loCol - in.moduleComp.InsertLocal(0, nameExpr, + in.moduleComp.InsertLocal(constants.RoundGL, nameExpr, symbolic.Sub(accessorProvider, column.Shift(colInModule, indexOnCol)), ) @@ -237,15 +253,12 @@ func BoundariesForReceiver(in *boundaryInputs, q query.GlobalConstraint) { var ( offsetRange = q.MinMaxOffset() - receiver = in.receiver maxShift = offsetRange.Max colsInExpr = distributed.ListColumnsFromExpr(q.Expression, false) - colsOnReceiver = onBoundaries(colsInExpr, maxShift) + colsOnReceiver = onBoundaries(colsInExpr, maxShift, &in.receiver) numBoundaries = offsetRange.Max - offsetRange.Min comp = in.moduleComp colInModule ifaces.Column - // list of local openings by the boundary index - allLists = make([][]query.LocalOpening, numBoundaries) ) for i := 0; i < numBoundaries; i++ { @@ -269,12 +282,12 @@ func BoundariesForReceiver(in *boundaryInputs, q query.GlobalConstraint) { // take it via accessor. var ( index = pos[0] + i - name = ifaces.QueryIDf("%v_%v", "FROM_RECEIVER_AT", index) - lo = comp.InsertLocalOpening(0, name, column.Shift(receiver, index)) - accessor = accessors.NewLocalOpeningAccessor(lo, 0) + name = ifaces.QueryIDf("%v_%v_%v", q.ID, "FROM_RECEIVER_AT", index) + lo = comp.InsertLocalOpening(constants.RoundGL, name, column.Shift(in.receiver.boundaryCol, index)) + accessor = accessors.NewLocalOpeningAccessor(lo, constants.RoundGL) ) - // add the localOpening to the list - allLists[i] = append(allLists[i], lo) + // add the localOpening to the map + in.receiver.boundaryOpenings.InsertNew(lo, index) // in.receiverOpenings = append(in.receiverOpenings, lo) // translate the column translationMap.InsertNew(string(col.GetColID()), accessor.AsVariable()) @@ -295,28 +308,19 @@ func BoundariesForReceiver(in *boundaryInputs, q query.GlobalConstraint) { if in.segID != 0 || q.NoBoundCancel { expr := q.Expression.Replay(translationMap) name := ifaces.QueryIDf("%v_%v_%v", "CONSISTENCY_AGAINST_RECEIVER", q.ID, i) - comp.InsertLocal(0, name, expr) + comp.InsertLocal(constants.RoundGL, name, expr) } } - // order receiverOpenings column by column - for i := 0; i < numBoundaries; i++ { - for _, list := range allLists { - if len(list) > i { - in.receiverOpenings = append(in.receiverOpenings, list[i]) - } - } - } - } // it indicates the column list having the provider cells (i.e., // some cells of the columns are needed to be provided to the next segment) -func onBoundaries(colsInExpr []ifaces.Column, maxShift int) collection.Mapping[ifaces.ColID, [2]int] { +func onBoundaries(colsInExpr []ifaces.Column, maxShift int, b *boundaries) collection.Mapping[ifaces.ColID, [2]int] { var ( - ctr = 0 + ctr = b.lastPosOnBoundaryCol colsOnReceiver = collection.NewMapping[ifaces.ColID, [2]int]() ) for _, col := range colsInExpr { @@ -334,6 +338,7 @@ func onBoundaries(colsInExpr []ifaces.Column, maxShift int) collection.Mapping[i } + b.lastPosOnBoundaryCol = ctr return colsOnReceiver } diff --git a/prover/protocol/distributed/compiler/global/global_test.go b/prover/protocol/distributed/compiler/global/global_test.go index e3156affb..2eeeb64d4 100644 --- a/prover/protocol/distributed/compiler/global/global_test.go +++ b/prover/protocol/distributed/compiler/global/global_test.go @@ -34,6 +34,8 @@ func TestDistributedGlobal(t *testing.T) { col1 = b.CompiledIOP.InsertCommit(0, "module.col1", 8) col2 = b.CompiledIOP.InsertCommit(0, "module.col2", 8) col3 = b.CompiledIOP.InsertCommit(0, "module.col3", 8) + + fibonacci = b.CompiledIOP.InsertCommit(0, "module.fibo", 16) ) b.CompiledIOP.InsertGlobal(0, "global0", @@ -44,6 +46,13 @@ func TestDistributedGlobal(t *testing.T) { ), ) + b.CompiledIOP.InsertGlobal(0, "fibonacci", + symbolic.Sub( + fibonacci, + column.Shift(fibonacci, -1), + column.Shift(fibonacci, -2)), + ) + } // initialProver @@ -53,6 +62,8 @@ func TestDistributedGlobal(t *testing.T) { run.AssignColumn("module.col2", smartvectors.ForTest(7, 0, 1, 3, 0, 4, 1, 0)) run.AssignColumn("module.col3", smartvectors.ForTest(2, 14, 0, 2, 3, 0, 10, 0)) + run.AssignColumn("module.fibo", smartvectors.ForTest(1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987)) + } // initial compiledIOP is the parent to all the SegmentModuleComp objects. diff --git a/prover/protocol/distributed/compiler/global/prover.go b/prover/protocol/distributed/compiler/global/prover.go index 8c8d6989d..c0a2dbe9b 100644 --- a/prover/protocol/distributed/compiler/global/prover.go +++ b/prover/protocol/distributed/compiler/global/prover.go @@ -1,45 +1,57 @@ package global import ( - "github.com/consensys/linea-monorepo/prover/protocol/ifaces" + "fmt" + + "github.com/consensys/linea-monorepo/prover/maths/common/vector" "github.com/consensys/linea-monorepo/prover/protocol/query" "github.com/consensys/linea-monorepo/prover/protocol/wizard" edc "github.com/consensys/linea-monorepo/prover/zkevm/prover/publicInput/execution_data_collector" ) +type boundaryAssignments struct { + boundaries boundaries + hashOpening query.LocalOpening + mimcHash edc.MIMCHasher +} type proverActionForBoundaries struct { - provider ifaces.Column - receiver ifaces.Column - providerOpenings []query.LocalOpening - receiverOpenings []query.LocalOpening - - hashOpeningProvider query.LocalOpening - hashOpeningReceiver query.LocalOpening - mimicHasherProvider edc.MIMCHasher - mimicHasherReceiver edc.MIMCHasher + provider boundaryAssignments + receiver boundaryAssignments } // it assigns all the LocalOpening covering the boundaries func (pa proverActionForBoundaries) Run(run *wizard.ProverRuntime) { var ( - providerWit = run.GetColumn(pa.provider.GetColID()).IntoRegVecSaveAlloc() - receiverWit = run.GetColumn(pa.receiver.GetColID()).IntoRegVecSaveAlloc() + provider = pa.provider.boundaries.boundaryCol + receiver = pa.receiver.boundaries.boundaryCol + providerOpenings = pa.provider.boundaries.boundaryOpenings + receiverOpenings = pa.receiver.boundaries.boundaryOpenings + + providerWit = run.GetColumn(provider.GetColID()).IntoRegVecSaveAlloc() + receiverWit = run.GetColumn(receiver.GetColID()).IntoRegVecSaveAlloc() ) - for i := range pa.providerOpenings { + fmt.Printf("provider %v\n", vector.Prettify(providerWit)) + fmt.Printf("receiver %v\n", vector.Prettify(receiverWit)) + + for _, loProvider := range providerOpenings.ListAllKeys() { + index := providerOpenings.MustGet(loProvider) + run.AssignLocalPoint(loProvider.ID, providerWit[index]) + } - run.AssignLocalPoint(pa.providerOpenings[i].ID, providerWit[i]) - run.AssignLocalPoint(pa.receiverOpenings[i].ID, receiverWit[i]) + for _, loReceiver := range receiverOpenings.ListAllKeys() { + index := receiverOpenings.MustGet(loReceiver) + run.AssignLocalPoint(loReceiver.ID, receiverWit[index]) } - pa.mimicHasherProvider.AssignHasher(run) - pa.mimicHasherReceiver.AssignHasher(run) + pa.provider.mimcHash.AssignHasher(run) + pa.receiver.mimcHash.AssignHasher(run) var ( - hashProvider = run.GetColumnAt(pa.mimicHasherProvider.HashFinal.GetColID(), 0) - hashReceiver = run.GetColumnAt(pa.mimicHasherReceiver.HashFinal.GetColID(), 0) + hashProvider = run.GetColumnAt(pa.provider.mimcHash.HashFinal.GetColID(), 0) + hashReceiver = run.GetColumnAt(pa.receiver.mimcHash.HashFinal.GetColID(), 0) ) - run.AssignLocalPoint(pa.hashOpeningProvider.ID, hashProvider) - run.AssignLocalPoint(pa.hashOpeningReceiver.ID, hashReceiver) + run.AssignLocalPoint(pa.provider.hashOpening.ID, hashProvider) + run.AssignLocalPoint(pa.receiver.hashOpening.ID, hashReceiver) } diff --git a/prover/protocol/distributed/conglomeration/conglomeration.go b/prover/protocol/distributed/conglomeration/conglomeration.go index 02a016a8a..bd0a88cf6 100644 --- a/prover/protocol/distributed/conglomeration/conglomeration.go +++ b/prover/protocol/distributed/conglomeration/conglomeration.go @@ -129,6 +129,9 @@ func (ctx *recursionCtx) captureCompPreVortex(tmpl *wizard.CompiledIOP) { _ = tmpl.GetPublicInputAccessor(constants.GrandProductPublicInput) _ = tmpl.GetPublicInputAccessor(constants.GrandSumPublicInput) _ = tmpl.GetPublicInputAccessor(constants.LogDerivativeSumPublicInput) + + _ = tmpl.GetPublicInputAccessor(constants.GlobalProviderPublicInput) + _ = tmpl.GetPublicInputAccessor(constants.GlobalReceiverPublicInput) ) ctx.LastRound = lastRound diff --git a/prover/protocol/distributed/conglomeration/conglomeration_test.go b/prover/protocol/distributed/conglomeration/conglomeration_test.go index d40bc10c6..d154dfb22 100644 --- a/prover/protocol/distributed/conglomeration/conglomeration_test.go +++ b/prover/protocol/distributed/conglomeration/conglomeration_test.go @@ -90,6 +90,9 @@ func TestConglomerationPureVortexSingleRound(t *testing.T) { builder.InsertPublicInput(constants.GrandProductPublicInput, accessors.NewConstant(field.NewElement(1))) builder.InsertPublicInput(constants.GrandSumPublicInput, accessors.NewConstant(field.NewElement(0))) builder.InsertPublicInput(constants.LogDerivativeSumPublicInput, accessors.NewConstant(field.NewElement(0))) + + builder.InsertPublicInput(constants.GlobalProviderPublicInput, accessors.NewConstant(field.NewElement(0))) + builder.InsertPublicInput(constants.GlobalReceiverPublicInput, accessors.NewConstant(field.NewElement(0))) } prover := func(k int) func(run *wizard.ProverRuntime) { @@ -172,6 +175,9 @@ func TestConglomerationPureVortexMultiRound(t *testing.T) { builder.InsertPublicInput(constants.GrandProductPublicInput, accessors.NewConstant(field.NewElement(1))) builder.InsertPublicInput(constants.GrandSumPublicInput, accessors.NewConstant(field.NewElement(0))) builder.InsertPublicInput(constants.LogDerivativeSumPublicInput, accessors.NewConstant(field.NewElement(0))) + + builder.InsertPublicInput(constants.GlobalProviderPublicInput, accessors.NewConstant(field.NewElement(0))) + builder.InsertPublicInput(constants.GlobalReceiverPublicInput, accessors.NewConstant(field.NewElement(0))) } prover := func(k int) func(run *wizard.ProverRuntime) { @@ -232,6 +238,9 @@ func TestConglomerationLookup(t *testing.T) { builder.InsertPublicInput(constants.GrandProductPublicInput, accessors.NewConstant(field.NewElement(1))) builder.InsertPublicInput(constants.GrandSumPublicInput, accessors.NewConstant(field.NewElement(0))) builder.InsertPublicInput(constants.LogDerivativeSumPublicInput, accessors.NewConstant(field.NewElement(0))) + + builder.InsertPublicInput(constants.GlobalProviderPublicInput, accessors.NewConstant(field.NewElement(0))) + builder.InsertPublicInput(constants.GlobalReceiverPublicInput, accessors.NewConstant(field.NewElement(0))) } prover := func(k int) func(run *wizard.ProverRuntime) { diff --git a/prover/protocol/distributed/conglomeration/cross_segment_consistency.go b/prover/protocol/distributed/conglomeration/cross_segment_consistency.go index 911c2ef06..639faef83 100644 --- a/prover/protocol/distributed/conglomeration/cross_segment_consistency.go +++ b/prover/protocol/distributed/conglomeration/cross_segment_consistency.go @@ -30,19 +30,28 @@ func (pir *CrossSegmentCheck) Run(run wizard.Runtime) error { err error ) - for _, ctx := range pir.Ctxs { + for i, ctx := range pir.Ctxs { var ( wrappedRun = &runtimeTranslator{Prefix: ctx.Translator.Prefix, Rt: run} tmpl = ctx.Tmpl + nextTmpl = pir.Ctxs[(i+1)%len(pir.Ctxs)].Tmpl logDerivSum = tmpl.GetPublicInputAccessor(constants.LogDerivativeSumPublicInput).GetVal(wrappedRun) grandProd = tmpl.GetPublicInputAccessor(constants.GrandProductPublicInput).GetVal(wrappedRun) grandSum = tmpl.GetPublicInputAccessor(constants.GrandSumPublicInput).GetVal(wrappedRun) + + providerHash = tmpl.GetPublicInputAccessor(constants.GlobalProviderPublicInput).GetVal(wrappedRun) + nextReceiverHash = nextTmpl.GetPublicInputAccessor(constants.GlobalReceiverPublicInput).GetVal(wrappedRun) ) logDerivSumAcc.Add(&logDerivSumAcc, &logDerivSum) grandSumAcc.Add(&grandSumAcc, &grandSum) grandProductAcc.Mul(&grandProductAcc, &grandProd) + + if providerHash != nextReceiverHash { + err = errors.Join(err, fmt.Errorf("error in crosse checks for distributed global: "+ + "the provider of the current template is different from the receiver of the next template ")) + } } if logDerivSumAcc != field.Zero() { @@ -76,19 +85,26 @@ func (pir *CrossSegmentCheck) RunGnark(api frontend.API, run wizard.GnarkRuntime grandProductAcc = frontend.Variable(1) ) - for _, ctx := range pir.Ctxs { + for i, ctx := range pir.Ctxs { var ( wrappedRun = &gnarkRuntimeTranslator{Prefix: ctx.Translator.Prefix, Rt: run} tmpl = ctx.Tmpl + nextTmpl = pir.Ctxs[(i+1)%len(pir.Ctxs)].Tmpl logDerivSum = tmpl.GetPublicInputAccessor(constants.LogDerivativeSumPublicInput).GetFrontendVariable(api, wrappedRun) grandProd = tmpl.GetPublicInputAccessor(constants.GrandProductPublicInput).GetFrontendVariable(api, wrappedRun) grandSum = tmpl.GetPublicInputAccessor(constants.GrandSumPublicInput).GetFrontendVariable(api, wrappedRun) + + providerHash = tmpl.GetPublicInputAccessor(constants.GlobalProviderPublicInput).GetFrontendVariable(api, wrappedRun) + nextReceiverHash = nextTmpl.GetPublicInputAccessor(constants.GlobalReceiverPublicInput).GetFrontendVariable(api, wrappedRun) ) logDerivSumAcc = api.Add(logDerivSumAcc, logDerivSum) grandSumAcc = api.Add(grandSumAcc, grandSum) grandProductAcc = api.Mul(grandProductAcc, grandProd) + + api.AssertIsEqual(providerHash, nextReceiverHash) + } api.AssertIsEqual(logDerivSumAcc, field.Zero()) diff --git a/prover/protocol/distributed/constants/constant.go b/prover/protocol/distributed/constants/constant.go index 89c87d5c7..96b36ace0 100644 --- a/prover/protocol/distributed/constants/constant.go +++ b/prover/protocol/distributed/constants/constant.go @@ -8,3 +8,8 @@ const ( GlobalProviderPublicInput = "GLOBAL_PROVIDER_PUBLIC_INPUT" GlobalReceiverPublicInput = "GLOBAL_RECEIVER_PUBLIC_INPUT" ) + +const ( + RoundLPP = 0 + RoundGL = 0 +)