From f0fbb416f48ea3e87419746281c0a48ddde4228e Mon Sep 17 00:00:00 2001 From: Tony Stark Date: Thu, 9 Jan 2025 15:25:22 +0530 Subject: [PATCH] fix: validate node for plan in before session creation in subscription module --- x/subscription/expected/keeper.go | 1 + x/subscription/keeper/alias.go | 4 ++++ x/subscription/keeper/msg_handler.go | 4 ++++ x/subscription/types/errors.go | 18 ++++++++++++------ 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/x/subscription/expected/keeper.go b/x/subscription/expected/keeper.go index ea9e91d1..6b489c52 100644 --- a/x/subscription/expected/keeper.go +++ b/x/subscription/expected/keeper.go @@ -35,6 +35,7 @@ type ProviderKeeper interface { } type NodeKeeper interface { + HasNodeForPlan(ctx sdk.Context, id uint64, addr base.NodeAddress) bool GetNode(ctx sdk.Context, addr base.NodeAddress) (nodetypes.Node, bool) } diff --git a/x/subscription/keeper/alias.go b/x/subscription/keeper/alias.go index 7b91f5a7..6ef1a4ba 100644 --- a/x/subscription/keeper/alias.go +++ b/x/subscription/keeper/alias.go @@ -36,6 +36,10 @@ func (k *Keeper) GetNode(ctx sdk.Context, addr base.NodeAddress) (nodetypes.Node return k.node.GetNode(ctx, addr) } +func (k *Keeper) HasNodeForPlan(ctx sdk.Context, id uint64, addr base.NodeAddress) bool { + return k.node.HasNodeForPlan(ctx, id, addr) +} + func (k *Keeper) QuotePriceFunc(ctx sdk.Context, price sdk.DecCoin) (sdk.Coin, error) { return k.oracle.GetQuotePrice(ctx, price) } diff --git a/x/subscription/keeper/msg_handler.go b/x/subscription/keeper/msg_handler.go index 2813be85..6055638a 100644 --- a/x/subscription/keeper/msg_handler.go +++ b/x/subscription/keeper/msg_handler.go @@ -401,6 +401,10 @@ func (k *Keeper) HandleMsgStartSession(ctx sdk.Context, msg *v3.MsgStartSessionR return nil, types.NewErrorInvalidNodeStatus(nodeAddr, node.Status) } + if !k.HasNodeForPlan(ctx, subscription.PlanID, nodeAddr) { + return nil, types.NewErrorNodeForPlanNotFound(subscription.PlanID, nodeAddr) + } + accAddr, err := sdk.AccAddressFromBech32(msg.From) if err != nil { return nil, err diff --git a/x/subscription/types/errors.go b/x/subscription/types/errors.go index 4aa6e6f4..00f4e3e6 100644 --- a/x/subscription/types/errors.go +++ b/x/subscription/types/errors.go @@ -20,12 +20,13 @@ var ( ErrInvalidRenewalPolicy = sdkerrors.Register(ModuleName, 206, "invalid renewal policy") ErrInvalidSessionStatus = sdkerrors.Register(ModuleName, 207, "invalid session status") ErrInvalidSubscriptionStatus = sdkerrors.Register(ModuleName, 208, "invalid subscription status") - ErrNodeNotFound = sdkerrors.Register(ModuleName, 209, "node not found") - ErrPlanNotFound = sdkerrors.Register(ModuleName, 210, "plan not found") - ErrPriceNotFound = sdkerrors.Register(ModuleName, 211, "price not found") - ErrSessionNotFound = sdkerrors.Register(ModuleName, 212, "session not found") - ErrSubscriptionNotFound = sdkerrors.Register(ModuleName, 213, "subscription not found") - ErrUnauthorized = sdkerrors.Register(ModuleName, 214, "unauthorized") + ErrNodeForPlanNotFound = sdkerrors.Register(ModuleName, 209, "node for plan not found") + ErrNodeNotFound = sdkerrors.Register(ModuleName, 210, "node not found") + ErrPlanNotFound = sdkerrors.Register(ModuleName, 211, "plan not found") + ErrPriceNotFound = sdkerrors.Register(ModuleName, 212, "price not found") + ErrSessionNotFound = sdkerrors.Register(ModuleName, 213, "session not found") + ErrSubscriptionNotFound = sdkerrors.Register(ModuleName, 214, "subscription not found") + ErrUnauthorized = sdkerrors.Register(ModuleName, 215, "unauthorized") ) // NewErrorAllocationNotFound returns an error indicating that the specified allocation does not exist. @@ -63,6 +64,11 @@ func NewErrorInvalidSubscriptionStatus(id uint64, status v1base.Status) error { return sdkerrors.Wrapf(ErrInvalidSubscriptionStatus, "invalid status %s for subscription %d", status, id) } +// NewErrorNodeForPlanNotFound returns an error indicating that the specified node does not exist for the plan. +func NewErrorNodeForPlanNotFound(id uint64, addr base.NodeAddress) error { + return sdkerrors.Wrapf(ErrNodeForPlanNotFound, "node %s for plan %d does not exist", addr, id) +} + // NewErrorNodeNotFound returns an error indicating that the specified node does not exist. func NewErrorNodeNotFound(addr base.NodeAddress) error { return sdkerrors.Wrapf(ErrNodeNotFound, "node %s does not exist", addr)