Skip to content

Commit

Permalink
handle case where withdraw address is different from delegator address
Browse files Browse the repository at this point in the history
  • Loading branch information
george-aj committed Sep 4, 2023
1 parent c9749ea commit b6571f3
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 5 deletions.
48 changes: 43 additions & 5 deletions x/compound/client/cli/tx_compound_setting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/temporal-zone/temporal/x/compound/types"
"strconv"
"testing"
"time"

sdkmath "cosmossdk.io/math"
"github.com/cosmos/cosmos-sdk/client/flags"
Expand All @@ -14,6 +15,7 @@ import (
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
"github.com/stretchr/testify/require"

distrcli "github.com/cosmos/cosmos-sdk/x/distribution/client/cli"
"github.com/temporal-zone/temporal/testutil/network"
"github.com/temporal-zone/temporal/x/compound/client/cli"
)
Expand All @@ -33,7 +35,7 @@ func TestCreateCompoundSetting(t *testing.T) {
val := net.Validators[0]
val1 := net.Validators[1]
ctx := val.ClientCtx

tests := []struct {
desc string
valSetting string
Expand All @@ -48,7 +50,7 @@ func TestCreateCompoundSetting(t *testing.T) {
{
desc: "valid 1",
valSetting: fmt.Sprintf("[{\"validatorAddress\":\"%s\",\"percentToCompound\":50}]", val.ValAddress.String()),
amountToRemain: "10utprl",
amountToRemain: "10" + net.Config.BondDenom,
frequency: "111",
compoundValidators: []string{val.ValAddress.String()},

Expand All @@ -65,7 +67,7 @@ func TestCreateCompoundSetting(t *testing.T) {
"[{\"validatorAddress\":\"%s\",\"percentToCompound\":50},{\"validatorAddress\":\"%s\",\"percentToCompound\":50}]",
val.ValAddress.String(),
val1.ValAddress.String()),
amountToRemain: "10utprl",
amountToRemain: "10" + net.Config.BondDenom,
frequency: "111",
compoundValidators: []string{val.ValAddress.String(), val1.ValAddress.String()},

Expand All @@ -82,11 +84,28 @@ func TestCreateCompoundSetting(t *testing.T) {
"[{\"validatorAddress\":\"%s\",\"percentToCompound\":50},{\"validatorAddress\":\"%s\",\"percentToCompound\":50}]",
val.ValAddress.String(),
val.ValAddress.String()),
amountToRemain: "10utprl",
amountToRemain: "10" + net.Config.BondDenom,
frequency: "111",
compoundValidators: []string{val.ValAddress.String(), val.ValAddress.String()},
code: 18,

args: []string{
fmt.Sprintf("--%s=%s", flags.FlagFrom, val.Address.String()),
fmt.Sprintf("--%s=true", flags.FlagSkipConfirmation),
fmt.Sprintf("--%s=%s", flags.FlagBroadcastMode, flags.BroadcastSync),
fmt.Sprintf("--%s=%s", flags.FlagFees, sdk.NewCoins(sdk.NewCoin(net.Config.BondDenom, sdkmath.NewInt(10))).String()),
},
},
{
desc: "withdraw",
valSetting: fmt.Sprintf(
"[{\"validatorAddress\":\"%s\",\"percentToCompound\":50},{\"validatorAddress\":\"%s\",\"percentToCompound\":50}]",
val.ValAddress.String(),
val1.ValAddress.String()),
amountToRemain: "10" + net.Config.BondDenom,
frequency: "5",
compoundValidators: []string{val.ValAddress.String(), val1.ValAddress.String()},

args: []string{
fmt.Sprintf("--%s=%s", flags.FlagFrom, val.Address.String()),
fmt.Sprintf("--%s=true", flags.FlagSkipConfirmation),
Expand Down Expand Up @@ -124,7 +143,7 @@ func TestCreateCompoundSetting(t *testing.T) {
require.NoError(t, ctx.Codec.UnmarshalJSON(out.Bytes(), &resp))
require.NoError(t, clitestutil.CheckTxCode(net, ctx, resp.TxHash, tc.code))

if tc.desc != "invalid" {
if tc.desc != "invalid" && tc.desc != "withdraw" {
args = append([]string{val.Address.String()}, fmt.Sprintf("--%s=json", tmcli.OutputFlag))
out, err = clitestutil.ExecTestCLICmd(ctx, cli.CmdShowCompoundSetting(), args)
if tc.err != nil {
Expand All @@ -142,6 +161,25 @@ func TestCreateCompoundSetting(t *testing.T) {
require.Equal(t, compoundSetting.GetCompoundSetting().ValidatorSetting[i].ValidatorAddress, tc.compoundValidators[i])
}
}

if tc.desc == "withdraw" {
args = append([]string{val1.Address.String()}, tc.args...)
out, err = clitestutil.ExecTestCLICmd(ctx, distrcli.NewSetWithdrawAddrCmd(), args)
if tc.err != nil {
require.ErrorIs(t, err, tc.err)
return
}
require.NoError(t, err)

require.NoError(t, net.WaitForNextBlock())

var resp sdk.TxResponse
require.NoError(t, ctx.Codec.UnmarshalJSON(out.Bytes(), &resp))
require.NoError(t, clitestutil.CheckTxCode(net, ctx, resp.TxHash, 0))

_, err = net.WaitForHeightWithTimeout(15+ctx.Height, time.Second*45)
require.NoError(t, err)
}
})
}
}
Expand Down
11 changes: 11 additions & 0 deletions x/compound/keeper/compound.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ func (k Keeper) Compound(ctx sdk.Context, cs compTypes.CompoundSetting) error {
// Handle any leftover amount if 100% of rewards are to be compounded by adding any leftover amount to their first validator
compoundActions = k.HandleLeftOverAmount(compoundActions, totalCompoundPercent, amountToCompound)

// If withdraw address is different change it temporarily
withdrawAddr := k.distrKeeper.GetDelegatorWithdrawAddr(ctx, address)
if !withdrawAddr.Equals(address) {
k.distrKeeper.SetDelegatorWithdrawAddr(ctx, address, address)
}

// Claim all staking rewards, there is an edge case where if multiple validators worth of rewards are being
// compounded to a single validator and the compounding amount is greater than the sum of the staking reward being
// claimed on the delegate and the wallet balance, a panic will occur as the network will try to delegate more than
Expand All @@ -99,6 +105,11 @@ func (k Keeper) Compound(ctx sdk.Context, cs compTypes.CompoundSetting) error {
}
}

// Change withdraw address back to what it was
if !withdrawAddr.Equals(address) {
k.distrKeeper.SetDelegatorWithdrawAddr(ctx, address, withdrawAddr)
}

// Execute all CompoundActions
for _, compoundAction := range compoundActions {
err := Delegate(ctx, k, compoundAction, address)
Expand Down
2 changes: 2 additions & 0 deletions x/compound/types/expected_keepers.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ type DistrKeeper interface {
CalculateDelegationRewards(ctx sdk.Context, val stakingTypes.ValidatorI, del stakingTypes.DelegationI, endingPeriod uint64) (rewards sdk.DecCoins)
IncrementValidatorPeriod(ctx sdk.Context, val stakingTypes.ValidatorI) uint64
WithdrawDelegationRewards(ctx sdk.Context, delAddr sdk.AccAddress, valAddr sdk.ValAddress) (sdk.Coins, error)
GetDelegatorWithdrawAddr(ctx sdk.Context, delAddr sdk.AccAddress) sdk.AccAddress
SetDelegatorWithdrawAddr(ctx sdk.Context, delAddr, withdrawAddr sdk.AccAddress)
}

type StakingKeeper interface {
Expand Down

0 comments on commit b6571f3

Please sign in to comment.