Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

eds: make Size method return error #4091

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion share/eds/accessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ var EmptyAccessor = &Rsmt2D{ExtendedDataSquare: share.EmptyEDS()}
// Accessor is an interface for accessing extended data square data.
type Accessor interface {
// Size returns square size of the Accessor.
Size(ctx context.Context) int
Size(ctx context.Context) (int, error)
// DataHash returns data hash of the Accessor.
DataHash(ctx context.Context) (share.DataHash, error)
// AxisRoots returns share.AxisRoots (DataAvailabilityHeader) of the Accessor.
Expand Down
4 changes: 2 additions & 2 deletions share/eds/close_once.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ func (c *closeOnce) Close() error {
return err
}

func (c *closeOnce) Size(ctx context.Context) int {
func (c *closeOnce) Size(ctx context.Context) (int, error) {
if c.closed.Load() {
return 0
return 0, errAccessorClosed
}
return c.f.Size(ctx)
}
Expand Down
4 changes: 2 additions & 2 deletions share/eds/close_once_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ type stubEdsAccessorCloser struct {
closed bool
}

func (s *stubEdsAccessorCloser) Size(context.Context) int {
return 0
func (s *stubEdsAccessorCloser) Size(context.Context) (int, error) {
return 0, nil
}

func (s *stubEdsAccessorCloser) DataHash(context.Context) (share.DataHash, error) {
Expand Down
34 changes: 25 additions & 9 deletions share/eds/proofs_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,17 @@ func WithProofsCache(ac AccessorStreamer) AccessorStreamer {
}
}

func (c *proofsCache) Size(ctx context.Context) int {
func (c *proofsCache) Size(ctx context.Context) (int, error) {
size := c.size.Load()
if size == 0 {
size = int32(c.inner.Size(ctx))
c.size.Store(size)
loaded, err := c.inner.Size(ctx)
if err != nil {
return 0, fmt.Errorf("loading size from inner accessor: %w", err)
}
c.size.Store(int32(loaded))
return loaded, nil
}
return int(size)
return int(size), nil
}

func (c *proofsCache) DataHash(ctx context.Context) (share.DataHash, error) {
Expand Down Expand Up @@ -121,7 +125,11 @@ func (c *proofsCache) Sample(ctx context.Context, idx shwap.SampleCoords) (shwap

// build share proof from proofs cached for given axis
share := ax.shares[shrIdx]
proofs, err := ipld.GetProof(ctx, ax.proofs, ax.root, shrIdx, c.Size(ctx))
size, err := c.Size(ctx)
if err != nil {
return shwap.Sample{}, fmt.Errorf("getting size: %w", err)
}
proofs, err := ipld.GetProof(ctx, ax.proofs, ax.root, shrIdx, size)
if err != nil {
return shwap.Sample{}, fmt.Errorf("building proof from cache: %w", err)
}
Expand Down Expand Up @@ -159,9 +167,13 @@ func (c *proofsCache) axisWithProofs(ctx context.Context, axisType rsmt2d.Axis,
}

// build proofs from Shares and cache them
adder := ipld.NewProofsAdder(c.Size(ctx), true)
size, err := c.Size(ctx)
if err != nil {
return axisWithProofs{}, fmt.Errorf("getting size: %w", err)
}
adder := ipld.NewProofsAdder(size, true)
tree := wrapper.NewErasuredNamespacedMerkleTree(
uint64(c.Size(ctx)/2),
uint64(size/2),
uint(axisIdx),
nmt.NodeVisitor(adder.VisitFn()),
)
Expand Down Expand Up @@ -233,9 +245,13 @@ func (c *proofsCache) RowNamespaceData(
}

func (c *proofsCache) Shares(ctx context.Context) ([]libshare.Share, error) {
odsSize := c.Size(ctx) / 2
size, err := c.Size(ctx)
if err != nil {
return nil, fmt.Errorf("getting size: %w", err)
}
odsSize := size / 2
shares := make([]libshare.Share, 0, odsSize*odsSize)
for i := 0; i < c.Size(ctx)/2; i++ {
for i := 0; i < size/2; i++ {
ax, err := c.AxisHalf(ctx, rsmt2d.Row, i)
if err != nil {
return nil, err
Expand Down
7 changes: 5 additions & 2 deletions share/eds/rsmt2d.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ type Rsmt2D struct {
}

// Size returns the size of the Extended Data Square.
func (eds *Rsmt2D) Size(context.Context) int {
return int(eds.Width())
func (eds *Rsmt2D) Size(context.Context) (int, error) {
if eds.ExtendedDataSquare == nil {
return 0, fmt.Errorf("extended data square is not initialized")
}
return int(eds.Width()), nil
}

// DataHash returns data hash of the Accessor.
Expand Down
52 changes: 46 additions & 6 deletions share/eds/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,7 @@ func testAccessorRowNamespaceData(

// check that the amount of shares in the namespace is equal to the expected amount
require.Equal(t, amount, actualSharesAmount)
}
})
})

t.Run("not included", func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -279,7 +278,7 @@ func testAccessorRowNamespaceData(

err = rowData.Verify(roots, absentNs, i)
require.NoError(t, err)
}
})
})
}

Expand Down Expand Up @@ -472,9 +471,19 @@ func (q quadrantIdx) String() string {
}

func (q quadrantIdx) coordinates(edsSize int) (rowIdx, colIdx int) {
colIdx = edsSize/2*(int(q-1)%2) + 1
rowIdx = edsSize/2*(int(q-1)/2) + 1
return rowIdx, colIdx
half := edsSize / 2
switch q {
case q1:
return rand.IntN(half), rand.IntN(half)
case q2:
return rand.IntN(half), half + rand.IntN(half)
case q3:
return half + rand.IntN(half), rand.IntN(half)
case q4:
return half + rand.IntN(half), half + rand.IntN(half)
default:
panic("invalid quadrant")
}
}

func checkPowerOfTwo(n int) bool {
Expand All @@ -484,3 +493,34 @@ func checkPowerOfTwo(n int) bool {
}
return n&(n-1) == 0
}

func TestAccessorSampling(t *testing.T) {
ctx := context.Background()
acc := NewRandAccessor(t, 8)
defer acc.Close()

size, err := acc.Size(ctx)
require.NoError(t, err)

for squareHalf := 0; squareHalf < 2; squareHalf++ {
for axisType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} {
_, err := acc.AxisHalf(ctx, axisType, size/2*(squareHalf))
require.NoError(t, err)
}
}
}

func TestAccessorSamplingQuadrants(t *testing.T) {
ctx := context.Background()
acc := NewRandAccessor(t, 8)
defer acc.Close()

size, err := acc.Size(ctx)
require.NoError(t, err)

for q := range []quadrant{q1, q2, q3, q4} {
rowIdx, colIdx := q.coordinates(size)
_, err := acc.Sample(ctx, shwap.SampleCoords{Row: rowIdx, Col: colIdx})
require.NoError(t, err)
}
}
29 changes: 22 additions & 7 deletions share/eds/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,37 @@ func WithValidation(f Accessor) Accessor {
return &validation{Accessor: f, size: new(atomic.Int32)}
}

func (f validation) Size(ctx context.Context) int {
func (f validation) Size(ctx context.Context) (int, error) {
size := f.size.Load()
if size == 0 {
loaded := f.Accessor.Size(ctx)
loaded, err := f.Accessor.Size(ctx)
if err != nil {
return 0, fmt.Errorf("loading size: %w", err)
}
f.size.Store(int32(loaded))
return loaded
return loaded, nil
}
return int(size)
return int(size), nil
}

func (f validation) Sample(ctx context.Context, idx shwap.SampleCoords) (shwap.Sample, error) {
_, err := shwap.NewSampleID(1, idx, f.Size(ctx))
size, err := f.Size(ctx)
if err != nil {
return shwap.Sample{}, fmt.Errorf("getting size: %w", err)
}
_, err = shwap.NewSampleID(1, idx, size)
if err != nil {
return shwap.Sample{}, fmt.Errorf("sample validation: %w", err)
}
return f.Accessor.Sample(ctx, idx)
}

func (f validation) AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) {
_, err := shwap.NewRowID(1, axisIdx, f.Size(ctx))
size, err := f.Size(ctx)
if err != nil {
return AxisHalf{}, fmt.Errorf("getting size: %w", err)
}
_, err = shwap.NewRowID(1, axisIdx, size)
if err != nil {
return AxisHalf{}, fmt.Errorf("axis half validation: %w", err)
}
Expand All @@ -55,7 +66,11 @@ func (f validation) RowNamespaceData(
namespace libshare.Namespace,
rowIdx int,
) (shwap.RowNamespaceData, error) {
_, err := shwap.NewRowNamespaceDataID(1, rowIdx, namespace, f.Size(ctx))
size, err := f.Size(ctx)
if err != nil {
return shwap.RowNamespaceData{}, fmt.Errorf("getting size: %w", err)
}
_, err = shwap.NewRowNamespaceDataID(1, rowIdx, namespace, size)
if err != nil {
return shwap.RowNamespaceData{}, fmt.Errorf("row namespace data validation: %w", err)
}
Expand Down
5 changes: 3 additions & 2 deletions store/cache/accessor_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,11 @@ type mockAccessor struct {
m sync.Mutex
data []byte
isClosed bool
size int
}

func (m *mockAccessor) Size(context.Context) int {
panic("implement me")
func (m *mockAccessor) Size(context.Context) (int, error) {
return m.size, nil
}

func (m *mockAccessor) DataHash(context.Context) (share.DataHash, error) {
Expand Down
4 changes: 2 additions & 2 deletions store/cache/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ func (n NoopFile) Reader() (io.Reader, error) {
return noopReader{}, nil
}

func (n NoopFile) Size(context.Context) int {
return 0
func (n NoopFile) Size(context.Context) (int, error) {
return 0, nil
}

func (n NoopFile) DataHash(context.Context) (share.DataHash, error) {
Expand Down
14 changes: 11 additions & 3 deletions store/file/ods.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,11 @@ func OpenODS(path string) (*ODS, error) {
}

// Size returns EDS size stored in file's header.
func (o *ODS) Size(context.Context) int {
return o.size()
func (o *ODS) Size(context.Context) (int, error) {
if o.hdr == nil {
return 0, fmt.Errorf("header is not initialized")
}
return o.size(), nil
}

func (o *ODS) size() int {
Expand Down Expand Up @@ -238,8 +241,13 @@ func (o *ODS) Sample(ctx context.Context, idx shwap.SampleCoords) (shwap.Sample,
// to calculate the sample
rowIdx, colIdx := idx.Row, idx.Col

size, err := o.Size(ctx)
if err != nil {
return shwap.Sample{}, fmt.Errorf("getting size: %w", err)
}

axisType, axisIdx, shrIdx := rsmt2d.Row, rowIdx, colIdx
if colIdx < o.size()/2 && rowIdx >= o.size()/2 {
if colIdx < size/2 && rowIdx >= size/2 {
axisType, axisIdx, shrIdx = rsmt2d.Col, colIdx, rowIdx
}

Expand Down
7 changes: 5 additions & 2 deletions store/file/ods_q4.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (odsq4 *ODSQ4) tryLoadQ4() *q4 {
return q4
}

func (odsq4 *ODSQ4) Size(ctx context.Context) int {
func (odsq4 *ODSQ4) Size(ctx context.Context) (int, error) {
return odsq4.ods.Size(ctx)
}

Expand All @@ -137,7 +137,10 @@ func (odsq4 *ODSQ4) Sample(ctx context.Context, idx shwap.SampleCoords) (shwap.S
}

func (odsq4 *ODSQ4) AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (eds.AxisHalf, error) {
size := odsq4.Size(ctx) // TODO(@Wondertan): Should return error.
size, err := odsq4.Size(ctx)
if err != nil {
return nil, fmt.Errorf("getting size: %w", err)
}

if axisIdx >= size/2 {
// lazy load Q4 file and read axis from it if loaded
Expand Down