Skip to content

Commit

Permalink
Merge pull request #7007 from onflow/illia-malachyn/7006-fix-subscrip…
Browse files Browse the repository at this point in the history
…tion-id-validation

Generate random string with maxLen symbols instead of UUID
  • Loading branch information
peterargue authored Feb 13, 2025
2 parents 77f4db5 + 28d1c9f commit 4f1be89
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 52 deletions.
11 changes: 8 additions & 3 deletions engine/access/rest/websockets/subscription_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package websockets
import (
"fmt"

"github.com/google/uuid"
randutils "github.com/onflow/flow-go/utils/rand"
)

const maxLen = 20
Expand All @@ -17,12 +17,17 @@ type SubscriptionID struct {
}

// NewSubscriptionID creates a new SubscriptionID based on the provided input.
// - If the input `id` is empty, a new UUID is generated and returned.
// - If the input `id` is empty, a random ID is generated and returned.
// - If the input `id` is non-empty, it is validated and returned if no errors.
func NewSubscriptionID(id string) (SubscriptionID, error) {
if len(id) == 0 {
randomString, err := randutils.GenerateRandomString(maxLen)
if err != nil {
return SubscriptionID{}, fmt.Errorf("could not generate subscription ID: %w", err)
}

return SubscriptionID{
id: uuid.New().String(),
id: randomString,
}, nil
}

Expand Down
3 changes: 1 addition & 2 deletions engine/access/rest/websockets/subscription_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)

Expand All @@ -14,7 +13,7 @@ func TestNewSubscriptionID(t *testing.T) {

assert.NoError(t, err)
assert.NotEmpty(t, subscriptionID.id)
assert.NoError(t, uuid.Validate(subscriptionID.id), "Generated ID should be a valid UUID")
assert.Len(t, subscriptionID.id, maxLen)
})

t.Run("should return valid SubscriptionID when input ID is valid", func(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
p2pmsg "github.com/onflow/flow-go/network/p2p/message"
mockp2p "github.com/onflow/flow-go/network/p2p/mock"
p2ptest "github.com/onflow/flow-go/network/p2p/test"
"github.com/onflow/flow-go/utils/rand"
"github.com/onflow/flow-go/utils/unittest"
)

Expand Down Expand Up @@ -116,12 +117,13 @@ func TestValidationInspector_InvalidTopicId_Detection(t *testing.T) {
// create unknown topic
unknownTopic := channels.Topic(fmt.Sprintf("%s/%s", p2ptest.GossipSubTopicIdFixture(), sporkID))
// create malformed topic
malformedTopic := channels.Topic(unittest.RandomStringFixture(t, 100))
malformedTopic, err := rand.GenerateRandomString(100)
require.NoError(t, err)
// a topics spork ID is considered invalid if it does not match the current spork ID
invalidSporkIDTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.PushBlocks, unittest.IdentifierFixture()))

// set topic oracle to return list with all topics to avoid hasSubscription failures and force topic validation
topicProvider.UpdateTopics([]string{unknownTopic.String(), malformedTopic.String(), invalidSporkIDTopic.String()})
topicProvider.UpdateTopics([]string{unknownTopic.String(), malformedTopic, invalidSporkIDTopic.String()})

validationInspector.Start(signalerCtx)
nodes := []p2p.LibP2PNode{victimNode, spammer.SpammerNode}
Expand All @@ -131,15 +133,15 @@ func TestValidationInspector_InvalidTopicId_Detection(t *testing.T) {

// prepare to spam - generate control messages
graftCtlMsgsWithUnknownTopic := spammer.GenerateCtlMessages(int(controlMessageCount), p2ptest.WithGraft(messageCount, unknownTopic.String()))
graftCtlMsgsWithMalformedTopic := spammer.GenerateCtlMessages(int(controlMessageCount), p2ptest.WithGraft(messageCount, malformedTopic.String()))
graftCtlMsgsWithMalformedTopic := spammer.GenerateCtlMessages(int(controlMessageCount), p2ptest.WithGraft(messageCount, malformedTopic))
graftCtlMsgsInvalidSporkIDTopic := spammer.GenerateCtlMessages(int(controlMessageCount), p2ptest.WithGraft(messageCount, invalidSporkIDTopic.String()))

pruneCtlMsgsWithUnknownTopic := spammer.GenerateCtlMessages(int(controlMessageCount), p2ptest.WithPrune(messageCount, unknownTopic.String()))
pruneCtlMsgsWithMalformedTopic := spammer.GenerateCtlMessages(int(controlMessageCount), p2ptest.WithPrune(messageCount, malformedTopic.String()))
pruneCtlMsgsWithMalformedTopic := spammer.GenerateCtlMessages(int(controlMessageCount), p2ptest.WithPrune(messageCount, malformedTopic))
pruneCtlMsgsInvalidSporkIDTopic := spammer.GenerateCtlMessages(int(controlMessageCount), p2ptest.WithPrune(messageCount, invalidSporkIDTopic.String()))

iHaveCtlMsgsWithUnknownTopic := spammer.GenerateCtlMessages(int(controlMessageCount), p2ptest.WithIHave(messageCount, 1000, unknownTopic.String()))
iHaveCtlMsgsWithMalformedTopic := spammer.GenerateCtlMessages(int(controlMessageCount), p2ptest.WithIHave(messageCount, 1000, malformedTopic.String()))
iHaveCtlMsgsWithMalformedTopic := spammer.GenerateCtlMessages(int(controlMessageCount), p2ptest.WithIHave(messageCount, 1000, malformedTopic))
iHaveCtlMsgsInvalidSporkIDTopic := spammer.GenerateCtlMessages(int(controlMessageCount), p2ptest.WithIHave(messageCount, 1000, invalidSporkIDTopic.String()))

// spam the victim peer with invalid graft messages
Expand Down Expand Up @@ -849,7 +851,8 @@ func TestValidationInspector_InspectRpcPublishMessages(t *testing.T) {
// create unknown topic
unknownTopic := channels.Topic(fmt.Sprintf("%s/%s", p2ptest.GossipSubTopicIdFixture(), sporkID)).String()
// create malformed topic
malformedTopic := channels.Topic(unittest.RandomStringFixture(t, 100)).String()
malformedTopic, err := rand.GenerateRandomString(100)
require.NoError(t, err)
// a topics spork ID is considered invalid if it does not match the current spork ID
invalidSporkIDTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.PushBlocks, unittest.IdentifierFixture())).String()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
p2pmsg "github.com/onflow/flow-go/network/p2p/message"
mockp2p "github.com/onflow/flow-go/network/p2p/mock"
p2ptest "github.com/onflow/flow-go/network/p2p/test"
randutils "github.com/onflow/flow-go/utils/rand"
"github.com/onflow/flow-go/utils/unittest"
)

Expand Down Expand Up @@ -1711,12 +1712,15 @@ func TestControlMessageValidationInspector_InspectionConfigToggle(t *testing.T)
rpcTracker.On("WasIHaveRPCSent", mock.AnythingOfType("string")).Return(true).Maybe()
inspector.Start(signalerCtx)

topic, err := randutils.GenerateRandomString(100)
require.NoError(t, err)

rpc := unittest.P2PRPCFixture(
unittest.WithGrafts(unittest.P2PRPCGraftFixtures(unittest.IdentifierListFixture(numOfMsgs).Strings()...)...),
unittest.WithPrunes(unittest.P2PRPCPruneFixtures(unittest.IdentifierListFixture(numOfMsgs).Strings()...)...),
unittest.WithIHaves(unittest.P2PRPCIHaveFixtures(numOfMsgs, unittest.IdentifierListFixture(numOfMsgs).Strings()...)...),
unittest.WithIWants(unittest.P2PRPCIWantFixtures(numOfMsgs, numOfMsgs)...),
unittest.WithPubsubMessages(unittest.GossipSubMessageFixtures(numOfMsgs, unittest.RandomStringFixture(t, 100), unittest.WithFrom(unittest.PeerIdFixture(t)))...),
unittest.WithPubsubMessages(unittest.GossipSubMessageFixtures(numOfMsgs, topic, unittest.WithFrom(unittest.PeerIdFixture(t)))...),
)

from := unittest.PeerIdFixture(t)
Expand All @@ -1740,7 +1744,8 @@ func invalidTopics(t *testing.T, sporkID flow.Identifier) (string, string, strin
// create unknown topic
unknownTopic := channels.Topic(fmt.Sprintf("%s/%s", unittest.IdentifierFixture(), sporkID)).String()
// create malformed topic
malformedTopic := channels.Topic(unittest.RandomStringFixture(t, 100)).String()
malformedTopic, err := randutils.GenerateRandomString(100)
require.NoError(t, err)
// a topics spork ID is considered invalid if it does not match the current spork ID
invalidSporkIDTopic := channels.Topic(fmt.Sprintf("%s/%s", channels.PushBlocks, unittest.IdentifierFixture())).String()
return unknownTopic, malformedTopic, invalidSporkIDTopic
Expand Down
7 changes: 5 additions & 2 deletions network/p2p/test/fixtures.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (
"github.com/onflow/flow-go/network/p2p/utils"
validator "github.com/onflow/flow-go/network/validator/pubsub"
"github.com/onflow/flow-go/utils/logging"
randutils "github.com/onflow/flow-go/utils/rand"
"github.com/onflow/flow-go/utils/unittest"
)

Expand Down Expand Up @@ -863,7 +864,8 @@ func GossipSubRpcFixture(t *testing.T, msgCnt int, opts ...GossipSubCtrlOption)
subscriptions := make([]*pb.RPC_SubOpts, numSubscriptions)
for i := 0; i < numSubscriptions; i++ {
subscribe := rand.Intn(2) == 1
topicID := unittest.RandomStringFixture(t, topicIdSize)
topicID, err := randutils.GenerateRandomString(topicIdSize)
require.NoError(t, err)
subscriptions[i] = &pb.RPC_SubOpts{
Subscribe: &subscribe,
Topicid: &topicID,
Expand Down Expand Up @@ -1030,7 +1032,8 @@ func GossipSubMessageIdsFixture(count int) []string {
// Note: the message is not signed.
func GossipSubMessageFixture(t *testing.T) *pb.Message {
byteSize := 100
topic := unittest.RandomStringFixture(t, byteSize)
topic, err := randutils.GenerateRandomString(byteSize)
require.NoError(t, err)
return &pb.Message{
From: unittest.RandomBytes(byteSize),
Data: unittest.RandomBytes(byteSize),
Expand Down
5 changes: 4 additions & 1 deletion network/p2p/test/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (

pb "github.com/libp2p/go-libp2p-pubsub/pb"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/stretchr/testify/require"

"github.com/onflow/flow-go/utils/rand"
"github.com/onflow/flow-go/utils/unittest"
)

Expand Down Expand Up @@ -47,7 +49,8 @@ func WithoutSignerId() func(*pb.Message) {
// Returns:
// *pb.Message: pubsub message
func PubsubMessageFixture(t *testing.T, opts ...func(*pb.Message)) *pb.Message {
topic := unittest.RandomStringFixture(t, 10)
topic, err := rand.GenerateRandomString(10)
require.NoError(t, err)

m := &pb.Message{
Data: unittest.RandomByteSlice(t, 100),
Expand Down
32 changes: 32 additions & 0 deletions utils/rand/rand.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ package rand

import (
"crypto/rand"
"encoding/base64"
"encoding/binary"
"fmt"
)
Expand Down Expand Up @@ -167,3 +168,34 @@ func Samples(n uint, m uint, swap func(i, j uint)) error {
}
return nil
}

// GenerateRandomString generates a cryptographically secure random string of size n.
// n must be > 0
func GenerateRandomString(length int) (string, error) {
if length <= 0 {
return "", fmt.Errorf("length should greater than 0, got %d", length)
}

// The base64 encoding uses 64 different characters to represent data in
// strings, which makes it possible to represent 6 bits of data with each
// character (as 2^6 is 64). This means that every 3 bytes (24 bits) of
// input data will be represented by 4 characters (4 * 6 bits) in the
// base64 encoding. Consequently, base64 encoding increases the size of
// the data by approximately 1/3 compared to the original input data.
//
// 1. (n+3) / 4 - This calculates how many groups of 4 characters are needed
// in the base64 encoded output to represent at least 'n' characters.
// The +3 ensures rounding up, as integer division truncates the result.
//
// 2. ... * 3 - Each group of 4 base64 characters represents 3 bytes
// of input data. This multiplication calculates the number of bytes
// needed to produce the required length of the base64 string.
byteSlice := make([]byte, (length+3)/4*3)
_, err := rand.Read(byteSlice)
if err != nil {
return "", fmt.Errorf("failed to generate random string: %w", err)
}

encodedString := base64.URLEncoding.EncodeToString(byteSlice)
return encodedString[:length], nil
}
10 changes: 10 additions & 0 deletions utils/rand/rand_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,13 @@ func TestSamples(t *testing.T) {
assert.Equal(t, constant, fullSlice)
})
}

func TestRandomString(t *testing.T) {
t.Run("basic random string", func(t *testing.T) {
length := 32
str, err := GenerateRandomString(length)
require.NoError(t, err)
t.Logf("string: %s", str)
require.Equal(t, length, len(str))
})
}
36 changes: 0 additions & 36 deletions utils/unittest/strings.go

This file was deleted.

0 comments on commit 4f1be89

Please sign in to comment.