Skip to content

Commit

Permalink
Use raw bytes for SRS points to save half storage (Layr-Labs#238)
Browse files Browse the repository at this point in the history
  • Loading branch information
jianoaix authored Feb 4, 2024
1 parent 25f6f0a commit 832c156
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 55 deletions.
Binary file modified inabox/resources/kzg/g1.point
Binary file not shown.
Binary file modified inabox/resources/kzg/g1.point.300000
Binary file not shown.
Binary file modified inabox/resources/kzg/g2.point
Binary file not shown.
Binary file modified inabox/resources/kzg/g2.point.300000
Binary file not shown.
7 changes: 4 additions & 3 deletions pkg/encoding/kzgEncoder/precomputeSRS.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"sync"
"time"

"github.com/Layr-Labs/eigenda/pkg/encoding/utils"
kzg "github.com/Layr-Labs/eigenda/pkg/kzg"
bls "github.com/Layr-Labs/eigenda/pkg/kzg/bn254"
)
Expand Down Expand Up @@ -226,7 +227,7 @@ func (p *SRSTable) TableReaderThreads(filePath string, dimE, l uint64, numWorker
}()

// 2 due to circular FFT mul
subTableSize := dimE * 2 * 64
subTableSize := dimE * 2 * utils.G1PointBytes
totalSubTableSize := subTableSize * l

if numWorker > l {
Expand All @@ -236,7 +237,7 @@ func (p *SRSTable) TableReaderThreads(filePath string, dimE, l uint64, numWorker
reader := bufio.NewReaderSize(g1f, int(totalSubTableSize+l))
buf := make([]byte, totalSubTableSize+l)
if _, err := io.ReadFull(reader, buf); err != nil {
log.Println("TableReaderThreads.ERR.1", err)
log.Println("TableReaderThreads.ERR.1", err, "file path:", filePath)
return nil, err
}

Expand Down Expand Up @@ -280,7 +281,7 @@ func (p *SRSTable) readWorker(
for b := range jobChan {
slicePoints := make([]bls.G1Point, dimE*2)
for i := uint64(0); i < dimE*2; i++ {
g1 := buf[b.start+i*64 : b.start+(i+1)*64]
g1 := buf[b.start+i*utils.G1PointBytes : b.start+(i+1)*utils.G1PointBytes]
err := slicePoints[i].UnmarshalText(g1[:])
if err != nil {
log.Printf("Error. From %v to %v. %v", b.start, b.end, err)
Expand Down
68 changes: 34 additions & 34 deletions pkg/encoding/utils/pointsIO.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package utils

import (
"bufio"
"io"
"log"
"os"
"sync"
Expand All @@ -10,11 +11,31 @@ import (
bls "github.com/Layr-Labs/eigenda/pkg/kzg/bn254"
)

const (
// Num of bytes per G1 point in serialized format in file.
G1PointBytes = 32
// Num of bytes per G2 point in serialized format in file.
G2PointBytes = 64
)

type EncodeParams struct {
NumNodeE uint64
ChunkLenE uint64
}

// ReadDesiredBytes reads exactly numBytesToRead bytes from the reader and returns
// the result.
func ReadDesiredBytes(reader *bufio.Reader, numBytesToRead uint64) ([]byte, error) {
buf := make([]byte, numBytesToRead)
_, err := io.ReadFull(reader, buf)
// Note that ReadFull() guarantees the bytes read is len(buf) IFF err is nil.
// See https://pkg.go.dev/io#ReadFull.
if err != nil {
return nil, err
}
return buf, nil
}

func ReadG1Points(filepath string, n uint64, numWorker uint64) ([]bls.G1Point, error) {
g1f, err := os.Open(filepath)
if err != nil {
Expand All @@ -30,28 +51,21 @@ func ReadG1Points(filepath string, n uint64, numWorker uint64) ([]bls.G1Point, e
}()

startTimer := time.Now()
g1r := bufio.NewReaderSize(g1f, int(n*64))
g1r := bufio.NewReaderSize(g1f, int(n*G1PointBytes))

if n < numWorker {
numWorker = n
}

buf, _, err := g1r.ReadLine()
buf, err := ReadDesiredBytes(g1r, n*G1PointBytes)
if err != nil {
return nil, err
}

if uint64(len(buf)) < 64*n {
log.Printf("Error. Insufficient G1 points. Only contains %v. Requesting %v\n", len(buf)/64, n)
log.Println()
log.Println("ReadG1Points.ERR.1", err)
return nil, err
}

// measure reading time
t := time.Now()
elapsed := t.Sub(startTimer)
log.Printf(" Reading G1 points (%v bytes) takes %v\n", (n * 64), elapsed)
log.Printf(" Reading G1 points (%v bytes) takes %v\n", (n * G1PointBytes), elapsed)
startTimer = time.Now()

s1Outs := make([]bls.G1Point, n)
Expand All @@ -73,7 +87,7 @@ func ReadG1Points(filepath string, n uint64, numWorker uint64) ([]bls.G1Point, e
}
//fmt.Printf("worker %v start %v end %v. size %v\n", i, start, end, end - start)
//todo: handle error?
go readG1Worker(buf, s1Outs, start, end, 64, &wg)
go readG1Worker(buf, s1Outs, start, end, G1PointBytes, &wg)
}
wg.Wait()

Expand Down Expand Up @@ -101,9 +115,9 @@ func ReadG1PointSection(filepath string, from, to uint64, numWorker uint64) ([]b
n := to - from

startTimer := time.Now()
g1r := bufio.NewReaderSize(g1f, int(to*64))
g1r := bufio.NewReaderSize(g1f, int(to*G1PointBytes))

_, err = g1f.Seek(int64(from*64), 0)
_, err = g1f.Seek(int64(from)*G1PointBytes, 0)
if err != nil {
return nil, err
}
Expand All @@ -112,22 +126,15 @@ func ReadG1PointSection(filepath string, from, to uint64, numWorker uint64) ([]b
numWorker = n
}

buf, _, err := g1r.ReadLine()
buf, err := ReadDesiredBytes(g1r, n*G1PointBytes)
if err != nil {
return nil, err
}

if uint64(len(buf)) < 64*n {
log.Printf("Error. Insufficient G1 points. Only contains %v. Requesting %v\n", len(buf)/64, n)
log.Println()
log.Println("ReadG1PointSection.ERR.1", err)
return nil, err
}

// measure reading time
t := time.Now()
elapsed := t.Sub(startTimer)
log.Printf(" Reading G1 points (%v bytes) takes %v\n", (n * 64), elapsed)
log.Printf(" Reading G1 points (%v bytes) takes %v\n", (n * G1PointBytes), elapsed)
startTimer = time.Now()

s1Outs := make([]bls.G1Point, n)
Expand All @@ -148,7 +155,7 @@ func ReadG1PointSection(filepath string, from, to uint64, numWorker uint64) ([]b
end = (i + 1) * size
}
//todo: handle error?
go readG1Worker(buf, s1Outs, start, end, 64, &wg)
go readG1Worker(buf, s1Outs, start, end, G1PointBytes, &wg)
}
wg.Wait()

Expand Down Expand Up @@ -210,28 +217,21 @@ func ReadG2Points(filepath string, n uint64, numWorker uint64) ([]bls.G2Point, e
}()

startTimer := time.Now()
g1r := bufio.NewReaderSize(g1f, int(n*128))
g1r := bufio.NewReaderSize(g1f, int(n*G2PointBytes))

if n < numWorker {
numWorker = n
}

buf, _, err := g1r.ReadLine()
buf, err := ReadDesiredBytes(g1r, n*G2PointBytes)
if err != nil {
return nil, err
}

if uint64(len(buf)) < 128*n {
log.Printf("Error. Insufficient G1 points. Only contains %v. Requesting %v\n", len(buf)/128, n)
log.Println()
log.Println("ReadG2Points.ERR.1", err)
return nil, err
}

// measure reading time
t := time.Now()
elapsed := t.Sub(startTimer)
log.Printf(" Reading G2 points (%v bytes) takes %v\n", (n * 128), elapsed)
log.Printf(" Reading G2 points (%v bytes) takes %v\n", (n * G2PointBytes), elapsed)

startTimer = time.Now()

Expand All @@ -252,7 +252,7 @@ func ReadG2Points(filepath string, n uint64, numWorker uint64) ([]bls.G2Point, e
end = (i + 1) * size
}
//todo: handle error?
go readG2Worker(buf, s2Outs, start, end, 128, &wg)
go readG2Worker(buf, s2Outs, start, end, G2PointBytes, &wg)
}
wg.Wait()

Expand Down
17 changes: 4 additions & 13 deletions pkg/kzg/bn254/bn254_all.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,19 @@
package bn254

import (
"encoding/hex"
"errors"
)

func (p *G1Point) MarshalText() []byte {
return []byte(hex.EncodeToString(ToCompressedG1(p)))
return ToCompressedG1(p)
}

// UnmarshalText decodes hex formatted text (no 0x prefix) into a G1Point
func (p *G1Point) UnmarshalText(text []byte) error {
if p == nil {
return errors.New("cannot decode into nil G1Point")
}
data, err := hex.DecodeString(string(text))
if err != nil {
return err
}
d, err := FromCompressedG1(data)
d, err := FromCompressedG1(text)
if err != nil {
return err
}
Expand All @@ -52,19 +47,15 @@ func (p *G1Point) UnmarshalText(text []byte) error {

// MarshalText encodes G2Point into hex formatted text (no 0x prefix)
func (p *G2Point) MarshalText() []byte {
return []byte(hex.EncodeToString(ToCompressedG2(p)))
return ToCompressedG2(p)
}

// UnmarshalText decodes hex formatted text (no 0x prefix) into a G2Point
func (p *G2Point) UnmarshalText(text []byte) error {
if p == nil {
return errors.New("cannot decode into nil G2Point")
}
data, err := hex.DecodeString(string(text))
if err != nil {
return err
}
d, err := FromCompressedG2(data)
d, err := FromCompressedG2(text)
if err != nil {
return err
}
Expand Down
15 changes: 10 additions & 5 deletions pkg/kzg/bn254/bn254_all_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ func TestPointG1Marshalling(t *testing.T) {

bytes := point.MarshalText()

// The point is serialized to raw bytes and it should be 32 bytes.
assert.Equal(t, len(bytes), 32)

var anotherPoint G1Point
err := anotherPoint.UnmarshalText(bytes)
require.Nil(t, err)
Expand All @@ -52,10 +55,10 @@ func TestPointG1Marshalling_InvalidG1(t *testing.T) {

g1 = new(G1Point)
err = g1.UnmarshalText([]byte("G"))
assert.EqualError(t, err, "encoding/hex: invalid byte: U+0047 'G'")
assert.EqualError(t, err, "short buffer")

err = g1.UnmarshalText([]byte("8000000000000000000000000000000000000000000000000000000000000099"))
assert.EqualError(t, err, "invalid compressed coordinate: square root doesn't exist")
assert.EqualError(t, err, "invalid fp.Element encoding")
}

func TestPointG2Marshalling(t *testing.T) {
Expand All @@ -65,6 +68,8 @@ func TestPointG2Marshalling(t *testing.T) {
MulG2(&point, &GenG2, &x)

bytes := point.MarshalText()
// The point is serialized to raw bytes and it should be 64 bytes (2x the G1).
assert.Equal(t, len(bytes), 64)

var anotherPoint G2Point
err := anotherPoint.UnmarshalText(bytes)
Expand All @@ -80,12 +85,12 @@ func TestPointG2Marshalling_InvalidG2(t *testing.T) {

g2 = new(G2Point)
err = g2.UnmarshalText([]byte("G"))
assert.EqualError(t, err, "encoding/hex: invalid byte: U+0047 'G'")
assert.EqualError(t, err, "short buffer")

err = g2.UnmarshalText([]byte("898e9393920d483a7260bfb731fb5d25f1aa493335a9e71297e485b7aef312c21800deef121f1e76426a00665e5c4479674322d4f75edadd46debd5cd992f6ed"))
assert.EqualError(t, err, "invalid point: subgroup check failed")
assert.EqualError(t, err, "invalid fp.Element encoding")

err = g2.UnmarshalText([]byte("998e9393920d483a7260bfb731fb5d25f1aa493335a9e71297e485b7aef312c21800deef121f1e76426a00665e5c4479674322d4f75edadd46debd5cd992ffff"))
assert.EqualError(t, err, "invalid compressed coordinate: square root doesn't exist")
assert.EqualError(t, err, "invalid fp.Element encoding")

}

0 comments on commit 832c156

Please sign in to comment.