Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Variable length encoding #86

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ type handler struct {
cache Cache
store Store
limitCounter httprate.LimitCounter
keyVersion int
keyVersion byte
}

var _ proto.QuotaControl = &handler{}
Expand Down
1 change: 0 additions & 1 deletion proto/access_key_test.go

This file was deleted.

88 changes: 74 additions & 14 deletions proto/internal/encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@ import (
"crypto/rand"
"encoding/binary"
"fmt"
"math"

"github.com/goware/base64"
"github.com/jxskiss/base62"
)

var (
ErrInvalidKeyLength = fmt.Errorf("invalid access key length")
ErrVersionMismatch = fmt.Errorf("version mismatch")
)

type Encoding interface {
Version() int
Version() byte
Encode(projectID uint64, ecosystemID uint64) string
Decode(accessKey string) (projectID uint64, ecosystemID uint64, err error)
}
Expand All @@ -28,7 +30,7 @@ const (
// V0 is the v0 encoding format for project access keys
type V0 struct{}

func (V0) Version() int { return 0 }
func (V0) Version() byte { return 0 }

func (V0) Encode(projectID uint64, ecosystemID uint64) string {
buf := make([]byte, sizeV0)
Expand All @@ -50,47 +52,105 @@ func (V0) Decode(accessKey string) (projectID uint64, ecosystemID uint64, err er

type V1 struct{}

func (V1) Version() int { return 1 }
func (V1) Version() byte { return 1 }

func (V1) Encode(projectID uint64, ecosystemID uint64) string {
func (v V1) Encode(projectID uint64, ecosystemID uint64) string {
buf := make([]byte, sizeV1)
buf[0] = byte(1)
buf[0] = v.Version()
binary.BigEndian.PutUint64(buf[1:], projectID)
rand.Read(buf[9:])
return base64.Base64UrlEncode(buf)
}

func (V1) Decode(accessKey string) (projectID uint64, ecosystemID uint64, err error) {
func (v V1) Decode(accessKey string) (projectID uint64, ecosystemID uint64, err error) {
buf, err := base64.Base64UrlDecode(accessKey)
if err != nil {
return 0, 0, fmt.Errorf("base64 decode: %w", err)
}
if len(buf) != sizeV1 {
return 0, 0, ErrInvalidKeyLength
}
if buf[0] != v.Version() {
return 0, 0, ErrVersionMismatch
}
return binary.BigEndian.Uint64(buf[1:9]), 0, nil
}

type V2 struct{}

func (V2) Version() int { return 2 }
func (V2) Version() byte { return 2 }

func (V2) Encode(projectID uint64, ecosystemID uint64) string {
func (v V2) Encode(projectID uint64, ecosystemID uint64) string {
buf := make([]byte, sizeV2)
buf[0] = byte(2)
binary.BigEndian.PutUint64(buf[1:], projectID)
binary.BigEndian.PutUint64(buf[9:], ecosystemID)
rand.Read(buf[17:])
buf[0] = v.Version()

encodedProjectID := encodeUint64(projectID)
encodedEcosystemID := encodeUint64(ecosystemID)
buf[1] = byte(len(encodedProjectID)) + (byte(len(encodedEcosystemID) << 4))
copy(buf[2:], encodedProjectID)
copy(buf[2+len(encodedProjectID):], encodedEcosystemID)

rand.Read(buf[2+len(encodedProjectID)+len(encodedEcosystemID):])

return base64.Base64UrlEncode(buf)
}

func (V2) Decode(accessKey string) (projectID uint64, ecosystemID uint64, err error) {
func (v V2) Decode(accessKey string) (projectID uint64, ecosystemID uint64, err error) {
buf, err := base64.Base64UrlDecode(accessKey)
if err != nil {
return 0, 0, fmt.Errorf("base64 decode: %w", err)
}
if len(buf) != sizeV2 {
return 0, 0, ErrInvalidKeyLength
}
return binary.BigEndian.Uint64(buf[1:9]), binary.BigEndian.Uint64(buf[9:17]), nil
if buf[0] != v.Version() {
return 0, 0, fmt.Errorf("version mismatch")
}

projectLength := buf[1] & 0x0f
ecosystemLength := buf[1] >> 4

if projectID, err = decodeUint64(buf[2 : 2+projectLength]); err != nil {
return 0, 0, fmt.Errorf("decode projectID: %w", err)
}

if ecosystemID, err = decodeUint64(buf[2+projectLength : 2+projectLength+ecosystemLength]); err != nil {
return 0, 0, fmt.Errorf("decode ecosystemID: %w", err)
}

return projectID, ecosystemID, nil
}

func encodeUint64(n uint64) []byte {
switch {
case n <= math.MaxUint8:
return []byte{byte(n)}
case n <= math.MaxUint16:
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(n))
return buf
case n <= math.MaxUint32:
buf := make([]byte, 4)
binary.BigEndian.PutUint32(buf, uint32(n))
return buf
default:
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, uint64(n))
return buf
}
}

func decodeUint64(buf []byte) (uint64, error) {
switch len(buf) {
case 1:
return uint64(buf[0]), nil
case 2:
return uint64(binary.BigEndian.Uint16(buf)), nil
case 4:
return uint64(binary.BigEndian.Uint32(buf)), nil
case 8:
return uint64(binary.BigEndian.Uint64(buf)), nil
default:
return 0, fmt.Errorf("invalid uint64 length")
}
}
11 changes: 7 additions & 4 deletions proto/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func Ptr[T any](v T) *T {
return &v
}

var supportedEncodings = []encoding.Encoding{
var SupportedEncodings = []encoding.Encoding{
encoding.V0{},
encoding.V1{},
encoding.V2{},
Expand All @@ -25,19 +25,22 @@ var AccessKeyVersion = encoding.V2{}.Version()

func GetProjectID(accessKey string) (projectID, ecosystemID uint64, err error) {
var errs []error
for _, e := range supportedEncodings {
for i := len(SupportedEncodings) - 1; i >= 0; i-- {
e := SupportedEncodings[i]

projectID, ecosystemID, err := e.Decode(accessKey)
if err != nil {
errs = append(errs, fmt.Errorf("decode v%d: %w", e.Version(), err))
continue
}

return projectID, ecosystemID, nil
}
return 0, 0, errors.Join(errs...)
}

func GenerateAccessKey(version int, projectID, ecosystemID uint64) string {
for _, e := range supportedEncodings {
func GenerateAccessKey(version byte, projectID, ecosystemID uint64) string {
for _, e := range SupportedEncodings {
if e.Version() == version {
return e.Encode(projectID, ecosystemID)
}
Expand Down
59 changes: 29 additions & 30 deletions proto/proto_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package proto_test

import (
"fmt"
"math"
"testing"

"github.com/0xsequence/quotacontrol/proto"
Expand All @@ -10,37 +12,34 @@ import (
)

func TestAccessKeyEncoding(t *testing.T) {
t.Run("v0", func(t *testing.T) {
projectID := uint64(12345)
accessKey := proto.GenerateAccessKey(0, projectID, 0)
t.Log("=> k", accessKey)

outID, outecosystemID, err := proto.GetProjectID(accessKey)
require.NoError(t, err)
require.Equal(t, projectID, outID)
require.Equal(t, uint64(0), outecosystemID)
})

t.Run("v1", func(t *testing.T) {
projectID := uint64(12345)
accessKey := proto.GenerateAccessKey(1, projectID, 0)
t.Log("=> k", accessKey)
outID, ecosystemID, err := proto.GetProjectID(accessKey)
require.NoError(t, err)
require.Equal(t, projectID, outID)
require.Equal(t, uint64(0), ecosystemID)
})
t.Run("v1", func(t *testing.T) {
projectID := uint64(12345)
ecosystemID := uint64(54321)
accessKey := proto.GenerateAccessKey(2, projectID, ecosystemID)
t.Log("=> k", accessKey)
inputList := [][]uint64{
{1, 2},
{127, 128},
{math.MaxUint8, math.MaxUint8 + 1},
{math.MaxUint16, math.MaxUint16 + 1},
{math.MaxUint32, math.MaxUint32 + 1},
{math.MaxUint64, 1},
}
for _, e := range proto.SupportedEncodings {
version := e.Version()
t.Run(fmt.Sprintf("v%d", version), func(t *testing.T) {
for _, input := range inputList {
projectID := input[0]
ecosystemID := input[1]
accessKey := proto.GenerateAccessKey(version, projectID, ecosystemID)
t.Logf("=> key: [%d/%d] %s", projectID, ecosystemID, accessKey)

outID, outecosystemID, err := proto.GetProjectID(accessKey)
require.NoError(t, err)
require.Equal(t, projectID, outID)
require.Equal(t, ecosystemID, outecosystemID)
})
outID, outecosystemID, err := proto.GetProjectID(accessKey)
require.NoError(t, err)
require.Equal(t, projectID, outID)
if version < 2 {
require.Equal(t, uint64(0), outecosystemID)
} else {
require.Equal(t, ecosystemID, outecosystemID)
}
}
})
}
}

func TestAccessKeyValidateOrigin(t *testing.T) {
Expand Down
Loading