Skip to content

Commit 8294769

Browse files
committed
chore: Export msgpack ext encode/decode functions
Fixes #420.
1 parent 10b3c8f commit 8294769

File tree

6 files changed

+78
-60
lines changed

6 files changed

+78
-60
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release.
1212

1313
### Changed
1414

15+
- Made MessagePack extension encoding and decoding functions exportable,
16+
allowing users to reuse the logic for custom extensions (#421).
17+
1518
### Fixed
1619

1720
## [v2.2.0] - 2024-12-16

arrow/arrow.go

+17-15
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ import (
77
"github.com/vmihailenco/msgpack/v5"
88
)
99

10-
// Arrow MessagePack extension type.
11-
const arrowExtId = 8
10+
// ExtID represents the Arrow MessagePack extension type identifier.
11+
const ExtID = 8
1212

1313
// Arrow struct wraps a raw arrow data buffer.
1414
type Arrow struct {
@@ -26,31 +26,33 @@ func (a Arrow) Raw() []byte {
2626
return a.data
2727
}
2828

29-
func arrowDecoder(d *msgpack.Decoder, v reflect.Value, extLen int) error {
29+
// EncodeExt encodes an Arrow into a MessagePack extension.
30+
func EncodeExt(_ *msgpack.Encoder, v reflect.Value) ([]byte, error) {
31+
arr, ok := v.Interface().(Arrow)
32+
if !ok {
33+
return []byte{}, fmt.Errorf("encode: not an Arrow type")
34+
}
35+
return arr.data, nil
36+
}
37+
38+
// DecodeExt decodes a MessagePack extension into an Arrow.
39+
func DecodeExt(d *msgpack.Decoder, v reflect.Value, extLen int) error {
3040
arrow := Arrow{
3141
data: make([]byte, extLen),
3242
}
3343
n, err := d.Buffered().Read(arrow.data)
3444
if err != nil {
35-
return fmt.Errorf("arrowDecoder: can't read bytes on Arrow decode: %w", err)
45+
return fmt.Errorf("decode: can't read bytes on Arrow decode: %w", err)
3646
}
3747
if n < extLen || n != len(arrow.data) {
38-
return fmt.Errorf("arrowDecoder: unexpected end of stream after %d Arrow bytes", n)
48+
return fmt.Errorf("decode: unexpected end of stream after %d Arrow bytes", n)
3949
}
4050

4151
v.Set(reflect.ValueOf(arrow))
4252
return nil
4353
}
4454

45-
func arrowEncoder(e *msgpack.Encoder, v reflect.Value) ([]byte, error) {
46-
arr, ok := v.Interface().(Arrow)
47-
if !ok {
48-
return []byte{}, fmt.Errorf("arrowEncoder: not an Arrow type")
49-
}
50-
return arr.data, nil
51-
}
52-
5355
func init() {
54-
msgpack.RegisterExtDecoder(arrowExtId, Arrow{}, arrowDecoder)
55-
msgpack.RegisterExtEncoder(arrowExtId, Arrow{}, arrowEncoder)
56+
msgpack.RegisterExtEncoder(ExtID, Arrow{}, EncodeExt)
57+
msgpack.RegisterExtDecoder(ExtID, Arrow{}, DecodeExt)
5658
}

datetime/datetime.go

+10-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Package with support of Tarantool's datetime data type.
1+
// Package datetime provides support for Tarantool's datetime data type.
22
//
33
// Datetime data type supported in Tarantool since 2.10.
44
//
@@ -34,9 +34,10 @@ import (
3434
// * [optional] all the other fields (nsec, tzoffset, tzindex) if any of them
3535
// were having not 0 value. They are packed naturally in little-endian order;
3636

37-
// Datetime external type. Supported since Tarantool 2.10. See more details in
37+
// ExtID represents the Datetime MessagePack extension type identifier.
38+
// Supported since Tarantool 2.10. See more details in
3839
// issue https://github.com/tarantool/tarantool/issues/5946.
39-
const datetimeExtID = 4
40+
const ExtID = 4
4041

4142
// datetime structure keeps a number of seconds and nanoseconds since Unix Epoch.
4243
// Time is normalized by UTC, so time-zone offset is informative only.
@@ -242,7 +243,8 @@ func (d *Datetime) ToTime() time.Time {
242243
return d.time
243244
}
244245

245-
func datetimeEncoder(e *msgpack.Encoder, v reflect.Value) ([]byte, error) {
246+
// EncodeExt encodes a Datetime into a MessagePack extension.
247+
func EncodeExt(_ *msgpack.Encoder, v reflect.Value) ([]byte, error) {
246248
dtime := v.Interface().(Datetime)
247249
tm := dtime.ToTime()
248250

@@ -275,7 +277,8 @@ func datetimeEncoder(e *msgpack.Encoder, v reflect.Value) ([]byte, error) {
275277
return buf, nil
276278
}
277279

278-
func datetimeDecoder(d *msgpack.Decoder, v reflect.Value, extLen int) error {
280+
// DecodeExt decodes a MessagePack extension into a Datetime.
281+
func DecodeExt(d *msgpack.Decoder, v reflect.Value, extLen int) error {
279282
if extLen != maxSize && extLen != secondsSize {
280283
return fmt.Errorf("invalid data length: got %d, wanted %d or %d",
281284
extLen, secondsSize, maxSize)
@@ -333,6 +336,6 @@ func datetimeDecoder(d *msgpack.Decoder, v reflect.Value, extLen int) error {
333336
}
334337

335338
func init() {
336-
msgpack.RegisterExtDecoder(datetimeExtID, Datetime{}, datetimeDecoder)
337-
msgpack.RegisterExtEncoder(datetimeExtID, Datetime{}, datetimeEncoder)
339+
msgpack.RegisterExtEncoder(ExtID, Datetime{}, EncodeExt)
340+
msgpack.RegisterExtDecoder(ExtID, Datetime{}, DecodeExt)
338341
}

datetime/interval.go

+21-16
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ import (
88
"github.com/vmihailenco/msgpack/v5"
99
)
1010

11-
const interval_extId = 6
11+
// IntervalExtID represents the Interval MessagePack extension type identifier.
12+
const IntervalExtID = 6
1213

1314
const (
1415
fieldYear = 0
@@ -74,7 +75,7 @@ func encodeIntervalValue(e *msgpack.Encoder, typ uint64, value int64) (err error
7475
if err == nil {
7576
if value > 0 {
7677
err = e.EncodeUint(uint64(value))
77-
} else if value < 0 {
78+
} else {
7879
err = e.EncodeInt(value)
7980
}
8081
}
@@ -181,20 +182,24 @@ func decodeInterval(d *msgpack.Decoder, v reflect.Value) (err error) {
181182
return nil
182183
}
183184

184-
func init() {
185-
msgpack.RegisterExtEncoder(interval_extId, Interval{},
186-
func(e *msgpack.Encoder, v reflect.Value) (ret []byte, err error) {
187-
var b bytes.Buffer
185+
// EncodeIntervalExt encodes an Interval into a MessagePack extension.
186+
func EncodeIntervalExt(_ *msgpack.Encoder, v reflect.Value) (ret []byte, err error) {
187+
var b bytes.Buffer
188188

189-
enc := msgpack.NewEncoder(&b)
190-
if err = encodeInterval(enc, v); err == nil {
191-
ret = b.Bytes()
192-
}
189+
enc := msgpack.NewEncoder(&b)
190+
if err = encodeInterval(enc, v); err == nil {
191+
ret = b.Bytes()
192+
}
193193

194-
return
195-
})
196-
msgpack.RegisterExtDecoder(interval_extId, Interval{},
197-
func(d *msgpack.Decoder, v reflect.Value, extLen int) error {
198-
return decodeInterval(d, v)
199-
})
194+
return
195+
}
196+
197+
// DecodeIntervalExt decodes a MessagePack extension into an Interval.
198+
func DecodeIntervalExt(d *msgpack.Decoder, v reflect.Value, _ int) error {
199+
return decodeInterval(d, v)
200+
}
201+
202+
func init() {
203+
msgpack.RegisterExtEncoder(IntervalExtID, Interval{}, EncodeIntervalExt)
204+
msgpack.RegisterExtDecoder(IntervalExtID, Interval{}, DecodeIntervalExt)
200205
}

decimal/decimal.go

+11-10
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Package decimal with support of Tarantool's decimal data type.
1+
// Package decimal provides support for Tarantool's decimal data type.
22
//
33
// Decimal data type supported in Tarantool since 2.2.
44
//
@@ -37,11 +37,10 @@ import (
3737
// - Tarantool module decimal:
3838
// https://www.tarantool.io/en/doc/latest/reference/reference_lua/decimal/
3939

40-
const (
41-
// Decimal external type.
42-
decimalExtID = 1
43-
decimalPrecision = 38
44-
)
40+
// ExtID represents the Decimal MessagePack extension type identifier.
41+
const ExtID = 1
42+
43+
const decimalPrecision = 38
4544

4645
var (
4746
one decimal.Decimal = decimal.NewFromInt(1)
@@ -71,7 +70,8 @@ func MakeDecimalFromString(src string) (Decimal, error) {
7170
return result, nil
7271
}
7372

74-
func decimalEncoder(e *msgpack.Encoder, v reflect.Value) ([]byte, error) {
73+
// EncodeExt encodes a Decimal into a MessagePack extension.
74+
func EncodeExt(_ *msgpack.Encoder, v reflect.Value) ([]byte, error) {
7575
dec := v.Interface().(Decimal)
7676
if dec.GreaterThan(maxSupportedDecimal) {
7777
return nil,
@@ -94,7 +94,8 @@ func decimalEncoder(e *msgpack.Encoder, v reflect.Value) ([]byte, error) {
9494
return bcdBuf, nil
9595
}
9696

97-
func decimalDecoder(d *msgpack.Decoder, v reflect.Value, extLen int) error {
97+
// DecodeExt decodes a MessagePack extension into a Decimal.
98+
func DecodeExt(d *msgpack.Decoder, v reflect.Value, extLen int) error {
9899
b := make([]byte, extLen)
99100
n, err := d.Buffered().Read(b)
100101
if err != nil {
@@ -131,6 +132,6 @@ func decimalDecoder(d *msgpack.Decoder, v reflect.Value, extLen int) error {
131132
}
132133

133134
func init() {
134-
msgpack.RegisterExtDecoder(decimalExtID, Decimal{}, decimalDecoder)
135-
msgpack.RegisterExtEncoder(decimalExtID, Decimal{}, decimalEncoder)
135+
msgpack.RegisterExtEncoder(ExtID, Decimal{}, EncodeExt)
136+
msgpack.RegisterExtDecoder(ExtID, Decimal{}, DecodeExt)
136137
}

uuid/uuid.go

+16-12
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Package with support of Tarantool's UUID data type.
1+
// Package uuid provides support for Tarantool's UUID data type.
22
//
33
// UUID data type supported in Tarantool since 2.4.1.
44
//
@@ -24,8 +24,8 @@ import (
2424
"github.com/vmihailenco/msgpack/v5"
2525
)
2626

27-
// UUID external type.
28-
const uuid_extID = 2
27+
// ExtID represents the UUID MessagePack extension type identifier.
28+
const ExtID = 2
2929

3030
func encodeUUID(e *msgpack.Encoder, v reflect.Value) error {
3131
id := v.Interface().(uuid.UUID)
@@ -64,15 +64,19 @@ func decodeUUID(d *msgpack.Decoder, v reflect.Value) error {
6464
return nil
6565
}
6666

67+
// EncodeExt encodes a UUID into a MessagePack extension.
68+
func EncodeExt(_ *msgpack.Encoder, v reflect.Value) ([]byte, error) {
69+
u := v.Interface().(uuid.UUID)
70+
return u.MarshalBinary()
71+
}
72+
73+
// DecodeExt decodes a MessagePack extension into a UUID.
74+
func DecodeExt(d *msgpack.Decoder, v reflect.Value, _ int) error {
75+
return decodeUUID(d, v)
76+
}
77+
6778
func init() {
6879
msgpack.Register(reflect.TypeOf((*uuid.UUID)(nil)).Elem(), encodeUUID, decodeUUID)
69-
msgpack.RegisterExtEncoder(uuid_extID, uuid.UUID{},
70-
func(e *msgpack.Encoder, v reflect.Value) ([]byte, error) {
71-
uuid := v.Interface().(uuid.UUID)
72-
return uuid.MarshalBinary()
73-
})
74-
msgpack.RegisterExtDecoder(uuid_extID, uuid.UUID{},
75-
func(d *msgpack.Decoder, v reflect.Value, extLen int) error {
76-
return decodeUUID(d, v)
77-
})
80+
msgpack.RegisterExtEncoder(ExtID, uuid.UUID{}, EncodeExt)
81+
msgpack.RegisterExtDecoder(ExtID, uuid.UUID{}, DecodeExt)
7882
}

0 commit comments

Comments
 (0)