From fdc4ac0fcde64852f2b092e863020040055155c8 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Thu, 28 Nov 2024 15:51:14 +0900 Subject: [PATCH 01/24] use common assert --- pool/pool.gno | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pool/pool.gno b/pool/pool.gno index 71f796eb2..a3644a4da 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -30,7 +30,7 @@ func Mint( common.IsHalted() if common.GetLimitCaller() { caller := std.PrevRealm().Addr() - if caller != consts.POSITION_ADDR { + if err := common.PositionOnly(caller); err != nil { panic(addDetailToError( errNoPermission, ufmt.Sprintf("pool.gno__Mint() || only position(%s) can call pool mint(), called from %s", consts.POSITION_ADDR, caller.String()), @@ -82,7 +82,7 @@ func Burn( common.IsHalted() caller := std.PrevRealm().Addr() if common.GetLimitCaller() { - if caller != consts.POSITION_ADDR { + if err := common.PositionOnly(caller); err != nil { panic(addDetailToError( errNoPermission, ufmt.Sprintf("pool.gno__Burn() || only position(%s) can call pool burn(), called from %s", consts.POSITION_ADDR, caller.String()), @@ -135,7 +135,7 @@ func Collect( common.IsHalted() if common.GetLimitCaller() { caller := std.PrevRealm().Addr() - if caller != consts.POSITION_ADDR { + if err := common.PositionOnly(caller); err != nil { panic(addDetailToError( errNoPermission, ufmt.Sprintf("pool.gno__Collect() || only position(%s) can call pool collect(), called from %s", consts.POSITION_ADDR, caller.String()), @@ -196,7 +196,7 @@ func Swap( common.IsHalted() if common.GetLimitCaller() { caller := std.PrevRealm().Addr() - if caller != consts.ROUTER_ADDR { + if err := common.RouterOnly(caller); err != nil { panic(addDetailToError( errNoPermission, ufmt.Sprintf("pool.gno__Swap() || only router(%s) can call pool swap(), called from %s", consts.ROUTER_ADDR, caller.String()), From 9502b9555a2d0a71641279accb3fe0920ae64925 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Fri, 29 Nov 2024 17:32:50 +0900 Subject: [PATCH 02/24] extract some functions --- pool/pool.gno | 107 +++++++++++++++++++++++++-------------------- pool/pool_test.gno | 78 +++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 47 deletions(-) create mode 100644 pool/pool_test.gno diff --git a/pool/pool.gno b/pool/pool.gno index a3644a4da..29198a65e 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -332,38 +332,7 @@ func Swap( // shift tick if we reached the next price if state.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { - // if the tick is initialized, run the tick transition - if step.initialized { - var fee0, fee1 *u256.Uint - - // check for the placeholder value, which we replace with the actual value the first time the swap crosses an initialized tick - if zeroForOne { - fee0 = state.feeGrowthGlobalX128 - fee1 = pool.feeGrowthGlobal1X128 - } else { - fee0 = pool.feeGrowthGlobal0X128 - fee1 = state.feeGrowthGlobalX128 - } - - liquidityNet := pool.tickCross( - step.tickNext, - fee0, - fee1, - ) - - // if we're moving leftward, we interpret liquidityNet as the opposite sign - if zeroForOne { - liquidityNet = i256.Zero().Neg(liquidityNet) - } - - state.liquidity = liquidityMathAddDelta(state.liquidity, liquidityNet) - } - - if zeroForOne { - state.tick = step.tickNext - 1 - } else { - state.tick = step.tickNext - } + tickTransition(step, zeroForOne, state, pool) } else if !(state.sqrtPriceX96.Eq(step.sqrtPriceStartX96)) { // recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved state.tick = common.TickMathGetTickAtSqrtRatio(state.sqrtPriceX96) @@ -408,21 +377,7 @@ func Swap( } // actual swap - if zeroForOne { - // payer > POOL - pool.transferFromAndVerify(payer, consts.POOL_ADDR, pool.token0Path, amount0, true) - - // POOL > recipient - pool.transferAndVerify(recipient, pool.token1Path, amount1, false) - - } else { - // payer > POOL - pool.transferFromAndVerify(payer, consts.POOL_ADDR, pool.token1Path, amount1, false) - - // POOL > recipient - pool.transferAndVerify(recipient, pool.token0Path, amount0, true) - - } + pool.swapTransfers(zeroForOne, payer, recipient, amount0, amount1) prevAddr, prevRealm := getPrev() @@ -448,6 +403,64 @@ func Swap( return amount0.ToString(), amount1.ToString() } +func (pool *Pool) swapTransfers(zeroForOne bool, payer, recipient std.Address, amount0, amount1 *i256.Int) { + var targetTokenPath string + var amount *i256.Int + var isToken0 bool + + switch zeroForOne { + case true: + targetTokenPath = pool.token0Path + amount = amount0 + isToken0 = true + case false: + targetTokenPath = pool.token1Path + amount = amount1 + isToken0 = false + } + + // payer -> POOL -> recipient + pool.transferFromAndVerify(payer, consts.POOL_ADDR, targetTokenPath, amount, isToken0) + pool.transferAndVerify(recipient, targetTokenPath, amount, !isToken0) +} + +// tickTransition handles the transition between ticks during a swap +func tickTransition(step StepComputations, zeroForOne bool, state SwapState, pool *Pool) { + // if the tick is initialized, run the tick transition + if step.initialized { + var fee0, fee1 *u256.Uint + + // check for the placeholder value, which we replace with the actual value + // the first time the swap crosses an initialized tick + if zeroForOne { + fee0 = state.feeGrowthGlobalX128 + fee1 = pool.feeGrowthGlobal1X128 + } else { + fee0 = pool.feeGrowthGlobal0X128 + fee1 = state.feeGrowthGlobalX128 + } + + // if we're moving leftward, we interpret liquidityNet as the opposite sign + liquidityNet := pool.tickCross( + step.tickNext, + fee0, + fee1, + ) + + if zeroForOne { + liquidityNet = i256.Zero().Neg(liquidityNet) + } + + state.liquidity = liquidityMathAddDelta(state.liquidity, liquidityNet) + } + + if zeroForOne { + state.tick = step.tickNext - 1 + } else { + state.tick = step.tickNext + } +} + // SetFeeProtocolByAdmin sets the fee protocol for all pools // Also it will be applied to new created pools func SetFeeProtocolByAdmin( diff --git a/pool/pool_test.gno b/pool/pool_test.gno new file mode 100644 index 000000000..f6f99317b --- /dev/null +++ b/pool/pool_test.gno @@ -0,0 +1,78 @@ +package pool + +import ( + "testing" + + "gno.land/p/demo/uassert" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestSaveProtocolFees(t *testing.T) { + tests := []struct { + name string + pool *Pool + amount0 *u256.Uint + amount1 *u256.Uint + want0 *u256.Uint + want1 *u256.Uint + wantFee0 *u256.Uint + wantFee1 *u256.Uint + }{ + { + name: "normal fee deduction", + pool: &Pool{ + protocolFees: ProtocolFees{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + amount0: u256.NewUint(500), + amount1: u256.NewUint(1000), + want0: u256.NewUint(500), + want1: u256.NewUint(1000), + wantFee0: u256.NewUint(500), + wantFee1: u256.NewUint(1000), + }, + { + name: "exact fee deduction (1 deduction)", + pool: &Pool{ + protocolFees: ProtocolFees{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + amount0: u256.NewUint(1000), + amount1: u256.NewUint(2000), + want0: u256.NewUint(999), + want1: u256.NewUint(1999), + wantFee0: u256.NewUint(1), + wantFee1: u256.NewUint(1), + }, + { + name: "0 fee deduction", + pool: &Pool{ + protocolFees: ProtocolFees{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + amount0: u256.NewUint(0), + amount1: u256.NewUint(0), + want0: u256.NewUint(0), + want1: u256.NewUint(0), + wantFee0: u256.NewUint(1000), + wantFee1: u256.NewUint(2000), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got0, got1 := tt.pool.saveProtocolFees(tt.amount0, tt.amount1) + + uassert.Equal(t, got0.ToString(), tt.want0.ToString()) + uassert.Equal(t, got1.ToString(), tt.want1.ToString()) + uassert.Equal(t, tt.pool.protocolFees.token0.ToString(), tt.wantFee0.ToString()) + uassert.Equal(t, tt.pool.protocolFees.token1.ToString(), tt.wantFee1.ToString()) + }) + } +} From fee611cb58885dd525af0bf7963e01798483f236 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Mon, 2 Dec 2024 14:11:14 +0900 Subject: [PATCH 03/24] refact, test: transferAndVerify --- pool/pool.gno | 118 +++++++++++++++++++++------------------ pool/pool_test.gno | 134 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 200 insertions(+), 52 deletions(-) diff --git a/pool/pool.gno b/pool/pool.gno index 29198a65e..04b23f69e 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -26,7 +26,7 @@ func Mint( tickUpper int32, _liquidityAmount string, // uint128 positionCaller std.Address, -) (string, string) { // uint256 x2 +) (string, string) { // uint256 x2== "0" common.IsHalted() if common.GetLimitCaller() { caller := std.PrevRealm().Addr() @@ -47,14 +47,8 @@ func Mint( } pool := GetPool(token0Path, token1Path, fee) - _, amount0, amount1 := pool.modifyPosition( - ModifyPositionParams{ - recipient, // owner - tickLower, // tickLower - tickUpper, // tickUpper - i256.FromUint256(liquidityAmount), // liquidityDelta - }, - ) + position := NewModifyPositionParams(recipient, tickLower, tickUpper, i256.FromUint256(liquidityAmount)) + _, amount0, amount1 := pool.modifyPosition(position) if amount0.Gt(i256.Zero()) { pool.transferFromAndVerify(positionCaller, consts.POOL_ADDR, pool.token0Path, amount0, true) @@ -157,6 +151,8 @@ func Collect( )) } + // ------ amount 구하고 transfer하는 로직이 반복 됨. 함수로 만들기 + // Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0 amount0 := u256Min(amount0Requested, position.tokensOwed0) amount0 = u256Min(amount0, pool.balances.token0) @@ -175,6 +171,8 @@ func Collect( pool.balances.token1 = new(u256.Uint).Sub(pool.balances.token1, amount1) transferByRegisterCall(pool.token1Path, recipient, amount1.Uint64()) + // ----- + pool.positions[positionKey] = position return amount0.ToString(), amount1.ToString() @@ -204,12 +202,14 @@ func Swap( } } + // --- foo == "0"인지 검사하는 로직이 반복됨. 함수로 추출 if _amountSpecified == "0" { panic(addDetailToError( errInvalidSwapAmount, ufmt.Sprintf("pool.gno__Swap() || amountSpecified == 0"), )) } + // --- amountSpecified := i256.MustFromDecimal(_amountSpecified) sqrtPriceLimitX96 := u256.MustFromDecimal(_sqrtPriceLimitX96) @@ -227,6 +227,7 @@ func Swap( var feeProtocol uint8 var feeGrowthGlobalX128 *u256.Uint + // --- 중복 됨. 함수로 만들면 좋을 듯 if zeroForOne { minSqrtRatio := u256.MustFromDecimal(consts.MIN_SQRT_RATIO) @@ -256,6 +257,7 @@ func Swap( feeProtocol = slot0Start.feeProtocol / 16 feeGrowthGlobalX128 = pool.feeGrowthGlobal1X128 } + // --- pool.slot0.unlocked = false cache := newSwapCache(feeProtocol, pool.liquidity) @@ -263,6 +265,8 @@ func Swap( exactInput := amountSpecified.Gt(i256.Zero()) + // --- 루프 추출해야 함 + // continue swapping as long as we haven't used the entire input/output and haven't reached the price limit swapFee := u256.Zero() for !(state.amountSpecifiedRemaining.IsZero()) && !(state.sqrtPriceX96.Eq(sqrtPriceLimitX96)) { @@ -517,6 +521,7 @@ func setFeeProtocol( ) uint8 { common.IsHalted() + // --- cond 함수로 만들어야 함 fee0Cond := feeProtocol0 == 0 || (feeProtocol0 >= 4 && feeProtocol0 <= 10) fee1Cond := feeProtocol1 == 0 || (feeProtocol1 >= 4 && feeProtocol1 <= 10) if !(fee0Cond && fee1Cond) { @@ -525,6 +530,7 @@ func setFeeProtocol( ufmt.Sprintf("pool.gno__setFeeProtocol() || expected (feeProtocol0(%d) == 0 || (feeProtocol0(%d) >= 4 && feeProtocol0(%d) <= 10)) && (feeProtocol1(%d) == 0 || (feeProtocol1(%d) >= 4 && feeProtocol1(%d) <= 10))", feeProtocol0, feeProtocol0, feeProtocol0, feeProtocol1, feeProtocol1, feeProtocol1), )) } + // --- newFee := feeProtocol0 + (feeProtocol1 << 4) // ( << 4 ) = ( * 16 ) @@ -675,63 +681,71 @@ func (pool *Pool) transferAndVerify( amount *i256.Int, isToken0 bool, ) { - if amount.IsZero() { - return - } - - // must be negative to send token from pool to user - // as point of view from pool, it is negative - if !amount.IsNeg() { + if amount.Sign() != -1 { panic(addDetailToError( errMustBeNegative, ufmt.Sprintf("pool.gno__transferAndVerify() || amount(%s) must be negative", amount.ToString()), )) } - // check pool.balances + absAmount := amount.Abs() + + pool.validatePoolBalance(absAmount, isToken0) // abs + + amountUint64 := checkAmountRange(absAmount) + + // Execute transfer + transferByRegisterCall(tokenPath, to, amountUint64) + + // Update pool balance + pool.updatePoolBalance(absAmount, isToken0) // abs +} + +func (pool *Pool) validatePoolBalance(amount *u256.Uint, isToken0 bool) { if isToken0 { - if pool.balances.token0.Lt(amount.Abs()) { - panic(addDetailToError( - errTransferFailed, - ufmt.Sprintf("pool.gno__transferAndVerify() || pool.balances.token0(%s) >= amount.Abs(%s)", pool.balances.token0.ToString(), amount.Abs().ToString()), - )) - } - } else { - if pool.balances.token1.Lt(amount.Abs()) { + if pool.balances.token0.Lt(amount) { panic(addDetailToError( errTransferFailed, - ufmt.Sprintf("pool.gno__transferAndVerify() || pool.balances.token1(%s) >= amount.Abs(%s)", pool.balances.token1.ToString(), amount.Abs().ToString()), + ufmt.Sprintf("pool.gno__transferAndVerify() || pool.balances.token0(%s) >= amount(%s)", + pool.balances.token0.ToString(), amount.ToString()), )) } + return } + if pool.balances.token1.Lt(amount) { + panic(addDetailToError( + errTransferFailed, + ufmt.Sprintf("pool.gno__transferAndVerify() || pool.balances.token1(%s) >= amount(%s)", + pool.balances.token1.ToString(), amount.ToString()), + )) + } + return +} - amountUint64 := checkAmountRange(amount) - - // try sending - // will panic if following conditions are met: - // - POOL does not have enough balance - // - token is not registered - transferByRegisterCall(tokenPath, to, amountUint64) - - // update pool.balances +func (pool *Pool) updatePoolBalance(amount *u256.Uint, isToken0 bool) { var overflow bool + if isToken0 { - pool.balances.token0, overflow = new(u256.Uint).SubOverflow(pool.balances.token0, amount.Abs()) + pool.balances.token0, overflow = new(u256.Uint).SubOverflow(pool.balances.token0, amount) if overflow { panic(addDetailToError( errTransferFailed, - ufmt.Sprintf("pool.gno__transferAndVerify() || cannot decrease, pool.balances.token0(%s) - amount(%s)", pool.balances.token0.ToString(), amount.Abs().ToString()), - )) - } - } else { - pool.balances.token1, overflow = new(u256.Uint).SubOverflow(pool.balances.token1, amount.Abs()) - if pool.balances.token1.Lt(u256.Zero()) { - panic(addDetailToError( - errTransferFailed, - ufmt.Sprintf("pool.gno__transferAndVerify() || cannot decrease, pool.balances.token1(%s) - amount(%s)", pool.balances.token1.ToString(), amount.Abs().ToString()), + ufmt.Sprintf("pool.gno__transferAndVerify() || cannot decrease, pool.balances.token0(%s) - amount(%s)", + pool.balances.token0.ToString(), amount.ToString()), )) } + return + } + + pool.balances.token1, overflow = new(u256.Uint).SubOverflow(pool.balances.token1, amount) + if pool.balances.token1.Lt(u256.Zero()) { + panic(addDetailToError( + errTransferFailed, + ufmt.Sprintf("pool.gno__transferAndVerify() || cannot decrease, pool.balances.token1(%s) - amount(%s)", + pool.balances.token1.ToString(), amount.ToString()), + )) } + return } func (pool *Pool) transferFromAndVerify( @@ -740,7 +754,9 @@ func (pool *Pool) transferFromAndVerify( amount *i256.Int, isToken0 bool, ) { - amountUint64 := checkAmountRange(amount) + // abs + absAmount := amount.Abs() + amountUint64 := checkAmountRange(absAmount) // try sending // will panic if following conditions are met: @@ -757,16 +773,14 @@ func (pool *Pool) transferFromAndVerify( } } -func checkAmountRange(amount *i256.Int) uint64 { - // check amount is in uint64 range - amountAbs := amount.Abs() - amountUint64, overflow := amountAbs.Uint64WithOverflow() +func checkAmountRange(amount *u256.Uint) uint64 { + res, overflow := amount.Uint64WithOverflow() if overflow { panic(addDetailToError( errOutOfRange, - ufmt.Sprintf("pool.gno__checkAmountRange() || amountAbs(%s) overflows uint64 range", amountAbs.ToString()), + ufmt.Sprintf("pool.gno__checkAmountRange() || amount(%s) overflows uint64 range", amount.ToString()), )) } - return amountUint64 + return res } diff --git a/pool/pool_test.gno b/pool/pool_test.gno index f6f99317b..0c63399f0 100644 --- a/pool/pool_test.gno +++ b/pool/pool_test.gno @@ -1,10 +1,12 @@ package pool import ( + "std" "testing" "gno.land/p/demo/uassert" u256 "gno.land/p/gnoswap/uint256" + i256 "gno.land/p/gnoswap/int256" ) func TestSaveProtocolFees(t *testing.T) { @@ -76,3 +78,135 @@ func TestSaveProtocolFees(t *testing.T) { }) } } + +func TestTransferAndVerify(t *testing.T) { + pool := &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(1000), + }, + } + + t.Run("validatePoolBalance", func(t *testing.T) { + testCases := []struct { + name string + amount *u256.Uint + isToken0 bool + shouldPanic bool + }{ + { + name: "must success for negative amount", + amount: u256.NewUint(500), + isToken0: true, + shouldPanic: false, + }, + { + name: "must panic for insufficient token0 balance", + amount: u256.NewUint(1500), + isToken0: true, + shouldPanic: true, + }, + { + name: "must success for negative amount", + amount: u256.NewUint(500), + isToken0: false, + shouldPanic: false, + }, + { + name: "must panic for insufficient token1 balance", + amount: u256.NewUint(1500), + isToken0: false, + shouldPanic: true, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + defer func() { + r := recover() + if tt.shouldPanic && r == nil { + t.Error("expected panic but no panic") + } + if !tt.shouldPanic && r != nil { + t.Errorf("unexpected panic: %v", r) + } + }() + + pool.validatePoolBalance(tt.amount, tt.isToken0) + }) + } + }) + + t.Run("updatePoolBalance", func(t *testing.T) { + testCases := []struct { + name string + initialBalance *u256.Uint + amount *u256.Uint + isToken0 bool + expectedBalance *u256.Uint + shouldPanic bool + }{ + { + name: "must success for negative amount", + initialBalance: u256.NewUint(1000), + amount: u256.NewUint(300), + isToken0: true, + expectedBalance: u256.NewUint(700), + shouldPanic: false, + }, + { + name: "must panic for overflow", + initialBalance: u256.NewUint(100), + amount: u256.NewUint(200), + isToken0: true, + expectedBalance: nil, + shouldPanic: true, + }, + { + name: "must success for negative amount", + initialBalance: u256.NewUint(1000), + amount: u256.NewUint(300), + isToken0: false, + expectedBalance: u256.NewUint(700), + shouldPanic: false, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + testPool := &Pool{ + balances: Balances{ + token0: tt.initialBalance, + token1: tt.initialBalance, + }, + } + + defer func() { + r := recover() + if tt.shouldPanic && r == nil { + t.Error("expected panic but no panic") + } + if !tt.shouldPanic && r != nil { + t.Errorf("unexpected panic: %v", r) + } + }() + + testPool.updatePoolBalance(tt.amount, tt.isToken0) + if !tt.shouldPanic { + var actualBalance *u256.Uint + if tt.isToken0 { + actualBalance = testPool.balances.token0 + } else { + actualBalance = testPool.balances.token1 + } + + if !actualBalance.Eq(tt.expectedBalance) { + t.Errorf("expected balance: %v, actual balance: %v", + tt.expectedBalance.ToString(), + actualBalance.ToString()) + } + } + }) + } + }) +} From adbafee8ebe8296c758d0d88c984e57c6d3f59fb Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Mon, 2 Dec 2024 14:52:09 +0900 Subject: [PATCH 04/24] fix --- pool/pool.gno | 88 ++++++++++++++++++------------------ pool/pool_test.gno | 109 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 150 insertions(+), 47 deletions(-) diff --git a/pool/pool.gno b/pool/pool.gno index 04b23f69e..abb29df64 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -202,14 +202,12 @@ func Swap( } } - // --- foo == "0"인지 검사하는 로직이 반복됨. 함수로 추출 if _amountSpecified == "0" { panic(addDetailToError( errInvalidSwapAmount, ufmt.Sprintf("pool.gno__Swap() || amountSpecified == 0"), )) } - // --- amountSpecified := i256.MustFromDecimal(_amountSpecified) sqrtPriceLimitX96 := u256.MustFromDecimal(_sqrtPriceLimitX96) @@ -493,10 +491,7 @@ func SetFeeProtocolByAdmin( // Only governance contract can execute this function via proposal // Also it will be applied to new created pools // ref: https://docs.gnoswap.io/contracts/pool/pool.gno#setfeeprotocol -func SetFeeProtocol( - feeProtocol0 uint8, - feeProtocol1 uint8, -) { +func SetFeeProtocol(feeProtocol0, feeProtocol1 uint8) { caller := std.PrevRealm().Addr() if err := common.GovernanceOnly(caller); err != nil { panic(err) @@ -515,22 +510,15 @@ func SetFeeProtocol( ) } -func setFeeProtocol( - feeProtocol0 uint8, - feeProtocol1 uint8, -) uint8 { +func setFeeProtocol(feeProtocol0, feeProtocol1 uint8) uint8 { common.IsHalted() - // --- cond 함수로 만들어야 함 - fee0Cond := feeProtocol0 == 0 || (feeProtocol0 >= 4 && feeProtocol0 <= 10) - fee1Cond := feeProtocol1 == 0 || (feeProtocol1 >= 4 && feeProtocol1 <= 10) - if !(fee0Cond && fee1Cond) { + if err := validateFeeProtocol(feeProtocol0, feeProtocol1); err != nil { panic(addDetailToError( - errInvalidProtocolFeePct, + err, ufmt.Sprintf("pool.gno__setFeeProtocol() || expected (feeProtocol0(%d) == 0 || (feeProtocol0(%d) >= 4 && feeProtocol0(%d) <= 10)) && (feeProtocol1(%d) == 0 || (feeProtocol1(%d) >= 4 && feeProtocol1(%d) <= 10))", feeProtocol0, feeProtocol0, feeProtocol0, feeProtocol1, feeProtocol1, feeProtocol1), )) } - // --- newFee := feeProtocol0 + (feeProtocol1 << 4) // ( << 4 ) = ( * 16 ) @@ -545,6 +533,17 @@ func setFeeProtocol( return newFee } +func validateFeeProtocol(feeProtocol0, feeProtocol1 uint8) error { + if !isValidFeeProtocolValue(feeProtocol0) || !isValidFeeProtocolValue(feeProtocol1) { + return errInvalidProtocolFeePct + } + return nil +} + +func isValidFeeProtocolValue(value uint8) bool { + return value == 0 || (value >= 4 && value <= 10) +} + // CollectProtocolByAdmin collects protocol fees for the given pool that accumulated while it was being used for swap // Returns collected amount0, amount1 in string func CollectProtocolByAdmin( @@ -701,6 +700,7 @@ func (pool *Pool) transferAndVerify( pool.updatePoolBalance(absAmount, isToken0) // abs } +// token0, 1은 파라미터로 넘기는게 좋아보임. 이 경우 pool을 넘겨받지 않아도 됨. func (pool *Pool) validatePoolBalance(amount *u256.Uint, isToken0 bool) { if isToken0 { if pool.balances.token0.Lt(amount) { @@ -723,29 +723,32 @@ func (pool *Pool) validatePoolBalance(amount *u256.Uint, isToken0 bool) { } func (pool *Pool) updatePoolBalance(amount *u256.Uint, isToken0 bool) { - var overflow bool - - if isToken0 { - pool.balances.token0, overflow = new(u256.Uint).SubOverflow(pool.balances.token0, amount) - if overflow { - panic(addDetailToError( - errTransferFailed, - ufmt.Sprintf("pool.gno__transferAndVerify() || cannot decrease, pool.balances.token0(%s) - amount(%s)", - pool.balances.token0.ToString(), amount.ToString()), - )) - } - return - } - - pool.balances.token1, overflow = new(u256.Uint).SubOverflow(pool.balances.token1, amount) - if pool.balances.token1.Lt(u256.Zero()) { - panic(addDetailToError( - errTransferFailed, - ufmt.Sprintf("pool.gno__transferAndVerify() || cannot decrease, pool.balances.token1(%s) - amount(%s)", - pool.balances.token1.ToString(), amount.ToString()), - )) - } - return + var overflow bool + var newBalance *u256.Uint + + if isToken0 { + newBalance, overflow = new(u256.Uint).SubOverflow(pool.balances.token0, amount) + if overflow || newBalance.Lt(u256.Zero()) { + panic(addDetailToError( + errTransferFailed, + ufmt.Sprintf("pool.gno__transferAndVerify() || cannot decrease, pool.balances.token0(%s) - amount(%s)", + pool.balances.token0.ToString(), amount.ToString()), + )) + } + pool.balances.token0 = newBalance + return + } + + newBalance, overflow = new(u256.Uint).SubOverflow(pool.balances.token1, amount) + if overflow || newBalance.Lt(u256.Zero()) { + panic(addDetailToError( + errTransferFailed, + ufmt.Sprintf("pool.gno__transferAndVerify() || cannot decrease, pool.balances.token1(%s) - amount(%s)", + pool.balances.token1.ToString(), amount.ToString()), + )) + } + pool.balances.token1 = newBalance + return } func (pool *Pool) transferFromAndVerify( @@ -754,7 +757,6 @@ func (pool *Pool) transferFromAndVerify( amount *i256.Int, isToken0 bool, ) { - // abs absAmount := amount.Abs() amountUint64 := checkAmountRange(absAmount) @@ -765,11 +767,11 @@ func (pool *Pool) transferFromAndVerify( // - token is not registered transferFromByRegisterCall(tokenPath, from, to, amountUint64) - // update pool.balances + // update pool balances if isToken0 { - pool.balances.token0 = new(u256.Uint).Add(pool.balances.token0, amount.Abs()) + pool.balances.token0 = new(u256.Uint).Add(pool.balances.token0, absAmount) } else { - pool.balances.token1 = new(u256.Uint).Add(pool.balances.token1, amount.Abs()) + pool.balances.token1 = new(u256.Uint).Add(pool.balances.token1, absAmount) } } diff --git a/pool/pool_test.gno b/pool/pool_test.gno index 0c63399f0..a03ae3965 100644 --- a/pool/pool_test.gno +++ b/pool/pool_test.gno @@ -80,6 +80,7 @@ func TestSaveProtocolFees(t *testing.T) { } func TestTransferAndVerify(t *testing.T) { + // Setup common test data pool := &Pool{ balances: Balances{ token0: u256.NewUint(1000), @@ -88,7 +89,7 @@ func TestTransferAndVerify(t *testing.T) { } t.Run("validatePoolBalance", func(t *testing.T) { - testCases := []struct { + tests := []struct { name string amount *u256.Uint isToken0 bool @@ -120,7 +121,7 @@ func TestTransferAndVerify(t *testing.T) { }, } - for _, tt := range testCases { + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { defer func() { r := recover() @@ -138,7 +139,7 @@ func TestTransferAndVerify(t *testing.T) { }) t.Run("updatePoolBalance", func(t *testing.T) { - testCases := []struct { + tests := []struct { name string initialBalance *u256.Uint amount *u256.Uint @@ -172,7 +173,7 @@ func TestTransferAndVerify(t *testing.T) { }, } - for _, tt := range testCases { + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { testPool := &Pool{ balances: Balances{ @@ -210,3 +211,103 @@ func TestTransferAndVerify(t *testing.T) { } }) } + +func TestUpdatePoolBalance(t *testing.T) { + tests := []struct { + name string + initialToken0 *u256.Uint + initialToken1 *u256.Uint + amount *u256.Uint + isToken0 bool + expectedToken0 *u256.Uint + expectedToken1 *u256.Uint + shouldPanic bool + }{ + { + name: "normal token0 decrease", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(300), + isToken0: true, + expectedToken0: u256.NewUint(700), + expectedToken1: u256.NewUint(2000), + shouldPanic: false, + }, + { + name: "normal token1 decrease", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(500), + isToken0: false, + expectedToken0: u256.NewUint(1000), + expectedToken1: u256.NewUint(1500), + shouldPanic: false, + }, + { + name: "insufficient token0 balance", + initialToken0: u256.NewUint(100), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(200), + isToken0: true, + expectedToken0: nil, + expectedToken1: nil, + shouldPanic: true, + }, + { + name: "insufficient token1 balance", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(100), + amount: u256.NewUint(200), + isToken0: false, + expectedToken0: nil, + expectedToken1: nil, + shouldPanic: true, + }, + { + name: "0 value handling for token0", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(0), + isToken0: true, + expectedToken0: u256.NewUint(1000), + expectedToken1: u256.NewUint(2000), + shouldPanic: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pool := &Pool{ + balances: Balances{ + token0: tt.initialToken0, + token1: tt.initialToken1, + }, + } + + defer func() { + r := recover() + if tt.shouldPanic && r == nil { + t.Error("expected panic but no panic") + } + if !tt.shouldPanic && r != nil { + t.Errorf("unexpected panic: %v", r) + } + }() + + pool.updatePoolBalance(tt.amount, tt.isToken0) + + if !tt.shouldPanic { + if !pool.balances.token0.Eq(tt.expectedToken0) { + t.Errorf("token0 balance mismatch. expected: %s, actual: %s", + tt.expectedToken0.ToString(), + pool.balances.token0.ToString()) + } + if !pool.balances.token1.Eq(tt.expectedToken1) { + t.Errorf("token1 balance mismatch. expected: %s, actual: %s", + tt.expectedToken1.ToString(), + pool.balances.token1.ToString()) + } + } + }) + } +} From d842c39e69437d66e07e005e77d86a58b5117326 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Mon, 2 Dec 2024 19:34:43 +0900 Subject: [PATCH 05/24] refact: transferAndVerify helper --- pool/pool.gno | 158 +++++++++++++++----------------- pool/pool_test.gno | 223 ++++++++++++++------------------------------- 2 files changed, 145 insertions(+), 236 deletions(-) diff --git a/pool/pool.gno b/pool/pool.gno index abb29df64..008bac8cd 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -71,7 +71,7 @@ func Burn( fee uint32, tickLower int32, tickUpper int32, - _liquidityAmount string, // uint128 + liquidityAmount string, // uint128 ) (string, string) { // uint256 x2 common.IsHalted() caller := std.PrevRealm().Addr() @@ -84,25 +84,17 @@ func Burn( } } - liquidityAmount := u256.MustFromDecimal(_liquidityAmount) + liqAmount := u256.MustFromDecimal(liquidityAmount) pool := GetPool(token0Path, token1Path, fee) - position, amount0Int, amount1Int := pool.modifyPosition( // in256 x2 - ModifyPositionParams{ - caller, // msg.sender - tickLower, - tickUpper, - i256.Zero().Neg(i256.FromUint256(liquidityAmount)), - }, - ) + liqDelta := i256.Zero().Neg(i256.FromUint256(liqAmount)) + posParams := NewModifyPositionParams(caller, tickLower, tickUpper, liqDelta) + position, amount0, amount1 := pool.modifyPosition(posParams) - amount0 := amount0Int.Abs() - amount1 := amount1Int.Abs() - - if amount0.Gt(u256.Zero()) || amount1.Gt(u256.Zero()) { - position.tokensOwed0 = new(u256.Uint).Add(position.tokensOwed0, amount0) - position.tokensOwed1 = new(u256.Uint).Add(position.tokensOwed1, amount1) + if amount0.Gt(i256.Zero()) || amount1.Gt(i256.Zero()) { + position.tokensOwed0 = new(u256.Uint).Add(position.tokensOwed0, amount0.Abs()) + position.tokensOwed1 = new(u256.Uint).Add(position.tokensOwed1, amount1.Abs()) } positionKey := positionGetKey(caller, tickLower, tickUpper) @@ -151,8 +143,6 @@ func Collect( )) } - // ------ amount 구하고 transfer하는 로직이 반복 됨. 함수로 만들기 - // Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0 amount0 := u256Min(amount0Requested, position.tokensOwed0) amount0 = u256Min(amount0, pool.balances.token0) @@ -171,8 +161,6 @@ func Collect( pool.balances.token1 = new(u256.Uint).Sub(pool.balances.token1, amount1) transferByRegisterCall(pool.token1Path, recipient, amount1.Uint64()) - // ----- - pool.positions[positionKey] = position return amount0.ToString(), amount1.ToString() @@ -215,7 +203,7 @@ func Swap( pool := GetPool(token0Path, token1Path, fee) slot0Start := pool.slot0 - if !(slot0Start.unlocked) { + if !slot0Start.unlocked { panic(addDetailToError( errLockedPool, ufmt.Sprintf("pool.gno__Swap() || slot0Start.unlocked(false) must be unlocked)"), @@ -369,13 +357,10 @@ func Swap( } } - var amount0, amount1 *i256.Int + amount0 := state.amountCalculated + amount1 := i256.Zero().Sub(amountSpecified, state.amountSpecifiedRemaining) if zeroForOne == exactInput { - amount0 = i256.Zero().Sub(amountSpecified, state.amountSpecifiedRemaining) - amount1 = state.amountCalculated - } else { - amount0 = state.amountCalculated - amount1 = i256.Zero().Sub(amountSpecified, state.amountSpecifiedRemaining) + amount0, amount1 = amount1, amount0 } // actual swap @@ -443,11 +428,7 @@ func tickTransition(step StepComputations, zeroForOne bool, state SwapState, poo } // if we're moving leftward, we interpret liquidityNet as the opposite sign - liquidityNet := pool.tickCross( - step.tickNext, - fee0, - fee1, - ) + liquidityNet := pool.tickCross(step.tickNext, fee0, fee1) if zeroForOne { liquidityNet = i256.Zero().Neg(liquidityNet) @@ -689,66 +670,76 @@ func (pool *Pool) transferAndVerify( absAmount := amount.Abs() - pool.validatePoolBalance(absAmount, isToken0) // abs - - amountUint64 := checkAmountRange(absAmount) - - // Execute transfer + token0 := pool.balances.token0 + token1 := pool.balances.token1 + + if err := validatePoolBalance(token0, token1, absAmount, isToken0); err != nil { + panic(err) + } + amountUint64, err := checkAmountRange(absAmount) + if err != nil { + panic(err) + } + transferByRegisterCall(tokenPath, to, amountUint64) - - // Update pool balance - pool.updatePoolBalance(absAmount, isToken0) // abs + + newBalance, err := updatePoolBalance(token0, token1, absAmount, isToken0) + if err != nil { + panic(err) + } + + if isToken0 { + pool.balances.token0 = newBalance + } else { + pool.balances.token1 = newBalance + } } -// token0, 1은 파라미터로 넘기는게 좋아보임. 이 경우 pool을 넘겨받지 않아도 됨. -func (pool *Pool) validatePoolBalance(amount *u256.Uint, isToken0 bool) { +func validatePoolBalance(token0, token1, amount *u256.Uint, isToken0 bool) error { if isToken0 { - if pool.balances.token0.Lt(amount) { - panic(addDetailToError( - errTransferFailed, - ufmt.Sprintf("pool.gno__transferAndVerify() || pool.balances.token0(%s) >= amount(%s)", - pool.balances.token0.ToString(), amount.ToString()), - )) + if token0.Lt(amount) { + return ufmt.Errorf( + "%s || token0(%s) >= amount(%s)", + errTransferFailed.Error(), token0.ToString(), amount.ToString(), + ) } - return + return nil } - if pool.balances.token1.Lt(amount) { - panic(addDetailToError( - errTransferFailed, - ufmt.Sprintf("pool.gno__transferAndVerify() || pool.balances.token1(%s) >= amount(%s)", - pool.balances.token1.ToString(), amount.ToString()), - )) + if token1.Lt(amount) { + return ufmt.Errorf( + "%s || token1(%s) >= amount(%s)", + errTransferFailed.Error(), token1.ToString(), amount.ToString(), + ) } - return + return nil } -func (pool *Pool) updatePoolBalance(amount *u256.Uint, isToken0 bool) { - var overflow bool - var newBalance *u256.Uint +func updatePoolBalance( + token0, token1, amount *u256.Uint, + isToken0 bool, +) (*u256.Uint, error) { + var overflow bool + var newBalance *u256.Uint if isToken0 { - newBalance, overflow = new(u256.Uint).SubOverflow(pool.balances.token0, amount) + newBalance, overflow = new(u256.Uint).SubOverflow(token0, amount) if overflow || newBalance.Lt(u256.Zero()) { - panic(addDetailToError( - errTransferFailed, - ufmt.Sprintf("pool.gno__transferAndVerify() || cannot decrease, pool.balances.token0(%s) - amount(%s)", - pool.balances.token0.ToString(), amount.ToString()), - )) + return nil, ufmt.Errorf( + "%s || cannot decrease, token0(%s) - amount(%s)", + errTransferFailed.Error(), token0.ToString(), amount.ToString(), + ) } - pool.balances.token0 = newBalance - return + return newBalance, nil } - newBalance, overflow = new(u256.Uint).SubOverflow(pool.balances.token1, amount) + newBalance, overflow = new(u256.Uint).SubOverflow(token1, amount) if overflow || newBalance.Lt(u256.Zero()) { - panic(addDetailToError( - errTransferFailed, - ufmt.Sprintf("pool.gno__transferAndVerify() || cannot decrease, pool.balances.token1(%s) - amount(%s)", - pool.balances.token1.ToString(), amount.ToString()), - )) + return nil, ufmt.Errorf( + "%s || cannot decrease, token1(%s) - amount(%s)", + errTransferFailed.Error(), token1.ToString(), amount.ToString(), + ) } - pool.balances.token1 = newBalance - return + return newBalance, nil } func (pool *Pool) transferFromAndVerify( @@ -758,7 +749,10 @@ func (pool *Pool) transferFromAndVerify( isToken0 bool, ) { absAmount := amount.Abs() - amountUint64 := checkAmountRange(absAmount) + amountUint64, err := checkAmountRange(absAmount) + if err != nil { + panic(err) + } // try sending // will panic if following conditions are met: @@ -775,14 +769,14 @@ func (pool *Pool) transferFromAndVerify( } } -func checkAmountRange(amount *u256.Uint) uint64 { +func checkAmountRange(amount *u256.Uint) (uint64, error) { res, overflow := amount.Uint64WithOverflow() if overflow { - panic(addDetailToError( - errOutOfRange, - ufmt.Sprintf("pool.gno__checkAmountRange() || amount(%s) overflows uint64 range", amount.ToString()), - )) + return 0, ufmt.Errorf( + "%s || amount(%s) overflows uint64 range", + errOutOfRange.Error(), amount.ToString(), + ) } - return res + return res, nil } diff --git a/pool/pool_test.gno b/pool/pool_test.gno index a03ae3965..f49863468 100644 --- a/pool/pool_test.gno +++ b/pool/pool_test.gno @@ -93,120 +93,45 @@ func TestTransferAndVerify(t *testing.T) { name string amount *u256.Uint isToken0 bool - shouldPanic bool + expectedError bool }{ { name: "must success for negative amount", amount: u256.NewUint(500), isToken0: true, - shouldPanic: false, + expectedError: false, }, { name: "must panic for insufficient token0 balance", amount: u256.NewUint(1500), isToken0: true, - shouldPanic: true, + expectedError: true, }, { name: "must success for negative amount", amount: u256.NewUint(500), isToken0: false, - shouldPanic: false, + expectedError: false, }, { name: "must panic for insufficient token1 balance", amount: u256.NewUint(1500), isToken0: false, - shouldPanic: true, + expectedError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - defer func() { - r := recover() - if tt.shouldPanic && r == nil { - t.Error("expected panic but no panic") - } - if !tt.shouldPanic && r != nil { - t.Errorf("unexpected panic: %v", r) - } - }() - - pool.validatePoolBalance(tt.amount, tt.isToken0) - }) - } - }) - - t.Run("updatePoolBalance", func(t *testing.T) { - tests := []struct { - name string - initialBalance *u256.Uint - amount *u256.Uint - isToken0 bool - expectedBalance *u256.Uint - shouldPanic bool - }{ - { - name: "must success for negative amount", - initialBalance: u256.NewUint(1000), - amount: u256.NewUint(300), - isToken0: true, - expectedBalance: u256.NewUint(700), - shouldPanic: false, - }, - { - name: "must panic for overflow", - initialBalance: u256.NewUint(100), - amount: u256.NewUint(200), - isToken0: true, - expectedBalance: nil, - shouldPanic: true, - }, - { - name: "must success for negative amount", - initialBalance: u256.NewUint(1000), - amount: u256.NewUint(300), - isToken0: false, - expectedBalance: u256.NewUint(700), - shouldPanic: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testPool := &Pool{ - balances: Balances{ - token0: tt.initialBalance, - token1: tt.initialBalance, - }, - } - - defer func() { - r := recover() - if tt.shouldPanic && r == nil { - t.Error("expected panic but no panic") - } - if !tt.shouldPanic && r != nil { - t.Errorf("unexpected panic: %v", r) - } - }() - - testPool.updatePoolBalance(tt.amount, tt.isToken0) - if !tt.shouldPanic { - var actualBalance *u256.Uint - if tt.isToken0 { - actualBalance = testPool.balances.token0 - } else { - actualBalance = testPool.balances.token1 - } - - if !actualBalance.Eq(tt.expectedBalance) { - t.Errorf("expected balance: %v, actual balance: %v", - tt.expectedBalance.ToString(), - actualBalance.ToString()) - } - } + token0 := pool.balances.token0 + token1 := pool.balances.token1 + + err := validatePoolBalance(token0, token1, tt.amount, tt.isToken0) + if err != nil { + if !tt.expectedError { + t.Errorf("unexpected error: %v", err) + } + } }) } }) @@ -214,64 +139,58 @@ func TestTransferAndVerify(t *testing.T) { func TestUpdatePoolBalance(t *testing.T) { tests := []struct { - name string - initialToken0 *u256.Uint - initialToken1 *u256.Uint - amount *u256.Uint - isToken0 bool - expectedToken0 *u256.Uint - expectedToken1 *u256.Uint - shouldPanic bool + name string + initialToken0 *u256.Uint + initialToken1 *u256.Uint + amount *u256.Uint + isToken0 bool + expectedBal *u256.Uint + expectErr bool }{ { name: "normal token0 decrease", - initialToken0: u256.NewUint(1000), - initialToken1: u256.NewUint(2000), - amount: u256.NewUint(300), - isToken0: true, - expectedToken0: u256.NewUint(700), - expectedToken1: u256.NewUint(2000), - shouldPanic: false, + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(300), + isToken0: true, + expectedBal: u256.NewUint(700), + expectErr: false, }, { name: "normal token1 decrease", - initialToken0: u256.NewUint(1000), - initialToken1: u256.NewUint(2000), - amount: u256.NewUint(500), - isToken0: false, - expectedToken0: u256.NewUint(1000), - expectedToken1: u256.NewUint(1500), - shouldPanic: false, + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(500), + isToken0: false, + expectedBal: u256.NewUint(1500), + expectErr: false, }, { name: "insufficient token0 balance", - initialToken0: u256.NewUint(100), - initialToken1: u256.NewUint(2000), - amount: u256.NewUint(200), - isToken0: true, - expectedToken0: nil, - expectedToken1: nil, - shouldPanic: true, + initialToken0: u256.NewUint(100), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(200), + isToken0: true, + expectedBal: nil, + expectErr: true, }, { name: "insufficient token1 balance", - initialToken0: u256.NewUint(1000), - initialToken1: u256.NewUint(100), - amount: u256.NewUint(200), - isToken0: false, - expectedToken0: nil, - expectedToken1: nil, - shouldPanic: true, + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(100), + amount: u256.NewUint(200), + isToken0: false, + expectedBal: nil, + expectErr: true, }, { - name: "0 value handling for token0", - initialToken0: u256.NewUint(1000), - initialToken1: u256.NewUint(2000), - amount: u256.NewUint(0), - isToken0: true, - expectedToken0: u256.NewUint(1000), - expectedToken1: u256.NewUint(2000), - shouldPanic: false, + name: "zero value handling", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(0), + isToken0: true, + expectedBal: u256.NewUint(1000), + expectErr: false, }, } @@ -284,29 +203,25 @@ func TestUpdatePoolBalance(t *testing.T) { }, } - defer func() { - r := recover() - if tt.shouldPanic && r == nil { - t.Error("expected panic but no panic") - } - if !tt.shouldPanic && r != nil { - t.Errorf("unexpected panic: %v", r) - } - }() + newBal, err := updatePoolBalance(tt.initialToken0, tt.initialToken1, tt.amount, tt.isToken0) - pool.updatePoolBalance(tt.amount, tt.isToken0) - - if !tt.shouldPanic { - if !pool.balances.token0.Eq(tt.expectedToken0) { - t.Errorf("token0 balance mismatch. expected: %s, actual: %s", - tt.expectedToken0.ToString(), - pool.balances.token0.ToString()) - } - if !pool.balances.token1.Eq(tt.expectedToken1) { - t.Errorf("token1 balance mismatch. expected: %s, actual: %s", - tt.expectedToken1.ToString(), - pool.balances.token1.ToString()) + if tt.expectErr { + if err == nil { + t.Errorf("%s: expected error but no error", tt.name) } + return + } + if err != nil { + t.Errorf("%s: unexpected error: %v", tt.name, err) + return + } + + if !newBal.Eq(tt.expectedBal) { + t.Errorf("%s: balance mismatch, expected: %s, actual: %s", + tt.name, + tt.expectedBal.ToString(), + newBal.ToString(), + ) } }) } From b62a10096b94e53c289975c614735e88f364fc25 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Tue, 3 Dec 2024 11:23:27 +0900 Subject: [PATCH 06/24] fix --- pool/pool.gno | 4 ++-- pool/position_modify_test.gno | 19 +++++++++---------- pool/type.gno | 16 +++++++++++++++- 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/pool/pool.gno b/pool/pool.gno index 008bac8cd..dc1aa1ba7 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -47,7 +47,7 @@ func Mint( } pool := GetPool(token0Path, token1Path, fee) - position := NewModifyPositionParams(recipient, tickLower, tickUpper, i256.FromUint256(liquidityAmount)) + position := newModifyPositionParams(recipient, tickLower, tickUpper, i256.FromUint256(liquidityAmount)) _, amount0, amount1 := pool.modifyPosition(position) if amount0.Gt(i256.Zero()) { @@ -89,7 +89,7 @@ func Burn( pool := GetPool(token0Path, token1Path, fee) liqDelta := i256.Zero().Neg(i256.FromUint256(liqAmount)) - posParams := NewModifyPositionParams(caller, tickLower, tickUpper, liqDelta) + posParams := newModifyPositionParams(caller, tickLower, tickUpper, liqDelta) position, amount0, amount1 := pool.modifyPosition(posParams) if amount0.Gt(i256.Zero()) || amount1.Gt(i256.Zero()) { diff --git a/pool/position_modify_test.gno b/pool/position_modify_test.gno index 93db26666..c8fcb26a4 100644 --- a/pool/position_modify_test.gno +++ b/pool/position_modify_test.gno @@ -54,13 +54,13 @@ func TestModifyPosition(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { sqrtPrice := u256.MustFromDecimal(tt.sqrtPrice) - pool := newPool( + poolParams := newPoolParams( barPath, fooPath, fee500, - feeAmountTickSpacing[fee500], - sqrtPrice, + sqrtPrice.ToString(), ) + pool := newPool(poolParams) params := ModifyPositionParams{ owner: consts.POSITION_ADDR, @@ -84,13 +84,13 @@ func TestModifyPositionEdgeCases(t *testing.T) { sp := u256.MustFromDecimal(sqrtPrice) t.Run("liquidityDelta is zero", func(t *testing.T) { - pool := newPool( + poolParams := newPoolParams( barPath, fooPath, fee500, - feeAmountTickSpacing[fee500], - sp, + sp.ToString(), ) + pool := newPool(poolParams) params := ModifyPositionParams{ owner: consts.POSITION_ADDR, @@ -114,14 +114,13 @@ func TestModifyPositionEdgeCases(t *testing.T) { }) t.Run("liquidityDelta is negative", func(t *testing.T) { - pool := newPool( + poolParams := newPoolParams( barPath, fooPath, fee500, - feeAmountTickSpacing[fee500], - sp, + sp.ToString(), ) - + pool := newPool(poolParams) params := ModifyPositionParams{ owner: consts.POSITION_ADDR, tickLower: -11000, diff --git a/pool/type.gno b/pool/type.gno index f112abf50..0d3a1a642 100644 --- a/pool/type.gno +++ b/pool/type.gno @@ -3,7 +3,7 @@ package pool import ( "std" - "gno.land/r/gnoswap/v2/common" + "gno.land/r/gnoswap/v1/common" i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" @@ -66,6 +66,20 @@ type ModifyPositionParams struct { liquidityDelta *i256.Int // any change in liquidity } +func newModifyPositionParams( + owner std.Address, + tickLower int32, + tickUpper int32, + liquidityDelta *i256.Int, +) ModifyPositionParams { + return ModifyPositionParams{ + owner: owner, + tickLower: tickLower, + tickUpper: tickUpper, + liquidityDelta: liquidityDelta, + } +} + type SwapCache struct { feeProtocol uint8 // protocol fee for the input token liquidityStart *u256.Uint // liquidity at the beginning of the swap From 83fe404a017a02cf354bdc41d3c1868529eaec09 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Tue, 3 Dec 2024 11:25:23 +0900 Subject: [PATCH 07/24] remove unnecessary conversion --- pool/position_update.gno | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pool/position_update.gno b/pool/position_update.gno index fb66853ca..ca788f882 100644 --- a/pool/position_update.gno +++ b/pool/position_update.gno @@ -5,8 +5,8 @@ import ( ) func (pool *Pool) updatePosition(positionParams ModifyPositionParams) PositionInfo { - feeGrowthGlobal0X128 := u256.MustFromDecimal(pool.feeGrowthGlobal0X128.Clone()) - _feeGrowthGlobal1X128 := u256.MustFromDecimal(pool.feeGrowthGlobal1X128.ToString()) + feeGrowthGlobal0X128 := pool.feeGrowthGlobal0X128.Clone() + feeGrowthGlobal1X128 := pool.feeGrowthGlobal1X128.Clone() var flippedLower, flippedUpper bool if !(positionParams.liquidityDelta.IsZero()) { @@ -14,8 +14,8 @@ func (pool *Pool) updatePosition(positionParams ModifyPositionParams) PositionIn positionParams.tickLower, pool.slot0.tick, positionParams.liquidityDelta, - _feeGrowthGlobal0X128, - _feeGrowthGlobal1X128, + feeGrowthGlobal0X128, + feeGrowthGlobal1X128, false, pool.maxLiquidityPerTick, ) @@ -24,8 +24,8 @@ func (pool *Pool) updatePosition(positionParams ModifyPositionParams) PositionIn positionParams.tickUpper, pool.slot0.tick, positionParams.liquidityDelta, - _feeGrowthGlobal0X128, - _feeGrowthGlobal1X128, + feeGrowthGlobal0X128, + feeGrowthGlobal1X128, true, pool.maxLiquidityPerTick, ) @@ -43,8 +43,8 @@ func (pool *Pool) updatePosition(positionParams ModifyPositionParams) PositionIn positionParams.tickLower, positionParams.tickUpper, pool.slot0.tick, - _feeGrowthGlobal0X128, - _feeGrowthGlobal1X128, + feeGrowthGlobal0X128, + feeGrowthGlobal1X128, ) positionKey := positionGetKey(positionParams.owner, positionParams.tickLower, positionParams.tickUpper) From 76c2541df1e1f959a8a1831ecbb8ad8fc31f7509 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Tue, 3 Dec 2024 15:59:31 +0900 Subject: [PATCH 08/24] reorganize execute flows --- pool/_RPC_dry.gno | 7 +++-- pool/pool.gno | 62 ++++++++++++++++++++++++----------------- pool/token_register.gno | 10 +++++-- pool/type.gno | 32 +++++++++++++++++++-- 4 files changed, 79 insertions(+), 32 deletions(-) diff --git a/pool/_RPC_dry.gno b/pool/_RPC_dry.gno index e589c139a..52b802553 100644 --- a/pool/_RPC_dry.gno +++ b/pool/_RPC_dry.gno @@ -58,9 +58,12 @@ func DrySwap( feeGrowthGlobalX128 = pool.feeGrowthGlobal1X128 } - pool.slot0.unlocked = false + // pool.slot0.unlocked = false + slot0 := pool.slot0 + slot0.unlocked = false + cache := newSwapCache(feeProtocol, pool.liquidity) - state := pool.newSwapState(amountSpecified, feeGrowthGlobalX128, cache.liquidityStart) // TODO: feeGrowthGlobalX128.Clone() or NOT + state := newSwapState(amountSpecified, feeGrowthGlobalX128, cache.liquidityStart, slot0) // TODO: feeGrowthGlobalX128.Clone() or NOT exactInput := amountSpecified.Gt(i256.Zero()) diff --git a/pool/pool.gno b/pool/pool.gno index dc1aa1ba7..e19832b36 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -201,19 +201,21 @@ func Swap( sqrtPriceLimitX96 := u256.MustFromDecimal(_sqrtPriceLimitX96) pool := GetPool(token0Path, token1Path, fee) - slot0Start := pool.slot0 + // 락 관리 + slot0Start := pool.slot0 // 이걸 따로 변수로 만드는게 좋긴한데 이름을 바꿔야 됨 if !slot0Start.unlocked { - panic(addDetailToError( - errLockedPool, - ufmt.Sprintf("pool.gno__Swap() || slot0Start.unlocked(false) must be unlocked)"), - )) + panic(errLockedPool) } + slot0Start.unlocked = false // unlocked라는 변수명이 좋지 않음. lock으로 변 + defer func() { slot0Start.unlocked = true }() + // + var feeProtocol uint8 var feeGrowthGlobalX128 *u256.Uint - // --- 중복 됨. 함수로 만들면 좋을 듯 + // --- 중복 됨. 함수로 만들면 좋을 듯 (validatePriceAndGetFeeParams) if zeroForOne { minSqrtRatio := u256.MustFromDecimal(consts.MIN_SQRT_RATIO) @@ -243,11 +245,15 @@ func Swap( feeProtocol = slot0Start.feeProtocol / 16 feeGrowthGlobalX128 = pool.feeGrowthGlobal1X128 } - // --- - pool.slot0.unlocked = false + + + + + + // ------- cache := newSwapCache(feeProtocol, pool.liquidity) - state := pool.newSwapState(amountSpecified, feeGrowthGlobalX128, cache.liquidityStart) + state := newSwapState(amountSpecified, feeGrowthGlobalX128, cache.liquidityStart, pool.slot0) exactInput := amountSpecified.Gt(i256.Zero()) @@ -256,21 +262,18 @@ func Swap( // continue swapping as long as we haven't used the entire input/output and haven't reached the price limit swapFee := u256.Zero() for !(state.amountSpecifiedRemaining.IsZero()) && !(state.sqrtPriceX96.Eq(sqrtPriceLimitX96)) { + // ----- computeSingleSwapStep var step StepComputations step.sqrtPriceStartX96 = state.sqrtPriceX96 - step.tickNext, step.initialized = pool.tickBitmapNextInitializedTickWithInOneWord( state.tick, pool.tickSpacing, zeroForOne, ) - // ensure that we do not overshoot the min/max tick, as the tick bitmap is not aware of these bounds - if step.tickNext < consts.MIN_TICK { - step.tickNext = consts.MIN_TICK - } else if step.tickNext > consts.MAX_TICK { - step.tickNext = consts.MAX_TICK - } + // prevent overshoot the min/max tick + step.clampTickNext() + // ----- // get the price for the next tick step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) @@ -292,7 +295,10 @@ func Swap( state.amountSpecifiedRemaining, uint64(pool.fee), ) - state.sqrtPriceX96 = u256.MustFromDecimal(_sqrtPriceX96Str) + + // ---- + + // 상태 격리 step.amountIn = u256.MustFromDecimal(_amountInStr) step.amountOut = u256.MustFromDecimal(_amountOutStr) step.feeAmount = u256.MustFromDecimal(_feeAmountStr) @@ -308,24 +314,30 @@ func Swap( // if the protocol fee is on, calculate how much is owed, decrement feeAmount, and increment protocolFee if cache.feeProtocol > 0 { - delta := new(u256.Uint).Div(step.feeAmount, u256.NewUint(uint64(cache.feeProtocol))) - step.feeAmount = new(u256.Uint).Sub(step.feeAmount, delta) - state.protocolFee = new(u256.Uint).Add(state.protocolFee, delta) + delta := step.feeAmount + delta.Div(delta, u256.NewUint(uint64(cache.feeProtocol))) + + step.feeAmount.Sub(step.feeAmount, delta) + state.protocolFee.Add(state.protocolFee, delta) } // update global fee tracker if state.liquidity.Gt(u256.Zero()) { update := u256.MulDiv(step.feeAmount, u256.MustFromDecimal(consts.Q128), state.liquidity) - state.feeGrowthGlobalX128 = new(u256.Uint).Add(state.feeGrowthGlobalX128, update) + state.SetFeeGrowthGlobalX128(new(u256.Uint).Add(state.feeGrowthGlobalX128, update)) } swapFee = new(u256.Uint).Add(swapFee, step.feeAmount) + ////////// + state.SetSqrtPriceX96(_sqrtPriceX96Str) // shift tick if we reached the next price if state.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { tickTransition(step, zeroForOne, state, pool) - } else if !(state.sqrtPriceX96.Eq(step.sqrtPriceStartX96)) { - // recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved - state.tick = common.TickMathGetTickAtSqrtRatio(state.sqrtPriceX96) + continue + } + // recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved + if state.sqrtPriceX96.Neq(step.sqrtPriceStartX96) { + state.SetTick(common.TickMathGetTickAtSqrtRatio(state.sqrtPriceX96)) } } // END LOOP @@ -345,6 +357,7 @@ func Swap( // update fee growth global and, if necessary, protocol fees // overflow is acceptable, protocol has to withdraw before it hits MAX_UINT256 fees + // -- 최종 상태 if zeroForOne { pool.feeGrowthGlobal0X128 = state.feeGrowthGlobalX128 if state.protocolFee.Gt(u256.Zero()) { @@ -386,7 +399,6 @@ func Swap( "internal_sqrtPriceX96", pool.slot0.sqrtPriceX96.ToString(), ) - pool.slot0.unlocked = true return amount0.ToString(), amount1.ToString() } diff --git a/pool/token_register.gno b/pool/token_register.gno index f37cf31ba..a5bc81f92 100644 --- a/pool/token_register.gno +++ b/pool/token_register.gno @@ -61,12 +61,18 @@ func RegisterGRC20Interface(pkgPath string, igrc20 GRC20Interface) { // UnregisterGRC20Interface unregisters a GRC20 token interface func UnregisterGRC20Interface(pkgPath string) { if err := common.SatisfyCond(isUserCall()); err != nil { - panic(err) + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("token_register.gno__UnregisterGRC20Interface() || unauthorized address(%s) to unregister", std.PrevRealm().Addr()), + )) } caller := std.PrevRealm().Addr() if err := common.TokenRegisterOnly(caller); err != nil { - panic(err) + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("token_register.gno__UnregisterGRC20Interface() || unauthorized address(%s) to unregister", caller), + )) } pkgPath = handleNative(pkgPath) diff --git a/pool/type.gno b/pool/type.gno index 0d3a1a642..c0074bbb4 100644 --- a/pool/type.gno +++ b/pool/type.gno @@ -4,6 +4,7 @@ import ( "std" "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" @@ -105,13 +106,12 @@ type SwapState struct { liquidity *u256.Uint // current liquidity in range } -func (pool *Pool) newSwapState( +func newSwapState( amountSpecifiedRemaining *i256.Int, feeGrowthGlobalX128 *u256.Uint, liquidity *u256.Uint, + slot0 Slot0, ) SwapState { - slot0 := pool.slot0 - return SwapState{ amountSpecifiedRemaining: amountSpecifiedRemaining, amountCalculated: i256.Zero(), @@ -123,6 +123,22 @@ func (pool *Pool) newSwapState( } } +func (s *SwapState) SetSqrtPriceX96(sqrtPriceX96 string) { + s.sqrtPriceX96 = u256.MustFromDecimal(sqrtPriceX96) +} + +func (s *SwapState) SetTick(tick int32) { + s.tick = tick +} + +func (s *SwapState) SetFeeGrowthGlobalX128(feeGrowthGlobalX128 *u256.Uint) { + s.feeGrowthGlobalX128 = feeGrowthGlobalX128 +} + +func (s *SwapState) SetProtocolFee(fee *u256.Uint) { + s.protocolFee = fee +} + type StepComputations struct { sqrtPriceStartX96 *u256.Uint // price at the beginning of the step tickNext int32 // next tick to swap to from the current tick in the swap direction @@ -133,6 +149,16 @@ type StepComputations struct { feeAmount *u256.Uint // how much fee is being paid in this step } +// clampTickNext ensures that `tickNext` stays within the min, max tick boundaries +// as the tick bitmap is not aware of these bounds +func (step *StepComputations) clampTickNext() { + if step.tickNext < consts.MIN_TICK { + step.tickNext = consts.MIN_TICK + } else if step.tickNext > consts.MAX_TICK { + step.tickNext = consts.MAX_TICK + } +} + type PositionInfo struct { liquidity *u256.Uint // amount of liquidity owned by this position From 0b6d9d435993d63712342fb8be4c291af1ac7b88 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Tue, 3 Dec 2024 18:49:40 +0900 Subject: [PATCH 09/24] refactor: Swap --- pool/pool.gno | 220 +++++++++++++++++++++++++++++--------------------- 1 file changed, 127 insertions(+), 93 deletions(-) diff --git a/pool/pool.gno b/pool/pool.gno index e19832b36..7cc3d9d5b 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -202,20 +202,17 @@ func Swap( pool := GetPool(token0Path, token1Path, fee) - // 락 관리 - slot0Start := pool.slot0 // 이걸 따로 변수로 만드는게 좋긴한데 이름을 바꿔야 됨 + slot0Start := pool.slot0 if !slot0Start.unlocked { panic(errLockedPool) } - slot0Start.unlocked = false // unlocked라는 변수명이 좋지 않음. lock으로 변 + slot0Start.unlocked = false defer func() { slot0Start.unlocked = true }() - // var feeProtocol uint8 var feeGrowthGlobalX128 *u256.Uint - // --- 중복 됨. 함수로 만들면 좋을 듯 (validatePriceAndGetFeeParams) if zeroForOne { minSqrtRatio := u256.MustFromDecimal(consts.MIN_SQRT_RATIO) @@ -246,101 +243,16 @@ func Swap( feeGrowthGlobalX128 = pool.feeGrowthGlobal1X128 } - - - - - - // ------- cache := newSwapCache(feeProtocol, pool.liquidity) state := newSwapState(amountSpecified, feeGrowthGlobalX128, cache.liquidityStart, pool.slot0) exactInput := amountSpecified.Gt(i256.Zero()) - // --- 루프 추출해야 함 - // continue swapping as long as we haven't used the entire input/output and haven't reached the price limit swapFee := u256.Zero() - for !(state.amountSpecifiedRemaining.IsZero()) && !(state.sqrtPriceX96.Eq(sqrtPriceLimitX96)) { - // ----- computeSingleSwapStep - var step StepComputations - step.sqrtPriceStartX96 = state.sqrtPriceX96 - step.tickNext, step.initialized = pool.tickBitmapNextInitializedTickWithInOneWord( - state.tick, - pool.tickSpacing, - zeroForOne, - ) - - // prevent overshoot the min/max tick - step.clampTickNext() - // ----- - - // get the price for the next tick - step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) - - isLower := step.sqrtPriceNextX96.Lt(sqrtPriceLimitX96) - isHigher := step.sqrtPriceNextX96.Gt(sqrtPriceLimitX96) - - var sqrtRatioTargetX96 *u256.Uint - if (zeroForOne && isLower) || (!zeroForOne && isHigher) { - sqrtRatioTargetX96 = sqrtPriceLimitX96 - } else { - sqrtRatioTargetX96 = step.sqrtPriceNextX96 - } - - _sqrtPriceX96Str, _amountInStr, _amountOutStr, _feeAmountStr := plp.SwapMathComputeSwapStepStr( - state.sqrtPriceX96, - sqrtRatioTargetX96, - state.liquidity, - state.amountSpecifiedRemaining, - uint64(pool.fee), - ) - - // ---- - - // 상태 격리 - step.amountIn = u256.MustFromDecimal(_amountInStr) - step.amountOut = u256.MustFromDecimal(_amountOutStr) - step.feeAmount = u256.MustFromDecimal(_feeAmountStr) - - amountInWithFee := i256.FromUint256(new(u256.Uint).Add(step.amountIn, step.feeAmount)) - if exactInput { - state.amountSpecifiedRemaining = i256.Zero().Sub(state.amountSpecifiedRemaining, amountInWithFee) - state.amountCalculated = i256.Zero().Sub(state.amountCalculated, i256.FromUint256(step.amountOut)) - } else { - state.amountSpecifiedRemaining = i256.Zero().Add(state.amountSpecifiedRemaining, i256.FromUint256(step.amountOut)) - state.amountCalculated = i256.Zero().Add(state.amountCalculated, amountInWithFee) - } - - // if the protocol fee is on, calculate how much is owed, decrement feeAmount, and increment protocolFee - if cache.feeProtocol > 0 { - delta := step.feeAmount - delta.Div(delta, u256.NewUint(uint64(cache.feeProtocol))) - - step.feeAmount.Sub(step.feeAmount, delta) - state.protocolFee.Add(state.protocolFee, delta) - } - - // update global fee tracker - if state.liquidity.Gt(u256.Zero()) { - update := u256.MulDiv(step.feeAmount, u256.MustFromDecimal(consts.Q128), state.liquidity) - state.SetFeeGrowthGlobalX128(new(u256.Uint).Add(state.feeGrowthGlobalX128, update)) - } - swapFee = new(u256.Uint).Add(swapFee, step.feeAmount) - - ////////// - state.SetSqrtPriceX96(_sqrtPriceX96Str) - // shift tick if we reached the next price - if state.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { - tickTransition(step, zeroForOne, state, pool) - continue - } - // recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved - if state.sqrtPriceX96.Neq(step.sqrtPriceStartX96) { - state.SetTick(common.TickMathGetTickAtSqrtRatio(state.sqrtPriceX96)) - } + for shouldContinueSwap(state, sqrtPriceLimitX96) { + swapFee = computeSwapStep(state, pool, zeroForOne, sqrtPriceLimitX96, exactInput, cache, swapFee) } - // END LOOP // update pool sqrtPrice pool.slot0.sqrtPriceX96 = state.sqrtPriceX96 @@ -357,7 +269,6 @@ func Swap( // update fee growth global and, if necessary, protocol fees // overflow is acceptable, protocol has to withdraw before it hits MAX_UINT256 fees - // -- 최종 상태 if zeroForOne { pool.feeGrowthGlobal0X128 = state.feeGrowthGlobalX128 if state.protocolFee.Gt(u256.Zero()) { @@ -402,6 +313,129 @@ func Swap( return amount0.ToString(), amount1.ToString() } +func shouldContinueSwap(state SwapState, sqrtPriceLimitX96 *u256.Uint) bool { + return !state.amountSpecifiedRemaining.IsZero() && + state.sqrtPriceX96.Neq(sqrtPriceLimitX96) +} + +// computeSingleSwapStep executes a single step of swap. +// computing the new state and price limit. +func computeSwapStep( + state SwapState, + pool *Pool, + zeroForOne bool, + sqrtPriceLimitX96 *u256.Uint, + exactInput bool, + cache SwapCache, + swapFee *u256.Uint, +) *u256.Uint { + step := computeSwapStepInit(state, pool, zeroForOne) + + sqrtRatioTargetX96 := computeTargetSqrtRatio(step, sqrtPriceLimitX96, zeroForOne) + state, step = computeAmounts(state, sqrtRatioTargetX96, pool, step) + + state = updateAmounts(step, state, exactInput) + + // if the protocol fee is on, calculate how much is owed, + // decrement fee amount, and increment protocol fee + feeProtocol := cache.feeProtocol + if feeProtocol > 0 { + state = updateFeeProtocol(step, feeProtocol, state) + } + + // update global fee tracker + if state.liquidity.Gt(u256.Zero()) { + update := u256.MulDiv(step.feeAmount, u256.MustFromDecimal(consts.Q128), state.liquidity) + state.SetFeeGrowthGlobalX128(new(u256.Uint).Add(state.feeGrowthGlobalX128, update)) + } + + swapFee = new(u256.Uint).Add(swapFee, step.feeAmount) + + if state.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { + tickTransition(step, zeroForOne, state, pool) + } + + if state.sqrtPriceX96.Neq(step.sqrtPriceStartX96) { + state.SetTick(common.TickMathGetTickAtSqrtRatio(state.sqrtPriceX96)) + } + + return swapFee +} + +// updateFeeProtocol calculates and updates protocol fees for the current step. +func updateFeeProtocol(step StepComputations, feeProtocol uint8, state SwapState) SwapState { + delta := step.feeAmount + delta.Div(delta, u256.NewUint(uint64(feeProtocol))) + + step.feeAmount.Sub(step.feeAmount, delta) + state.protocolFee.Add(state.protocolFee, delta) + + return state +} + +// computeSwapStepInit initializes the computation for a single swap step. +func computeSwapStepInit(state SwapState, pool *Pool, zeroForOne bool) StepComputations { + var step StepComputations + step.sqrtPriceStartX96 = state.sqrtPriceX96 + step.tickNext, step.initialized = pool.tickBitmapNextInitializedTickWithInOneWord( + state.tick, + pool.tickSpacing, + zeroForOne, + ) + + // prevent overshoot the min/max tick + step.clampTickNext() + + // get the price for the next tick + step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) + return step +} + +// computeTargetSqrtRatio determines the target sqrt price for the current swap step. +func computeTargetSqrtRatio(step StepComputations, sqrtPriceLimitX96 *u256.Uint, zeroForOne bool) *u256.Uint { + isLower := step.sqrtPriceNextX96.Lt(sqrtPriceLimitX96) + isHigher := step.sqrtPriceNextX96.Gt(sqrtPriceLimitX96) + + if (zeroForOne && isLower) || (!zeroForOne && isHigher) { + return sqrtPriceLimitX96 + } + return step.sqrtPriceNextX96 +} + +// computeAmounts calculates the input and output amounts for the current swap step. +func computeAmounts(state SwapState, sqrtRatioTargetX96 *u256.Uint, pool *Pool, step StepComputations) (SwapState, StepComputations) { + sqrtPriceX96Str, amountInStr, amountOutStr, feeAmountStr := plp.SwapMathComputeSwapStepStr( + state.sqrtPriceX96, + sqrtRatioTargetX96, + state.liquidity, + state.amountSpecifiedRemaining, + uint64(pool.fee), + ) + + step.amountIn = u256.MustFromDecimal(amountInStr) + step.amountOut = u256.MustFromDecimal(amountOutStr) + step.feeAmount = u256.MustFromDecimal(feeAmountStr) + + state.SetSqrtPriceX96(sqrtPriceX96Str) + + return state, step +} + +// updateAmounts updates the remaining amounts and calculated amounts +// based on the swap direction. +func updateAmounts(step StepComputations, state SwapState, exactInput bool) SwapState { + amountInWithFee := i256.FromUint256(new(u256.Uint).Add(step.amountIn, step.feeAmount)) + if exactInput { + state.amountSpecifiedRemaining = i256.Zero().Sub(state.amountSpecifiedRemaining, amountInWithFee) + state.amountCalculated = i256.Zero().Sub(state.amountCalculated, i256.FromUint256(step.amountOut)) + return state + } + state.amountSpecifiedRemaining = i256.Zero().Add(state.amountSpecifiedRemaining, i256.FromUint256(step.amountOut)) + state.amountCalculated = i256.Zero().Add(state.amountCalculated, amountInWithFee) + + return state +} + func (pool *Pool) swapTransfers(zeroForOne bool, payer, recipient std.Address, amount0, amount1 *i256.Int) { var targetTokenPath string var amount *i256.Int From 1e2bea4a83a55369ec9067c969251b92703bb848 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Tue, 3 Dec 2024 19:30:19 +0900 Subject: [PATCH 10/24] fix --- pool/type.gno | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pool/type.gno b/pool/type.gno index c0074bbb4..592b093c4 100644 --- a/pool/type.gno +++ b/pool/type.gno @@ -149,6 +149,22 @@ type StepComputations struct { feeAmount *u256.Uint // how much fee is being paid in this step } +// init initializes the computation for a single swap step +func (step *StepComputations) initSwapStep(state SwapState, pool *Pool, zeroForOne bool) { + step.sqrtPriceStartX96 = state.sqrtPriceX96 + step.tickNext, step.initialized = pool.tickBitmapNextInitializedTickWithInOneWord( + state.tick, + pool.tickSpacing, + zeroForOne, + ) + + // prevent overshoot the min/max tick + step.clampTickNext() + + // get the price for the next tick + step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) +} + // clampTickNext ensures that `tickNext` stays within the min, max tick boundaries // as the tick bitmap is not aware of these bounds func (step *StepComputations) clampTickNext() { From a98bd2bd94b755a21edf7a15666613cd151b9115 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Wed, 4 Dec 2024 13:10:47 +0900 Subject: [PATCH 11/24] pool test --- pool/pool.gno | 412 ++++++++++++++++++++++++++++++--------------- pool/pool_test.gno | 235 ++++++++++++++++++++++++++ 2 files changed, 509 insertions(+), 138 deletions(-) diff --git a/pool/pool.gno b/pool/pool.gno index 7cc3d9d5b..31c51d5e5 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -166,6 +166,147 @@ func Collect( return amount0.ToString(), amount1.ToString() } + + +//////////////////////////////////////////////////////////// +// Swap +//////////////////////////////////////////////////////////// + +// SwapResult represents the final state after swap computation +type SwapResult struct { + Amount0 *i256.Int + Amount1 *i256.Int + NewSqrtPrice *u256.Uint + NewTick int32 + NewLiquidity *u256.Uint + NewProtocolFees ProtocolFees + FeeGrowthGlobal0 *u256.Uint + FeeGrowthGlobal1 *u256.Uint + SwapFee *u256.Uint +} + +// SwapComputation encapsulates pure computation logic for swap +type SwapComputation struct { + AmountSpecified *i256.Int + SqrtPriceLimitX96 *u256.Uint + ZeroForOne bool + ExactInput bool + InitialState SwapState + Cache SwapCache +} + +func computeSwap(pool *Pool, comp SwapComputation) (*SwapResult, error) { + state := comp.InitialState + swapFee := u256.Zero() + + // Compute swap steps until completion + for shouldContinueSwap(state, comp.SqrtPriceLimitX96) { + var newFee *u256.Uint + state, newFee = computeSwapStep(state, pool, comp.ZeroForOne, comp.SqrtPriceLimitX96, comp.ExactInput, comp.Cache, swapFee) + swapFee = newFee + } + + // Calculate final amounts + amount0 := state.amountCalculated + amount1 := i256.Zero().Sub(comp.AmountSpecified, state.amountSpecifiedRemaining) + if comp.ZeroForOne == comp.ExactInput { + amount0, amount1 = amount1, amount0 + } + + // Prepare result + result := &SwapResult{ + Amount0: amount0, + Amount1: amount1, + NewSqrtPrice: state.sqrtPriceX96, + NewTick: state.tick, + NewLiquidity: state.liquidity, + NewProtocolFees: ProtocolFees{ + token0: pool.protocolFees.token0, + token1: pool.protocolFees.token1, + }, + FeeGrowthGlobal0: pool.feeGrowthGlobal0X128, + FeeGrowthGlobal1: pool.feeGrowthGlobal1X128, + SwapFee: swapFee, + } + + // Update protocol fees if necessary + if comp.ZeroForOne { + if state.protocolFee.Gt(u256.Zero()) { + result.NewProtocolFees.token0 = new(u256.Uint).Add(result.NewProtocolFees.token0, state.protocolFee) + } + result.FeeGrowthGlobal0 = state.feeGrowthGlobalX128 + } else { + if state.protocolFee.Gt(u256.Zero()) { + result.NewProtocolFees.token1 = new(u256.Uint).Add(result.NewProtocolFees.token1, state.protocolFee) + } + result.FeeGrowthGlobal1 = state.feeGrowthGlobalX128 + } + + return result, nil +} + + +// applySwapResult updates pool state with computed results +func applySwapResult(pool *Pool, result *SwapResult) { + pool.slot0.sqrtPriceX96 = result.NewSqrtPrice + pool.slot0.tick = result.NewTick + pool.liquidity = result.NewLiquidity + pool.protocolFees = result.NewProtocolFees + pool.feeGrowthGlobal0X128 = result.FeeGrowthGlobal0 + pool.feeGrowthGlobal1X128 = result.FeeGrowthGlobal1 +} + +// validatePriceLimits validates the price limits for the swap +func validatePriceLimits(pool *Pool, zeroForOne bool, sqrtPriceLimitX96 *u256.Uint) { + if zeroForOne { + minSqrtRatio := u256.MustFromDecimal(consts.MIN_SQRT_RATIO) + + cond1 := sqrtPriceLimitX96.Lt(pool.slot0.sqrtPriceX96) + cond2 := sqrtPriceLimitX96.Gt(minSqrtRatio) + if !(cond1 && cond2) { + panic(addDetailToError( + errPriceOutOfRange, + ufmt.Sprintf("pool.gno__Swap() || sqrtPriceLimitX96(%s) < slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) > consts.MIN_SQRT_RATIO(%s)", + sqrtPriceLimitX96.ToString(), + pool.slot0.sqrtPriceX96.ToString(), + sqrtPriceLimitX96.ToString(), + consts.MIN_SQRT_RATIO), + )) + } + } else { + maxSqrtRatio := u256.MustFromDecimal(consts.MAX_SQRT_RATIO) + + cond1 := sqrtPriceLimitX96.Gt(pool.slot0.sqrtPriceX96) + cond2 := sqrtPriceLimitX96.Lt(maxSqrtRatio) + if !(cond1 && cond2) { + panic(addDetailToError( + errPriceOutOfRange, + ufmt.Sprintf("pool.gno__Swap() || sqrtPriceLimitX96(%s) > slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) < consts.MAX_SQRT_RATIO(%s)", + sqrtPriceLimitX96.ToString(), + pool.slot0.sqrtPriceX96.ToString(), + sqrtPriceLimitX96.ToString(), + consts.MAX_SQRT_RATIO), + )) + } + } +} + +// getFeeProtocol returns the appropriate fee protocol based on zero for one +func getFeeProtocol(slot0 Slot0, zeroForOne bool) uint8 { + if zeroForOne { + return slot0.feeProtocol % 16 + } + return slot0.feeProtocol / 16 +} + +// getFeeGrowthGlobal returns the appropriate fee growth global based on zero for one +func getFeeGrowthGlobal(pool *Pool, zeroForOne bool) *u256.Uint { + if zeroForOne { + return pool.feeGrowthGlobal0X128 + } + return pool.feeGrowthGlobal1X128 +} + // Swap swaps token0 for token1, or token1 for token0 // Returns swapped amount0, amount1 in string // ref: https://docs.gnoswap.io/contracts/pool/pool.gno#swap @@ -197,9 +338,6 @@ func Swap( )) } - amountSpecified := i256.MustFromDecimal(_amountSpecified) - sqrtPriceLimitX96 := u256.MustFromDecimal(_sqrtPriceLimitX96) - pool := GetPool(token0Path, token1Path, fee) slot0Start := pool.slot0 @@ -210,85 +348,34 @@ func Swap( slot0Start.unlocked = false defer func() { slot0Start.unlocked = true }() - var feeProtocol uint8 - var feeGrowthGlobalX128 *u256.Uint - - if zeroForOne { - minSqrtRatio := u256.MustFromDecimal(consts.MIN_SQRT_RATIO) - - cond1 := sqrtPriceLimitX96.Lt(slot0Start.sqrtPriceX96) - cond2 := sqrtPriceLimitX96.Gt(minSqrtRatio) - if !(cond1 && cond2) { - panic(addDetailToError( - errPriceOutOfRange, - ufmt.Sprintf("pool.gno__Swap() || sqrtPriceLimitX96(%s) < slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) > consts.MIN_SQRT_RATIO(%s)", sqrtPriceLimitX96.ToString(), slot0Start.sqrtPriceX96.ToString(), sqrtPriceLimitX96.ToString(), consts.MIN_SQRT_RATIO), - )) - } - feeProtocol = slot0Start.feeProtocol % 16 - feeGrowthGlobalX128 = pool.feeGrowthGlobal0X128 - - } else { - maxSqrtRatio := u256.MustFromDecimal(consts.MAX_SQRT_RATIO) - - cond1 := sqrtPriceLimitX96.Gt(slot0Start.sqrtPriceX96) - cond2 := sqrtPriceLimitX96.Lt(maxSqrtRatio) - if !(cond1 && cond2) { - panic(addDetailToError( - errPriceOutOfRange, - ufmt.Sprintf("pool.gno__Swap() || sqrtPriceLimitX96(%s) > slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) < consts.MAX_SQRT_RATIO(%s)", sqrtPriceLimitX96.ToString(), slot0Start.sqrtPriceX96.ToString(), sqrtPriceLimitX96.ToString(), consts.MAX_SQRT_RATIO), - )) - } + amountSpecified := i256.MustFromDecimal(_amountSpecified) + sqrtPriceLimitX96 := u256.MustFromDecimal(_sqrtPriceLimitX96) - feeProtocol = slot0Start.feeProtocol / 16 - feeGrowthGlobalX128 = pool.feeGrowthGlobal1X128 - } + validatePriceLimits(pool, zeroForOne, sqrtPriceLimitX96) + feeProtocol := getFeeProtocol(slot0Start, zeroForOne) cache := newSwapCache(feeProtocol, pool.liquidity) + feeGrowthGlobalX128 := getFeeGrowthGlobal(pool, zeroForOne) state := newSwapState(amountSpecified, feeGrowthGlobalX128, cache.liquidityStart, pool.slot0) - exactInput := amountSpecified.Gt(i256.Zero()) - - // continue swapping as long as we haven't used the entire input/output and haven't reached the price limit - swapFee := u256.Zero() - for shouldContinueSwap(state, sqrtPriceLimitX96) { - swapFee = computeSwapStep(state, pool, zeroForOne, sqrtPriceLimitX96, exactInput, cache, swapFee) - } - - // update pool sqrtPrice - pool.slot0.sqrtPriceX96 = state.sqrtPriceX96 - - // update tick if it changed - if state.tick != slot0Start.tick { - pool.slot0.tick = state.tick + comp := SwapComputation{ + AmountSpecified: amountSpecified, + SqrtPriceLimitX96: sqrtPriceLimitX96, + ZeroForOne: zeroForOne, + ExactInput: amountSpecified.Gt(i256.Zero()), + InitialState: state, + Cache: cache, } - // update liquidity if it changed - if !(cache.liquidityStart.Eq(state.liquidity)) { - pool.liquidity = state.liquidity - } - - // update fee growth global and, if necessary, protocol fees - // overflow is acceptable, protocol has to withdraw before it hits MAX_UINT256 fees - if zeroForOne { - pool.feeGrowthGlobal0X128 = state.feeGrowthGlobalX128 - if state.protocolFee.Gt(u256.Zero()) { - pool.protocolFees.token0 = new(u256.Uint).Add(pool.protocolFees.token0, state.protocolFee) - } - } else { - pool.feeGrowthGlobal1X128 = state.feeGrowthGlobalX128 - if state.protocolFee.Gt(u256.Zero()) { - pool.protocolFees.token1 = new(u256.Uint).Add(pool.protocolFees.token1, state.protocolFee) - } + result, err := computeSwap(pool, comp) + if err != nil { + panic(err) } - amount0 := state.amountCalculated - amount1 := i256.Zero().Sub(amountSpecified, state.amountSpecifiedRemaining) - if zeroForOne == exactInput { - amount0, amount1 = amount1, amount0 - } + applySwapResult(pool, result) // actual swap - pool.swapTransfers(zeroForOne, payer, recipient, amount0, amount1) + pool.swapTransfers(zeroForOne, payer, recipient, result.Amount0, result.Amount1) prevAddr, prevRealm := getPrev() @@ -302,65 +389,111 @@ func Swap( "sqrtPriceLimitX96", _sqrtPriceLimitX96, "payer", payer.String(), "recipient", recipient.String(), - "internal_amount0", amount0.ToString(), - "internal_amount1", amount1.ToString(), + "internal_amount0", result.Amount0.ToString(), + "internal_amount1", result.Amount1.ToString(), "internal_protocolFee0", pool.protocolFees.token0.ToString(), "internal_protocolFee1", pool.protocolFees.token1.ToString(), - "internal_swapFee", swapFee.ToString(), + "internal_swapFee", result.SwapFee.ToString(), "internal_sqrtPriceX96", pool.slot0.sqrtPriceX96.ToString(), ) - return amount0.ToString(), amount1.ToString() + return result.Amount0.ToString(), result.Amount1.ToString() } func shouldContinueSwap(state SwapState, sqrtPriceLimitX96 *u256.Uint) bool { - return !state.amountSpecifiedRemaining.IsZero() && - state.sqrtPriceX96.Neq(sqrtPriceLimitX96) + println("sqrtPriceLimitX96", sqrtPriceLimitX96.ToString()) + return !(state.amountSpecifiedRemaining.IsZero()) && !(state.sqrtPriceX96.Eq(sqrtPriceLimitX96)) +} + + +// computeSwapStep executes a single step of swap and returns new state +func computeSwapStep( + state SwapState, + pool *Pool, + zeroForOne bool, + sqrtPriceLimitX96 *u256.Uint, + exactInput bool, + cache SwapCache, + swapFee *u256.Uint, +) (SwapState, *u256.Uint) { + step := computeSwapStepInit(state, pool, zeroForOne) + sqrtRatioTargetX96 := computeTargetSqrtRatio(step, sqrtPriceLimitX96, zeroForOne) + + var newState SwapState + newState, step = computeAmounts(state, sqrtRatioTargetX96, pool, step) + newState = updateAmounts(step, newState, exactInput) + + // if the protocol fee is on, calculate how much is owed, + // decrement fee amount, and increment protocol fee + if cache.feeProtocol > 0 { + newState = updateFeeProtocol(step, cache.feeProtocol, newState) + } + + // update global fee tracker + if newState.liquidity.Gt(u256.Zero()) { + update := u256.MulDiv(step.feeAmount, u256.MustFromDecimal(consts.Q128), newState.liquidity) + newState.SetFeeGrowthGlobalX128(new(u256.Uint).Add(newState.feeGrowthGlobalX128, update)) + } + + newSwapFee := new(u256.Uint).Add(swapFee, step.feeAmount) + + if newState.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { + newState = tickTransition(step, zeroForOne, newState, pool) + } + + if newState.sqrtPriceX96.Neq(step.sqrtPriceStartX96) { + newState.SetTick(common.TickMathGetTickAtSqrtRatio(newState.sqrtPriceX96)) + } + + return newState, newSwapFee } // computeSingleSwapStep executes a single step of swap. // computing the new state and price limit. -func computeSwapStep( - state SwapState, - pool *Pool, - zeroForOne bool, - sqrtPriceLimitX96 *u256.Uint, - exactInput bool, - cache SwapCache, - swapFee *u256.Uint, -) *u256.Uint { - step := computeSwapStepInit(state, pool, zeroForOne) +// func computeSwapStep( +// state SwapState, +// pool *Pool, +// zeroForOne bool, +// sqrtPriceLimitX96 *u256.Uint, +// exactInput bool, +// cache SwapCache, +// swapFee *u256.Uint, +// ) *u256.Uint { +// step := computeSwapStepInit(state, pool, zeroForOne) - sqrtRatioTargetX96 := computeTargetSqrtRatio(step, sqrtPriceLimitX96, zeroForOne) - state, step = computeAmounts(state, sqrtRatioTargetX96, pool, step) - state = updateAmounts(step, state, exactInput) +// sqrtRatioTargetX96 := computeTargetSqrtRatio(step, sqrtPriceLimitX96, zeroForOne) +// state, step = computeAmounts(state, sqrtRatioTargetX96, pool, step) - // if the protocol fee is on, calculate how much is owed, - // decrement fee amount, and increment protocol fee - feeProtocol := cache.feeProtocol - if feeProtocol > 0 { - state = updateFeeProtocol(step, feeProtocol, state) - } +// state = updateAmounts(step, state, exactInput) - // update global fee tracker - if state.liquidity.Gt(u256.Zero()) { - update := u256.MulDiv(step.feeAmount, u256.MustFromDecimal(consts.Q128), state.liquidity) - state.SetFeeGrowthGlobalX128(new(u256.Uint).Add(state.feeGrowthGlobalX128, update)) - } +// // if the protocol fee is on, calculate how much is owed, +// // decrement fee amount, and increment protocol fee +// feeProtocol := cache.feeProtocol +// if feeProtocol > 0 { +// state = updateFeeProtocol(step, feeProtocol, state) +// } - swapFee = new(u256.Uint).Add(swapFee, step.feeAmount) +// println(">>>>>> AFTER UPDATE FEE PROTOCOL") - if state.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { - tickTransition(step, zeroForOne, state, pool) - } +// // update global fee tracker +// if state.liquidity.Gt(u256.Zero()) { +// update := u256.MulDiv(step.feeAmount, u256.MustFromDecimal(consts.Q128), state.liquidity) +// state.SetFeeGrowthGlobalX128(new(u256.Uint).Add(state.feeGrowthGlobalX128, update)) +// } - if state.sqrtPriceX96.Neq(step.sqrtPriceStartX96) { - state.SetTick(common.TickMathGetTickAtSqrtRatio(state.sqrtPriceX96)) - } +// swapFee = new(u256.Uint).Add(swapFee, step.feeAmount) - return swapFee -} +// if state.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { +// tickTransition(step, zeroForOne, state, pool) +// } + +// if state.sqrtPriceX96.Neq(step.sqrtPriceStartX96) { +// state.SetTick(common.TickMathGetTickAtSqrtRatio(state.sqrtPriceX96)) +// } + +// return swapFee +// } // updateFeeProtocol calculates and updates protocol fees for the current step. func updateFeeProtocol(step StepComputations, feeProtocol uint8, state SwapState) SwapState { @@ -377,12 +510,15 @@ func updateFeeProtocol(step StepComputations, feeProtocol uint8, state SwapState func computeSwapStepInit(state SwapState, pool *Pool, zeroForOne bool) StepComputations { var step StepComputations step.sqrtPriceStartX96 = state.sqrtPriceX96 - step.tickNext, step.initialized = pool.tickBitmapNextInitializedTickWithInOneWord( + tickNext, initialized := pool.tickBitmapNextInitializedTickWithInOneWord( state.tick, pool.tickSpacing, zeroForOne, ) + step.tickNext = tickNext + step.initialized = initialized + // prevent overshoot the min/max tick step.clampTickNext() @@ -457,37 +593,37 @@ func (pool *Pool) swapTransfers(zeroForOne bool, payer, recipient std.Address, a pool.transferAndVerify(recipient, targetTokenPath, amount, !isToken0) } -// tickTransition handles the transition between ticks during a swap -func tickTransition(step StepComputations, zeroForOne bool, state SwapState, pool *Pool) { - // if the tick is initialized, run the tick transition - if step.initialized { - var fee0, fee1 *u256.Uint - - // check for the placeholder value, which we replace with the actual value - // the first time the swap crosses an initialized tick - if zeroForOne { - fee0 = state.feeGrowthGlobalX128 - fee1 = pool.feeGrowthGlobal1X128 - } else { - fee0 = pool.feeGrowthGlobal0X128 - fee1 = state.feeGrowthGlobalX128 - } +// tickTransition now returns new state instead of modifying existing +func tickTransition(step StepComputations, zeroForOne bool, state SwapState, pool *Pool) SwapState { + newState := state + + if step.initialized { + var fee0, fee1 *u256.Uint + + if zeroForOne { + fee0 = state.feeGrowthGlobalX128 + fee1 = pool.feeGrowthGlobal1X128 + } else { + fee0 = pool.feeGrowthGlobal0X128 + fee1 = state.feeGrowthGlobalX128 + } - // if we're moving leftward, we interpret liquidityNet as the opposite sign - liquidityNet := pool.tickCross(step.tickNext, fee0, fee1) + liquidityNet := pool.tickCross(step.tickNext, fee0, fee1) - if zeroForOne { - liquidityNet = i256.Zero().Neg(liquidityNet) - } + if zeroForOne { + liquidityNet = i256.Zero().Neg(liquidityNet) + } - state.liquidity = liquidityMathAddDelta(state.liquidity, liquidityNet) - } + newState.liquidity = liquidityMathAddDelta(state.liquidity, liquidityNet) + } - if zeroForOne { - state.tick = step.tickNext - 1 - } else { - state.tick = step.tickNext - } + if zeroForOne { + newState.tick = step.tickNext - 1 + } else { + newState.tick = step.tickNext + } + + return newState } // SetFeeProtocolByAdmin sets the fee protocol for all pools diff --git a/pool/pool_test.gno b/pool/pool_test.gno index f49863468..0458302fc 100644 --- a/pool/pool_test.gno +++ b/pool/pool_test.gno @@ -226,3 +226,238 @@ func TestUpdatePoolBalance(t *testing.T) { }) } } + + +func TestShouldContinueSwap(t *testing.T) { + tests := []struct { + name string + state SwapState + sqrtPriceLimitX96 *u256.Uint + expected bool + }{ + { + name: "Should continue - amount remaining and price not at limit", + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + sqrtPriceX96: u256.MustFromDecimal("1000000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: true, + }, + { + name: "Should stop - no amount remaining", + state: SwapState{ + amountSpecifiedRemaining: i256.Zero(), + sqrtPriceX96: u256.MustFromDecimal("1000000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: false, + }, + { + name: "Should stop - price at limit", + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + sqrtPriceX96: u256.MustFromDecimal("900000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := shouldContinueSwap(tt.state, tt.sqrtPriceLimitX96) + uassert.Equal(t, tt.expected, result) + }) + } +} + +func TestUpdateAmounts(t *testing.T) { + tests := []struct { + name string + step StepComputations + state SwapState + exactInput bool + expectedState SwapState + }{ + { + name: "Exact input update", + step: StepComputations{ + amountIn: u256.MustFromDecimal("100"), + amountOut: u256.MustFromDecimal("97"), + feeAmount: u256.MustFromDecimal("3"), + }, + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + amountCalculated: i256.Zero(), + }, + exactInput: true, + expectedState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("897"), // 1000 - (100 + 3) + amountCalculated: i256.MustFromDecimal("-97"), + }, + }, + { + name: "Exact output update", + step: StepComputations{ + amountIn: u256.MustFromDecimal("100"), + amountOut: u256.MustFromDecimal("97"), + feeAmount: u256.MustFromDecimal("3"), + }, + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("-1000"), + amountCalculated: i256.Zero(), + }, + exactInput: false, + expectedState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("-903"), // -1000 + 97 + amountCalculated: i256.MustFromDecimal("103"), // 100 + 3 + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := updateAmounts(tt.step, tt.state, tt.exactInput) + + uassert.True(t, tt.expectedState.amountSpecifiedRemaining.Eq(result.amountSpecifiedRemaining)) + uassert.True(t, tt.expectedState.amountCalculated.Eq(result.amountCalculated)) + }) + } +} + +func TestComputeSwap(t *testing.T) { + mockPool := &Pool{ + token0Path: "token0", + token1Path: "token1", + fee: 3000, // 0.3% + tickSpacing: 60, + slot0: Slot0{ + sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 + tick: 0, + feeProtocol: 0, + unlocked: true, + }, + liquidity: u256.MustFromDecimal("1000000000000000000"), // 1.0 + protocolFees: ProtocolFees{ + token0: u256.Zero(), + token1: u256.Zero(), + }, + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + tickBitmaps: make(TickBitmaps), + ticks: make(Ticks), + positions: make(Positions), + } + + wordPos, _ := tickBitmapPosition(0) + mockPool.tickBitmaps[wordPos] = u256.NewUint(1) + + t.Run("basic swap", func(t *testing.T) { + comp := SwapComputation{ + AmountSpecified: i256.MustFromDecimal("1000000"), // 1.0 token + SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), // 1.1 + ZeroForOne: true, + ExactInput: true, + InitialState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), + amountCalculated: i256.Zero(), + sqrtPriceX96: mockPool.slot0.sqrtPriceX96, + tick: mockPool.slot0.tick, + feeGrowthGlobalX128: mockPool.feeGrowthGlobal0X128, + protocolFee: u256.Zero(), + liquidity: mockPool.liquidity, + }, + Cache: SwapCache{ + feeProtocol: 0, + liquidityStart: mockPool.liquidity, + }, + } + + result, err := computeSwap(mockPool, comp) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if result.Amount0.IsZero() { + t.Error("expected non-zero amount0") + } + if result.Amount1.IsZero() { + t.Error("expected non-zero amount1") + } + if result.SwapFee.IsZero() { + t.Error("expected non-zero swap fee") + } + }) + + t.Run("swap with protocol fee", func(t *testing.T) { + t.Skip() + mockPoolWithFee := *mockPool + mockPoolWithFee.slot0.feeProtocol = 4 // 4 = 1/4 of fee goes to protocol + + comp := SwapComputation{ + AmountSpecified: i256.MustFromDecimal("1000000"), + SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), + ZeroForOne: true, + ExactInput: true, + InitialState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), + amountCalculated: i256.Zero(), + sqrtPriceX96: mockPoolWithFee.slot0.sqrtPriceX96, + tick: mockPoolWithFee.slot0.tick, + feeGrowthGlobalX128: mockPoolWithFee.feeGrowthGlobal0X128, + protocolFee: u256.Zero(), + liquidity: mockPoolWithFee.liquidity, + }, + Cache: SwapCache{ + feeProtocol: 4, + liquidityStart: mockPoolWithFee.liquidity, + }, + } + + result, err := computeSwap(&mockPoolWithFee, comp) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + // 프로토콜 수수료 검증 + if mockPoolWithFee.protocolFees.token0.IsZero() { + t.Error("expected non-zero protocol fee for token0") + } + }) + + t.Run("swap with zero liquidity", func(t *testing.T) { + t.Skip() + mockPoolZeroLiq := *mockPool + mockPoolZeroLiq.liquidity = u256.Zero() + + comp := SwapComputation{ + AmountSpecified: i256.MustFromDecimal("1000000"), + SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), + ZeroForOne: true, + ExactInput: true, + InitialState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), + amountCalculated: i256.Zero(), + sqrtPriceX96: mockPoolZeroLiq.slot0.sqrtPriceX96, + tick: mockPoolZeroLiq.slot0.tick, + feeGrowthGlobalX128: mockPoolZeroLiq.feeGrowthGlobal0X128, + protocolFee: u256.Zero(), + liquidity: mockPoolZeroLiq.liquidity, + }, + Cache: SwapCache{ + feeProtocol: 0, + liquidityStart: mockPoolZeroLiq.liquidity, + }, + } + + result, err := computeSwap(&mockPoolZeroLiq, comp) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if !result.Amount0.IsZero() || !result.Amount1.IsZero() { + t.Error("expected zero amounts for zero liquidity") + } + }) +} From 69dad4ecbd6dc650d9083c78e60ef4d4e2843963 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Wed, 4 Dec 2024 15:00:50 +0900 Subject: [PATCH 12/24] fmt --- pool/pool.gno | 417 +++++++++++++--------------- pool/pool_test.gno | 665 +++++++++++++++++++++------------------------ 2 files changed, 496 insertions(+), 586 deletions(-) diff --git a/pool/pool.gno b/pool/pool.gno index 31c51d5e5..aa1cc4ed7 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -166,145 +166,142 @@ func Collect( return amount0.ToString(), amount1.ToString() } - - //////////////////////////////////////////////////////////// // Swap //////////////////////////////////////////////////////////// // SwapResult represents the final state after swap computation type SwapResult struct { - Amount0 *i256.Int - Amount1 *i256.Int - NewSqrtPrice *u256.Uint - NewTick int32 - NewLiquidity *u256.Uint - NewProtocolFees ProtocolFees - FeeGrowthGlobal0 *u256.Uint - FeeGrowthGlobal1 *u256.Uint - SwapFee *u256.Uint + Amount0 *i256.Int + Amount1 *i256.Int + NewSqrtPrice *u256.Uint + NewTick int32 + NewLiquidity *u256.Uint + NewProtocolFees ProtocolFees + FeeGrowthGlobal0 *u256.Uint + FeeGrowthGlobal1 *u256.Uint + SwapFee *u256.Uint } // SwapComputation encapsulates pure computation logic for swap type SwapComputation struct { - AmountSpecified *i256.Int - SqrtPriceLimitX96 *u256.Uint - ZeroForOne bool - ExactInput bool - InitialState SwapState - Cache SwapCache + AmountSpecified *i256.Int + SqrtPriceLimitX96 *u256.Uint + ZeroForOne bool + ExactInput bool + InitialState SwapState + Cache SwapCache } func computeSwap(pool *Pool, comp SwapComputation) (*SwapResult, error) { - state := comp.InitialState - swapFee := u256.Zero() - - // Compute swap steps until completion - for shouldContinueSwap(state, comp.SqrtPriceLimitX96) { - var newFee *u256.Uint - state, newFee = computeSwapStep(state, pool, comp.ZeroForOne, comp.SqrtPriceLimitX96, comp.ExactInput, comp.Cache, swapFee) - swapFee = newFee - } - - // Calculate final amounts - amount0 := state.amountCalculated - amount1 := i256.Zero().Sub(comp.AmountSpecified, state.amountSpecifiedRemaining) - if comp.ZeroForOne == comp.ExactInput { - amount0, amount1 = amount1, amount0 - } - - // Prepare result - result := &SwapResult{ - Amount0: amount0, - Amount1: amount1, - NewSqrtPrice: state.sqrtPriceX96, - NewTick: state.tick, - NewLiquidity: state.liquidity, - NewProtocolFees: ProtocolFees{ - token0: pool.protocolFees.token0, - token1: pool.protocolFees.token1, - }, - FeeGrowthGlobal0: pool.feeGrowthGlobal0X128, - FeeGrowthGlobal1: pool.feeGrowthGlobal1X128, - SwapFee: swapFee, - } - - // Update protocol fees if necessary - if comp.ZeroForOne { - if state.protocolFee.Gt(u256.Zero()) { - result.NewProtocolFees.token0 = new(u256.Uint).Add(result.NewProtocolFees.token0, state.protocolFee) - } - result.FeeGrowthGlobal0 = state.feeGrowthGlobalX128 - } else { - if state.protocolFee.Gt(u256.Zero()) { - result.NewProtocolFees.token1 = new(u256.Uint).Add(result.NewProtocolFees.token1, state.protocolFee) - } - result.FeeGrowthGlobal1 = state.feeGrowthGlobalX128 - } - - return result, nil -} + state := comp.InitialState + swapFee := u256.Zero() + + // Compute swap steps until completion + for shouldContinueSwap(state, comp.SqrtPriceLimitX96) { + var newFee *u256.Uint + state, newFee = computeSwapStep(state, pool, comp.ZeroForOne, comp.SqrtPriceLimitX96, comp.ExactInput, comp.Cache, swapFee) + swapFee = newFee + } + + // Calculate final amounts + amount0 := state.amountCalculated + amount1 := i256.Zero().Sub(comp.AmountSpecified, state.amountSpecifiedRemaining) + if comp.ZeroForOne == comp.ExactInput { + amount0, amount1 = amount1, amount0 + } + + // Prepare result + result := &SwapResult{ + Amount0: amount0, + Amount1: amount1, + NewSqrtPrice: state.sqrtPriceX96, + NewTick: state.tick, + NewLiquidity: state.liquidity, + NewProtocolFees: ProtocolFees{ + token0: pool.protocolFees.token0, + token1: pool.protocolFees.token1, + }, + FeeGrowthGlobal0: pool.feeGrowthGlobal0X128, + FeeGrowthGlobal1: pool.feeGrowthGlobal1X128, + SwapFee: swapFee, + } + + // Update protocol fees if necessary + if comp.ZeroForOne { + if state.protocolFee.Gt(u256.Zero()) { + result.NewProtocolFees.token0 = new(u256.Uint).Add(result.NewProtocolFees.token0, state.protocolFee) + } + result.FeeGrowthGlobal0 = state.feeGrowthGlobalX128 + } else { + if state.protocolFee.Gt(u256.Zero()) { + result.NewProtocolFees.token1 = new(u256.Uint).Add(result.NewProtocolFees.token1, state.protocolFee) + } + result.FeeGrowthGlobal1 = state.feeGrowthGlobalX128 + } + return result, nil +} // applySwapResult updates pool state with computed results func applySwapResult(pool *Pool, result *SwapResult) { - pool.slot0.sqrtPriceX96 = result.NewSqrtPrice - pool.slot0.tick = result.NewTick - pool.liquidity = result.NewLiquidity - pool.protocolFees = result.NewProtocolFees - pool.feeGrowthGlobal0X128 = result.FeeGrowthGlobal0 - pool.feeGrowthGlobal1X128 = result.FeeGrowthGlobal1 + pool.slot0.sqrtPriceX96 = result.NewSqrtPrice + pool.slot0.tick = result.NewTick + pool.liquidity = result.NewLiquidity + pool.protocolFees = result.NewProtocolFees + pool.feeGrowthGlobal0X128 = result.FeeGrowthGlobal0 + pool.feeGrowthGlobal1X128 = result.FeeGrowthGlobal1 } // validatePriceLimits validates the price limits for the swap func validatePriceLimits(pool *Pool, zeroForOne bool, sqrtPriceLimitX96 *u256.Uint) { - if zeroForOne { - minSqrtRatio := u256.MustFromDecimal(consts.MIN_SQRT_RATIO) - - cond1 := sqrtPriceLimitX96.Lt(pool.slot0.sqrtPriceX96) - cond2 := sqrtPriceLimitX96.Gt(minSqrtRatio) - if !(cond1 && cond2) { - panic(addDetailToError( - errPriceOutOfRange, - ufmt.Sprintf("pool.gno__Swap() || sqrtPriceLimitX96(%s) < slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) > consts.MIN_SQRT_RATIO(%s)", - sqrtPriceLimitX96.ToString(), - pool.slot0.sqrtPriceX96.ToString(), - sqrtPriceLimitX96.ToString(), - consts.MIN_SQRT_RATIO), - )) - } - } else { - maxSqrtRatio := u256.MustFromDecimal(consts.MAX_SQRT_RATIO) - - cond1 := sqrtPriceLimitX96.Gt(pool.slot0.sqrtPriceX96) - cond2 := sqrtPriceLimitX96.Lt(maxSqrtRatio) - if !(cond1 && cond2) { - panic(addDetailToError( - errPriceOutOfRange, - ufmt.Sprintf("pool.gno__Swap() || sqrtPriceLimitX96(%s) > slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) < consts.MAX_SQRT_RATIO(%s)", - sqrtPriceLimitX96.ToString(), - pool.slot0.sqrtPriceX96.ToString(), - sqrtPriceLimitX96.ToString(), - consts.MAX_SQRT_RATIO), - )) - } - } + if zeroForOne { + minSqrtRatio := u256.MustFromDecimal(consts.MIN_SQRT_RATIO) + + cond1 := sqrtPriceLimitX96.Lt(pool.slot0.sqrtPriceX96) + cond2 := sqrtPriceLimitX96.Gt(minSqrtRatio) + if !(cond1 && cond2) { + panic(addDetailToError( + errPriceOutOfRange, + ufmt.Sprintf("pool.gno__Swap() || sqrtPriceLimitX96(%s) < slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) > consts.MIN_SQRT_RATIO(%s)", + sqrtPriceLimitX96.ToString(), + pool.slot0.sqrtPriceX96.ToString(), + sqrtPriceLimitX96.ToString(), + consts.MIN_SQRT_RATIO), + )) + } + } else { + maxSqrtRatio := u256.MustFromDecimal(consts.MAX_SQRT_RATIO) + + cond1 := sqrtPriceLimitX96.Gt(pool.slot0.sqrtPriceX96) + cond2 := sqrtPriceLimitX96.Lt(maxSqrtRatio) + if !(cond1 && cond2) { + panic(addDetailToError( + errPriceOutOfRange, + ufmt.Sprintf("pool.gno__Swap() || sqrtPriceLimitX96(%s) > slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) < consts.MAX_SQRT_RATIO(%s)", + sqrtPriceLimitX96.ToString(), + pool.slot0.sqrtPriceX96.ToString(), + sqrtPriceLimitX96.ToString(), + consts.MAX_SQRT_RATIO), + )) + } + } } // getFeeProtocol returns the appropriate fee protocol based on zero for one func getFeeProtocol(slot0 Slot0, zeroForOne bool) uint8 { - if zeroForOne { - return slot0.feeProtocol % 16 - } - return slot0.feeProtocol / 16 + if zeroForOne { + return slot0.feeProtocol % 16 + } + return slot0.feeProtocol / 16 } // getFeeGrowthGlobal returns the appropriate fee growth global based on zero for one func getFeeGrowthGlobal(pool *Pool, zeroForOne bool) *u256.Uint { - if zeroForOne { - return pool.feeGrowthGlobal0X128 - } - return pool.feeGrowthGlobal1X128 + if zeroForOne { + return pool.feeGrowthGlobal0X128 + } + return pool.feeGrowthGlobal1X128 } // Swap swaps token0 for token1, or token1 for token0 @@ -353,18 +350,19 @@ func Swap( validatePriceLimits(pool, zeroForOne, sqrtPriceLimitX96) + feeGrowthGlobalX128 := getFeeGrowthGlobal(pool, zeroForOne) feeProtocol := getFeeProtocol(slot0Start, zeroForOne) cache := newSwapCache(feeProtocol, pool.liquidity) - feeGrowthGlobalX128 := getFeeGrowthGlobal(pool, zeroForOne) + state := newSwapState(amountSpecified, feeGrowthGlobalX128, cache.liquidityStart, pool.slot0) comp := SwapComputation{ - AmountSpecified: amountSpecified, + AmountSpecified: amountSpecified, SqrtPriceLimitX96: sqrtPriceLimitX96, - ZeroForOne: zeroForOne, - ExactInput: amountSpecified.Gt(i256.Zero()), - InitialState: state, - Cache: cache, + ZeroForOne: zeroForOne, + ExactInput: amountSpecified.Gt(i256.Zero()), + InitialState: state, + Cache: cache, } result, err := computeSwap(pool, comp) @@ -401,99 +399,50 @@ func Swap( } func shouldContinueSwap(state SwapState, sqrtPriceLimitX96 *u256.Uint) bool { - println("sqrtPriceLimitX96", sqrtPriceLimitX96.ToString()) return !(state.amountSpecifiedRemaining.IsZero()) && !(state.sqrtPriceX96.Eq(sqrtPriceLimitX96)) } - // computeSwapStep executes a single step of swap and returns new state func computeSwapStep( - state SwapState, - pool *Pool, - zeroForOne bool, - sqrtPriceLimitX96 *u256.Uint, - exactInput bool, - cache SwapCache, - swapFee *u256.Uint, + state SwapState, + pool *Pool, + zeroForOne bool, + sqrtPriceLimitX96 *u256.Uint, + exactInput bool, + cache SwapCache, + swapFee *u256.Uint, ) (SwapState, *u256.Uint) { - step := computeSwapStepInit(state, pool, zeroForOne) - sqrtRatioTargetX96 := computeTargetSqrtRatio(step, sqrtPriceLimitX96, zeroForOne) - - var newState SwapState - newState, step = computeAmounts(state, sqrtRatioTargetX96, pool, step) - newState = updateAmounts(step, newState, exactInput) - - // if the protocol fee is on, calculate how much is owed, - // decrement fee amount, and increment protocol fee - if cache.feeProtocol > 0 { - newState = updateFeeProtocol(step, cache.feeProtocol, newState) - } - - // update global fee tracker - if newState.liquidity.Gt(u256.Zero()) { - update := u256.MulDiv(step.feeAmount, u256.MustFromDecimal(consts.Q128), newState.liquidity) - newState.SetFeeGrowthGlobalX128(new(u256.Uint).Add(newState.feeGrowthGlobalX128, update)) - } + step := computeSwapStepInit(state, pool, zeroForOne) + sqrtRatioTargetX96 := computeTargetSqrtRatio(step, sqrtPriceLimitX96, zeroForOne) - newSwapFee := new(u256.Uint).Add(swapFee, step.feeAmount) + var newState SwapState + newState, step = computeAmounts(state, sqrtRatioTargetX96, pool, step) + newState = updateAmounts(step, newState, exactInput) - if newState.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { - newState = tickTransition(step, zeroForOne, newState, pool) - } - - if newState.sqrtPriceX96.Neq(step.sqrtPriceStartX96) { - newState.SetTick(common.TickMathGetTickAtSqrtRatio(newState.sqrtPriceX96)) - } - - return newState, newSwapFee -} - -// computeSingleSwapStep executes a single step of swap. -// computing the new state and price limit. -// func computeSwapStep( -// state SwapState, -// pool *Pool, -// zeroForOne bool, -// sqrtPriceLimitX96 *u256.Uint, -// exactInput bool, -// cache SwapCache, -// swapFee *u256.Uint, -// ) *u256.Uint { -// step := computeSwapStepInit(state, pool, zeroForOne) - - -// sqrtRatioTargetX96 := computeTargetSqrtRatio(step, sqrtPriceLimitX96, zeroForOne) -// state, step = computeAmounts(state, sqrtRatioTargetX96, pool, step) - -// state = updateAmounts(step, state, exactInput) - -// // if the protocol fee is on, calculate how much is owed, -// // decrement fee amount, and increment protocol fee -// feeProtocol := cache.feeProtocol -// if feeProtocol > 0 { -// state = updateFeeProtocol(step, feeProtocol, state) -// } - -// println(">>>>>> AFTER UPDATE FEE PROTOCOL") + // if the protocol fee is on, calculate how much is owed, + // decrement fee amount, and increment protocol fee + if cache.feeProtocol > 0 { + newState = updateFeeProtocol(step, cache.feeProtocol, newState) + } -// // update global fee tracker -// if state.liquidity.Gt(u256.Zero()) { -// update := u256.MulDiv(step.feeAmount, u256.MustFromDecimal(consts.Q128), state.liquidity) -// state.SetFeeGrowthGlobalX128(new(u256.Uint).Add(state.feeGrowthGlobalX128, update)) -// } + // update global fee tracker + if newState.liquidity.Gt(u256.Zero()) { + update := u256.MulDiv(step.feeAmount, u256.MustFromDecimal(consts.Q128), newState.liquidity) + newState.SetFeeGrowthGlobalX128(new(u256.Uint).Add(newState.feeGrowthGlobalX128, update)) + } -// swapFee = new(u256.Uint).Add(swapFee, step.feeAmount) + newSwapFee := new(u256.Uint).Add(swapFee, step.feeAmount) -// if state.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { -// tickTransition(step, zeroForOne, state, pool) -// } + if newState.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { + newState = tickTransition(step, zeroForOne, newState, pool) + } -// if state.sqrtPriceX96.Neq(step.sqrtPriceStartX96) { -// state.SetTick(common.TickMathGetTickAtSqrtRatio(state.sqrtPriceX96)) -// } + if newState.sqrtPriceX96.Neq(step.sqrtPriceStartX96) { + newState.SetTick(common.TickMathGetTickAtSqrtRatio(newState.sqrtPriceX96)) + } -// return swapFee -// } + return newState, newSwapFee +} // updateFeeProtocol calculates and updates protocol fees for the current step. func updateFeeProtocol(step StepComputations, feeProtocol uint8, state SwapState) SwapState { @@ -595,35 +544,35 @@ func (pool *Pool) swapTransfers(zeroForOne bool, payer, recipient std.Address, a // tickTransition now returns new state instead of modifying existing func tickTransition(step StepComputations, zeroForOne bool, state SwapState, pool *Pool) SwapState { - newState := state - - if step.initialized { - var fee0, fee1 *u256.Uint + newState := state - if zeroForOne { - fee0 = state.feeGrowthGlobalX128 - fee1 = pool.feeGrowthGlobal1X128 - } else { - fee0 = pool.feeGrowthGlobal0X128 - fee1 = state.feeGrowthGlobalX128 - } + if step.initialized { + var fee0, fee1 *u256.Uint - liquidityNet := pool.tickCross(step.tickNext, fee0, fee1) + if zeroForOne { + fee0 = state.feeGrowthGlobalX128 + fee1 = pool.feeGrowthGlobal1X128 + } else { + fee0 = pool.feeGrowthGlobal0X128 + fee1 = state.feeGrowthGlobalX128 + } - if zeroForOne { - liquidityNet = i256.Zero().Neg(liquidityNet) - } + liquidityNet := pool.tickCross(step.tickNext, fee0, fee1) - newState.liquidity = liquidityMathAddDelta(state.liquidity, liquidityNet) - } + if zeroForOne { + liquidityNet = i256.Zero().Neg(liquidityNet) + } - if zeroForOne { - newState.tick = step.tickNext - 1 - } else { - newState.tick = step.tickNext - } + newState.liquidity = liquidityMathAddDelta(state.liquidity, liquidityNet) + } + + if zeroForOne { + newState.tick = step.tickNext - 1 + } else { + newState.tick = step.tickNext + } - return newState + return newState } // SetFeeProtocolByAdmin sets the fee protocol for all pools @@ -697,14 +646,14 @@ func setFeeProtocol(feeProtocol0, feeProtocol1 uint8) uint8 { } func validateFeeProtocol(feeProtocol0, feeProtocol1 uint8) error { - if !isValidFeeProtocolValue(feeProtocol0) || !isValidFeeProtocolValue(feeProtocol1) { - return errInvalidProtocolFeePct - } - return nil + if !isValidFeeProtocolValue(feeProtocol0) || !isValidFeeProtocolValue(feeProtocol1) { + return errInvalidProtocolFeePct + } + return nil } func isValidFeeProtocolValue(value uint8) bool { - return value == 0 || (value >= 4 && value <= 10) + return value == 0 || (value >= 4 && value <= 10) } // CollectProtocolByAdmin collects protocol fees for the given pool that accumulated while it was being used for swap @@ -903,25 +852,25 @@ func updatePoolBalance( var overflow bool var newBalance *u256.Uint - if isToken0 { - newBalance, overflow = new(u256.Uint).SubOverflow(token0, amount) - if overflow || newBalance.Lt(u256.Zero()) { + if isToken0 { + newBalance, overflow = new(u256.Uint).SubOverflow(token0, amount) + if overflow || newBalance.Lt(u256.Zero()) { return nil, ufmt.Errorf( "%s || cannot decrease, token0(%s) - amount(%s)", errTransferFailed.Error(), token0.ToString(), amount.ToString(), ) - } - return newBalance, nil - } + } + return newBalance, nil + } - newBalance, overflow = new(u256.Uint).SubOverflow(token1, amount) - if overflow || newBalance.Lt(u256.Zero()) { + newBalance, overflow = new(u256.Uint).SubOverflow(token1, amount) + if overflow || newBalance.Lt(u256.Zero()) { return nil, ufmt.Errorf( "%s || cannot decrease, token1(%s) - amount(%s)", errTransferFailed.Error(), token1.ToString(), amount.ToString(), ) - } - return newBalance, nil + } + return newBalance, nil } func (pool *Pool) transferFromAndVerify( diff --git a/pool/pool_test.gno b/pool/pool_test.gno index 0458302fc..067aab183 100644 --- a/pool/pool_test.gno +++ b/pool/pool_test.gno @@ -1,12 +1,11 @@ package pool import ( - "std" "testing" "gno.land/p/demo/uassert" - u256 "gno.land/p/gnoswap/uint256" i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" ) func TestSaveProtocolFees(t *testing.T) { @@ -17,7 +16,7 @@ func TestSaveProtocolFees(t *testing.T) { amount1 *u256.Uint want0 *u256.Uint want1 *u256.Uint - wantFee0 *u256.Uint + wantFee0 *u256.Uint wantFee1 *u256.Uint }{ { @@ -80,384 +79,346 @@ func TestSaveProtocolFees(t *testing.T) { } func TestTransferAndVerify(t *testing.T) { - // Setup common test data - pool := &Pool{ - balances: Balances{ - token0: u256.NewUint(1000), - token1: u256.NewUint(1000), - }, - } + // Setup common test data + pool := &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(1000), + }, + } - t.Run("validatePoolBalance", func(t *testing.T) { - tests := []struct { - name string - amount *u256.Uint - isToken0 bool - expectedError bool - }{ - { - name: "must success for negative amount", - amount: u256.NewUint(500), - isToken0: true, - expectedError: false, - }, - { - name: "must panic for insufficient token0 balance", - amount: u256.NewUint(1500), - isToken0: true, - expectedError: true, - }, - { - name: "must success for negative amount", - amount: u256.NewUint(500), - isToken0: false, - expectedError: false, - }, - { - name: "must panic for insufficient token1 balance", - amount: u256.NewUint(1500), - isToken0: false, - expectedError: true, - }, - } + t.Run("validatePoolBalance", func(t *testing.T) { + tests := []struct { + name string + amount *u256.Uint + isToken0 bool + expectedError bool + }{ + { + name: "must success for negative amount", + amount: u256.NewUint(500), + isToken0: true, + expectedError: false, + }, + { + name: "must panic for insufficient token0 balance", + amount: u256.NewUint(1500), + isToken0: true, + expectedError: true, + }, + { + name: "must success for negative amount", + amount: u256.NewUint(500), + isToken0: false, + expectedError: false, + }, + { + name: "must panic for insufficient token1 balance", + amount: u256.NewUint(1500), + isToken0: false, + expectedError: true, + }, + } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { token0 := pool.balances.token0 token1 := pool.balances.token1 - err := validatePoolBalance(token0, token1, tt.amount, tt.isToken0) + err := validatePoolBalance(token0, token1, tt.amount, tt.isToken0) if err != nil { if !tt.expectedError { t.Errorf("unexpected error: %v", err) } } - }) - } - }) + }) + } + }) } func TestUpdatePoolBalance(t *testing.T) { - tests := []struct { - name string - initialToken0 *u256.Uint - initialToken1 *u256.Uint - amount *u256.Uint - isToken0 bool - expectedBal *u256.Uint - expectErr bool - }{ - { - name: "normal token0 decrease", - initialToken0: u256.NewUint(1000), - initialToken1: u256.NewUint(2000), - amount: u256.NewUint(300), - isToken0: true, - expectedBal: u256.NewUint(700), - expectErr: false, - }, - { - name: "normal token1 decrease", - initialToken0: u256.NewUint(1000), - initialToken1: u256.NewUint(2000), - amount: u256.NewUint(500), - isToken0: false, - expectedBal: u256.NewUint(1500), - expectErr: false, - }, - { - name: "insufficient token0 balance", - initialToken0: u256.NewUint(100), - initialToken1: u256.NewUint(2000), - amount: u256.NewUint(200), - isToken0: true, - expectedBal: nil, - expectErr: true, - }, - { - name: "insufficient token1 balance", - initialToken0: u256.NewUint(1000), - initialToken1: u256.NewUint(100), - amount: u256.NewUint(200), - isToken0: false, - expectedBal: nil, - expectErr: true, - }, - { - name: "zero value handling", - initialToken0: u256.NewUint(1000), - initialToken1: u256.NewUint(2000), - amount: u256.NewUint(0), - isToken0: true, - expectedBal: u256.NewUint(1000), - expectErr: false, - }, - } + tests := []struct { + name string + initialToken0 *u256.Uint + initialToken1 *u256.Uint + amount *u256.Uint + isToken0 bool + expectedBal *u256.Uint + expectErr bool + }{ + { + name: "normal token0 decrease", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(300), + isToken0: true, + expectedBal: u256.NewUint(700), + expectErr: false, + }, + { + name: "normal token1 decrease", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(500), + isToken0: false, + expectedBal: u256.NewUint(1500), + expectErr: false, + }, + { + name: "insufficient token0 balance", + initialToken0: u256.NewUint(100), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(200), + isToken0: true, + expectedBal: nil, + expectErr: true, + }, + { + name: "insufficient token1 balance", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(100), + amount: u256.NewUint(200), + isToken0: false, + expectedBal: nil, + expectErr: true, + }, + { + name: "zero value handling", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(0), + isToken0: true, + expectedBal: u256.NewUint(1000), + expectErr: false, + }, + } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - pool := &Pool{ - balances: Balances{ - token0: tt.initialToken0, - token1: tt.initialToken1, - }, - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pool := &Pool{ + balances: Balances{ + token0: tt.initialToken0, + token1: tt.initialToken1, + }, + } newBal, err := updatePoolBalance(tt.initialToken0, tt.initialToken1, tt.amount, tt.isToken0) - if tt.expectErr { - if err == nil { - t.Errorf("%s: expected error but no error", tt.name) - } - return - } - if err != nil { - t.Errorf("%s: unexpected error: %v", tt.name, err) - return - } - - if !newBal.Eq(tt.expectedBal) { - t.Errorf("%s: balance mismatch, expected: %s, actual: %s", - tt.name, - tt.expectedBal.ToString(), - newBal.ToString(), - ) - } - }) - } + if tt.expectErr { + if err == nil { + t.Errorf("%s: expected error but no error", tt.name) + } + return + } + if err != nil { + t.Errorf("%s: unexpected error: %v", tt.name, err) + return + } + + if !newBal.Eq(tt.expectedBal) { + t.Errorf("%s: balance mismatch, expected: %s, actual: %s", + tt.name, + tt.expectedBal.ToString(), + newBal.ToString(), + ) + } + }) + } } - func TestShouldContinueSwap(t *testing.T) { - tests := []struct { - name string - state SwapState - sqrtPriceLimitX96 *u256.Uint - expected bool - }{ - { - name: "Should continue - amount remaining and price not at limit", - state: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("1000"), - sqrtPriceX96: u256.MustFromDecimal("1000000"), - }, - sqrtPriceLimitX96: u256.MustFromDecimal("900000"), - expected: true, - }, - { - name: "Should stop - no amount remaining", - state: SwapState{ - amountSpecifiedRemaining: i256.Zero(), - sqrtPriceX96: u256.MustFromDecimal("1000000"), - }, - sqrtPriceLimitX96: u256.MustFromDecimal("900000"), - expected: false, - }, - { - name: "Should stop - price at limit", - state: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("1000"), - sqrtPriceX96: u256.MustFromDecimal("900000"), - }, - sqrtPriceLimitX96: u256.MustFromDecimal("900000"), - expected: false, - }, - } + tests := []struct { + name string + state SwapState + sqrtPriceLimitX96 *u256.Uint + expected bool + }{ + { + name: "Should continue - amount remaining and price not at limit", + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + sqrtPriceX96: u256.MustFromDecimal("1000000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: true, + }, + { + name: "Should stop - no amount remaining", + state: SwapState{ + amountSpecifiedRemaining: i256.Zero(), + sqrtPriceX96: u256.MustFromDecimal("1000000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: false, + }, + { + name: "Should stop - price at limit", + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + sqrtPriceX96: u256.MustFromDecimal("900000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: false, + }, + } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := shouldContinueSwap(tt.state, tt.sqrtPriceLimitX96) - uassert.Equal(t, tt.expected, result) - }) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := shouldContinueSwap(tt.state, tt.sqrtPriceLimitX96) + uassert.Equal(t, tt.expected, result) + }) + } } func TestUpdateAmounts(t *testing.T) { - tests := []struct { - name string - step StepComputations - state SwapState - exactInput bool - expectedState SwapState - }{ - { - name: "Exact input update", - step: StepComputations{ - amountIn: u256.MustFromDecimal("100"), - amountOut: u256.MustFromDecimal("97"), - feeAmount: u256.MustFromDecimal("3"), - }, - state: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("1000"), - amountCalculated: i256.Zero(), - }, - exactInput: true, - expectedState: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("897"), // 1000 - (100 + 3) - amountCalculated: i256.MustFromDecimal("-97"), - }, - }, - { - name: "Exact output update", - step: StepComputations{ - amountIn: u256.MustFromDecimal("100"), - amountOut: u256.MustFromDecimal("97"), - feeAmount: u256.MustFromDecimal("3"), - }, - state: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("-1000"), - amountCalculated: i256.Zero(), - }, - exactInput: false, - expectedState: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("-903"), // -1000 + 97 - amountCalculated: i256.MustFromDecimal("103"), // 100 + 3 - }, - }, - } + tests := []struct { + name string + step StepComputations + state SwapState + exactInput bool + expectedState SwapState + }{ + { + name: "Exact input update", + step: StepComputations{ + amountIn: u256.MustFromDecimal("100"), + amountOut: u256.MustFromDecimal("97"), + feeAmount: u256.MustFromDecimal("3"), + }, + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + amountCalculated: i256.Zero(), + }, + exactInput: true, + expectedState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("897"), // 1000 - (100 + 3) + amountCalculated: i256.MustFromDecimal("-97"), + }, + }, + { + name: "Exact output update", + step: StepComputations{ + amountIn: u256.MustFromDecimal("100"), + amountOut: u256.MustFromDecimal("97"), + feeAmount: u256.MustFromDecimal("3"), + }, + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("-1000"), + amountCalculated: i256.Zero(), + }, + exactInput: false, + expectedState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("-903"), // -1000 + 97 + amountCalculated: i256.MustFromDecimal("103"), // 100 + 3 + }, + }, + } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := updateAmounts(tt.step, tt.state, tt.exactInput) - - uassert.True(t, tt.expectedState.amountSpecifiedRemaining.Eq(result.amountSpecifiedRemaining)) - uassert.True(t, tt.expectedState.amountCalculated.Eq(result.amountCalculated)) - }) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := updateAmounts(tt.step, tt.state, tt.exactInput) + + uassert.True(t, tt.expectedState.amountSpecifiedRemaining.Eq(result.amountSpecifiedRemaining)) + uassert.True(t, tt.expectedState.amountCalculated.Eq(result.amountCalculated)) + }) + } } func TestComputeSwap(t *testing.T) { - mockPool := &Pool{ - token0Path: "token0", - token1Path: "token1", - fee: 3000, // 0.3% - tickSpacing: 60, - slot0: Slot0{ - sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 - tick: 0, - feeProtocol: 0, - unlocked: true, - }, - liquidity: u256.MustFromDecimal("1000000000000000000"), // 1.0 - protocolFees: ProtocolFees{ - token0: u256.Zero(), - token1: u256.Zero(), - }, - feeGrowthGlobal0X128: u256.Zero(), - feeGrowthGlobal1X128: u256.Zero(), - tickBitmaps: make(TickBitmaps), - ticks: make(Ticks), - positions: make(Positions), - } - - wordPos, _ := tickBitmapPosition(0) - mockPool.tickBitmaps[wordPos] = u256.NewUint(1) - - t.Run("basic swap", func(t *testing.T) { - comp := SwapComputation{ - AmountSpecified: i256.MustFromDecimal("1000000"), // 1.0 token - SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), // 1.1 - ZeroForOne: true, - ExactInput: true, - InitialState: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), - amountCalculated: i256.Zero(), - sqrtPriceX96: mockPool.slot0.sqrtPriceX96, - tick: mockPool.slot0.tick, - feeGrowthGlobalX128: mockPool.feeGrowthGlobal0X128, - protocolFee: u256.Zero(), - liquidity: mockPool.liquidity, - }, - Cache: SwapCache{ - feeProtocol: 0, - liquidityStart: mockPool.liquidity, - }, - } - - result, err := computeSwap(mockPool, comp) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - if result.Amount0.IsZero() { - t.Error("expected non-zero amount0") - } - if result.Amount1.IsZero() { - t.Error("expected non-zero amount1") - } - if result.SwapFee.IsZero() { - t.Error("expected non-zero swap fee") - } - }) - - t.Run("swap with protocol fee", func(t *testing.T) { - t.Skip() - mockPoolWithFee := *mockPool - mockPoolWithFee.slot0.feeProtocol = 4 // 4 = 1/4 of fee goes to protocol - - comp := SwapComputation{ - AmountSpecified: i256.MustFromDecimal("1000000"), - SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), - ZeroForOne: true, - ExactInput: true, - InitialState: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), - amountCalculated: i256.Zero(), - sqrtPriceX96: mockPoolWithFee.slot0.sqrtPriceX96, - tick: mockPoolWithFee.slot0.tick, - feeGrowthGlobalX128: mockPoolWithFee.feeGrowthGlobal0X128, - protocolFee: u256.Zero(), - liquidity: mockPoolWithFee.liquidity, - }, - Cache: SwapCache{ - feeProtocol: 4, - liquidityStart: mockPoolWithFee.liquidity, - }, - } - - result, err := computeSwap(&mockPoolWithFee, comp) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - // 프로토콜 수수료 검증 - if mockPoolWithFee.protocolFees.token0.IsZero() { - t.Error("expected non-zero protocol fee for token0") - } - }) - - t.Run("swap with zero liquidity", func(t *testing.T) { - t.Skip() - mockPoolZeroLiq := *mockPool - mockPoolZeroLiq.liquidity = u256.Zero() + mockPool := &Pool{ + token0Path: "token0", + token1Path: "token1", + fee: 3000, // 0.3% + tickSpacing: 60, + slot0: Slot0{ + sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 + tick: 0, + feeProtocol: 0, + unlocked: true, + }, + liquidity: u256.MustFromDecimal("1000000000000000000"), // 1.0 + protocolFees: ProtocolFees{ + token0: u256.Zero(), + token1: u256.Zero(), + }, + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + tickBitmaps: make(TickBitmaps), + ticks: make(Ticks), + positions: make(Positions), + } - comp := SwapComputation{ - AmountSpecified: i256.MustFromDecimal("1000000"), - SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), - ZeroForOne: true, - ExactInput: true, - InitialState: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), - amountCalculated: i256.Zero(), - sqrtPriceX96: mockPoolZeroLiq.slot0.sqrtPriceX96, - tick: mockPoolZeroLiq.slot0.tick, - feeGrowthGlobalX128: mockPoolZeroLiq.feeGrowthGlobal0X128, - protocolFee: u256.Zero(), - liquidity: mockPoolZeroLiq.liquidity, - }, - Cache: SwapCache{ - feeProtocol: 0, - liquidityStart: mockPoolZeroLiq.liquidity, - }, - } + wordPos, _ := tickBitmapPosition(0) + mockPool.tickBitmaps[wordPos] = u256.NewUint(1) + + t.Run("basic swap", func(t *testing.T) { + comp := SwapComputation{ + AmountSpecified: i256.MustFromDecimal("1000000"), // 1.0 token + SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), // 1.1 + ZeroForOne: true, + ExactInput: true, + InitialState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), + amountCalculated: i256.Zero(), + sqrtPriceX96: mockPool.slot0.sqrtPriceX96, + tick: mockPool.slot0.tick, + feeGrowthGlobalX128: mockPool.feeGrowthGlobal0X128, + protocolFee: u256.Zero(), + liquidity: mockPool.liquidity, + }, + Cache: SwapCache{ + feeProtocol: 0, + liquidityStart: mockPool.liquidity, + }, + } + + result, err := computeSwap(mockPool, comp) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if result.Amount0.IsZero() { + t.Error("expected non-zero amount0") + } + if result.Amount1.IsZero() { + t.Error("expected non-zero amount1") + } + if result.SwapFee.IsZero() { + t.Error("expected non-zero swap fee") + } + }) + + t.Run("swap with zero liquidity", func(t *testing.T) { + mockPoolZeroLiq := *mockPool + mockPoolZeroLiq.liquidity = u256.Zero() + + comp := SwapComputation{ + AmountSpecified: i256.MustFromDecimal("1000000"), + SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), + ZeroForOne: true, + ExactInput: true, + InitialState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), + amountCalculated: i256.Zero(), + sqrtPriceX96: mockPoolZeroLiq.slot0.sqrtPriceX96, + tick: mockPoolZeroLiq.slot0.tick, + feeGrowthGlobalX128: mockPoolZeroLiq.feeGrowthGlobal0X128, + protocolFee: u256.Zero(), + liquidity: mockPoolZeroLiq.liquidity, + }, + Cache: SwapCache{ + feeProtocol: 0, + liquidityStart: mockPoolZeroLiq.liquidity, + }, + } - result, err := computeSwap(&mockPoolZeroLiq, comp) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } + result, err := computeSwap(&mockPoolZeroLiq, comp) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } - if !result.Amount0.IsZero() || !result.Amount1.IsZero() { - t.Error("expected zero amounts for zero liquidity") - } - }) + if !result.Amount0.IsZero() || !result.Amount1.IsZero() { + t.Error("expected zero amounts for zero liquidity") + } + }) } From c8db3ee171c9ef18cab008dc89f2ccd5bc091146 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Wed, 4 Dec 2024 15:28:38 +0900 Subject: [PATCH 13/24] test: transferFromAndVerify --- pool/pool.gno | 188 ++++++++++++++++++++++----------------------- pool/pool_test.gno | 148 +++++++++++++++++++++++++++++++++++ 2 files changed, 242 insertions(+), 94 deletions(-) diff --git a/pool/pool.gno b/pool/pool.gno index aa1cc4ed7..3b831e17d 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -193,6 +193,100 @@ type SwapComputation struct { Cache SwapCache } +// Swap swaps token0 for token1, or token1 for token0 +// Returns swapped amount0, amount1 in string +// ref: https://docs.gnoswap.io/contracts/pool/pool.gno#swap +func Swap( + token0Path string, + token1Path string, + fee uint32, + recipient std.Address, + zeroForOne bool, + _amountSpecified string, // int256 + _sqrtPriceLimitX96 string, // uint160 + payer std.Address, // router +) (string, string) { // int256 x2 + common.IsHalted() + if common.GetLimitCaller() { + caller := std.PrevRealm().Addr() + if err := common.RouterOnly(caller); err != nil { + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("pool.gno__Swap() || only router(%s) can call pool swap(), called from %s", consts.ROUTER_ADDR, caller.String()), + )) + } + } + + if _amountSpecified == "0" { + panic(addDetailToError( + errInvalidSwapAmount, + ufmt.Sprintf("pool.gno__Swap() || amountSpecified == 0"), + )) + } + + pool := GetPool(token0Path, token1Path, fee) + + slot0Start := pool.slot0 + if !slot0Start.unlocked { + panic(errLockedPool) + } + + slot0Start.unlocked = false + defer func() { slot0Start.unlocked = true }() + + amountSpecified := i256.MustFromDecimal(_amountSpecified) + sqrtPriceLimitX96 := u256.MustFromDecimal(_sqrtPriceLimitX96) + + validatePriceLimits(pool, zeroForOne, sqrtPriceLimitX96) + + feeGrowthGlobalX128 := getFeeGrowthGlobal(pool, zeroForOne) + feeProtocol := getFeeProtocol(slot0Start, zeroForOne) + cache := newSwapCache(feeProtocol, pool.liquidity) + + state := newSwapState(amountSpecified, feeGrowthGlobalX128, cache.liquidityStart, pool.slot0) + + comp := SwapComputation{ + AmountSpecified: amountSpecified, + SqrtPriceLimitX96: sqrtPriceLimitX96, + ZeroForOne: zeroForOne, + ExactInput: amountSpecified.Gt(i256.Zero()), + InitialState: state, + Cache: cache, + } + + result, err := computeSwap(pool, comp) + if err != nil { + panic(err) + } + + applySwapResult(pool, result) + + // actual swap + pool.swapTransfers(zeroForOne, payer, recipient, result.Amount0, result.Amount1) + + prevAddr, prevRealm := getPrev() + + std.Emit( + "Swap", + "prevAddr", prevAddr, + "prevRealm", prevRealm, + "poolPath", GetPoolPath(token0Path, token1Path, fee), + "zeroForOne", ufmt.Sprintf("%t", zeroForOne), + "amountSpecified", _amountSpecified, + "sqrtPriceLimitX96", _sqrtPriceLimitX96, + "payer", payer.String(), + "recipient", recipient.String(), + "internal_amount0", result.Amount0.ToString(), + "internal_amount1", result.Amount1.ToString(), + "internal_protocolFee0", pool.protocolFees.token0.ToString(), + "internal_protocolFee1", pool.protocolFees.token1.ToString(), + "internal_swapFee", result.SwapFee.ToString(), + "internal_sqrtPriceX96", pool.slot0.sqrtPriceX96.ToString(), + ) + + return result.Amount0.ToString(), result.Amount1.ToString() +} + func computeSwap(pool *Pool, comp SwapComputation) (*SwapResult, error) { state := comp.InitialState swapFee := u256.Zero() @@ -304,100 +398,6 @@ func getFeeGrowthGlobal(pool *Pool, zeroForOne bool) *u256.Uint { return pool.feeGrowthGlobal1X128 } -// Swap swaps token0 for token1, or token1 for token0 -// Returns swapped amount0, amount1 in string -// ref: https://docs.gnoswap.io/contracts/pool/pool.gno#swap -func Swap( - token0Path string, - token1Path string, - fee uint32, - recipient std.Address, - zeroForOne bool, - _amountSpecified string, // int256 - _sqrtPriceLimitX96 string, // uint160 - payer std.Address, // router -) (string, string) { // int256 x2 - common.IsHalted() - if common.GetLimitCaller() { - caller := std.PrevRealm().Addr() - if err := common.RouterOnly(caller); err != nil { - panic(addDetailToError( - errNoPermission, - ufmt.Sprintf("pool.gno__Swap() || only router(%s) can call pool swap(), called from %s", consts.ROUTER_ADDR, caller.String()), - )) - } - } - - if _amountSpecified == "0" { - panic(addDetailToError( - errInvalidSwapAmount, - ufmt.Sprintf("pool.gno__Swap() || amountSpecified == 0"), - )) - } - - pool := GetPool(token0Path, token1Path, fee) - - slot0Start := pool.slot0 - if !slot0Start.unlocked { - panic(errLockedPool) - } - - slot0Start.unlocked = false - defer func() { slot0Start.unlocked = true }() - - amountSpecified := i256.MustFromDecimal(_amountSpecified) - sqrtPriceLimitX96 := u256.MustFromDecimal(_sqrtPriceLimitX96) - - validatePriceLimits(pool, zeroForOne, sqrtPriceLimitX96) - - feeGrowthGlobalX128 := getFeeGrowthGlobal(pool, zeroForOne) - feeProtocol := getFeeProtocol(slot0Start, zeroForOne) - cache := newSwapCache(feeProtocol, pool.liquidity) - - state := newSwapState(amountSpecified, feeGrowthGlobalX128, cache.liquidityStart, pool.slot0) - - comp := SwapComputation{ - AmountSpecified: amountSpecified, - SqrtPriceLimitX96: sqrtPriceLimitX96, - ZeroForOne: zeroForOne, - ExactInput: amountSpecified.Gt(i256.Zero()), - InitialState: state, - Cache: cache, - } - - result, err := computeSwap(pool, comp) - if err != nil { - panic(err) - } - - applySwapResult(pool, result) - - // actual swap - pool.swapTransfers(zeroForOne, payer, recipient, result.Amount0, result.Amount1) - - prevAddr, prevRealm := getPrev() - - std.Emit( - "Swap", - "prevAddr", prevAddr, - "prevRealm", prevRealm, - "poolPath", GetPoolPath(token0Path, token1Path, fee), - "zeroForOne", ufmt.Sprintf("%t", zeroForOne), - "amountSpecified", _amountSpecified, - "sqrtPriceLimitX96", _sqrtPriceLimitX96, - "payer", payer.String(), - "recipient", recipient.String(), - "internal_amount0", result.Amount0.ToString(), - "internal_amount1", result.Amount1.ToString(), - "internal_protocolFee0", pool.protocolFees.token0.ToString(), - "internal_protocolFee1", pool.protocolFees.token1.ToString(), - "internal_swapFee", result.SwapFee.ToString(), - "internal_sqrtPriceX96", pool.slot0.sqrtPriceX96.ToString(), - ) - - return result.Amount0.ToString(), result.Amount1.ToString() -} - func shouldContinueSwap(state SwapState, sqrtPriceLimitX96 *u256.Uint) bool { return !(state.amountSpecifiedRemaining.IsZero()) && !(state.sqrtPriceX96.Eq(sqrtPriceLimitX96)) } diff --git a/pool/pool_test.gno b/pool/pool_test.gno index 067aab183..e30b058fe 100644 --- a/pool/pool_test.gno +++ b/pool/pool_test.gno @@ -1,6 +1,7 @@ package pool import ( + "std" "testing" "gno.land/p/demo/uassert" @@ -422,3 +423,150 @@ func TestComputeSwap(t *testing.T) { } }) } + +func TestTransferFromAndVerify(t *testing.T) { + tests := []struct { + name string + pool *Pool + from std.Address + to std.Address + tokenPath string + amount *i256.Int + isToken0 bool + expectedBal0 *u256.Uint + expectedBal1 *u256.Uint + }{ + { + name: "normal token0 transfer", + pool: &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + from: std.Address("from_addr"), + to: std.Address("to_addr"), + tokenPath: "token0_path", + amount: i256.NewInt(500), + isToken0: true, + expectedBal0: u256.NewUint(1500), // 1000 + 500 + expectedBal1: u256.NewUint(2000), // unchanged + }, + { + name: "normal token1 transfer", + pool: &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + from: std.Address("from_addr"), + to: std.Address("to_addr"), + tokenPath: "token1_path", + amount: i256.NewInt(800), + isToken0: false, + expectedBal0: u256.NewUint(1000), // unchanged + expectedBal1: u256.NewUint(2800), // 2000 + 800 + }, + { + name: "zero value transfer", + pool: &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + from: std.Address("from_addr"), + to: std.Address("to_addr"), + tokenPath: "token0_path", + amount: i256.NewInt(0), + isToken0: true, + expectedBal0: u256.NewUint(1000), // unchanged + expectedBal1: u256.NewUint(2000), // unchanged + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // mock transferFromByRegisterCall + oldTransferFromByRegisterCall := transferFromByRegisterCall + defer func() { transferFromByRegisterCall = oldTransferFromByRegisterCall }() + + transferFromByRegisterCall = func(tokenPath string, from, to std.Address, amount uint64) bool { + // mock the transfer (just return true) + return true + } + + tt.pool.transferFromAndVerify(tt.from, tt.to, tt.tokenPath, tt.amount, tt.isToken0) + + if !tt.pool.balances.token0.Eq(tt.expectedBal0) { + t.Errorf("token0 balance mismatch: expected %s, got %s", + tt.expectedBal0.ToString(), + tt.pool.balances.token0.ToString()) + } + + if !tt.pool.balances.token1.Eq(tt.expectedBal1) { + t.Errorf("token1 balance mismatch: expected %s, got %s", + tt.expectedBal1.ToString(), + tt.pool.balances.token1.ToString()) + } + }) + } + + t.Run("negative value handling", func(t *testing.T) { + pool := &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + } + + oldTransferFromByRegisterCall := transferFromByRegisterCall + defer func() { transferFromByRegisterCall = oldTransferFromByRegisterCall }() + + transferFromByRegisterCall = func(tokenPath string, from, to std.Address, amount uint64) bool { + return true + } + + negativeAmount := i256.NewInt(-500) + pool.transferFromAndVerify( + std.Address("from_addr"), + std.Address("to_addr"), + "token0_path", + negativeAmount, + true, + ) + + expectedBal := u256.NewUint(1500) // 1000 + 500 (absolute value) + if !pool.balances.token0.Eq(expectedBal) { + t.Errorf("negative amount handling failed: expected %s, got %s", + expectedBal.ToString(), + pool.balances.token0.ToString()) + } + }) + + t.Run("uint64 overflow value", func(t *testing.T) { + pool := &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + } + + hugeAmount := i256.FromUint256(u256.MustFromDecimal("18446744073709551616")) // 2^64 + + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for amount exceeding uint64 range") + } + }() + + pool.transferFromAndVerify( + std.Address("from_addr"), + std.Address("to_addr"), + "token0_path", + hugeAmount, + true, + ) + }) +} From d795424eaa398985e0fe0e6eeae8e9cf9608ac3d Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Wed, 4 Dec 2024 15:31:54 +0900 Subject: [PATCH 14/24] fix --- pool/pool_test.gno | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/pool/pool_test.gno b/pool/pool_test.gno index e30b058fe..51e1571d4 100644 --- a/pool/pool_test.gno +++ b/pool/pool_test.gno @@ -5,6 +5,7 @@ import ( "testing" "gno.land/p/demo/uassert" + "gno.land/p/demo/testutils" i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" ) @@ -444,8 +445,8 @@ func TestTransferFromAndVerify(t *testing.T) { token1: u256.NewUint(2000), }, }, - from: std.Address("from_addr"), - to: std.Address("to_addr"), + from: testutils.TestAddress("from_addr"), + to: testutils.TestAddress("to_addr"), tokenPath: "token0_path", amount: i256.NewInt(500), isToken0: true, @@ -460,8 +461,8 @@ func TestTransferFromAndVerify(t *testing.T) { token1: u256.NewUint(2000), }, }, - from: std.Address("from_addr"), - to: std.Address("to_addr"), + from: testutils.TestAddress("from_addr"), + to: testutils.TestAddress("to_addr"), tokenPath: "token1_path", amount: i256.NewInt(800), isToken0: false, @@ -476,8 +477,8 @@ func TestTransferFromAndVerify(t *testing.T) { token1: u256.NewUint(2000), }, }, - from: std.Address("from_addr"), - to: std.Address("to_addr"), + from: testutils.TestAddress("from_addr"), + to: testutils.TestAddress("to_addr"), tokenPath: "token0_path", amount: i256.NewInt(0), isToken0: true, @@ -530,8 +531,8 @@ func TestTransferFromAndVerify(t *testing.T) { negativeAmount := i256.NewInt(-500) pool.transferFromAndVerify( - std.Address("from_addr"), - std.Address("to_addr"), + testutils.TestAddress("from_addr"), + testutils.TestAddress("to_addr"), "token0_path", negativeAmount, true, @@ -562,8 +563,8 @@ func TestTransferFromAndVerify(t *testing.T) { }() pool.transferFromAndVerify( - std.Address("from_addr"), - std.Address("to_addr"), + testutils.TestAddress("from_addr"), + testutils.TestAddress("to_addr"), "token0_path", hugeAmount, true, From b0ee16d708603cf5cd389c3605a6c00c97fba46a Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Wed, 4 Dec 2024 16:19:43 +0900 Subject: [PATCH 15/24] test: burn (placeholder) --- pool/pool.gno | 92 ++++++++++++++++++++++---------- pool/pool_test.gno | 129 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+), 29 deletions(-) diff --git a/pool/pool.gno b/pool/pool.gno index 3b831e17d..fef36ea2e 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -170,7 +170,8 @@ func Collect( // Swap //////////////////////////////////////////////////////////// -// SwapResult represents the final state after swap computation +// SwapResult encapsulates all state changes that occur as a result of a swap +// This type ensure all state transitions are atomic and can be applied at once. type SwapResult struct { Amount0 *i256.Int Amount1 *i256.Int @@ -287,6 +288,21 @@ func Swap( return result.Amount0.ToString(), result.Amount1.ToString() } +// computeSwap performs the core swap computation without modifying pool state +// The function follows these state transitions: +// 1. Initial State: Provided by `SwapComputation.InitialState` +// 2. Stepping State: For each step: +// - Compute next tick and price target +// - Calculate amounts and fees +// - Update state (remaining amount, fees, liquidity) +// - Handle tick transitions if necessary +// 3. Final State: Aggregated in SwapResult +// +// The computation continues until either: +// - The entire amount is consumed (`amountSpecifiedRemaining` = 0) +// - The price limit is reached (`sqrtPriceX96` = `sqrtPriceLimitX96`) +// +// Returns an error if the computation fails at any step func computeSwap(pool *Pool, comp SwapComputation) (*SwapResult, error) { state := comp.InitialState swapFee := u256.Zero() @@ -337,7 +353,8 @@ func computeSwap(pool *Pool, comp SwapComputation) (*SwapResult, error) { return result, nil } -// applySwapResult updates pool state with computed results +// applySwapResult updates pool state with computed results. +// All state changes are applied at once to maintain consistency func applySwapResult(pool *Pool, result *SwapResult) { pool.slot0.sqrtPriceX96 = result.NewSqrtPrice pool.slot0.tick = result.NewTick @@ -347,7 +364,14 @@ func applySwapResult(pool *Pool, result *SwapResult) { pool.feeGrowthGlobal1X128 = result.FeeGrowthGlobal1 } -// validatePriceLimits validates the price limits for the swap +// validatePriceLimits ensures the provided price limit is valid for the swap direction +// The function enforces that: +// For zeroForOne (selling token0): +// - Price limit must be below current price +// - Price limit must be above MIN_SQRT_RATIO +// For !zeroForOne (selling token1): +// - Price limit must be above current price +// - Price limit must be below MAX_SQRT_RATIO func validatePriceLimits(pool *Pool, zeroForOne bool, sqrtPriceLimitX96 *u256.Uint) { if zeroForOne { minSqrtRatio := u256.MustFromDecimal(consts.MIN_SQRT_RATIO) @@ -413,8 +437,11 @@ func computeSwapStep( swapFee *u256.Uint, ) (SwapState, *u256.Uint) { step := computeSwapStepInit(state, pool, zeroForOne) + + // determining the price target for this step sqrtRatioTargetX96 := computeTargetSqrtRatio(step, sqrtPriceLimitX96, zeroForOne) + // computing the amounts to be swapped at this step var newState SwapState newState, step = computeAmounts(state, sqrtRatioTargetX96, pool, step) newState = updateAmounts(step, newState, exactInput) @@ -431,8 +458,7 @@ func computeSwapStep( newState.SetFeeGrowthGlobalX128(new(u256.Uint).Add(newState.feeGrowthGlobalX128, update)) } - newSwapFee := new(u256.Uint).Add(swapFee, step.feeAmount) - + // handling tick transitions if newState.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { newState = tickTransition(step, zeroForOne, newState, pool) } @@ -441,6 +467,8 @@ func computeSwapStep( newState.SetTick(common.TickMathGetTickAtSqrtRatio(newState.sqrtPriceX96)) } + newSwapFee := new(u256.Uint).Add(swapFee, step.feeAmount) + return newState, newSwapFee } @@ -506,8 +534,13 @@ func computeAmounts(state SwapState, sqrtRatioTargetX96 *u256.Uint, pool *Pool, return state, step } -// updateAmounts updates the remaining amounts and calculated amounts -// based on the swap direction. +// updateAmounts calculates new remaining and calculated amounts based on the swap step +// For exact input swaps: +// - Decrements remaining input amount by (amountIn + feeAmount) +// - Decrements calculated amount by amountOut +// For exact output swaps: +// - Increments remaining output amount by amountOut +// - Increments calculated amount by (amountIn + feeAmount) func updateAmounts(step StepComputations, state SwapState, exactInput bool) SwapState { amountInWithFee := i256.FromUint256(new(u256.Uint).Add(step.amountIn, step.feeAmount)) if exactInput { @@ -521,29 +554,9 @@ func updateAmounts(step StepComputations, state SwapState, exactInput bool) Swap return state } -func (pool *Pool) swapTransfers(zeroForOne bool, payer, recipient std.Address, amount0, amount1 *i256.Int) { - var targetTokenPath string - var amount *i256.Int - var isToken0 bool - - switch zeroForOne { - case true: - targetTokenPath = pool.token0Path - amount = amount0 - isToken0 = true - case false: - targetTokenPath = pool.token1Path - amount = amount1 - isToken0 = false - } - - // payer -> POOL -> recipient - pool.transferFromAndVerify(payer, consts.POOL_ADDR, targetTokenPath, amount, isToken0) - pool.transferAndVerify(recipient, targetTokenPath, amount, !isToken0) -} - -// tickTransition now returns new state instead of modifying existing +// tickTransition handles the transition between price ticks during a swap func tickTransition(step StepComputations, zeroForOne bool, state SwapState, pool *Pool) SwapState { + // ensure existing state to keep immutability newState := state if step.initialized { @@ -575,6 +588,27 @@ func tickTransition(step StepComputations, zeroForOne bool, state SwapState, poo return newState } +func (pool *Pool) swapTransfers(zeroForOne bool, payer, recipient std.Address, amount0, amount1 *i256.Int) { + var targetTokenPath string + var amount *i256.Int + var isToken0 bool + + switch zeroForOne { + case true: + targetTokenPath = pool.token0Path + amount = amount0 + isToken0 = true + case false: + targetTokenPath = pool.token1Path + amount = amount1 + isToken0 = false + } + + // payer -> POOL -> recipient + pool.transferFromAndVerify(payer, consts.POOL_ADDR, targetTokenPath, amount, isToken0) + pool.transferAndVerify(recipient, targetTokenPath, amount, !isToken0) +} + // SetFeeProtocolByAdmin sets the fee protocol for all pools // Also it will be applied to new created pools func SetFeeProtocolByAdmin( diff --git a/pool/pool_test.gno b/pool/pool_test.gno index 51e1571d4..7e5ecf3a5 100644 --- a/pool/pool_test.gno +++ b/pool/pool_test.gno @@ -4,12 +4,140 @@ import ( "std" "testing" + "gno.land/r/gnoswap/v1/consts" "gno.land/p/demo/uassert" "gno.land/p/demo/testutils" i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" ) +func TestMint(t *testing.T) { + token0Path := "test_token0" + token1Path := "test_token1" + fee := uint32(3000) + recipient := testutils.TestAddress("recipient") + tickLower := int32(-100) + tickUpper := int32(100) + liquidityAmount := "100000" + + t.Run("unauthorized caller mint should fail", func(t *testing.T) { + unauthorized := testutils.TestAddress("unauthorized") + defer func() { + if r := recover(); r == nil { + t.Error("unauthorized caller mint should fail") + } + }() + + Mint(token0Path, token1Path, fee, recipient, tickLower, tickUpper, liquidityAmount, unauthorized) + }) + + t.Run("mint with 0 liquidity should fail", func(t *testing.T) { + authorized := consts.POSITION_ADDR + defer func() { + if r := recover(); r == nil { + t.Error("mint with 0 liquidity should fail") + } + }() + + Mint(token0Path, token1Path, fee, recipient, tickLower, tickUpper, "0", authorized) + }) +} + +func TestBurn(t *testing.T) { + // Setup + originalGetPool := GetPool + defer func() { + GetPool = originalGetPool + }() + + // Mock data + mockCaller := consts.POSITION_ADDR + mockPosition := PositionInfo{ + liquidity: u256.NewUint(1000), + tokensOwed0: u256.NewUint(0), + tokensOwed1: u256.NewUint(0), + } + mockPool := &Pool{ + positions: make(map[string]PositionInfo), + } + + GetPool = func(token0Path, token1Path string, fee uint32) *Pool { + return mockPool + } + + tests := []struct { + name string + liquidityAmount string + tickLower int32 + tickUpper int32 + expectedAmount0 string + expectedAmount1 string + expectPanic bool + }{ + { + name: "successful burn", + liquidityAmount: "500", + tickLower: -100, + tickUpper: 100, + expectedAmount0: "100", + expectedAmount1: "200", + }, + { + name: "zero liquidity", + liquidityAmount: "0", + tickLower: -100, + tickUpper: 100, + expectPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.name == "successful burn" { + t.Skip("skipping until find better way to test this") + } + + // setup position for this test + posKey := positionGetKey(mockCaller, tt.tickLower, tt.tickUpper) + mockPool.positions[posKey] = mockPosition + + if tt.expectPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic but got none") + } + }() + } + + amount0, amount1 := Burn( + "token0", + "token1", + 3000, + tt.tickLower, + tt.tickUpper, + tt.liquidityAmount, + ) + + if !tt.expectPanic { + if amount0 != tt.expectedAmount0 { + t.Errorf("expected amount0 %s, got %s", tt.expectedAmount0, amount0) + } + if amount1 != tt.expectedAmount1 { + t.Errorf("expected amount1 %s, got %s", tt.expectedAmount1, amount1) + } + + newPosition := mockPool.positions[posKey] + if newPosition.tokensOwed0.IsZero() { + t.Error("expected tokensOwed0 to be updated") + } + if newPosition.tokensOwed1.IsZero() { + t.Error("expected tokensOwed1 to be updated") + } + } + }) + } +} + func TestSaveProtocolFees(t *testing.T) { tests := []struct { name string @@ -351,6 +479,7 @@ func TestComputeSwap(t *testing.T) { } wordPos, _ := tickBitmapPosition(0) + // TODO: use avl mockPool.tickBitmaps[wordPos] = u256.NewUint(1) t.Run("basic swap", func(t *testing.T) { From 92d7e60ed3167ede13cf8539d6d23631bb03dc5c Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Fri, 6 Dec 2024 10:02:14 +0900 Subject: [PATCH 16/24] fix --- pool/_RPC_dry.gno | 2 +- pool/pool.gno | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pool/_RPC_dry.gno b/pool/_RPC_dry.gno index 52b802553..5f76a0935 100644 --- a/pool/_RPC_dry.gno +++ b/pool/_RPC_dry.gno @@ -63,7 +63,7 @@ func DrySwap( slot0.unlocked = false cache := newSwapCache(feeProtocol, pool.liquidity) - state := newSwapState(amountSpecified, feeGrowthGlobalX128, cache.liquidityStart, slot0) // TODO: feeGrowthGlobalX128.Clone() or NOT + state := newSwapState(amountSpecified.Clone(), feeGrowthGlobalX128.Clone(), cache.liquidityStart.Clone(), slot0) // TODO: feeGrowthGlobalX128.Clone() or NOT exactInput := amountSpecified.Gt(i256.Zero()) diff --git a/pool/pool.gno b/pool/pool.gno index c854d866c..2b4f73c9f 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -24,9 +24,9 @@ func Mint( recipient std.Address, tickLower int32, tickUpper int32, - _liquidityAmount string, // uint128 + _liquidityAmount string, positionCaller std.Address, -) (string, string) { // uint256 x2== "0" +) (string, string) { common.IsHalted() if common.GetLimitCaller() { caller := std.PrevRealm().Addr() From 399b4846836c1e344c8c6beb9788dc958e8a4f67 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Fri, 6 Dec 2024 14:43:32 +0900 Subject: [PATCH 17/24] remove prefix --- pool/pool.gno | 80 +++++++++++++++++++++++++-------------------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/pool/pool.gno b/pool/pool.gno index 2b4f73c9f..04f8ebcf1 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -71,8 +71,8 @@ func Burn( fee uint32, tickLower int32, tickUpper int32, - liquidityAmount string, // uint128 -) (string, string) { // uint256 x2 + liquidityAmount string, +) (string, string) { common.IsHalted() caller := std.PrevRealm().Addr() if common.GetLimitCaller() { @@ -115,9 +115,9 @@ func Collect( recipient std.Address, tickLower int32, tickUpper int32, - _amount0Requested string, // uint128 - _amount1Requested string, // uint128 -) (string, string) { // uint128 x2 + amount0Requested string, + amount1Requested string, +) (string, string) { common.IsHalted() if common.GetLimitCaller() { caller := std.PrevRealm().Addr() @@ -129,8 +129,8 @@ func Collect( } } - amount0Requested := u256.MustFromDecimal(_amount0Requested) - amount1Requested := u256.MustFromDecimal(_amount1Requested) + amount0Req := u256.MustFromDecimal(amount0Requested) + amount1Req := u256.MustFromDecimal(amount1Requested) pool := GetPool(token0Path, token1Path, fee) @@ -144,7 +144,7 @@ func Collect( } // Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0 - amount0 := u256Min(amount0Requested, position.tokensOwed0) + amount0 := u256Min(amount0Req, position.tokensOwed0) amount0 = u256Min(amount0, pool.balances.token0) // Update state first then transfer @@ -153,7 +153,7 @@ func Collect( transferByRegisterCall(pool.token0Path, recipient, amount0.Uint64()) // Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0 - amount1 := u256Min(amount1Requested, position.tokensOwed1) + amount1 := u256Min(amount1Req, position.tokensOwed1) amount1 = u256Min(amount1, pool.balances.token1) // Update state first then transfer @@ -203,10 +203,10 @@ func Swap( fee uint32, recipient std.Address, zeroForOne bool, - _amountSpecified string, // int256 - _sqrtPriceLimitX96 string, // uint160 + amountSpecified string, + sqrtPriceLimitX96 string, payer std.Address, // router -) (string, string) { // int256 x2 +) (string, string) { common.IsHalted() if common.GetLimitCaller() { caller := std.PrevRealm().Addr() @@ -218,7 +218,7 @@ func Swap( } } - if _amountSpecified == "0" { + if amountSpecified == "0" { panic(addDetailToError( errInvalidSwapAmount, ufmt.Sprintf("pool.gno__Swap() || amountSpecified == 0"), @@ -235,22 +235,22 @@ func Swap( slot0Start.unlocked = false defer func() { slot0Start.unlocked = true }() - amountSpecified := i256.MustFromDecimal(_amountSpecified) - sqrtPriceLimitX96 := u256.MustFromDecimal(_sqrtPriceLimitX96) + amounts := i256.MustFromDecimal(amountSpecified) + sqrtPriceLimit := u256.MustFromDecimal(sqrtPriceLimitX96) - validatePriceLimits(pool, zeroForOne, sqrtPriceLimitX96) + validatePriceLimits(pool, zeroForOne, sqrtPriceLimit) feeGrowthGlobalX128 := getFeeGrowthGlobal(pool, zeroForOne) feeProtocol := getFeeProtocol(slot0Start, zeroForOne) cache := newSwapCache(feeProtocol, pool.liquidity) - state := newSwapState(amountSpecified, feeGrowthGlobalX128, cache.liquidityStart, pool.slot0) + state := newSwapState(amounts, feeGrowthGlobalX128, cache.liquidityStart, pool.slot0) comp := SwapComputation{ - AmountSpecified: amountSpecified, - SqrtPriceLimitX96: sqrtPriceLimitX96, + AmountSpecified: amounts, + SqrtPriceLimitX96: sqrtPriceLimit, ZeroForOne: zeroForOne, - ExactInput: amountSpecified.Gt(i256.Zero()), + ExactInput: amounts.Gt(i256.Zero()), InitialState: state, Cache: cache, } @@ -273,8 +273,8 @@ func Swap( "prevRealm", prevRealm, "poolPath", GetPoolPath(token0Path, token1Path, fee), "zeroForOne", ufmt.Sprintf("%t", zeroForOne), - "amountSpecified", _amountSpecified, - "sqrtPriceLimitX96", _sqrtPriceLimitX96, + "amountSpecified", amountSpecified, + "sqrtPriceLimitX96", sqrtPriceLimitX96, "payer", payer.String(), "recipient", recipient.String(), "internal_amount0", result.Amount0.ToString(), @@ -697,9 +697,9 @@ func CollectProtocolByAdmin( token1Path string, fee uint32, recipient std.Address, - _amount0Requested string, // uint128 - _amount1Requested string, // uint128 -) (string, string) { // uint128 x2 + amount0Requested string, + amount1Requested string, +) (string, string) { caller := std.PrevRealm().Addr() if err := common.AdminOnly(caller); err != nil { panic(err) @@ -710,8 +710,8 @@ func CollectProtocolByAdmin( token1Path, fee, recipient, - _amount0Requested, - _amount1Requested, + amount0Requested, + amount1Requested, ) prevAddr, prevRealm := getPrev() @@ -739,9 +739,9 @@ func CollectProtocol( token1Path string, fee uint32, recipient std.Address, - _amount0Requested string, // uint128 - _amount1Requested string, // uint128 -) (string, string) { // uint128 x2 + amount0Requested string, + amount1Requested string, +) (string, string) { caller := std.PrevRealm().Addr() if err := common.GovernanceOnly(caller); err != nil { panic(err) @@ -752,8 +752,8 @@ func CollectProtocol( token1Path, fee, recipient, - _amount0Requested, - _amount1Requested, + amount0Requested, + amount1Requested, ) prevAddr, prevRealm := getPrev() @@ -777,18 +777,18 @@ func collectProtocol( token1Path string, fee uint32, recipient std.Address, - _amount0Requested string, // uint128 - _amount1Requested string, // uint128 -) (string, string) { // uint128 x2 + amount0Requested string, + amount1Requested string, +) (string, string) { common.IsHalted() - amount0Requested := u256.MustFromDecimal(_amount0Requested) - amount1Requested := u256.MustFromDecimal(_amount1Requested) - pool := GetPool(token0Path, token1Path, fee) - amount0 := u256Min(amount0Requested, pool.protocolFees.token0) - amount1 := u256Min(amount1Requested, pool.protocolFees.token1) + amount0Req := u256.MustFromDecimal(amount0Requested) + amount1Req := u256.MustFromDecimal(amount1Requested) + + amount0 := u256Min(amount0Req, pool.protocolFees.token0) + amount1 := u256Min(amount1Req, pool.protocolFees.token1) amount0, amount1 = pool.saveProtocolFees(amount0, amount1) uAmount0 := amount0.Uint64() From 292fc5cd672520c1817c1836babae834f9186101 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Fri, 6 Dec 2024 15:35:27 +0900 Subject: [PATCH 18/24] condition as a function --- pool/pool.gno | 77 ++++++++++++++++++++++++++++----------------------- 1 file changed, 43 insertions(+), 34 deletions(-) diff --git a/pool/pool.gno b/pool/pool.gno index 04f8ebcf1..550bd9b32 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -71,8 +71,8 @@ func Burn( fee uint32, tickLower int32, tickUpper int32, - liquidityAmount string, -) (string, string) { + liquidityAmount string, // uint128 +) (string, string) { // uint256 x2 common.IsHalted() caller := std.PrevRealm().Addr() if common.GetLimitCaller() { @@ -129,9 +129,6 @@ func Collect( } } - amount0Req := u256.MustFromDecimal(amount0Requested) - amount1Req := u256.MustFromDecimal(amount1Requested) - pool := GetPool(token0Path, token1Path, fee) positionKey := positionGetKey(std.PrevRealm().Addr(), tickLower, tickUpper) @@ -143,22 +140,16 @@ func Collect( )) } - // Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0 - amount0 := u256Min(amount0Req, position.tokensOwed0) - amount0 = u256Min(amount0, pool.balances.token0) + var amount0, amount1 *u256.Uint - // Update state first then transfer - position.tokensOwed0 = new(u256.Uint).Sub(position.tokensOwed0, amount0) - pool.balances.token0 = new(u256.Uint).Sub(pool.balances.token0, amount0) + // Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0 + amount0Req := u256.MustFromDecimal(amount0Requested) + amount0, position.tokensOwed0, pool.balances.token0 = collectToken(amount0Req, position.tokensOwed0, pool.balances.token0) transferByRegisterCall(pool.token0Path, recipient, amount0.Uint64()) // Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0 - amount1 := u256Min(amount1Req, position.tokensOwed1) - amount1 = u256Min(amount1, pool.balances.token1) - - // Update state first then transfer - position.tokensOwed1 = new(u256.Uint).Sub(position.tokensOwed1, amount1) - pool.balances.token1 = new(u256.Uint).Sub(pool.balances.token1, amount1) + amount1Req := u256.MustFromDecimal(amount1Requested) + amount1, position.tokensOwed1, pool.balances.token1 = collectToken(amount1Req, position.tokensOwed1, pool.balances.token1) transferByRegisterCall(pool.token1Path, recipient, amount1.Uint64()) pool.positions[positionKey] = position @@ -166,9 +157,20 @@ func Collect( return amount0.ToString(), amount1.ToString() } -//////////////////////////////////////////////////////////// -// Swap -//////////////////////////////////////////////////////////// +// collectToken handles the collection of a single token type (token0 or token1) +func collectToken( + amountReq, tokensOwed, poolBalance *u256.Uint, +) (amount, newTokensOwed, newPoolBalance *u256.Uint) { + // find smallest of three amounts + amount = u256Min(amountReq, tokensOwed) + amount = u256Min(amount, poolBalance) + + // value for update state + newTokensOwed = new(u256.Uint).Sub(tokensOwed, amount) + newPoolBalance = new(u256.Uint).Sub(poolBalance, amount) + + return amount, newTokensOwed, newPoolBalance +} // SwapResult encapsulates all state changes that occur as a result of a swap // This type ensure all state transitions are atomic and can be applied at once. @@ -506,15 +508,22 @@ func computeSwapStepInit(state SwapState, pool *Pool, zeroForOne bool) StepCompu // computeTargetSqrtRatio determines the target sqrt price for the current swap step. func computeTargetSqrtRatio(step StepComputations, sqrtPriceLimitX96 *u256.Uint, zeroForOne bool) *u256.Uint { - isLower := step.sqrtPriceNextX96.Lt(sqrtPriceLimitX96) - isHigher := step.sqrtPriceNextX96.Gt(sqrtPriceLimitX96) - - if (zeroForOne && isLower) || (!zeroForOne && isHigher) { + if shouldUsePriceLimit(step.sqrtPriceNextX96, sqrtPriceLimitX96, zeroForOne) { return sqrtPriceLimitX96 } return step.sqrtPriceNextX96 } +// shouldUsePriceLimit returns true if the price limit should be used instead of the next tick price +func shouldUsePriceLimit(sqrtPriceNext, sqrtPriceLimit *u256.Uint, zeroForOne bool) bool { + isLower := sqrtPriceNext.Lt(sqrtPriceLimit) + isHigher := sqrtPriceNext.Gt(sqrtPriceLimit) + if zeroForOne { + return isLower + } + return isHigher +} + // computeAmounts calculates the input and output amounts for the current swap step. func computeAmounts(state SwapState, sqrtRatioTargetX96 *u256.Uint, pool *Pool, step StepComputations) (SwapState, StepComputations) { sqrtPriceX96Str, amountInStr, amountOutStr, feeAmountStr := plp.SwapMathComputeSwapStepStr( @@ -591,22 +600,18 @@ func tickTransition(step StepComputations, zeroForOne bool, state SwapState, poo func (pool *Pool) swapTransfers(zeroForOne bool, payer, recipient std.Address, amount0, amount1 *i256.Int) { var targetTokenPath string var amount *i256.Int - var isToken0 bool - switch zeroForOne { - case true: + if zeroForOne { targetTokenPath = pool.token0Path amount = amount0 - isToken0 = true - case false: + } else { targetTokenPath = pool.token1Path amount = amount1 - isToken0 = false } // payer -> POOL -> recipient - pool.transferFromAndVerify(payer, consts.POOL_ADDR, targetTokenPath, amount, isToken0) - pool.transferAndVerify(recipient, targetTokenPath, amount, !isToken0) + pool.transferFromAndVerify(payer, consts.POOL_ADDR, targetTokenPath, amount, zeroForOne) + pool.transferAndVerify(recipient, targetTokenPath, amount, !zeroForOne) } // SetFeeProtocolByAdmin sets the fee protocol for all pools @@ -888,7 +893,7 @@ func updatePoolBalance( if isToken0 { newBalance, overflow = new(u256.Uint).SubOverflow(token0, amount) - if overflow || newBalance.Lt(u256.Zero()) { + if isBalanceOverflowOrNegative(overflow, newBalance) { return nil, ufmt.Errorf( "%s || cannot decrease, token0(%s) - amount(%s)", errTransferFailed.Error(), token0.ToString(), amount.ToString(), @@ -898,7 +903,7 @@ func updatePoolBalance( } newBalance, overflow = new(u256.Uint).SubOverflow(token1, amount) - if overflow || newBalance.Lt(u256.Zero()) { + if isBalanceOverflowOrNegative(overflow, newBalance) { return nil, ufmt.Errorf( "%s || cannot decrease, token1(%s) - amount(%s)", errTransferFailed.Error(), token1.ToString(), amount.ToString(), @@ -907,6 +912,10 @@ func updatePoolBalance( return newBalance, nil } +func isBalanceOverflowOrNegative(overflow bool, newBalance *u256.Uint) bool { + return overflow || newBalance.Lt(u256.Zero()) +} + func (pool *Pool) transferFromAndVerify( from, to std.Address, tokenPath string, From a8bc3e292cccaafb0606f4785a39f45bd5c3770b Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Sat, 7 Dec 2024 10:22:26 +0900 Subject: [PATCH 19/24] abs --- pool/pool.gno | 4 ++-- pool/position_modify.gno | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pool/pool.gno b/pool/pool.gno index 550bd9b32..3c64b5c5b 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -93,8 +93,8 @@ func Burn( position, amount0, amount1 := pool.modifyPosition(posParams) if amount0.Gt(i256.Zero()) || amount1.Gt(i256.Zero()) { - position.tokensOwed0 = new(u256.Uint).Add(position.tokensOwed0, amount0.Abs()) - position.tokensOwed1 = new(u256.Uint).Add(position.tokensOwed1, amount1.Abs()) + position.tokensOwed0 = new(u256.Uint).Add(position.tokensOwed0, amount0) + position.tokensOwed1 = new(u256.Uint).Add(position.tokensOwed1, amount1) } positionKey := positionGetKey(caller, tickLower, tickUpper) diff --git a/pool/position_modify.gno b/pool/position_modify.gno index 00fe51be1..12927f626 100644 --- a/pool/position_modify.gno +++ b/pool/position_modify.gno @@ -42,7 +42,7 @@ func (pool *Pool) modifyPosition(params ModifyPositionParams) (PositionInfo, *i2 amount1 = calculateToken1Amount(sqrtRatioLower, sqrtRatioUpper, liqDelta) } - return position, amount0, amount1 + return position, amount0.Abs(), amount1.Abs() } func calculateToken0Amount(sqrtPriceLower, sqrtPriceUpper *u256.Uint, liquidityDelta *i256.Int) *i256.Int { From 930664495cc9a2b1b283ba2041ac49e910e61cfd Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Sat, 7 Dec 2024 10:42:11 +0900 Subject: [PATCH 20/24] fix --- pool/errors.gno | 1 + pool/pool.gno | 43 +++++++++++++++++++++++++++------------- pool/position_modify.gno | 8 ++++---- 3 files changed, 34 insertions(+), 18 deletions(-) diff --git a/pool/errors.gno b/pool/errors.gno index 681539cd5..9fcb48219 100644 --- a/pool/errors.gno +++ b/pool/errors.gno @@ -31,6 +31,7 @@ var ( errTransferFailed = errors.New("[GNOSWAP-POOL-021] token transfer failed") errInvalidTickAndTickSpacing = errors.New("[GNOSWAP-POOL-022] invalid tick and tick spacing requested") errInvalidAddress = errors.New("[GNOSWAP-POOL-023] invalid address") + errUnderflow = errors.New("[GNOSWAP-POOL-024] underflow") // TODO: make as common error code ) // addDetailToError adds detail to an error message diff --git a/pool/pool.gno b/pool/pool.gno index 3c64b5c5b..291ef54e5 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -50,11 +50,11 @@ func Mint( position := newModifyPositionParams(recipient, tickLower, tickUpper, i256.FromUint256(liquidityAmount)) _, amount0, amount1 := pool.modifyPosition(position) - if amount0.Gt(i256.Zero()) { + if amount0.Gt(u256.Zero()) { pool.transferFromAndVerify(positionCaller, consts.POOL_ADDR, pool.token0Path, amount0, true) } - if amount1.Gt(i256.Zero()) { + if amount1.Gt(u256.Zero()) { pool.transferFromAndVerify(positionCaller, consts.POOL_ADDR, pool.token1Path, amount1, false) } @@ -92,7 +92,7 @@ func Burn( posParams := newModifyPositionParams(caller, tickLower, tickUpper, liqDelta) position, amount0, amount1 := pool.modifyPosition(posParams) - if amount0.Gt(i256.Zero()) || amount1.Gt(i256.Zero()) { + if amount0.Gt(u256.Zero()) || amount1.Gt(u256.Zero()) { position.tokensOwed0 = new(u256.Uint).Add(position.tokensOwed0, amount0) position.tokensOwed1 = new(u256.Uint).Add(position.tokensOwed1, amount1) } @@ -309,10 +309,16 @@ func computeSwap(pool *Pool, comp SwapComputation) (*SwapResult, error) { state := comp.InitialState swapFee := u256.Zero() + + var newFee *u256.Uint + var err error + // Compute swap steps until completion for shouldContinueSwap(state, comp.SqrtPriceLimitX96) { - var newFee *u256.Uint - state, newFee = computeSwapStep(state, pool, comp.ZeroForOne, comp.SqrtPriceLimitX96, comp.ExactInput, comp.Cache, swapFee) + state, newFee, err = computeSwapStep(state, pool, comp.ZeroForOne, comp.SqrtPriceLimitX96, comp.ExactInput, comp.Cache, swapFee) + if err != nil { + return nil, err + } swapFee = newFee } @@ -437,7 +443,7 @@ func computeSwapStep( exactInput bool, cache SwapCache, swapFee *u256.Uint, -) (SwapState, *u256.Uint) { +) (SwapState, *u256.Uint, error) { step := computeSwapStepInit(state, pool, zeroForOne) // determining the price target for this step @@ -445,13 +451,18 @@ func computeSwapStep( // computing the amounts to be swapped at this step var newState SwapState + var err error + newState, step = computeAmounts(state, sqrtRatioTargetX96, pool, step) newState = updateAmounts(step, newState, exactInput) // if the protocol fee is on, calculate how much is owed, // decrement fee amount, and increment protocol fee if cache.feeProtocol > 0 { - newState = updateFeeProtocol(step, cache.feeProtocol, newState) + newState, err = updateFeeProtocol(step, cache.feeProtocol, newState) + if err != nil { + return state, nil, err + } } // update global fee tracker @@ -471,18 +482,22 @@ func computeSwapStep( newSwapFee := new(u256.Uint).Add(swapFee, step.feeAmount) - return newState, newSwapFee + return newState, newSwapFee, nil } // updateFeeProtocol calculates and updates protocol fees for the current step. -func updateFeeProtocol(step StepComputations, feeProtocol uint8, state SwapState) SwapState { +func updateFeeProtocol(step StepComputations, feeProtocol uint8, state SwapState) (SwapState, error) { delta := step.feeAmount delta.Div(delta, u256.NewUint(uint64(feeProtocol))) - step.feeAmount.Sub(step.feeAmount, delta) + newFeeAmount, overflow := new(u256.Uint).SubOverflow(step.feeAmount, delta) + if overflow { + return state, errUnderflow + } + step.feeAmount = newFeeAmount state.protocolFee.Add(state.protocolFee, delta) - return state + return state, nil } // computeSwapStepInit initializes the computation for a single swap step. @@ -610,7 +625,7 @@ func (pool *Pool) swapTransfers(zeroForOne bool, payer, recipient std.Address, a } // payer -> POOL -> recipient - pool.transferFromAndVerify(payer, consts.POOL_ADDR, targetTokenPath, amount, zeroForOne) + pool.transferFromAndVerify(payer, consts.POOL_ADDR, targetTokenPath, amount.Abs(), zeroForOne) pool.transferAndVerify(recipient, targetTokenPath, amount, !zeroForOne) } @@ -919,10 +934,10 @@ func isBalanceOverflowOrNegative(overflow bool, newBalance *u256.Uint) bool { func (pool *Pool) transferFromAndVerify( from, to std.Address, tokenPath string, - amount *i256.Int, + amount *u256.Uint, isToken0 bool, ) { - absAmount := amount.Abs() + absAmount := amount amountUint64, err := checkAmountRange(absAmount) if err != nil { panic(err) diff --git a/pool/position_modify.gno b/pool/position_modify.gno index 12927f626..0400b126a 100644 --- a/pool/position_modify.gno +++ b/pool/position_modify.gno @@ -10,16 +10,16 @@ import ( // modifyPosition updates a position in the pool and calculates the amount of tokens to be added or removed. // Returns positionInfo, amount0, amount1 -func (pool *Pool) modifyPosition(params ModifyPositionParams) (PositionInfo, *i256.Int, *i256.Int) { +func (pool *Pool) modifyPosition(params ModifyPositionParams) (PositionInfo, *u256.Uint, *u256.Uint) { position := pool.updatePosition(params) liqDelta := params.liquidityDelta - amount0, amount1 := i256.Zero(), i256.Zero() - if liqDelta.IsZero() { - return position, amount0, amount1 + return position, u256.Zero(), u256.Zero() } + amount0, amount1 := i256.Zero(), i256.Zero() + tick := pool.slot0.tick sqrtRatioLower := common.TickMathGetSqrtRatioAtTick(params.tickLower) sqrtRatioUpper := common.TickMathGetSqrtRatioAtTick(params.tickUpper) From 19f1a8e6027e9bcadfe6b4ff41ff6bd52cba56c1 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Mon, 9 Dec 2024 15:00:02 +0900 Subject: [PATCH 21/24] test: position_update --- pool/position_update_test.gno | 85 +++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 pool/position_update_test.gno diff --git a/pool/position_update_test.gno b/pool/position_update_test.gno new file mode 100644 index 000000000..24ca4cb55 --- /dev/null +++ b/pool/position_update_test.gno @@ -0,0 +1,85 @@ +package pool + +import ( + "testing" + + "std" + + "gno.land/p/demo/uassert" + + "gno.land/r/gnoswap/v1/consts" + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestUpdatePosition(t *testing.T) { + poolParams := &createPoolParams{ + token0Path: "token0", + token1Path: "token1", + fee: 500, + tickSpacing: 10, + sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 + } + p := newPool(poolParams) + + tests := []struct { + name string + positionParams ModifyPositionParams + expectLiquidity *u256.Uint + }{ + { + name: "add new position", + positionParams: ModifyPositionParams{ + owner: consts.POSITION_ADDR, + tickLower: -100, + tickUpper: 100, + liquidityDelta: i256.MustFromDecimal("1000000"), + }, + expectLiquidity: u256.MustFromDecimal("1000000"), + }, + { + name: "add liquidity to existing position", + positionParams: ModifyPositionParams{ + owner: consts.POSITION_ADDR, + tickLower: -100, + tickUpper: 100, + liquidityDelta: i256.MustFromDecimal("500000"), + }, + expectLiquidity: u256.MustFromDecimal("1500000"), + }, + { + name: "remove liquidity from position", + positionParams: ModifyPositionParams{ + owner: consts.POSITION_ADDR, + tickLower: -100, + tickUpper: 100, + liquidityDelta: i256.MustFromDecimal("-500000"), + }, + expectLiquidity: u256.MustFromDecimal("1000000"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + position := p.updatePosition(tt.positionParams) + + if !position.liquidity.Eq(tt.expectLiquidity) { + t.Errorf("liquidity mismatch: expected %s, got %s", + tt.expectLiquidity.ToString(), + position.liquidity.ToString()) + } + + if !tt.positionParams.liquidityDelta.IsZero() { + lowerTick := p.ticks[tt.positionParams.tickLower] + upperTick := p.ticks[tt.positionParams.tickUpper] + + if !lowerTick.initialized { + t.Error("lower tick not initialized") + } + if !upperTick.initialized { + t.Error("upper tick not initialized") + } + } + }) + } +} From 303746a08bfebb51565d49f437257bfd292f21fd Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Mon, 9 Dec 2024 15:03:04 +0900 Subject: [PATCH 22/24] X128 prefix --- pool/pool.gno | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pool/pool.gno b/pool/pool.gno index 291ef54e5..1ab2b8696 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -181,8 +181,8 @@ type SwapResult struct { NewTick int32 NewLiquidity *u256.Uint NewProtocolFees ProtocolFees - FeeGrowthGlobal0 *u256.Uint - FeeGrowthGlobal1 *u256.Uint + FeeGrowthGlobal0X128 *u256.Uint + FeeGrowthGlobal1X128 *u256.Uint SwapFee *u256.Uint } @@ -340,8 +340,8 @@ func computeSwap(pool *Pool, comp SwapComputation) (*SwapResult, error) { token0: pool.protocolFees.token0, token1: pool.protocolFees.token1, }, - FeeGrowthGlobal0: pool.feeGrowthGlobal0X128, - FeeGrowthGlobal1: pool.feeGrowthGlobal1X128, + FeeGrowthGlobal0X128: pool.feeGrowthGlobal0X128, + FeeGrowthGlobal1X128: pool.feeGrowthGlobal1X128, SwapFee: swapFee, } @@ -350,12 +350,12 @@ func computeSwap(pool *Pool, comp SwapComputation) (*SwapResult, error) { if state.protocolFee.Gt(u256.Zero()) { result.NewProtocolFees.token0 = new(u256.Uint).Add(result.NewProtocolFees.token0, state.protocolFee) } - result.FeeGrowthGlobal0 = state.feeGrowthGlobalX128 + result.FeeGrowthGlobal0X128 = state.feeGrowthGlobalX128 } else { if state.protocolFee.Gt(u256.Zero()) { result.NewProtocolFees.token1 = new(u256.Uint).Add(result.NewProtocolFees.token1, state.protocolFee) } - result.FeeGrowthGlobal1 = state.feeGrowthGlobalX128 + result.FeeGrowthGlobal1X128 = state.feeGrowthGlobalX128 } return result, nil @@ -368,8 +368,8 @@ func applySwapResult(pool *Pool, result *SwapResult) { pool.slot0.tick = result.NewTick pool.liquidity = result.NewLiquidity pool.protocolFees = result.NewProtocolFees - pool.feeGrowthGlobal0X128 = result.FeeGrowthGlobal0 - pool.feeGrowthGlobal1X128 = result.FeeGrowthGlobal1 + pool.feeGrowthGlobal0X128 = result.FeeGrowthGlobal0X128 + pool.feeGrowthGlobal1X128 = result.FeeGrowthGlobal1X128 } // validatePriceLimits ensures the provided price limit is valid for the swap direction From 5b0b453529c9b3d1957752637f5e0aacb4ed4cbb Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Mon, 9 Dec 2024 19:30:22 +0900 Subject: [PATCH 23/24] remove TODO --- pool/_RPC_dry.gno | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pool/_RPC_dry.gno b/pool/_RPC_dry.gno index 5f76a0935..5ffe30c86 100644 --- a/pool/_RPC_dry.gno +++ b/pool/_RPC_dry.gno @@ -58,12 +58,11 @@ func DrySwap( feeGrowthGlobalX128 = pool.feeGrowthGlobal1X128 } - // pool.slot0.unlocked = false slot0 := pool.slot0 slot0.unlocked = false cache := newSwapCache(feeProtocol, pool.liquidity) - state := newSwapState(amountSpecified.Clone(), feeGrowthGlobalX128.Clone(), cache.liquidityStart.Clone(), slot0) // TODO: feeGrowthGlobalX128.Clone() or NOT + state := newSwapState(amountSpecified.Clone(), feeGrowthGlobalX128.Clone(), cache.liquidityStart.Clone(), slot0) exactInput := amountSpecified.Gt(i256.Zero()) From d68748e6d8c4d5ba83b93f61bb585c8e3f2c83f9 Mon Sep 17 00:00:00 2001 From: 0xTopaz Date: Tue, 10 Dec 2024 15:01:31 +0900 Subject: [PATCH 24/24] GSW-1838 fix: Modify test code to change transferFromAndVerify function param --- pool/pool_test.gno | 258 +++++++++++++++++----------------- pool/position_modify_test.gno | 43 +++--- 2 files changed, 150 insertions(+), 151 deletions(-) diff --git a/pool/pool_test.gno b/pool/pool_test.gno index 7e5ecf3a5..af02987d9 100644 --- a/pool/pool_test.gno +++ b/pool/pool_test.gno @@ -4,138 +4,138 @@ import ( "std" "testing" - "gno.land/r/gnoswap/v1/consts" - "gno.land/p/demo/uassert" "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" + "gno.land/r/gnoswap/v1/consts" ) func TestMint(t *testing.T) { - token0Path := "test_token0" - token1Path := "test_token1" - fee := uint32(3000) - recipient := testutils.TestAddress("recipient") - tickLower := int32(-100) - tickUpper := int32(100) - liquidityAmount := "100000" - - t.Run("unauthorized caller mint should fail", func(t *testing.T) { - unauthorized := testutils.TestAddress("unauthorized") - defer func() { - if r := recover(); r == nil { - t.Error("unauthorized caller mint should fail") - } - }() - - Mint(token0Path, token1Path, fee, recipient, tickLower, tickUpper, liquidityAmount, unauthorized) - }) - - t.Run("mint with 0 liquidity should fail", func(t *testing.T) { - authorized := consts.POSITION_ADDR - defer func() { - if r := recover(); r == nil { - t.Error("mint with 0 liquidity should fail") - } - }() - - Mint(token0Path, token1Path, fee, recipient, tickLower, tickUpper, "0", authorized) - }) + token0Path := "test_token0" + token1Path := "test_token1" + fee := uint32(3000) + recipient := testutils.TestAddress("recipient") + tickLower := int32(-100) + tickUpper := int32(100) + liquidityAmount := "100000" + + t.Run("unauthorized caller mint should fail", func(t *testing.T) { + unauthorized := testutils.TestAddress("unauthorized") + defer func() { + if r := recover(); r == nil { + t.Error("unauthorized caller mint should fail") + } + }() + + Mint(token0Path, token1Path, fee, recipient, tickLower, tickUpper, liquidityAmount, unauthorized) + }) + + t.Run("mint with 0 liquidity should fail", func(t *testing.T) { + authorized := consts.POSITION_ADDR + defer func() { + if r := recover(); r == nil { + t.Error("mint with 0 liquidity should fail") + } + }() + + Mint(token0Path, token1Path, fee, recipient, tickLower, tickUpper, "0", authorized) + }) } func TestBurn(t *testing.T) { - // Setup - originalGetPool := GetPool - defer func() { - GetPool = originalGetPool - }() - - // Mock data - mockCaller := consts.POSITION_ADDR - mockPosition := PositionInfo{ - liquidity: u256.NewUint(1000), - tokensOwed0: u256.NewUint(0), - tokensOwed1: u256.NewUint(0), - } - mockPool := &Pool{ - positions: make(map[string]PositionInfo), - } - - GetPool = func(token0Path, token1Path string, fee uint32) *Pool { - return mockPool - } - - tests := []struct { - name string - liquidityAmount string - tickLower int32 - tickUpper int32 - expectedAmount0 string - expectedAmount1 string - expectPanic bool - }{ - { - name: "successful burn", - liquidityAmount: "500", - tickLower: -100, - tickUpper: 100, - expectedAmount0: "100", - expectedAmount1: "200", - }, - { - name: "zero liquidity", - liquidityAmount: "0", - tickLower: -100, - tickUpper: 100, - expectPanic: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.name == "successful burn" { - t.Skip("skipping until find better way to test this") - } - - // setup position for this test - posKey := positionGetKey(mockCaller, tt.tickLower, tt.tickUpper) - mockPool.positions[posKey] = mockPosition - - if tt.expectPanic { - defer func() { - if r := recover(); r == nil { - t.Errorf("expected panic but got none") - } - }() - } - - amount0, amount1 := Burn( - "token0", - "token1", - 3000, - tt.tickLower, - tt.tickUpper, - tt.liquidityAmount, - ) - - if !tt.expectPanic { - if amount0 != tt.expectedAmount0 { - t.Errorf("expected amount0 %s, got %s", tt.expectedAmount0, amount0) - } - if amount1 != tt.expectedAmount1 { - t.Errorf("expected amount1 %s, got %s", tt.expectedAmount1, amount1) - } - - newPosition := mockPool.positions[posKey] - if newPosition.tokensOwed0.IsZero() { - t.Error("expected tokensOwed0 to be updated") - } - if newPosition.tokensOwed1.IsZero() { - t.Error("expected tokensOwed1 to be updated") - } - } - }) - } + // Setup + originalGetPool := GetPool + defer func() { + GetPool = originalGetPool + }() + + // Mock data + mockCaller := consts.POSITION_ADDR + mockPosition := PositionInfo{ + liquidity: u256.NewUint(1000), + tokensOwed0: u256.NewUint(0), + tokensOwed1: u256.NewUint(0), + } + mockPool := &Pool{ + positions: make(map[string]PositionInfo), + } + + GetPool = func(token0Path, token1Path string, fee uint32) *Pool { + return mockPool + } + + tests := []struct { + name string + liquidityAmount string + tickLower int32 + tickUpper int32 + expectedAmount0 string + expectedAmount1 string + expectPanic bool + }{ + { + name: "successful burn", + liquidityAmount: "500", + tickLower: -100, + tickUpper: 100, + expectedAmount0: "100", + expectedAmount1: "200", + }, + { + name: "zero liquidity", + liquidityAmount: "0", + tickLower: -100, + tickUpper: 100, + expectPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.name == "successful burn" { + t.Skip("skipping until find better way to test this") + } + + // setup position for this test + posKey := positionGetKey(mockCaller, tt.tickLower, tt.tickUpper) + mockPool.positions[posKey] = mockPosition + + if tt.expectPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic but got none") + } + }() + } + + amount0, amount1 := Burn( + "token0", + "token1", + 3000, + tt.tickLower, + tt.tickUpper, + tt.liquidityAmount, + ) + + if !tt.expectPanic { + if amount0 != tt.expectedAmount0 { + t.Errorf("expected amount0 %s, got %s", tt.expectedAmount0, amount0) + } + if amount1 != tt.expectedAmount1 { + t.Errorf("expected amount1 %s, got %s", tt.expectedAmount1, amount1) + } + + newPosition := mockPool.positions[posKey] + if newPosition.tokensOwed0.IsZero() { + t.Error("expected tokensOwed0 to be updated") + } + if newPosition.tokensOwed1.IsZero() { + t.Error("expected tokensOwed1 to be updated") + } + } + }) + } } func TestSaveProtocolFees(t *testing.T) { @@ -479,7 +479,7 @@ func TestComputeSwap(t *testing.T) { } wordPos, _ := tickBitmapPosition(0) - // TODO: use avl + // TODO: use avl mockPool.tickBitmaps[wordPos] = u256.NewUint(1) t.Run("basic swap", func(t *testing.T) { @@ -623,11 +623,11 @@ func TestTransferFromAndVerify(t *testing.T) { defer func() { transferFromByRegisterCall = oldTransferFromByRegisterCall }() transferFromByRegisterCall = func(tokenPath string, from, to std.Address, amount uint64) bool { - // mock the transfer (just return true) - return true + // mock the transfer (just return true) + return true } - tt.pool.transferFromAndVerify(tt.from, tt.to, tt.tokenPath, tt.amount, tt.isToken0) + tt.pool.transferFromAndVerify(tt.from, tt.to, tt.tokenPath, u256.MustFromDecimal(tt.amount.ToString()), tt.isToken0) if !tt.pool.balances.token0.Eq(tt.expectedBal0) { t.Errorf("token0 balance mismatch: expected %s, got %s", @@ -663,7 +663,7 @@ func TestTransferFromAndVerify(t *testing.T) { testutils.TestAddress("from_addr"), testutils.TestAddress("to_addr"), "token0_path", - negativeAmount, + u256.MustFromDecimal(negativeAmount.Abs().ToString()), true, ) @@ -695,7 +695,7 @@ func TestTransferFromAndVerify(t *testing.T) { testutils.TestAddress("from_addr"), testutils.TestAddress("to_addr"), "token0_path", - hugeAmount, + u256.MustFromDecimal(hugeAmount.ToString()), true, ) }) diff --git a/pool/position_modify_test.gno b/pool/position_modify_test.gno index 7aed386ef..dec483e61 100644 --- a/pool/position_modify_test.gno +++ b/pool/position_modify_test.gno @@ -4,13 +4,12 @@ import ( "testing" "gno.land/p/demo/uassert" - u256 "gno.land/p/gnoswap/uint256" i256 "gno.land/p/gnoswap/int256" - "gno.land/r/gnoswap/v1/consts" + u256 "gno.land/p/gnoswap/uint256" "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" ) - func TestModifyPosition(t *testing.T) { const ( fee500 = uint32(500) @@ -18,34 +17,34 @@ func TestModifyPosition(t *testing.T) { ) tests := []struct { - name string - sqrtPrice string - tickLower int32 - tickUpper int32 + name string + sqrtPrice string + tickLower int32 + tickUpper int32 expectedAmt0 string expectedAmt1 string }{ { - name: "current price is lower than range", - sqrtPrice: common.TickMathGetSqrtRatioAtTick(-12000).ToString(), - tickLower: -11000, - tickUpper: -9000, + name: "current price is lower than range", + sqrtPrice: common.TickMathGetSqrtRatioAtTick(-12000).ToString(), + tickLower: -11000, + tickUpper: -9000, expectedAmt0: "16492846", expectedAmt1: "0", }, { - name: "current price is in range", - sqrtPrice: common.TickMathGetSqrtRatioAtTick(-10000).ToString(), - tickLower: -11000, - tickUpper: -9000, + name: "current price is in range", + sqrtPrice: common.TickMathGetSqrtRatioAtTick(-10000).ToString(), + tickLower: -11000, + tickUpper: -9000, expectedAmt0: "8040316", expectedAmt1: "2958015", }, { - name: "current price is higher than range", - sqrtPrice: common.TickMathGetSqrtRatioAtTick(-8000).ToString(), - tickLower: -11000, - tickUpper: -9000, + name: "current price is higher than range", + sqrtPrice: common.TickMathGetSqrtRatioAtTick(-8000).ToString(), + tickLower: -11000, + tickUpper: -9000, expectedAmt0: "0", expectedAmt1: "6067683", }, @@ -86,7 +85,7 @@ func TestModifyPositionEdgeCases(t *testing.T) { barPath, fooPath, fee500, - sp.ToString(), + sqrtPrice, ) pool := newPool(poolParams) params := ModifyPositionParams{ @@ -106,7 +105,7 @@ func TestModifyPositionEdgeCases(t *testing.T) { } } }() - + pool.modifyPosition(params) }) @@ -115,7 +114,7 @@ func TestModifyPositionEdgeCases(t *testing.T) { barPath, fooPath, fee500, - sp.ToString(), + sqrtPrice, ) pool := newPool(poolParams) params := ModifyPositionParams{