
Commit 57e379e

bug fix marshal unmarshal batch data
1 parent: 79bd576

File tree

3 files changed: +123 -40 lines

executor/batch/handler.go (19 additions, 11 deletions)

@@ -1,8 +1,8 @@
 package batch

 import (
+    "bytes"
     "context"
-    "crypto/sha256"
     "fmt"
     "io"
     "time"
@@ -138,6 +138,8 @@ func (bs *BatchSubmitter) prepareBatch(blockHeight uint64) error {
         bs.localBatchInfo.BatchFileSize = 0
         bs.localBatchInfo.Start = blockHeight
         bs.localBatchInfo.End = 0
+
+        bs.batchWriter.Reset(bs.batchFile)
     }
     return nil
 }
@@ -158,7 +160,11 @@ func (bs *BatchSubmitter) finalizeBatch(ctx context.Context, blockHeight uint64)
     if err != nil {
         return errors.Wrap(err, "failed to write raw commit")
     }
-    fileSize, err := bs.batchFileSize()
+    err = bs.batchWriter.Close()
+    if err != nil {
+        return errors.Wrap(err, "failed to close batch writer")
+    }
+    fileSize, err := bs.batchFileSize(false)
     if err != nil {
         return err
     }
@@ -169,7 +175,7 @@ func (bs *BatchSubmitter) finalizeBatch(ctx context.Context, blockHeight uint64)

     // TODO: improve this logic to avoid hold all the batch data in memory
     chunks := make([][]byte, 0)
-    for offset := int64(0); ; offset += int64(bs.batchCfg.MaxChunkSize) {
+    for offset := int64(0); ; {
         readLength, err := bs.batchFile.ReadAt(batchBuffer, offset)
         if err != nil && err != io.EOF {
             return err
@@ -178,15 +184,15 @@ func (bs *BatchSubmitter) finalizeBatch(ctx context.Context, blockHeight uint64)
         }

         // trim the buffer to the actual read length
-        chunk := make([]byte, readLength)
-        copy(chunk, batchBuffer[:readLength])
+        chunk := bytes.Clone(batchBuffer[:readLength])
         chunks = append(chunks, chunk)

-        checksum := sha256.Sum256(batchBuffer)
+        checksum := executortypes.GetChecksumFromChunk(chunk)
         checksums = append(checksums, checksum[:])
         if uint64(readLength) < bs.batchCfg.MaxChunkSize {
             break
         }
+        offset += int64(readLength)
     }

     headerData := executortypes.MarshalBatchDataHeader(
@@ -236,7 +242,7 @@ func (bs *BatchSubmitter) finalizeBatch(ctx context.Context, blockHeight uint64)
 }

 func (bs *BatchSubmitter) checkBatch(ctx context.Context, blockHeight uint64, latestHeight uint64, blockTime time.Time) error {
-    fileSize, err := bs.batchFileSize()
+    fileSize, err := bs.batchFileSize(true)
     if err != nil {
         return err
     }
@@ -263,13 +269,15 @@ func (bs *BatchSubmitter) checkBatch(ctx context.Context, blockHeight uint64, la
     return nil
 }

-func (bs *BatchSubmitter) batchFileSize() (int64, error) {
+func (bs *BatchSubmitter) batchFileSize(flush bool) (int64, error) {
     if bs.batchFile == nil {
         return 0, errors.New("batch file is not initialized")
     }
-    err := bs.batchWriter.Flush()
-    if err != nil {
-        return 0, errors.Wrap(err, "failed to flush batch writer")
+    if flush {
+        err := bs.batchWriter.Flush()
+        if err != nil {
+            return 0, errors.Wrap(err, "failed to flush batch writer")
+        }
     }

     info, err := bs.batchFile.Stat()
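
Context for the handler.go changes above: the checksum used to be computed over the entire reusable read buffer, so on the final partial read it covered stale bytes past readLength and could not match the trimmed chunk that was actually stored; the batch writer is now closed before the file is read back (and Reset when the next batch starts), and the loop advances the offset by the bytes actually read. Below is a minimal standalone sketch of the corrected read-and-checksum pattern; it uses a plain *os.File and an illustrative maxChunkSize rather than the repository's BatchSubmitter fields.

// Sketch only: simplified version of the corrected chunking loop, assuming a
// plain *os.File and an int maxChunkSize (the real code uses BatchSubmitter's
// batchFile, batchCfg.MaxChunkSize, and executortypes.GetChecksumFromChunk).
package sketch

import (
    "bytes"
    "crypto/sha256"
    "io"
    "os"
)

func splitIntoChunks(file *os.File, maxChunkSize int) ([][]byte, [][]byte, error) {
    var chunks, checksums [][]byte
    buf := make([]byte, maxChunkSize)
    for offset := int64(0); ; {
        n, err := file.ReadAt(buf, offset)
        if err != nil && err != io.EOF {
            return nil, nil, err
        } else if n == 0 {
            break // nothing left to read
        }

        // Copy only the bytes that were read; the tail of buf may still hold
        // data from a previous iteration.
        chunk := bytes.Clone(buf[:n])
        chunks = append(chunks, chunk)

        // Hash the trimmed chunk, not the whole reusable buffer.
        checksum := sha256.Sum256(chunk)
        checksums = append(checksums, checksum[:])

        if n < maxChunkSize {
            break // final, partial chunk
        }
        offset += int64(n) // advance by what was actually read
    }
    return chunks, checksums, nil
}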

executor/types/batch.go (66 additions, 29 deletions)

@@ -2,8 +2,9 @@ package types

 import (
     "context"
+    "crypto/sha256"
     "encoding/binary"
-    "errors"
+    "fmt"
     "time"

     btypes "github.com/initia-labs/opinit-bots/node/broadcaster/types"
@@ -39,41 +40,67 @@ const (
     BatchDataTypeChunk
 )

+type BatchDataHeader struct {
+    Start     uint64
+    End       uint64
+    Checksums [][]byte
+}
+
+type BatchDataChunk struct {
+    Start     uint64
+    End       uint64
+    Index     uint64
+    Length    uint64
+    ChunkData []byte
+}
+
+func GetChecksumFromChunk(chunk []byte) [32]byte {
+    return sha256.Sum256(chunk)
+}
+
 func MarshalBatchDataHeader(
     start uint64,
     end uint64,
     checksums [][]byte,
 ) []byte {
     data := make([]byte, 1)
     data[0] = byte(BatchDataTypeHeader)
-    data = binary.AppendUvarint(data, start)
-    data = binary.AppendUvarint(data, end)
-    data = binary.AppendUvarint(data, uint64(len(checksums)))
+    data = binary.BigEndian.AppendUint64(data, start)
+    data = binary.BigEndian.AppendUint64(data, end)
+    data = binary.BigEndian.AppendUint64(data, uint64(len(checksums)))
     for _, checksum := range checksums {
         data = append(data, checksum...)
     }
     return data
 }

-func UnmarshalBatchDataHeader(data []byte) (start uint64, end uint64, length uint64, checksums [][]byte, err error) {
+func UnmarshalBatchDataHeader(data []byte) (BatchDataHeader, error) {
     if len(data) < 25 {
-        err = errors.New("invalid data length")
-        return
+        err := fmt.Errorf("invalid data length: %d, expected > 25", len(data))
+        return BatchDataHeader{}, err
     }
-    start, _ = binary.Uvarint(data[1:9])
-    end, _ = binary.Uvarint(data[9:17])
-    length, _ = binary.Uvarint(data[17:25])
-    checksums = make([][]byte, 0, length)
-
-    if len(data)-25%32 != 0 || (uint64(len(data)-25)/32) != length {
-        err = errors.New("invalid checksum data")
-        return
+    start := binary.BigEndian.Uint64(data[1:9])
+    end := binary.BigEndian.Uint64(data[9:17])
+    if start > end {
+        return BatchDataHeader{}, fmt.Errorf("invalid start: %d, end: %d", start, end)
     }

+    length := binary.BigEndian.Uint64(data[17:25])
+    if (len(data)-25)%32 != 0 || (uint64(len(data)-25)/32) != length {
+        err := fmt.Errorf("invalid checksum length: %d, data length: %d", length, len(data)-25)
+        return BatchDataHeader{}, err
+    }
+
+    checksums := make([][]byte, 0, length)
     for i := 25; i < len(data); i += 32 {
         checksums = append(checksums, data[i:i+32])
     }
-    return
+
+    return BatchDataHeader{
+        Start:     start,
+        End:       end,
+        Checksums: checksums,
+    }, nil
 }

 func MarshalBatchDataChunk(
@@ -85,23 +112,33 @@ func MarshalBatchDataChunk(
 ) []byte {
     data := make([]byte, 1)
     data[0] = byte(BatchDataTypeChunk)
-    data = binary.AppendUvarint(data, start)
-    data = binary.AppendUvarint(data, end)
-    data = binary.AppendUvarint(data, index)
-    data = binary.AppendUvarint(data, length)
+    data = binary.BigEndian.AppendUint64(data, start)
+    data = binary.BigEndian.AppendUint64(data, end)
+    data = binary.BigEndian.AppendUint64(data, index)
+    data = binary.BigEndian.AppendUint64(data, length)
     data = append(data, chunkData...)
     return data
 }

-func UnmarshalBatchDataChunk(data []byte) (start uint64, end uint64, index uint64, length uint64, chunkData []byte, err error) {
+func UnmarshalBatchDataChunk(data []byte) (BatchDataChunk, error) {
     if len(data) < 33 {
-        err = errors.New("invalid data length")
-        return
+        err := fmt.Errorf("invalid data length: %d, expected > 33", len(data))
+        return BatchDataChunk{}, err
     }
-    start, _ = binary.Uvarint(data[1:9])
-    end, _ = binary.Uvarint(data[9:17])
-    index, _ = binary.Uvarint(data[17:25])
-    length, _ = binary.Uvarint(data[25:33])
-    chunkData = data[33:]
-    return
+    start := binary.BigEndian.Uint64(data[1:9])
+    end := binary.BigEndian.Uint64(data[9:17])
+    if start > end {
+        return BatchDataChunk{}, fmt.Errorf("invalid start: %d, end: %d", start, end)
+    }
+    index := binary.BigEndian.Uint64(data[17:25])
+    length := binary.BigEndian.Uint64(data[25:33])
+    chunkData := data[33:]
+
+    return BatchDataChunk{
+        Start:     start,
+        End:       end,
+        Index:     index,
+        Length:    length,
+        ChunkData: chunkData,
+    }, nil
 }
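
Why the encoding change above matters: binary.AppendUvarint writes a variable number of bytes per value, but the unmarshal side reads fixed 8-byte windows (data[1:9], data[9:17], data[17:25]), so every field after the first decoded garbage whenever a varint took fewer than 8 bytes. The old length check also had a precedence slip: len(data)-25%32 parses as len(data) - (25%32) = len(data) - 25, so any header that actually contained checksums failed validation. Switching both sides to fixed-width big-endian keeps the offsets stable. A small standalone demonstration follows; the sample values and the 32-byte zero "checksum" are arbitrary.

// Standalone demonstration (not repository code) of why variable-length
// uvarints cannot be decoded at fixed offsets.
package main

import (
    "encoding/binary"
    "fmt"
)

func main() {
    checksum := make([]byte, 32) // stands in for one sha256 checksum

    // Old scheme: fields appended as uvarints (one byte each for small
    // values), then decoded at fixed offsets 1:9, 9:17, 17:25.
    old := []byte{0x00}                  // type byte
    old = binary.AppendUvarint(old, 1)   // start
    old = binary.AppendUvarint(old, 100) // end
    old = binary.AppendUvarint(old, 1)   // checksum count
    old = append(old, checksum...)

    start, _ := binary.Uvarint(old[1:9])
    end, _ := binary.Uvarint(old[9:17]) // already inside the checksum bytes
    fmt.Println(start, end)             // prints "1 0", not "1 100"

    // New scheme: fixed 8-byte big-endian fields keep every offset stable.
    nu := []byte{0x00}
    nu = binary.BigEndian.AppendUint64(nu, 1)
    nu = binary.BigEndian.AppendUint64(nu, 100)
    nu = binary.BigEndian.AppendUint64(nu, 1)
    nu = append(nu, checksum...)

    fmt.Println(binary.BigEndian.Uint64(nu[1:9]), binary.BigEndian.Uint64(nu[9:17])) // prints "1 100"
}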

executor/types/batch_test.go (38 additions, 0 deletions)

@@ -0,0 +1,38 @@
+package types
+
+import (
+    "testing"
+
+    "github.com/stretchr/testify/require"
+)
+
+func TestBatchDataHeader(t *testing.T) {
+    start := uint64(1)
+    end := uint64(100)
+
+    chunks := [][]byte{
+        []byte("chunk1"),
+        []byte("chunk2"),
+        []byte("chunk3"),
+    }
+
+    checksums := make([][]byte, 0, len(chunks))
+    for _, chunk := range chunks {
+        checksum := GetChecksumFromChunk(chunk)
+        checksums = append(checksums, checksum[:])
+    }
+
+    headerData := MarshalBatchDataHeader(
+        start,
+        end,
+        checksums)
+    require.Equal(t, 1+8+8+8+3*32, len(headerData))
+
+    header, err := UnmarshalBatchDataHeader(headerData)
+    require.NoError(t, err)
+
+    require.Equal(t, start, header.Start)
+    require.Equal(t, end, header.End)
+    require.Equal(t, checksums, header.Checksums)
+    require.Equal(t, len(chunks), len(header.Checksums))
+}
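
The new test covers only the header path; a companion test for the chunk path could look like the sketch below. It assumes MarshalBatchDataChunk takes its arguments in the order the fields are written (start, end, index, length, chunkData); the actual parameter list sits in unchanged lines not shown in this diff.

// Possible companion test for the chunk round trip, mirroring the style of
// TestBatchDataHeader above. The argument order of MarshalBatchDataChunk is
// assumed, not confirmed by this diff.
package types

import (
    "testing"

    "github.com/stretchr/testify/require"
)

func TestBatchDataChunk(t *testing.T) {
    start := uint64(1)
    end := uint64(100)
    index := uint64(0)
    length := uint64(3)
    chunkData := []byte("chunk1")

    data := MarshalBatchDataChunk(start, end, index, length, chunkData)
    // 1 type byte + 4 fixed 8-byte fields + payload
    require.Equal(t, 1+8+8+8+8+len(chunkData), len(data))

    chunk, err := UnmarshalBatchDataChunk(data)
    require.NoError(t, err)

    require.Equal(t, start, chunk.Start)
    require.Equal(t, end, chunk.End)
    require.Equal(t, index, chunk.Index)
    require.Equal(t, length, chunk.Length)
    require.Equal(t, chunkData, chunk.ChunkData)
}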
