From 168c7df7ea6582e764ae02192f200665c94042a9 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Mon, 16 Dec 2024 12:51:40 +0900 Subject: [PATCH] refactor: dry swap (#421) * remove duplicate functions * GSW-1838 refactor: Using the new constructor to protect raw data in sqrtRatio calculations * GSW-1838 refactor: need to lock pool.slot0 to prevent re-entry * refactor: Using clone data to protect original data * refactor: remove unused import * fix: sqrtRatio calculation default value issue * refactor: swap and swap math * refactor: computeSwapStep - edge case test (amount : zero and over liquidity) - EXACT IN / EXACT OUT Case * refactor: Separate pool transfer-related test code * refactor: Rename a receiver function param and liquidity math bug fix - Fix absolute value before checking if delta value is negative to fix handling error issue if negative. - Rename the param in the receiver function from pool to p. * refactor: Changed large numbers to const for readability & removed some comments * refactor: Remove unnecessary function usage * refactor: Fix error messages to make their meaning clear --------- Co-authored-by: 0xTopaz Co-authored-by: 0xTopaz <60733299+onlyhyde@users.noreply.github.com> Co-authored-by: Blake <104744707+r3v4s@users.noreply.github.com> --- _deploy/p/gnoswap/pool/sqrt_price_math.gno | 2 +- _deploy/r/gnoswap/common/tick_math.gno | 22 +- _deploy/r/gnoswap/consts/consts.gno | 3 +- pool/_RPC_dry.gno | 208 ----- pool/_helper_test.gno | 157 ++++ pool/liquidity_math.gno | 41 +- pool/liquidity_math_test.gno | 8 +- pool/pool.gno | 744 ++-------------- pool/pool_test.gno | 563 ------------ pool/pool_transfer.gno | 171 ++++ pool/pool_transfer_test.gno | 298 +++++++ pool/position_modify.gno | 46 +- pool/swap.gno | 524 +++++++++++ pool/swap_test.gno | 954 +++++++++++++++++++++ pool/tick.gno | 85 +- pool/tick_bitmap.gno | 20 +- pool/type.gno | 145 +++- pool/utils.gno | 14 + 18 files changed, 2448 insertions(+), 1557 deletions(-) delete mode 100644 pool/_RPC_dry.gno create mode 100644 pool/pool_transfer.gno create mode 100644 pool/pool_transfer_test.gno create mode 100644 pool/swap.gno create mode 100644 pool/swap_test.gno diff --git a/_deploy/p/gnoswap/pool/sqrt_price_math.gno b/_deploy/p/gnoswap/pool/sqrt_price_math.gno index 29b16feaf..34a37d655 100644 --- a/_deploy/p/gnoswap/pool/sqrt_price_math.gno +++ b/_deploy/p/gnoswap/pool/sqrt_price_math.gno @@ -118,7 +118,7 @@ func sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown( res := new(u256.Uint).Add(sqrtPX96, quotient) if res.Gt(max160) { - panic("sqrtPx96 + quotient overflow uint160") + panic("GetNextSqrtPriceFromAmount1RoundingDown sqrtPx96 + quotient overflow uint160") } return res } else { diff --git a/_deploy/r/gnoswap/common/tick_math.gno b/_deploy/r/gnoswap/common/tick_math.gno index 40bca145a..2a0bc689a 100644 --- a/_deploy/r/gnoswap/common/tick_math.gno +++ b/_deploy/r/gnoswap/common/tick_math.gno @@ -43,21 +43,28 @@ var binaryLogConsts = [8]*u256.Uint{ var ( shift1By32Left = u256.MustFromDecimal("4294967296") // (1 << 32) + maxTick = int32(887272) ) func TickMathGetSqrtRatioAtTick(tick int32) *u256.Uint { // uint160 sqrtPriceX96 absTick := abs(tick) - if absTick > 887272 { // MAX_TICK + if absTick > maxTick { panic(addDetailToError( errOutOfRange, - ufmt.Sprintf("tick_math.gno__TickMathGetSqrtRatioAtTick() || tick is out of range (larger than 887272), tick: %d", tick), + ufmt.Sprintf("tick is out of range (larger than 887272), tick: %d", tick), )) } - ratio := u256.MustFromDecimal("340282366920938463463374607431768211456") // consts.Q128 + var initialBit int32 = 0x1 + var ratio *u256.Uint + if (absTick & initialBit) != 0 { + ratio = tickRatioMap[initialBit] + } else { + ratio = u256.MustFromDecimal("340282366920938463463374607431768211456") // consts.Q128 + } for mask, value := range tickRatioMap { - if absTick&mask != 0 { + if (mask != initialBit) && absTick&mask != 0 { // ratio = (ratio * value) >> 128 ratio = ratio.Mul(ratio, value) ratio = ratio.Rsh(ratio, 128) @@ -65,12 +72,11 @@ func TickMathGetSqrtRatioAtTick(tick int32) *u256.Uint { // uint160 sqrtPriceX96 } if tick > 0 { - _maxUint256 := u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935") // consts.MAX_UINT256 - _tmp := new(u256.Uint).Div(_maxUint256, ratio) - ratio = _tmp.Clone() + maxUint256 := u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935") // consts.MAX_UINT256 + ratio = new(u256.Uint).Div(maxUint256, ratio) } - shifted := ratio.Rsh(ratio, 32).Clone() // ratio >> 32 + shifted := new(u256.Uint).Rsh(ratio, 32) // ratio >> 32 remainder := ratio.Mod(ratio, shift1By32Left) // ratio % (1 << 32) var adj *u256.Uint diff --git a/_deploy/r/gnoswap/consts/consts.gno b/_deploy/r/gnoswap/consts/consts.gno index 10e47e85d..9f9228e9f 100644 --- a/_deploy/r/gnoswap/consts/consts.gno +++ b/_deploy/r/gnoswap/consts/consts.gno @@ -90,9 +90,8 @@ const ( UINT64_MAX uint64 = 18446744073709551615 MAX_UINT128 string = "340282366920938463463374607431768211455" - MAX_UINT160 string = "1461501637330902918203684832716283019655932542975" - + MAX_INT256 string = "57896044618658097711785492504343953926634992332820282019728792003956564819968" MAX_UINT256 string = "115792089237316195423570985008687907853269984665640564039457584007913129639935" // Tick Related diff --git a/pool/_RPC_dry.gno b/pool/_RPC_dry.gno deleted file mode 100644 index 5ffe30c86..000000000 --- a/pool/_RPC_dry.gno +++ /dev/null @@ -1,208 +0,0 @@ -package pool - -import ( - "gno.land/r/gnoswap/v1/common" - - "gno.land/r/gnoswap/v1/consts" - - plp "gno.land/p/gnoswap/pool" // pool package - - i256 "gno.land/p/gnoswap/int256" - u256 "gno.land/p/gnoswap/uint256" -) - -// DrySwap simulates a swap and returns the amount0, amount1 that would be received and a boolean indicating if the swap is possible -func DrySwap( - token0Path string, - token1Path string, - fee uint32, - zeroForOne bool, - _amountSpecified string, - _sqrtPriceLimitX96 string, -) (string, string, bool) { - - if _amountSpecified == "0" { - return "0", "0", false - } - - amountSpecified := i256.MustFromDecimal(_amountSpecified) - sqrtPriceLimitX96 := u256.MustFromDecimal(_sqrtPriceLimitX96) - - pool := GetPool(token0Path, token1Path, fee) - slot0Start := pool.slot0 - - var feeProtocol uint8 - var feeGrowthGlobalX128 *u256.Uint - - if zeroForOne { - minSqrtRatio := u256.MustFromDecimal(consts.MIN_SQRT_RATIO) - - cond1 := sqrtPriceLimitX96.Lt(slot0Start.sqrtPriceX96) - cond2 := sqrtPriceLimitX96.Gt(minSqrtRatio) - if !(cond1 && cond2) { - return "0", "0", false - } - - feeProtocol = slot0Start.feeProtocol % 16 - feeGrowthGlobalX128 = pool.feeGrowthGlobal0X128 - } else { - maxSqrtRatio := u256.MustFromDecimal(consts.MAX_SQRT_RATIO) - - cond1 := sqrtPriceLimitX96.Gt(slot0Start.sqrtPriceX96) - cond2 := sqrtPriceLimitX96.Lt(maxSqrtRatio) - if !(cond1 && cond2) { - return "0", "0", false - } - - feeProtocol = slot0Start.feeProtocol / 16 - feeGrowthGlobalX128 = pool.feeGrowthGlobal1X128 - } - - slot0 := pool.slot0 - slot0.unlocked = false - - cache := newSwapCache(feeProtocol, pool.liquidity) - state := newSwapState(amountSpecified.Clone(), feeGrowthGlobalX128.Clone(), cache.liquidityStart.Clone(), slot0) - - exactInput := amountSpecified.Gt(i256.Zero()) - - // continue swapping as long as we haven't used the entire input/output and haven't reached the price limit - for !(state.amountSpecifiedRemaining.IsZero()) && !(state.sqrtPriceX96.Eq(sqrtPriceLimitX96)) { - var step StepComputations - step.sqrtPriceStartX96 = state.sqrtPriceX96 - - step.tickNext, step.initialized = pool.tickBitmapNextInitializedTickWithInOneWord( - state.tick, - pool.tickSpacing, - zeroForOne, - ) - - // ensure that we do not overshoot the min/max tick, as the tick bitmap is not aware of these bounds - if step.tickNext < consts.MIN_TICK { - step.tickNext = consts.MIN_TICK - } else if step.tickNext > consts.MAX_TICK { - step.tickNext = consts.MAX_TICK - } - - // get the price for the next tick - step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) - - isLower := step.sqrtPriceNextX96.Lt(sqrtPriceLimitX96) - isHigher := step.sqrtPriceNextX96.Gt(sqrtPriceLimitX96) - - var sqrtRatioTargetX96 *u256.Uint - if (zeroForOne && isLower) || (!zeroForOne && isHigher) { - sqrtRatioTargetX96 = sqrtPriceLimitX96 - } else { - sqrtRatioTargetX96 = step.sqrtPriceNextX96 - } - - _sqrtPriceX96Str, _amountInStr, _amountOutStr, _feeAmountStr := plp.SwapMathComputeSwapStepStr( - state.sqrtPriceX96, - sqrtRatioTargetX96, - state.liquidity, - state.amountSpecifiedRemaining, - uint64(pool.fee), - ) - state.sqrtPriceX96 = u256.MustFromDecimal(_sqrtPriceX96Str) - step.amountIn = u256.MustFromDecimal(_amountInStr) - step.amountOut = u256.MustFromDecimal(_amountOutStr) - step.feeAmount = u256.MustFromDecimal(_feeAmountStr) - - amountInWithFee := i256.FromUint256(new(u256.Uint).Add(step.amountIn, step.feeAmount)) - if exactInput { - state.amountSpecifiedRemaining = i256.Zero().Sub(state.amountSpecifiedRemaining, amountInWithFee) - state.amountCalculated = i256.Zero().Sub(state.amountCalculated, i256.FromUint256(step.amountOut)) - } else { - state.amountSpecifiedRemaining = i256.Zero().Add(state.amountSpecifiedRemaining, i256.FromUint256(step.amountOut)) - state.amountCalculated = i256.Zero().Add(state.amountCalculated, amountInWithFee) - } - - // if the protocol fee is on, calculate how much is owed, decrement feeAmount, and increment protocolFee - if cache.feeProtocol > 0 { - delta := new(u256.Uint).Div(step.feeAmount, u256.NewUint(uint64(cache.feeProtocol))) - step.feeAmount = new(u256.Uint).Sub(step.feeAmount, delta) - state.protocolFee = new(u256.Uint).Add(state.protocolFee, delta) - } - - // update global fee tracker - if state.liquidity.Gt(u256.Zero()) { - // OBS if `DrySwap()` update its state, next ACTUAL `Swap()` gets affect - - // value1 := new(u256.Uint).Mul(step.feeAmount, u256.MustFromDecimal(consts.Q128)) - // value2 := new(u256.Uint).Div(value1, state.liquidity) - - // state.feeGrowthGlobalX128 = new(u256.Uint).Add(state.feeGrowthGlobalX128, value2) - } - - // shift tick if we reached the next price - if state.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { - // if the tick is initialized, run the tick transition - if step.initialized { - var fee0, fee1 *u256.Uint - - // check for the placeholder value, which we replace with the actual value the first time the swap crosses an initialized tick - if zeroForOne { - fee0 = state.feeGrowthGlobalX128 - fee1 = pool.feeGrowthGlobal1X128 - } else { - fee0 = pool.feeGrowthGlobal0X128 - fee1 = state.feeGrowthGlobalX128 - } - - liquidityNet := pool.tickCross( - step.tickNext, - fee0, - fee1, - ) - - // if we're moving leftward, we interpret liquidityNet as the opposite sign - if zeroForOne { - liquidityNet = i256.Zero().Neg(liquidityNet) - } - - state.liquidity = liquidityMathAddDelta(state.liquidity, liquidityNet) - } - - if zeroForOne { - state.tick = step.tickNext - 1 - } else { - state.tick = step.tickNext - } - } else if !(state.sqrtPriceX96.Eq(step.sqrtPriceStartX96)) { - // recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved - state.tick = common.TickMathGetTickAtSqrtRatio(state.sqrtPriceX96) - } - } - // END LOOP - - var amount0, amount1 *i256.Int - if zeroForOne == exactInput { - amount0 = i256.Zero().Sub(amountSpecified, state.amountSpecifiedRemaining) - amount1 = state.amountCalculated - } else { - amount0 = state.amountCalculated - amount1 = i256.Zero().Sub(amountSpecified, state.amountSpecifiedRemaining) - } - - pool.slot0.unlocked = true - - if zeroForOne { - if pool.balances.token1.Lt(amount1.Abs()) { - // NOT ENOUGH BALANCE for output token1 - return "0", "0", false - } - } else { - if pool.balances.token0.Lt(amount0.Abs()) { - // NOT ENOUGH BALANCE for output token0 - return "0", "0", false - } - } - - // JUST NOT ENOUGH BALANCE - if amount0.IsZero() || amount1.IsZero() { - return "0", "0", false - } - - return amount0.ToString(), amount1.ToString(), true -} diff --git a/pool/_helper_test.gno b/pool/_helper_test.gno index 992d99cc3..ea719ff2c 100644 --- a/pool/_helper_test.gno +++ b/pool/_helper_test.gno @@ -37,6 +37,8 @@ const ( fee3000 uint32 = 3000 maxApprove uint64 = 18446744073709551615 max_timeout int64 = 9999999999 + + maxSqrtPriceLimitX96 string = "1461446703485210103287273052203988822378723970341" ) const ( @@ -50,6 +52,7 @@ var ( alice = pusers.AddressOrName(testutils.TestAddress("alice")) pool = pusers.AddressOrName(consts.POOL_ADDR) protocolFee = pusers.AddressOrName(consts.PROTOCOL_FEE_ADDR) + router = pusers.AddressOrName(consts.ROUTER_ADDR) adminRealm = std.NewUserRealm(users.Resolve(admin)) posRealm = std.NewCodeRealm(consts.POSITION_PATH) @@ -215,6 +218,160 @@ func MintPosition(t *testing.T, caller) } +func MintPositionAll(t *testing.T, caller std.Address) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(caller)) + TokenApprove(t, gnsPath, pusers.AddressOrName(caller), pool, maxApprove) + TokenApprove(t, gnsPath, pusers.AddressOrName(caller), router, maxApprove) + TokenApprove(t, wugnotPath, pusers.AddressOrName(caller), pool, maxApprove) + TokenApprove(t, wugnotPath, pusers.AddressOrName(caller), router, maxApprove) + + params := []struct { + tickLower int32 + tickUpper int32 + liquidity uint64 + zeroToOne bool + }{ + { + tickLower: -300, + tickUpper: -240, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -240, + tickUpper: -180, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -180, + tickUpper: -120, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -120, + tickUpper: -60, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -60, + tickUpper: 0, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: 0, + tickUpper: 60, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 60, + tickUpper: 120, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 120, + tickUpper: 180, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 180, + tickUpper: 240, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 240, + tickUpper: 300, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: -360, + tickUpper: -300, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -420, + tickUpper: -360, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -480, + tickUpper: -420, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -540, + tickUpper: -480, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -600, + tickUpper: -540, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: 300, + tickUpper: 360, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 360, + tickUpper: 420, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 420, + tickUpper: 480, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 480, + tickUpper: 540, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 540, + tickUpper: 600, + liquidity: 10, + zeroToOne: false, + }, + } + + for _, p := range params { + MintPosition(t, + wugnotPath, + gnsPath, + fee3000, + p.tickLower, + p.tickUpper, + "100", + "100", + "0", + "0", + max_timeout, + caller, + caller) + } + +} + func wugnotApprove(t *testing.T, owner, spender pusers.AddressOrName, amount uint64) { t.Helper() std.TestSetRealm(std.NewUserRealm(users.Resolve(owner))) diff --git a/pool/liquidity_math.gno b/pool/liquidity_math.gno index 165acf834..47aaea2cc 100644 --- a/pool/liquidity_math.gno +++ b/pool/liquidity_math.gno @@ -7,39 +7,52 @@ import ( "gno.land/p/demo/ufmt" ) -// liquidityMathAddDelta Calculate the new liquidity with delta liquidity. -// If delta liquidity is negative, it will subtract the absolute value of delta liquidity from the current liquidity. -// If delta liquidity is positive, it will add the absolute value of delta liquidity to the current liquidity. -// inputs: -// - x: current liquidity -// - y: delta liquidity -// Returns the new liquidity. +// liquidityMathAddDelta calculates the new liquidity by applying the delta liquidity to the current liquidity. +// If delta liquidity is negative, it subtracts the absolute value of delta liquidity from the current liquidity. +// If delta liquidity is positive, it adds the absolute value of delta liquidity to the current liquidity. +// +// Parameters: +// - x: The current liquidity as a uint256 value. +// - y: The delta liquidity as a signed int256 value. +// +// Returns: +// - The new liquidity as a uint256 value. +// +// Notes: +// - If `x` or `y` is nil, the function panics with an appropriate error message. +// - If `y` is negative, its absolute value is subtracted from `x`. +// - The result must be less than `x`. Otherwise, the function panics to prevent underflow. +// +// - If `y` is positive, it is added to `x`. +// - The result must be greater than or equal to `x`. Otherwise, the function panics to prevent overflow. +// +// - The function ensures correctness by validating the results of the arithmetic operations. func liquidityMathAddDelta(x *u256.Uint, y *i256.Int) *u256.Uint { if x == nil || y == nil { panic(addDetailToError( errInvalidInput, - ufmt.Sprintf("liquidity_math.gno__liquidityMathAddDelta() || x or y is nil"), + "x or y is nil", )) } - absDelta := y.Abs() var z *u256.Uint // Subtract or add based on the sign of y if y.Lt(i256.Zero()) { + absDelta := y.Abs() z = new(u256.Uint).Sub(x, absDelta) - if z.Gte(x) { // z must be < x + if z.Gte(x) { panic(addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("liquidity_math.gno__liquidityMathAddDelta() || LS(z must be < x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), + ufmt.Sprintf("Less than Condition(z must be < x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), )) } } else { - z = new(u256.Uint).Add(x, absDelta) - if z.Lt(x) { // z must be >= x + z = new(u256.Uint).Add(x, y.Abs()) + if z.Lt(x) { panic(addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("liquidity_math.gno__liquidityMathAddDelta() || LA(z must be >= x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), + ufmt.Sprintf("Less than or Equal Condition(z must be >= x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), )) } } diff --git a/pool/liquidity_math_test.gno b/pool/liquidity_math_test.gno index 794dc2b97..fc2f379b2 100644 --- a/pool/liquidity_math_test.gno +++ b/pool/liquidity_math_test.gno @@ -22,7 +22,7 @@ func TestLiquidityMathAddDelta(t *testing.T) { y = i256.MustFromDecimal("100") liquidityMathAddDelta(nil, y) }, - wantPanic: addDetailToError(errInvalidInput, ufmt.Sprintf("liquidity_math.gno__liquidityMathAddDelta() || x or y is nil")), + wantPanic: addDetailToError(errInvalidInput, "x or y is nil"), }, { name: "y is nil", @@ -31,7 +31,7 @@ func TestLiquidityMathAddDelta(t *testing.T) { x = u256.MustFromDecimal("100") liquidityMathAddDelta(x, nil) }, - wantPanic: addDetailToError(errInvalidInput, ufmt.Sprintf("liquidity_math.gno__liquidityMathAddDelta() || x or y is nil")), + wantPanic: addDetailToError(errInvalidInput, "x or y is nil"), }, { name: "underflow panic with sub delta", @@ -42,7 +42,7 @@ func TestLiquidityMathAddDelta(t *testing.T) { }, wantPanic: addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("liquidity_math.gno__liquidityMathAddDelta() || LS(z must be < x) (x: 0, y: -100, z:115792089237316195423570985008687907853269984665640564039457584007913129639836)")), + ufmt.Sprintf("Less than Condition(z must be < x) (x: 0, y: -100, z:115792089237316195423570985008687907853269984665640564039457584007913129639836)")), }, { name: "overflow panic with add delta", @@ -53,7 +53,7 @@ func TestLiquidityMathAddDelta(t *testing.T) { }, wantPanic: addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("liquidity_math.gno__liquidityMathAddDelta() || LA(z must be >= x) (x: 115792089237316195423570985008687907853269984665640564039457584007913129639935, y: 100, z:99)")), + ufmt.Sprintf("Less than or Equal Condition(z must be >= x) (x: 115792089237316195423570985008687907853269984665640564039457584007913129639935, y: 100, z:99)")), }, } diff --git a/pool/pool.gno b/pool/pool.gno index 268d634c2..db4df9b8e 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -5,8 +5,6 @@ import ( "gno.land/p/demo/ufmt" - plp "gno.land/p/gnoswap/pool" - "gno.land/r/gnoswap/v1/common" "gno.land/r/gnoswap/v1/consts" @@ -33,17 +31,14 @@ func Mint( if err := common.PositionOnly(caller); err != nil { panic(addDetailToError( errNoPermission, - ufmt.Sprintf("pool.gno__Mint() || only position(%s) can call pool mint(), called from %s", consts.POSITION_ADDR, caller.String()), + ufmt.Sprintf("only position(%s) can call pool mint(), called from %s", consts.POSITION_ADDR, caller.String()), )) } } liquidityAmount := u256.MustFromDecimal(_liquidityAmount) if liquidityAmount.IsZero() { - panic(addDetailToError( - errZeroLiquidity, - ufmt.Sprintf("pool.gno__Mint() || liquidityAmount == 0"), - )) + panic(errZeroLiquidity) } pool := GetPool(token0Path, token1Path, fee) @@ -79,7 +74,7 @@ func Burn( if err := common.PositionOnly(caller); err != nil { panic(addDetailToError( errNoPermission, - ufmt.Sprintf("pool.gno__Burn() || only position(%s) can call pool burn(), called from %s", consts.POSITION_ADDR, caller.String()), + ufmt.Sprintf("only position(%s) can call pool burn(), called from %s", consts.POSITION_ADDR, caller.String()), )) } } @@ -124,7 +119,7 @@ func Collect( if err := common.PositionOnly(caller); err != nil { panic(addDetailToError( errNoPermission, - ufmt.Sprintf("pool.gno__Collect() || only position(%s) can call pool collect(), called from %s", consts.POSITION_ADDR, caller.String()), + ufmt.Sprintf("only position(%s) can call pool collect(), called from %s", consts.POSITION_ADDR, caller.String()), )) } } @@ -136,7 +131,7 @@ func Collect( if !exist { panic(addDetailToError( errDataNotFound, - ufmt.Sprintf("pool.gno__Collect() || positionKey(%s) does not exist", positionKey), + ufmt.Sprintf("positionKey(%s) does not exist", positionKey), )) } @@ -163,7 +158,19 @@ func Collect( return amount0.ToString(), amount1.ToString() } -// collectToken handles the collection of a single token type (token0 or token1) +// collectToken handles the collection of tokens (either token0 or token1) from a position. +// It calculates the actual amount that can be collected based on three constraints: +// the requested amount, tokens owed, and available pool balance. +// +// Parameters: +// - amountReq: amount requested to collect +// - tokensOwed: amount of tokens owed to the position +// - poolBalance: current balance of tokens in the pool +// +// Returns: +// - amount: actual amount that will be collected (minimum of the three inputs) +// - newTokensOwed: remaining tokens owed after collection +// - newPoolBalance: remaining pool balance after collection func collectToken( amountReq, tokensOwed, poolBalance *u256.Uint, ) (amount, newTokensOwed, newPoolBalance *u256.Uint) { @@ -178,467 +185,6 @@ func collectToken( return amount, newTokensOwed, newPoolBalance } -// SwapResult encapsulates all state changes that occur as a result of a swap -// This type ensure all state transitions are atomic and can be applied at once. -type SwapResult struct { - Amount0 *i256.Int - Amount1 *i256.Int - NewSqrtPrice *u256.Uint - NewTick int32 - NewLiquidity *u256.Uint - NewProtocolFees ProtocolFees - FeeGrowthGlobal0X128 *u256.Uint - FeeGrowthGlobal1X128 *u256.Uint - SwapFee *u256.Uint -} - -// SwapComputation encapsulates pure computation logic for swap -type SwapComputation struct { - AmountSpecified *i256.Int - SqrtPriceLimitX96 *u256.Uint - ZeroForOne bool - ExactInput bool - InitialState SwapState - Cache SwapCache -} - -// Swap swaps token0 for token1, or token1 for token0 -// Returns swapped amount0, amount1 in string -// ref: https://docs.gnoswap.io/contracts/pool/pool.gno#swap -func Swap( - token0Path string, - token1Path string, - fee uint32, - recipient std.Address, - zeroForOne bool, - amountSpecified string, - sqrtPriceLimitX96 string, - payer std.Address, // router -) (string, string) { - common.IsHalted() - if common.GetLimitCaller() { - caller := std.PrevRealm().Addr() - if err := common.RouterOnly(caller); err != nil { - panic(addDetailToError( - errNoPermission, - ufmt.Sprintf("pool.gno__Swap() || only router(%s) can call pool swap(), called from %s", consts.ROUTER_ADDR, caller.String()), - )) - } - } - - if amountSpecified == "0" { - panic(addDetailToError( - errInvalidSwapAmount, - ufmt.Sprintf("pool.gno__Swap() || amountSpecified == 0"), - )) - } - - pool := GetPool(token0Path, token1Path, fee) - - slot0Start := pool.slot0 - if !slot0Start.unlocked { - panic(errLockedPool) - } - - slot0Start.unlocked = false - defer func() { slot0Start.unlocked = true }() - - amounts := i256.MustFromDecimal(amountSpecified) - sqrtPriceLimit := u256.MustFromDecimal(sqrtPriceLimitX96) - - validatePriceLimits(pool, zeroForOne, sqrtPriceLimit) - - feeGrowthGlobalX128 := getFeeGrowthGlobal(pool, zeroForOne) - feeProtocol := getFeeProtocol(slot0Start, zeroForOne) - cache := newSwapCache(feeProtocol, pool.liquidity) - - state := newSwapState(amounts, feeGrowthGlobalX128, cache.liquidityStart, pool.slot0) - - comp := SwapComputation{ - AmountSpecified: amounts, - SqrtPriceLimitX96: sqrtPriceLimit, - ZeroForOne: zeroForOne, - ExactInput: amounts.Gt(i256.Zero()), - InitialState: state, - Cache: cache, - } - - result, err := computeSwap(pool, comp) - if err != nil { - panic(err) - } - - applySwapResult(pool, result) - - // actual swap - pool.swapTransfers(zeroForOne, payer, recipient, result.Amount0, result.Amount1) - - prevAddr, prevRealm := getPrev() - - std.Emit( - "Swap", - "prevAddr", prevAddr, - "prevRealm", prevRealm, - "poolPath", GetPoolPath(token0Path, token1Path, fee), - "zeroForOne", ufmt.Sprintf("%t", zeroForOne), - "amountSpecified", amountSpecified, - "sqrtPriceLimitX96", sqrtPriceLimitX96, - "payer", payer.String(), - "recipient", recipient.String(), - "internal_amount0", result.Amount0.ToString(), - "internal_amount1", result.Amount1.ToString(), - "internal_protocolFee0", pool.protocolFees.token0.ToString(), - "internal_protocolFee1", pool.protocolFees.token1.ToString(), - "internal_swapFee", result.SwapFee.ToString(), - "internal_sqrtPriceX96", pool.slot0.sqrtPriceX96.ToString(), - ) - - return result.Amount0.ToString(), result.Amount1.ToString() -} - -// computeSwap performs the core swap computation without modifying pool state -// The function follows these state transitions: -// 1. Initial State: Provided by `SwapComputation.InitialState` -// 2. Stepping State: For each step: -// - Compute next tick and price target -// - Calculate amounts and fees -// - Update state (remaining amount, fees, liquidity) -// - Handle tick transitions if necessary -// -// 3. Final State: Aggregated in SwapResult -// -// The computation continues until either: -// - The entire amount is consumed (`amountSpecifiedRemaining` = 0) -// - The price limit is reached (`sqrtPriceX96` = `sqrtPriceLimitX96`) -// -// Returns an error if the computation fails at any step -func computeSwap(pool *Pool, comp SwapComputation) (*SwapResult, error) { - state := comp.InitialState - swapFee := u256.Zero() - - var newFee *u256.Uint - var err error - - // Compute swap steps until completion - for shouldContinueSwap(state, comp.SqrtPriceLimitX96) { - state, newFee, err = computeSwapStep(state, pool, comp.ZeroForOne, comp.SqrtPriceLimitX96, comp.ExactInput, comp.Cache, swapFee) - if err != nil { - return nil, err - } - swapFee = newFee - } - - // Calculate final amounts - amount0 := state.amountCalculated - amount1 := i256.Zero().Sub(comp.AmountSpecified, state.amountSpecifiedRemaining) - if comp.ZeroForOne == comp.ExactInput { - amount0, amount1 = amount1, amount0 - } - - // Prepare result - result := &SwapResult{ - Amount0: amount0, - Amount1: amount1, - NewSqrtPrice: state.sqrtPriceX96, - NewTick: state.tick, - NewLiquidity: state.liquidity, - NewProtocolFees: ProtocolFees{ - token0: pool.protocolFees.token0, - token1: pool.protocolFees.token1, - }, - FeeGrowthGlobal0X128: pool.feeGrowthGlobal0X128, - FeeGrowthGlobal1X128: pool.feeGrowthGlobal1X128, - SwapFee: swapFee, - } - - // Update protocol fees if necessary - if comp.ZeroForOne { - if state.protocolFee.Gt(u256.Zero()) { - result.NewProtocolFees.token0 = new(u256.Uint).Add(result.NewProtocolFees.token0, state.protocolFee) - } - result.FeeGrowthGlobal0X128 = state.feeGrowthGlobalX128 - } else { - if state.protocolFee.Gt(u256.Zero()) { - result.NewProtocolFees.token1 = new(u256.Uint).Add(result.NewProtocolFees.token1, state.protocolFee) - } - result.FeeGrowthGlobal1X128 = state.feeGrowthGlobalX128 - } - - return result, nil -} - -// applySwapResult updates pool state with computed results. -// All state changes are applied at once to maintain consistency -func applySwapResult(pool *Pool, result *SwapResult) { - pool.slot0.sqrtPriceX96 = result.NewSqrtPrice - pool.slot0.tick = result.NewTick - pool.liquidity = result.NewLiquidity - pool.protocolFees = result.NewProtocolFees - pool.feeGrowthGlobal0X128 = result.FeeGrowthGlobal0X128 - pool.feeGrowthGlobal1X128 = result.FeeGrowthGlobal1X128 -} - -// validatePriceLimits ensures the provided price limit is valid for the swap direction -// The function enforces that: -// For zeroForOne (selling token0): -// - Price limit must be below current price -// - Price limit must be above MIN_SQRT_RATIO -// -// For !zeroForOne (selling token1): -// - Price limit must be above current price -// - Price limit must be below MAX_SQRT_RATIO -func validatePriceLimits(pool *Pool, zeroForOne bool, sqrtPriceLimitX96 *u256.Uint) { - if zeroForOne { - minSqrtRatio := u256.MustFromDecimal(consts.MIN_SQRT_RATIO) - - cond1 := sqrtPriceLimitX96.Lt(pool.slot0.sqrtPriceX96) - cond2 := sqrtPriceLimitX96.Gt(minSqrtRatio) - if !(cond1 && cond2) { - panic(addDetailToError( - errPriceOutOfRange, - ufmt.Sprintf("pool.gno__Swap() || sqrtPriceLimitX96(%s) < slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) > consts.MIN_SQRT_RATIO(%s)", - sqrtPriceLimitX96.ToString(), - pool.slot0.sqrtPriceX96.ToString(), - sqrtPriceLimitX96.ToString(), - consts.MIN_SQRT_RATIO), - )) - } - } else { - maxSqrtRatio := u256.MustFromDecimal(consts.MAX_SQRT_RATIO) - - cond1 := sqrtPriceLimitX96.Gt(pool.slot0.sqrtPriceX96) - cond2 := sqrtPriceLimitX96.Lt(maxSqrtRatio) - if !(cond1 && cond2) { - panic(addDetailToError( - errPriceOutOfRange, - ufmt.Sprintf("pool.gno__Swap() || sqrtPriceLimitX96(%s) > slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) < consts.MAX_SQRT_RATIO(%s)", - sqrtPriceLimitX96.ToString(), - pool.slot0.sqrtPriceX96.ToString(), - sqrtPriceLimitX96.ToString(), - consts.MAX_SQRT_RATIO), - )) - } - } -} - -// getFeeProtocol returns the appropriate fee protocol based on zero for one -func getFeeProtocol(slot0 Slot0, zeroForOne bool) uint8 { - if zeroForOne { - return slot0.feeProtocol % 16 - } - return slot0.feeProtocol / 16 -} - -// getFeeGrowthGlobal returns the appropriate fee growth global based on zero for one -func getFeeGrowthGlobal(pool *Pool, zeroForOne bool) *u256.Uint { - if zeroForOne { - return pool.feeGrowthGlobal0X128 - } - return pool.feeGrowthGlobal1X128 -} - -func shouldContinueSwap(state SwapState, sqrtPriceLimitX96 *u256.Uint) bool { - return !(state.amountSpecifiedRemaining.IsZero()) && !(state.sqrtPriceX96.Eq(sqrtPriceLimitX96)) -} - -// computeSwapStep executes a single step of swap and returns new state -func computeSwapStep( - state SwapState, - pool *Pool, - zeroForOne bool, - sqrtPriceLimitX96 *u256.Uint, - exactInput bool, - cache SwapCache, - swapFee *u256.Uint, -) (SwapState, *u256.Uint, error) { - step := computeSwapStepInit(state, pool, zeroForOne) - - // determining the price target for this step - sqrtRatioTargetX96 := computeTargetSqrtRatio(step, sqrtPriceLimitX96, zeroForOne) - - // computing the amounts to be swapped at this step - var newState SwapState - var err error - - newState, step = computeAmounts(state, sqrtRatioTargetX96, pool, step) - newState = updateAmounts(step, newState, exactInput) - - // if the protocol fee is on, calculate how much is owed, - // decrement fee amount, and increment protocol fee - if cache.feeProtocol > 0 { - newState, err = updateFeeProtocol(step, cache.feeProtocol, newState) - if err != nil { - return state, nil, err - } - } - - // update global fee tracker - if newState.liquidity.Gt(u256.Zero()) { - update := u256.MulDiv(step.feeAmount, u256.MustFromDecimal(consts.Q128), newState.liquidity) - newState.SetFeeGrowthGlobalX128(new(u256.Uint).Add(newState.feeGrowthGlobalX128, update)) - } - - // handling tick transitions - if newState.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { - newState = tickTransition(step, zeroForOne, newState, pool) - } - - if newState.sqrtPriceX96.Neq(step.sqrtPriceStartX96) { - newState.SetTick(common.TickMathGetTickAtSqrtRatio(newState.sqrtPriceX96)) - } - - newSwapFee := new(u256.Uint).Add(swapFee, step.feeAmount) - - return newState, newSwapFee, nil -} - -// updateFeeProtocol calculates and updates protocol fees for the current step. -func updateFeeProtocol(step StepComputations, feeProtocol uint8, state SwapState) (SwapState, error) { - delta := step.feeAmount - delta.Div(delta, u256.NewUint(uint64(feeProtocol))) - - newFeeAmount, overflow := new(u256.Uint).SubOverflow(step.feeAmount, delta) - if overflow { - return state, errUnderflow - } - step.feeAmount = newFeeAmount - state.protocolFee.Add(state.protocolFee, delta) - - return state, nil -} - -// computeSwapStepInit initializes the computation for a single swap step. -func computeSwapStepInit(state SwapState, pool *Pool, zeroForOne bool) StepComputations { - var step StepComputations - step.sqrtPriceStartX96 = state.sqrtPriceX96 - tickNext, initialized := pool.tickBitmapNextInitializedTickWithInOneWord( - state.tick, - pool.tickSpacing, - zeroForOne, - ) - - step.tickNext = tickNext - step.initialized = initialized - - // prevent overshoot the min/max tick - step.clampTickNext() - - // get the price for the next tick - step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) - return step -} - -// computeTargetSqrtRatio determines the target sqrt price for the current swap step. -func computeTargetSqrtRatio(step StepComputations, sqrtPriceLimitX96 *u256.Uint, zeroForOne bool) *u256.Uint { - if shouldUsePriceLimit(step.sqrtPriceNextX96, sqrtPriceLimitX96, zeroForOne) { - return sqrtPriceLimitX96 - } - return step.sqrtPriceNextX96 -} - -// shouldUsePriceLimit returns true if the price limit should be used instead of the next tick price -func shouldUsePriceLimit(sqrtPriceNext, sqrtPriceLimit *u256.Uint, zeroForOne bool) bool { - isLower := sqrtPriceNext.Lt(sqrtPriceLimit) - isHigher := sqrtPriceNext.Gt(sqrtPriceLimit) - if zeroForOne { - return isLower - } - return isHigher -} - -// computeAmounts calculates the input and output amounts for the current swap step. -func computeAmounts(state SwapState, sqrtRatioTargetX96 *u256.Uint, pool *Pool, step StepComputations) (SwapState, StepComputations) { - sqrtPriceX96Str, amountInStr, amountOutStr, feeAmountStr := plp.SwapMathComputeSwapStepStr( - state.sqrtPriceX96, - sqrtRatioTargetX96, - state.liquidity, - state.amountSpecifiedRemaining, - uint64(pool.fee), - ) - - step.amountIn = u256.MustFromDecimal(amountInStr) - step.amountOut = u256.MustFromDecimal(amountOutStr) - step.feeAmount = u256.MustFromDecimal(feeAmountStr) - - state.SetSqrtPriceX96(sqrtPriceX96Str) - - return state, step -} - -// updateAmounts calculates new remaining and calculated amounts based on the swap step -// For exact input swaps: -// - Decrements remaining input amount by (amountIn + feeAmount) -// - Decrements calculated amount by amountOut -// -// For exact output swaps: -// - Increments remaining output amount by amountOut -// - Increments calculated amount by (amountIn + feeAmount) -func updateAmounts(step StepComputations, state SwapState, exactInput bool) SwapState { - amountInWithFee := i256.FromUint256(new(u256.Uint).Add(step.amountIn, step.feeAmount)) - if exactInput { - state.amountSpecifiedRemaining = i256.Zero().Sub(state.amountSpecifiedRemaining, amountInWithFee) - state.amountCalculated = i256.Zero().Sub(state.amountCalculated, i256.FromUint256(step.amountOut)) - return state - } - state.amountSpecifiedRemaining = i256.Zero().Add(state.amountSpecifiedRemaining, i256.FromUint256(step.amountOut)) - state.amountCalculated = i256.Zero().Add(state.amountCalculated, amountInWithFee) - - return state -} - -// tickTransition handles the transition between price ticks during a swap -func tickTransition(step StepComputations, zeroForOne bool, state SwapState, pool *Pool) SwapState { - // ensure existing state to keep immutability - newState := state - - if step.initialized { - var fee0, fee1 *u256.Uint - - if zeroForOne { - fee0 = state.feeGrowthGlobalX128 - fee1 = pool.feeGrowthGlobal1X128 - } else { - fee0 = pool.feeGrowthGlobal0X128 - fee1 = state.feeGrowthGlobalX128 - } - - liquidityNet := pool.tickCross(step.tickNext, fee0, fee1) - - if zeroForOne { - liquidityNet = i256.Zero().Neg(liquidityNet) - } - - newState.liquidity = liquidityMathAddDelta(state.liquidity, liquidityNet) - } - - if zeroForOne { - newState.tick = step.tickNext - 1 - } else { - newState.tick = step.tickNext - } - - pool.tickCrossHook(GetPoolPath(pool.token0Path, pool.token1Path, pool.fee), newState.tick, zeroForOne) - - return newState -} - -func (pool *Pool) swapTransfers(zeroForOne bool, payer, recipient std.Address, amount0, amount1 *i256.Int) { - var targetTokenPath string - var amount *i256.Int - - if zeroForOne { - targetTokenPath = pool.token0Path - amount = amount0 - } else { - targetTokenPath = pool.token1Path - amount = amount1 - } - - // payer -> POOL -> recipient - pool.transferFromAndVerify(payer, consts.POOL_ADDR, targetTokenPath, amount.Abs(), zeroForOne) - pool.transferAndVerify(recipient, targetTokenPath, amount, !zeroForOne) -} - // SetFeeProtocolByAdmin sets the fee protocol for all pools // Also it will be applied to new created pools func SetFeeProtocolByAdmin( @@ -686,16 +232,44 @@ func SetFeeProtocol(feeProtocol0, feeProtocol1 uint8) { ) } +// setFeeProtocol updates the protocol fee configuration for all existing pools and sets +// the default for new pools. This is an internal function called by both `admin` and `governance` +// protocol fee management functions. +// +// The protocol fee is stored as a single `uint8` value where: +// - Lower 4 bits store feeProtocol0 (for token0) +// - Upper 4 bits store feeProtocol1 (for token1) +// +// This compact representation allows storing both fee values in a single byte. +// +// Parameters (must be 0 or between 4 and 10 inclusive): +// - feeProtocol0: protocol fee for token0 +// - feeProtocol1: protocol fee for token1 +// +// Returns: +// - newFee (uint8): the combined fee protocol value +// +// Example: +// If feeProtocol0 = 4 and feeProtocol1 = 5: +// +// newFee = 4 + (5 << 4) +// // Results in: 0x54 (84 in decimal) +// // Binary: 0101 0100 +// // ^^^^ ^^^^ +// // fee1=5 fee0=4 func setFeeProtocol(feeProtocol0, feeProtocol1 uint8) uint8 { common.IsHalted() if err := validateFeeProtocol(feeProtocol0, feeProtocol1); err != nil { panic(addDetailToError( err, - ufmt.Sprintf("pool.gno__setFeeProtocol() || expected (feeProtocol0(%d) == 0 || (feeProtocol0(%d) >= 4 && feeProtocol0(%d) <= 10)) && (feeProtocol1(%d) == 0 || (feeProtocol1(%d) >= 4 && feeProtocol1(%d) <= 10))", feeProtocol0, feeProtocol0, feeProtocol0, feeProtocol1, feeProtocol1, feeProtocol1), + ufmt.Sprintf("expected (feeProtocol0(%d) == 0 || (feeProtocol0(%d) >= 4 && feeProtocol0(%d) <= 10)) && (feeProtocol1(%d) == 0 || (feeProtocol1(%d) >= 4 && feeProtocol1(%d) <= 10))", feeProtocol0, feeProtocol0, feeProtocol0, feeProtocol1, feeProtocol1, feeProtocol1), )) } + // combine both protocol fee into a single byte: + // - feePrtocol0 occupies the lower 4 bits + // - feeProtocol1 is shifted the lower 4 positions to occupy the upper 4 bits newFee := feeProtocol0 + (feeProtocol1 << 4) // ( << 4 ) = ( * 16 ) // iterate all pool @@ -716,6 +290,8 @@ func validateFeeProtocol(feeProtocol0, feeProtocol1 uint8) error { return nil } +// isValidFeeProtocolValue checks if a fee protocol value is within acceptable range. +// valid values are either 0 or between 4 and 10 inclusive. func isValidFeeProtocolValue(value uint8) bool { return value == 0 || (value >= 4 && value <= 10) } @@ -839,227 +415,29 @@ func collectProtocol( return amount0.ToString(), amount1.ToString() } -func (pool *Pool) saveProtocolFees(amount0, amount1 *u256.Uint) (*u256.Uint, *u256.Uint) { +// saveProtocolFees updates the protocol fee balances after collection. +// +// Parameters: +// - amount0: amount of token0 fees to collect +// - amount1: amount of token1 fees to collect +// +// Returns the adjusted amounts that will actually be collected for both tokens. +func (p *Pool) saveProtocolFees(amount0, amount1 *u256.Uint) (*u256.Uint, *u256.Uint) { cond01 := amount0.Gt(u256.Zero()) - cond02 := amount0.Eq(pool.protocolFees.token0) + cond02 := amount0.Eq(p.protocolFees.token0) if cond01 && cond02 { amount0 = new(u256.Uint).Sub(amount0, u256.One()) } cond11 := amount1.Gt(u256.Zero()) - cond12 := amount1.Eq(pool.protocolFees.token1) + cond12 := amount1.Eq(p.protocolFees.token1) if cond11 && cond12 { amount1 = new(u256.Uint).Sub(amount1, u256.One()) } - pool.protocolFees.token0 = new(u256.Uint).Sub(pool.protocolFees.token0, amount0) - pool.protocolFees.token1 = new(u256.Uint).Sub(pool.protocolFees.token1, amount1) + p.protocolFees.token0 = new(u256.Uint).Sub(p.protocolFees.token0, amount0) + p.protocolFees.token1 = new(u256.Uint).Sub(p.protocolFees.token1, amount1) // return rest fee return amount0, amount1 } - -func (pool *Pool) transferAndVerify( - to std.Address, - tokenPath string, - amount *i256.Int, - isToken0 bool, -) { - if amount.Sign() != -1 { - panic(addDetailToError( - errMustBeNegative, - ufmt.Sprintf("pool.gno__transferAndVerify() || amount(%s) must be negative", amount.ToString()), - )) - } - - absAmount := amount.Abs() - - token0 := pool.balances.token0 - token1 := pool.balances.token1 - - if err := validatePoolBalance(token0, token1, absAmount, isToken0); err != nil { - panic(err) - } - amountUint64, err := checkAmountRange(absAmount) - if err != nil { - panic(err) - } - - token := common.GetTokenTeller(tokenPath) - checkTransferError(token.Transfer(to, amountUint64)) - - newBalance, err := updatePoolBalance(token0, token1, absAmount, isToken0) - if err != nil { - panic(err) - } - - if isToken0 { - pool.balances.token0 = newBalance - } else { - pool.balances.token1 = newBalance - } -} - -func validatePoolBalance(token0, token1, amount *u256.Uint, isToken0 bool) error { - if isToken0 { - if token0.Lt(amount) { - return ufmt.Errorf( - "%s || token0(%s) >= amount(%s)", - errTransferFailed.Error(), token0.ToString(), amount.ToString(), - ) - } - return nil - } - if token1.Lt(amount) { - return ufmt.Errorf( - "%s || token1(%s) >= amount(%s)", - errTransferFailed.Error(), token1.ToString(), amount.ToString(), - ) - } - return nil -} - -func updatePoolBalance( - token0, token1, amount *u256.Uint, - isToken0 bool, -) (*u256.Uint, error) { - var overflow bool - var newBalance *u256.Uint - - if isToken0 { - newBalance, overflow = new(u256.Uint).SubOverflow(token0, amount) - if isBalanceOverflowOrNegative(overflow, newBalance) { - return nil, ufmt.Errorf( - "%s || cannot decrease, token0(%s) - amount(%s)", - errTransferFailed.Error(), token0.ToString(), amount.ToString(), - ) - } - return newBalance, nil - } - - newBalance, overflow = new(u256.Uint).SubOverflow(token1, amount) - if isBalanceOverflowOrNegative(overflow, newBalance) { - return nil, ufmt.Errorf( - "%s || cannot decrease, token1(%s) - amount(%s)", - errTransferFailed.Error(), token1.ToString(), amount.ToString(), - ) - } - return newBalance, nil -} - -func isBalanceOverflowOrNegative(overflow bool, newBalance *u256.Uint) bool { - return overflow || newBalance.Lt(u256.Zero()) -} - -func (pool *Pool) transferFromAndVerify( - from, to std.Address, - tokenPath string, - amount *u256.Uint, - isToken0 bool, -) { - absAmount := amount - amountUint64, err := checkAmountRange(absAmount) - if err != nil { - panic(err) - } - - token := common.GetTokenTeller(tokenPath) - checkTransferError(token.TransferFrom(from, to, amountUint64)) - - // update pool balances - if isToken0 { - pool.balances.token0 = new(u256.Uint).Add(pool.balances.token0, absAmount) - } else { - pool.balances.token1 = new(u256.Uint).Add(pool.balances.token1, absAmount) - } -} - -func checkAmountRange(amount *u256.Uint) (uint64, error) { - res, overflow := amount.Uint64WithOverflow() - if overflow { - return 0, ufmt.Errorf( - "%s || amount(%s) overflows uint64 range", - errOutOfRange.Error(), amount.ToString(), - ) - } - - return res, nil -} - -// receiver getters -func (p *Pool) GetToken0Path() string { - return p.token0Path -} - -func (p *Pool) GetToken1Path() string { - return p.token1Path -} - -func (p *Pool) GetFee() uint32 { - return p.fee -} - -func (p *Pool) GetBalanceToken0() *u256.Uint { - return p.balances.token0 -} - -func (p *Pool) GetBalanceToken1() *u256.Uint { - return p.balances.token1 -} - -func (p *Pool) GetTickSpacing() int32 { - return p.tickSpacing -} - -func (p *Pool) GetMaxLiquidityPerTick() *u256.Uint { - return p.maxLiquidityPerTick -} - -func (p *Pool) GetSlot0() Slot0 { - return p.slot0 -} - -func (p *Pool) GetSlot0SqrtPriceX96() *u256.Uint { - return p.slot0.sqrtPriceX96 -} - -func (p *Pool) GetSlot0Tick() int32 { - return p.slot0.tick -} - -func (p *Pool) GetSlot0FeeProtocol() uint8 { - return p.slot0.feeProtocol -} - -func (p *Pool) GetSlot0Unlocked() bool { - return p.slot0.unlocked -} - -func (p *Pool) GetFeeGrowthGlobal0X128() *u256.Uint { - return p.feeGrowthGlobal0X128 -} - -func (p *Pool) GetFeeGrowthGlobal1X128() *u256.Uint { - return p.feeGrowthGlobal1X128 -} - -func (p *Pool) GetProtocolFeesToken0() *u256.Uint { - return p.protocolFees.token0 -} - -func (p *Pool) GetProtocolFeesToken1() *u256.Uint { - return p.protocolFees.token1 -} - -func (p *Pool) GetLiquidity() *u256.Uint { - return p.liquidity -} - -func mustGetPool(poolPath string) *Pool { - pool, exist := pools[poolPath] - if !exist { - panic(addDetailToError(errDataNotFound, - ufmt.Sprintf("poolPath(%s) does not exist", poolPath))) - } - return pool -} diff --git a/pool/pool_test.gno b/pool/pool_test.gno index c6ef82d46..b58d95ac3 100644 --- a/pool/pool_test.gno +++ b/pool/pool_test.gno @@ -140,566 +140,3 @@ func TestBurn(t *testing.T) { } } -func TestSaveProtocolFees(t *testing.T) { - tests := []struct { - name string - pool *Pool - amount0 *u256.Uint - amount1 *u256.Uint - want0 *u256.Uint - want1 *u256.Uint - wantFee0 *u256.Uint - wantFee1 *u256.Uint - }{ - { - name: "normal fee deduction", - pool: &Pool{ - protocolFees: ProtocolFees{ - token0: u256.NewUint(1000), - token1: u256.NewUint(2000), - }, - }, - amount0: u256.NewUint(500), - amount1: u256.NewUint(1000), - want0: u256.NewUint(500), - want1: u256.NewUint(1000), - wantFee0: u256.NewUint(500), - wantFee1: u256.NewUint(1000), - }, - { - name: "exact fee deduction (1 deduction)", - pool: &Pool{ - protocolFees: ProtocolFees{ - token0: u256.NewUint(1000), - token1: u256.NewUint(2000), - }, - }, - amount0: u256.NewUint(1000), - amount1: u256.NewUint(2000), - want0: u256.NewUint(999), - want1: u256.NewUint(1999), - wantFee0: u256.NewUint(1), - wantFee1: u256.NewUint(1), - }, - { - name: "0 fee deduction", - pool: &Pool{ - protocolFees: ProtocolFees{ - token0: u256.NewUint(1000), - token1: u256.NewUint(2000), - }, - }, - amount0: u256.NewUint(0), - amount1: u256.NewUint(0), - want0: u256.NewUint(0), - want1: u256.NewUint(0), - wantFee0: u256.NewUint(1000), - wantFee1: u256.NewUint(2000), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got0, got1 := tt.pool.saveProtocolFees(tt.amount0, tt.amount1) - - uassert.Equal(t, got0.ToString(), tt.want0.ToString()) - uassert.Equal(t, got1.ToString(), tt.want1.ToString()) - uassert.Equal(t, tt.pool.protocolFees.token0.ToString(), tt.wantFee0.ToString()) - uassert.Equal(t, tt.pool.protocolFees.token1.ToString(), tt.wantFee1.ToString()) - }) - } -} - -func TestTransferAndVerify(t *testing.T) { - // Setup common test data - pool := &Pool{ - balances: Balances{ - token0: u256.NewUint(1000), - token1: u256.NewUint(1000), - }, - } - - t.Run("validatePoolBalance", func(t *testing.T) { - tests := []struct { - name string - amount *u256.Uint - isToken0 bool - expectedError bool - }{ - { - name: "must success for negative amount", - amount: u256.NewUint(500), - isToken0: true, - expectedError: false, - }, - { - name: "must panic for insufficient token0 balance", - amount: u256.NewUint(1500), - isToken0: true, - expectedError: true, - }, - { - name: "must success for negative amount", - amount: u256.NewUint(500), - isToken0: false, - expectedError: false, - }, - { - name: "must panic for insufficient token1 balance", - amount: u256.NewUint(1500), - isToken0: false, - expectedError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - token0 := pool.balances.token0 - token1 := pool.balances.token1 - - err := validatePoolBalance(token0, token1, tt.amount, tt.isToken0) - if err != nil { - if !tt.expectedError { - t.Errorf("unexpected error: %v", err) - } - } - }) - } - }) -} - -func TestUpdatePoolBalance(t *testing.T) { - tests := []struct { - name string - initialToken0 *u256.Uint - initialToken1 *u256.Uint - amount *u256.Uint - isToken0 bool - expectedBal *u256.Uint - expectErr bool - }{ - { - name: "normal token0 decrease", - initialToken0: u256.NewUint(1000), - initialToken1: u256.NewUint(2000), - amount: u256.NewUint(300), - isToken0: true, - expectedBal: u256.NewUint(700), - expectErr: false, - }, - { - name: "normal token1 decrease", - initialToken0: u256.NewUint(1000), - initialToken1: u256.NewUint(2000), - amount: u256.NewUint(500), - isToken0: false, - expectedBal: u256.NewUint(1500), - expectErr: false, - }, - { - name: "insufficient token0 balance", - initialToken0: u256.NewUint(100), - initialToken1: u256.NewUint(2000), - amount: u256.NewUint(200), - isToken0: true, - expectedBal: nil, - expectErr: true, - }, - { - name: "insufficient token1 balance", - initialToken0: u256.NewUint(1000), - initialToken1: u256.NewUint(100), - amount: u256.NewUint(200), - isToken0: false, - expectedBal: nil, - expectErr: true, - }, - { - name: "zero value handling", - initialToken0: u256.NewUint(1000), - initialToken1: u256.NewUint(2000), - amount: u256.NewUint(0), - isToken0: true, - expectedBal: u256.NewUint(1000), - expectErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - pool := &Pool{ - balances: Balances{ - token0: tt.initialToken0, - token1: tt.initialToken1, - }, - } - - newBal, err := updatePoolBalance(tt.initialToken0, tt.initialToken1, tt.amount, tt.isToken0) - - if tt.expectErr { - if err == nil { - t.Errorf("%s: expected error but no error", tt.name) - } - return - } - if err != nil { - t.Errorf("%s: unexpected error: %v", tt.name, err) - return - } - - if !newBal.Eq(tt.expectedBal) { - t.Errorf("%s: balance mismatch, expected: %s, actual: %s", - tt.name, - tt.expectedBal.ToString(), - newBal.ToString(), - ) - } - }) - } -} - -func TestShouldContinueSwap(t *testing.T) { - tests := []struct { - name string - state SwapState - sqrtPriceLimitX96 *u256.Uint - expected bool - }{ - { - name: "Should continue - amount remaining and price not at limit", - state: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("1000"), - sqrtPriceX96: u256.MustFromDecimal("1000000"), - }, - sqrtPriceLimitX96: u256.MustFromDecimal("900000"), - expected: true, - }, - { - name: "Should stop - no amount remaining", - state: SwapState{ - amountSpecifiedRemaining: i256.Zero(), - sqrtPriceX96: u256.MustFromDecimal("1000000"), - }, - sqrtPriceLimitX96: u256.MustFromDecimal("900000"), - expected: false, - }, - { - name: "Should stop - price at limit", - state: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("1000"), - sqrtPriceX96: u256.MustFromDecimal("900000"), - }, - sqrtPriceLimitX96: u256.MustFromDecimal("900000"), - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := shouldContinueSwap(tt.state, tt.sqrtPriceLimitX96) - uassert.Equal(t, tt.expected, result) - }) - } -} - -func TestUpdateAmounts(t *testing.T) { - tests := []struct { - name string - step StepComputations - state SwapState - exactInput bool - expectedState SwapState - }{ - { - name: "Exact input update", - step: StepComputations{ - amountIn: u256.MustFromDecimal("100"), - amountOut: u256.MustFromDecimal("97"), - feeAmount: u256.MustFromDecimal("3"), - }, - state: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("1000"), - amountCalculated: i256.Zero(), - }, - exactInput: true, - expectedState: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("897"), // 1000 - (100 + 3) - amountCalculated: i256.MustFromDecimal("-97"), - }, - }, - { - name: "Exact output update", - step: StepComputations{ - amountIn: u256.MustFromDecimal("100"), - amountOut: u256.MustFromDecimal("97"), - feeAmount: u256.MustFromDecimal("3"), - }, - state: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("-1000"), - amountCalculated: i256.Zero(), - }, - exactInput: false, - expectedState: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("-903"), // -1000 + 97 - amountCalculated: i256.MustFromDecimal("103"), // 100 + 3 - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := updateAmounts(tt.step, tt.state, tt.exactInput) - - uassert.True(t, tt.expectedState.amountSpecifiedRemaining.Eq(result.amountSpecifiedRemaining)) - uassert.True(t, tt.expectedState.amountCalculated.Eq(result.amountCalculated)) - }) - } -} - -func TestComputeSwap(t *testing.T) { - mockPool := &Pool{ - token0Path: "token0", - token1Path: "token1", - fee: 3000, // 0.3% - tickSpacing: 60, - slot0: Slot0{ - sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 - tick: 0, - feeProtocol: 0, - unlocked: true, - }, - liquidity: u256.MustFromDecimal("1000000000000000000"), // 1.0 - protocolFees: ProtocolFees{ - token0: u256.Zero(), - token1: u256.Zero(), - }, - feeGrowthGlobal0X128: u256.Zero(), - feeGrowthGlobal1X128: u256.Zero(), - tickBitmaps: make(TickBitmaps), - ticks: make(Ticks), - positions: make(Positions), - } - - wordPos, _ := tickBitmapPosition(0) - // TODO: use avl - mockPool.tickBitmaps[wordPos] = u256.NewUint(1) - - t.Run("basic swap", func(t *testing.T) { - comp := SwapComputation{ - AmountSpecified: i256.MustFromDecimal("1000000"), // 1.0 token - SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), // 1.1 - ZeroForOne: true, - ExactInput: true, - InitialState: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), - amountCalculated: i256.Zero(), - sqrtPriceX96: mockPool.slot0.sqrtPriceX96, - tick: mockPool.slot0.tick, - feeGrowthGlobalX128: mockPool.feeGrowthGlobal0X128, - protocolFee: u256.Zero(), - liquidity: mockPool.liquidity, - }, - Cache: SwapCache{ - feeProtocol: 0, - liquidityStart: mockPool.liquidity, - }, - } - - result, err := computeSwap(mockPool, comp) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - if result.Amount0.IsZero() { - t.Error("expected non-zero amount0") - } - if result.Amount1.IsZero() { - t.Error("expected non-zero amount1") - } - if result.SwapFee.IsZero() { - t.Error("expected non-zero swap fee") - } - }) - - t.Run("swap with zero liquidity", func(t *testing.T) { - mockPoolZeroLiq := *mockPool - mockPoolZeroLiq.liquidity = u256.Zero() - - comp := SwapComputation{ - AmountSpecified: i256.MustFromDecimal("1000000"), - SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), - ZeroForOne: true, - ExactInput: true, - InitialState: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), - amountCalculated: i256.Zero(), - sqrtPriceX96: mockPoolZeroLiq.slot0.sqrtPriceX96, - tick: mockPoolZeroLiq.slot0.tick, - feeGrowthGlobalX128: mockPoolZeroLiq.feeGrowthGlobal0X128, - protocolFee: u256.Zero(), - liquidity: mockPoolZeroLiq.liquidity, - }, - Cache: SwapCache{ - feeProtocol: 0, - liquidityStart: mockPoolZeroLiq.liquidity, - }, - } - - result, err := computeSwap(&mockPoolZeroLiq, comp) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - if !result.Amount0.IsZero() || !result.Amount1.IsZero() { - t.Error("expected zero amounts for zero liquidity") - } - }) -} - -func TestTransferFromAndVerify(t *testing.T) { - tests := []struct { - name string - pool *Pool - from std.Address - to std.Address - tokenPath string - amount *i256.Int - isToken0 bool - expectedBal0 *u256.Uint - expectedBal1 *u256.Uint - }{ - { - name: "normal token0 transfer", - pool: &Pool{ - balances: Balances{ - token0: u256.NewUint(1000), - token1: u256.NewUint(2000), - }, - }, - from: testutils.TestAddress("from_addr"), - to: testutils.TestAddress("to_addr"), - tokenPath: "gno.land/r/onbloc/bar", - amount: i256.NewInt(500), - isToken0: true, - expectedBal0: u256.NewUint(1500), // 1000 + 500 - expectedBal1: u256.NewUint(2000), // unchanged - }, - { - name: "normal token1 transfer", - pool: &Pool{ - balances: Balances{ - token0: u256.NewUint(1000), - token1: u256.NewUint(2000), - }, - }, - from: testutils.TestAddress("from_addr"), - to: testutils.TestAddress("to_addr"), - tokenPath: "gno.land/r/onbloc/foo", - amount: i256.NewInt(800), - isToken0: false, - expectedBal0: u256.NewUint(1000), // unchanged - expectedBal1: u256.NewUint(2800), // 2000 + 800 - }, - { - name: "zero value transfer", - pool: &Pool{ - balances: Balances{ - token0: u256.NewUint(1000), - token1: u256.NewUint(2000), - }, - }, - from: testutils.TestAddress("from_addr"), - to: testutils.TestAddress("to_addr"), - tokenPath: "gno.land/r/onbloc/bar", - amount: i256.NewInt(0), - isToken0: true, - expectedBal0: u256.NewUint(1000), // unchanged - expectedBal1: u256.NewUint(2000), // unchanged - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - oldCheckTransferError := checkTransferError - defer func() { - checkTransferError = oldCheckTransferError - }() - - checkTransferError = func(err error) { - return - } - - tt.pool.transferFromAndVerify(tt.from, tt.to, tt.tokenPath, u256.MustFromDecimal(tt.amount.ToString()), tt.isToken0) - - if !tt.pool.balances.token0.Eq(tt.expectedBal0) { - t.Errorf("token0 balance mismatch: expected %s, got %s", - tt.expectedBal0.ToString(), - tt.pool.balances.token0.ToString()) - } - - if !tt.pool.balances.token1.Eq(tt.expectedBal1) { - t.Errorf("token1 balance mismatch: expected %s, got %s", - tt.expectedBal1.ToString(), - tt.pool.balances.token1.ToString()) - } - }) - } - - t.Run("negative value handling", func(t *testing.T) { - pool := &Pool{ - balances: Balances{ - token0: u256.NewUint(1000), - token1: u256.NewUint(2000), - }, - } - - oldCheckTransferError := checkTransferError - defer func() { checkTransferError = oldCheckTransferError }() - - checkTransferError = func(err error) { - return - } - - negativeAmount := i256.NewInt(-500) - pool.transferFromAndVerify( - testutils.TestAddress("from_addr"), - testutils.TestAddress("to_addr"), - "gno.land/r/onbloc/qux", - u256.MustFromDecimal(negativeAmount.Abs().ToString()), - true, - ) - - expectedBal := u256.NewUint(1500) // 1000 + 500 (absolute value) - if !pool.balances.token0.Eq(expectedBal) { - t.Errorf("negative amount handling failed: expected %s, got %s", - expectedBal.ToString(), - pool.balances.token0.ToString()) - } - }) - - t.Run("uint64 overflow value", func(t *testing.T) { - pool := &Pool{ - balances: Balances{ - token0: u256.NewUint(1000), - token1: u256.NewUint(2000), - }, - } - - hugeAmount := i256.FromUint256(u256.MustFromDecimal("18446744073709551616")) // 2^64 - - defer func() { - if r := recover(); r == nil { - t.Error("expected panic for amount exceeding uint64 range") - } - }() - - pool.transferFromAndVerify( - testutils.TestAddress("from_addr"), - testutils.TestAddress("to_addr"), - "gno.land/r/onbloc/qux", - u256.MustFromDecimal(hugeAmount.ToString()), - true, - ) - }) -} diff --git a/pool/pool_transfer.gno b/pool/pool_transfer.gno new file mode 100644 index 000000000..201ae2cfb --- /dev/null +++ b/pool/pool_transfer.gno @@ -0,0 +1,171 @@ +package pool + +import ( + "std" + + "gno.land/p/demo/ufmt" + + "gno.land/r/gnoswap/v1/common" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +// transferAndVerify performs a token transfer out of the pool while ensuring +// the pool has sufficient balance and updating internal accounting. +// This function is typically used during swaps and liquidity removals. +// +// Important requirements: +// - The amount must be negative (representing an outflow from the pool) +// - The pool must have sufficient balance for the transfer +// - The transfer amount must fit within uint64 range +// +// Parameters: +// - to: destination address for the transfer +// - tokenPath: path identifier of the token to transfer +// - amount: amount to transfer (must be negative) +// - isToken0: true if transferring token0, false for token1 +// +// The function will: +// 1. Validate the amount is negative +// 2. Check pool has sufficient balance +// 3. Execute the transfer +// 4. Update pool's internal balance +// +// Panics if any validation fails or if the transfer fails +func (p *Pool) transferAndVerify( + to std.Address, + tokenPath string, + amount *i256.Int, + isToken0 bool, +) { + if amount.Sign() != -1 { + panic(ufmt.Sprintf( + "%v. got: %s", errMustBeNegative, amount.ToString(), + )) + } + + absAmount := amount.Abs() + + token0 := p.balances.token0 + token1 := p.balances.token1 + + if err := validatePoolBalance(token0, token1, absAmount, isToken0); err != nil { + panic(err) + } + amountUint64, err := safeConvertToUint64(absAmount) + if err != nil { + panic(err) + } + + token := common.GetTokenTeller(tokenPath) + checkTransferError(token.Transfer(to, amountUint64)) + + newBalance, err := updatePoolBalance(token0, token1, absAmount, isToken0) + if err != nil { + panic(err) + } + + if isToken0 { + p.balances.token0 = newBalance + } else { + p.balances.token1 = newBalance + } +} + +// transferFromAndVerify performs a token transfer into the pool using transferFrom +// while updating the pool's internal accounting. This function is typically used +// during swaps and liquidity additions. +// +// The function assumes the sender has approved the pool to spend their tokens. +// +// Parameters: +// - from: source address for the transfer +// - to: destination address (typically the pool) +// - tokenPath: path identifier of the token to transfer +// - amount: amount to transfer (must be positive) +// - isToken0: true if transferring token0, false for token1 +// +// The function will: +// 1. Convert amount to uint64 (must fit) +// 2. Execute the transferFrom +// 3. Update pool's internal balance +// +// Panics if the amount conversion fails or if the transfer fails +func (p *Pool) transferFromAndVerify( + from, to std.Address, + tokenPath string, + amount *u256.Uint, + isToken0 bool, +) { + absAmount := amount + amountUint64, err := safeConvertToUint64(absAmount) + if err != nil { + panic(err) + } + + token := common.GetTokenTeller(tokenPath) + checkTransferError(token.TransferFrom(from, to, amountUint64)) + + // update pool balances + if isToken0 { + p.balances.token0 = new(u256.Uint).Add(p.balances.token0, absAmount) + } else { + p.balances.token1 = new(u256.Uint).Add(p.balances.token1, absAmount) + } +} + +// validatePoolBalance checks if the pool has sufficient balance of either token0 and token1 +// before proceeding with a transfer. This prevents the pool won't go into a negative balance. +func validatePoolBalance(token0, token1, amount *u256.Uint, isToken0 bool) error { + if isToken0 { + if token0.Lt(amount) { + return ufmt.Errorf( + "%v. token0(%s) >= amount(%s)", + errTransferFailed, token0.ToString(), amount.ToString(), + ) + } + return nil + } + if token1.Lt(amount) { + return ufmt.Errorf( + "%v. token1(%s) >= amount(%s)", + errTransferFailed, token1.ToString(), amount.ToString(), + ) + } + return nil +} + +// updatePoolBalance calculates the new balance after a transfer and validate. +// It ensures the resulting balance won't be negative or overflow. +func updatePoolBalance( + token0, token1, amount *u256.Uint, + isToken0 bool, +) (*u256.Uint, error) { + var overflow bool + var newBalance *u256.Uint + + if isToken0 { + newBalance, overflow = new(u256.Uint).SubOverflow(token0, amount) + if isBalanceOverflowOrNegative(overflow, newBalance) { + return nil, ufmt.Errorf( + "%v. cannot decrease, token0(%s) - amount(%s)", + errTransferFailed, token0.ToString(), amount.ToString(), + ) + } + return newBalance, nil + } + + newBalance, overflow = new(u256.Uint).SubOverflow(token1, amount) + if isBalanceOverflowOrNegative(overflow, newBalance) { + return nil, ufmt.Errorf( + "%v. cannot decrease, token1(%s) - amount(%s)", + errTransferFailed, token1.ToString(), amount.ToString(), + ) + } + return newBalance, nil +} + +func isBalanceOverflowOrNegative(overflow bool, newBalance *u256.Uint) bool { + return overflow || newBalance.Lt(u256.Zero()) +} diff --git a/pool/pool_transfer_test.gno b/pool/pool_transfer_test.gno new file mode 100644 index 000000000..562825b9a --- /dev/null +++ b/pool/pool_transfer_test.gno @@ -0,0 +1,298 @@ +package pool + +import ( + "std" + "testing" + + "gno.land/p/demo/testutils" + pusers "gno.land/p/demo/users" + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/r/gnoswap/v1/consts" +) + +func TestTransferAndVerify(t *testing.T) { + // Setup common test data + pool := &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(1000), + }, + } + + t.Run("validatePoolBalance", func(t *testing.T) { + tests := []struct { + name string + amount *u256.Uint + isToken0 bool + expectedError bool + }{ + { + name: "must success for negative amount", + amount: u256.NewUint(500), + isToken0: true, + expectedError: false, + }, + { + name: "must panic for insufficient token0 balance", + amount: u256.NewUint(1500), + isToken0: true, + expectedError: true, + }, + { + name: "must success for negative amount", + amount: u256.NewUint(500), + isToken0: false, + expectedError: false, + }, + { + name: "must panic for insufficient token1 balance", + amount: u256.NewUint(1500), + isToken0: false, + expectedError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token0 := pool.balances.token0 + token1 := pool.balances.token1 + + err := validatePoolBalance(token0, token1, tt.amount, tt.isToken0) + if err != nil { + if !tt.expectedError { + t.Errorf("unexpected error: %v", err) + } + } + }) + } + }) +} + +func TestTransferFromAndVerify(t *testing.T) { + tests := []struct { + name string + pool *Pool + from std.Address + to std.Address + tokenPath string + amount *i256.Int + isToken0 bool + expectedBal0 *u256.Uint + expectedBal1 *u256.Uint + }{ + { + name: "normal token0 transfer", + pool: &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + from: testutils.TestAddress("from_addr"), + to: testutils.TestAddress("to_addr"), + tokenPath: fooPath, + amount: i256.NewInt(500), + isToken0: true, + expectedBal0: u256.NewUint(1500), // 1000 + 500 + expectedBal1: u256.NewUint(2000), // unchanged + }, + { + name: "normal token1 transfer", + pool: &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + from: testutils.TestAddress("from_addr"), + to: testutils.TestAddress("to_addr"), + tokenPath: fooPath, + amount: i256.NewInt(800), + isToken0: false, + expectedBal0: u256.NewUint(1000), // unchanged + expectedBal1: u256.NewUint(2800), // 2000 + 800 + }, + { + name: "zero value transfer", + pool: &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + from: testutils.TestAddress("from_addr"), + to: testutils.TestAddress("to_addr"), + tokenPath: fooPath, + amount: i256.NewInt(0), + isToken0: true, + expectedBal0: u256.NewUint(1000), // unchanged + expectedBal1: u256.NewUint(2000), // unchanged + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + TokenFaucet(t, fooPath, pusers.AddressOrName(tt.from)) + TokenApprove(t, fooPath, pusers.AddressOrName(tt.from), pool, u256.MustFromDecimal(tt.amount.ToString()).Uint64()) + + tt.pool.transferFromAndVerify(tt.from, tt.to, tt.tokenPath, u256.MustFromDecimal(tt.amount.ToString()), tt.isToken0) + + if !tt.pool.balances.token0.Eq(tt.expectedBal0) { + t.Errorf("token0 balance mismatch: expected %s, got %s", + tt.expectedBal0.ToString(), + tt.pool.balances.token0.ToString()) + } + + if !tt.pool.balances.token1.Eq(tt.expectedBal1) { + t.Errorf("token1 balance mismatch: expected %s, got %s", + tt.expectedBal1.ToString(), + tt.pool.balances.token1.ToString()) + } + }) + } + + t.Run("negative value handling", func(t *testing.T) { + pool := &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + } + + negativeAmount := i256.NewInt(-500) + + TokenFaucet(t, fooPath, pusers.AddressOrName(testutils.TestAddress("from_addr"))) + TokenApprove(t, fooPath, pusers.AddressOrName(testutils.TestAddress("from_addr")), pusers.AddressOrName(consts.POOL_ADDR), u256.MustFromDecimal(negativeAmount.Abs().ToString()).Uint64()) + pool.transferFromAndVerify( + testutils.TestAddress("from_addr"), + testutils.TestAddress("to_addr"), + fooPath, + u256.MustFromDecimal(negativeAmount.Abs().ToString()), + true, + ) + + expectedBal := u256.NewUint(1500) // 1000 + 500 (absolute value) + if !pool.balances.token0.Eq(expectedBal) { + t.Errorf("negative amount handling failed: expected %s, got %s", + expectedBal.ToString(), + pool.balances.token0.ToString()) + } + }) + + t.Run("uint64 overflow value", func(t *testing.T) { + pool := &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + } + + hugeAmount := i256.FromUint256(u256.MustFromDecimal("18446744073709551616")) // 2^64 + + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for amount exceeding uint64 range") + } + }() + + pool.transferFromAndVerify( + testutils.TestAddress("from_addr"), + testutils.TestAddress("to_addr"), + fooPath, + u256.MustFromDecimal(hugeAmount.ToString()), + true, + ) + }) +} + +func TestUpdatePoolBalance(t *testing.T) { + tests := []struct { + name string + initialToken0 *u256.Uint + initialToken1 *u256.Uint + amount *u256.Uint + isToken0 bool + expectedBal *u256.Uint + expectErr bool + }{ + { + name: "normal token0 decrease", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(300), + isToken0: true, + expectedBal: u256.NewUint(700), + expectErr: false, + }, + { + name: "normal token1 decrease", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(500), + isToken0: false, + expectedBal: u256.NewUint(1500), + expectErr: false, + }, + { + name: "insufficient token0 balance", + initialToken0: u256.NewUint(100), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(200), + isToken0: true, + expectedBal: nil, + expectErr: true, + }, + { + name: "insufficient token1 balance", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(100), + amount: u256.NewUint(200), + isToken0: false, + expectedBal: nil, + expectErr: true, + }, + { + name: "zero value handling", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(0), + isToken0: true, + expectedBal: u256.NewUint(1000), + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pool := &Pool{ + balances: Balances{ + token0: tt.initialToken0, + token1: tt.initialToken1, + }, + } + + newBal, err := updatePoolBalance(tt.initialToken0, tt.initialToken1, tt.amount, tt.isToken0) + + if tt.expectErr { + if err == nil { + t.Errorf("%s: expected error but no error", tt.name) + } + return + } + if err != nil { + t.Errorf("%s: unexpected error: %v", tt.name, err) + return + } + + if !newBal.Eq(tt.expectedBal) { + t.Errorf("%s: balance mismatch, expected: %s, actual: %s", + tt.name, + tt.expectedBal.ToString(), + newBal.ToString(), + ) + } + }) + } +} diff --git a/pool/position_modify.gno b/pool/position_modify.gno index 0400b126a..c6c574005 100644 --- a/pool/position_modify.gno +++ b/pool/position_modify.gno @@ -4,14 +4,29 @@ import ( "gno.land/r/gnoswap/v1/common" i256 "gno.land/p/gnoswap/int256" - u256 "gno.land/p/gnoswap/uint256" plp "gno.land/p/gnoswap/pool" + u256 "gno.land/p/gnoswap/uint256" ) -// modifyPosition updates a position in the pool and calculates the amount of tokens to be added or removed. -// Returns positionInfo, amount0, amount1 -func (pool *Pool) modifyPosition(params ModifyPositionParams) (PositionInfo, *u256.Uint, *u256.Uint) { - position := pool.updatePosition(params) +// modifyPosition updates a position in the pool and calculates the amount of tokens +// needed (for minting) or returned (for burning). The calculation depends on the current +// price (tick) relative to the position's price range. +// +// The function handles three cases: +// 1. Current price below range (tick < tickLower): only token0 is used/returned +// 2. Current price in range (tickLower <= tick < tickUpper): both tokens are used/returned +// 3. Current price above range (tick >= tickUpper): only token1 is used/returned +// +// Parameters: +// - params: ModifyPositionParams containing owner, tickLower, tickUpper, and liquidityDelta +// +// Returns: +// - PositionInfo: updated position information +// - *u256.Uint: amount of token0 needed/returned +// - *u256.Uint: amount of token1 needed/returned +func (p *Pool) modifyPosition(params ModifyPositionParams) (PositionInfo, *u256.Uint, *u256.Uint) { + // update position state + position := p.updatePosition(params) liqDelta := params.liquidityDelta if liqDelta.IsZero() { @@ -20,25 +35,36 @@ func (pool *Pool) modifyPosition(params ModifyPositionParams) (PositionInfo, *u2 amount0, amount1 := i256.Zero(), i256.Zero() - tick := pool.slot0.tick + // get current state and price bounds + tick := p.slot0.tick + // covert ticks to sqrt price to use in amount calculations + // price = 1.0001^tick, but we use sqrtPriceX96 sqrtRatioLower := common.TickMathGetSqrtRatioAtTick(params.tickLower) sqrtRatioUpper := common.TickMathGetSqrtRatioAtTick(params.tickUpper) - sqrtPriceX96 := pool.slot0.sqrtPriceX96 + sqrtPriceX96 := p.slot0.sqrtPriceX96 - // calculate amount0, amount1 based on the current tick position + // calculate token amounts based on current price position relative to range switch { case tick < params.tickLower: + // case 1 + // full range between lower and upper tick is used for token0 amount0 = calculateToken0Amount(sqrtRatioLower, sqrtRatioUpper, liqDelta) case tick < params.tickUpper: - liquidityBefore := pool.liquidity + // case 2 + liquidityBefore := p.liquidity + // token0 used from current price to upper tick amount0 = calculateToken0Amount(sqrtPriceX96, sqrtRatioUpper, liqDelta) + // token1 used from lower tick to current price amount1 = calculateToken1Amount(sqrtRatioLower, sqrtPriceX96, liqDelta) - pool.liquidity = liquidityMathAddDelta(liquidityBefore, liqDelta) + // update pool's active liquidity since price is in range + p.liquidity = liquidityMathAddDelta(liquidityBefore, liqDelta) default: + // case 3 + // full range between lower and upper tick is used for token1 amount1 = calculateToken1Amount(sqrtRatioLower, sqrtRatioUpper, liqDelta) } diff --git a/pool/swap.gno b/pool/swap.gno new file mode 100644 index 000000000..47c5e2d67 --- /dev/null +++ b/pool/swap.gno @@ -0,0 +1,524 @@ +package pool + +import ( + "std" + + "gno.land/p/demo/ufmt" + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" + + i256 "gno.land/p/gnoswap/int256" + plp "gno.land/p/gnoswap/pool" // pool package + u256 "gno.land/p/gnoswap/uint256" +) + +// SwapResult encapsulates all state changes that occur as a result of a swap +// This type ensure all state transitions are atomic and can be applied at once. +type SwapResult struct { + Amount0 *i256.Int + Amount1 *i256.Int + NewSqrtPrice *u256.Uint + NewTick int32 + NewLiquidity *u256.Uint + NewProtocolFees ProtocolFees + FeeGrowthGlobal0X128 *u256.Uint + FeeGrowthGlobal1X128 *u256.Uint +} + +// SwapComputation encapsulates pure computation logic for swap +type SwapComputation struct { + AmountSpecified *i256.Int + SqrtPriceLimitX96 *u256.Uint + ZeroForOne bool + ExactInput bool + InitialState SwapState + Cache SwapCache +} + +var ( + fixedPointQ128 = u256.MustFromDecimal(consts.Q128) +) + +// Swap swaps token0 for token1, or token1 for token0 +// Returns swapped amount0, amount1 in string +// ref: https://docs.gnoswap.io/contracts/pool/pool.gno#swap +func Swap( + token0Path string, + token1Path string, + fee uint32, + recipient std.Address, + zeroForOne bool, + amountSpecified string, + sqrtPriceLimitX96 string, + payer std.Address, // router +) (string, string) { + common.IsHalted() + if common.GetLimitCaller() { + caller := std.PrevRealm().Addr() + if err := common.RouterOnly(caller); err != nil { + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("only router(%s) can call pool swap(), called from %s", consts.ROUTER_ADDR, caller.String()), + )) + } + } + + if amountSpecified == "0" { + panic(addDetailToError( + errInvalidSwapAmount, + ufmt.Sprintf("amountSpecified == 0"), + )) + } + + pool := GetPool(token0Path, token1Path, fee) + + slot0Start := pool.slot0 + if !slot0Start.unlocked { + panic(errLockedPool) + } + pool.slot0.unlocked = false + defer func() { pool.slot0.unlocked = true }() + + sqrtPriceLimit := u256.MustFromDecimal(sqrtPriceLimitX96) + validatePriceLimits(slot0Start, zeroForOne, sqrtPriceLimit) + + amounts := i256.MustFromDecimal(amountSpecified) + feeGrowthGlobalX128 := getFeeGrowthGlobal(pool, zeroForOne) + feeProtocol := getFeeProtocol(slot0Start, zeroForOne) + cache := newSwapCache(feeProtocol, pool.liquidity.Clone()) + state := newSwapState(amounts, feeGrowthGlobalX128, cache.liquidityStart, slot0Start) + + comp := SwapComputation{ + AmountSpecified: amounts, + SqrtPriceLimitX96: sqrtPriceLimit, + ZeroForOne: zeroForOne, + ExactInput: amounts.Gt(i256.Zero()), + InitialState: state, + Cache: cache, + } + + result, err := computeSwap(pool, comp) + if err != nil { + panic(err) + } + + applySwapResult(pool, result) + + // actual swap + pool.swapTransfers(zeroForOne, payer, recipient, result.Amount0, result.Amount1) + + prevAddr, prevPkgPath := getPrev() + + std.Emit( + "Swap", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "poolPath", GetPoolPath(token0Path, token1Path, fee), + "zeroForOne", ufmt.Sprintf("%t", zeroForOne), + "amountSpecified", amountSpecified, + "sqrtPriceLimitX96", sqrtPriceLimitX96, + "payer", payer.String(), + "recipient", recipient.String(), + "internal_amount0", result.Amount0.ToString(), + "internal_amount1", result.Amount1.ToString(), + "internal_protocolFee0", pool.protocolFees.token0.ToString(), + "internal_protocolFee1", pool.protocolFees.token1.ToString(), + "internal_sqrtPriceX96", pool.slot0.sqrtPriceX96.ToString(), + ) + + return result.Amount0.ToString(), result.Amount1.ToString() +} + +// DrySwap simulates a swap and returns the amount0, amount1 that would be received and a boolean indicating if the swap is possible +func DrySwap( + token0Path string, + token1Path string, + fee uint32, + zeroForOne bool, + amountSpecified string, + sqrtPriceLimitX96 string, +) (string, string, bool) { + if amountSpecified == "0" { + return "0", "0", false + } + + pool := GetPool(token0Path, token1Path, fee) + + slot0Start := pool.slot0 + sqrtPriceLimit := u256.MustFromDecimal(sqrtPriceLimitX96) + validatePriceLimits(slot0Start, zeroForOne, sqrtPriceLimit) + + amounts := i256.MustFromDecimal(amountSpecified) + feeGrowthGlobalX128 := getFeeGrowthGlobal(pool, zeroForOne) + feeProtocol := getFeeProtocol(slot0Start, zeroForOne) + cache := newSwapCache(feeProtocol, pool.liquidity.Clone()) + state := newSwapState(amounts, feeGrowthGlobalX128, cache.liquidityStart, slot0Start) + + comp := SwapComputation{ + AmountSpecified: amounts, + SqrtPriceLimitX96: sqrtPriceLimit, + ZeroForOne: zeroForOne, + ExactInput: amounts.Gt(i256.Zero()), + InitialState: state, + Cache: cache, + } + + result, err := computeSwap(pool, comp) + if err != nil { + return "0", "0", false + } + + if zeroForOne { + if pool.balances.token1.Lt(result.Amount1.Abs()) { + return "0", "0", false + } + } else { + if pool.balances.token0.Lt(result.Amount0.Abs()) { + return "0", "0", false + } + } + + // Validate non-zero amounts + if result.Amount0.IsZero() || result.Amount1.IsZero() { + return "0", "0", false + } + + return result.Amount0.ToString(), result.Amount1.ToString(), true +} + +// computeSwap performs the core swap computation without modifying pool state +// The function follows these state transitions: +// 1. Initial State: Provided by `SwapComputation.InitialState` +// 2. Stepping State: For each step: +// - Compute next tick and price target +// - Calculate amounts and fees +// - Update state (remaining amount, fees, liquidity) +// - Handle tick transitions if necessary +// +// 3. Final State: Aggregated in SwapResult +// +// The computation continues until either: +// - The entire amount is consumed (`amountSpecifiedRemaining` = 0) +// - The price limit is reached (`sqrtPriceX96` = `sqrtPriceLimitX96`) +// +// Returns an error if the computation fails at any step +func computeSwap(pool *Pool, comp SwapComputation) (*SwapResult, error) { + state := comp.InitialState + var err error + + // Compute swap steps until completion + for shouldContinueSwap(state, comp.SqrtPriceLimitX96) { + state, err = computeSwapStep(state, pool, comp.ZeroForOne, comp.SqrtPriceLimitX96, comp.ExactInput, comp.Cache) + if err != nil { + return nil, err + } + } + + // Calculate final amounts + amount0 := state.amountCalculated + amount1 := i256.Zero().Sub(comp.AmountSpecified, state.amountSpecifiedRemaining) + if comp.ZeroForOne == comp.ExactInput { + amount0, amount1 = amount1, amount0 + } + + // Prepare result + result := &SwapResult{ + Amount0: amount0, + Amount1: amount1, + NewSqrtPrice: state.sqrtPriceX96, + NewTick: state.tick, + NewLiquidity: state.liquidity, + NewProtocolFees: ProtocolFees{ + token0: pool.protocolFees.token0, + token1: pool.protocolFees.token1, + }, + FeeGrowthGlobal0X128: pool.feeGrowthGlobal0X128, + FeeGrowthGlobal1X128: pool.feeGrowthGlobal1X128, + } + + // Update protocol fees if necessary + if comp.ZeroForOne { + if state.protocolFee.Gt(u256.Zero()) { + result.NewProtocolFees.token0 = new(u256.Uint).Add(result.NewProtocolFees.token0, state.protocolFee) + } + result.FeeGrowthGlobal0X128 = state.feeGrowthGlobalX128 + } else { + if state.protocolFee.Gt(u256.Zero()) { + result.NewProtocolFees.token1 = new(u256.Uint).Add(result.NewProtocolFees.token1, state.protocolFee) + } + result.FeeGrowthGlobal1X128 = state.feeGrowthGlobalX128 + } + + return result, nil +} + +// applySwapResult updates pool state with computed results. +// All state changes are applied at once to maintain consistency +func applySwapResult(pool *Pool, result *SwapResult) { + pool.slot0.sqrtPriceX96 = result.NewSqrtPrice + pool.slot0.tick = result.NewTick + pool.liquidity = result.NewLiquidity + pool.protocolFees = result.NewProtocolFees + pool.feeGrowthGlobal0X128 = result.FeeGrowthGlobal0X128 + pool.feeGrowthGlobal1X128 = result.FeeGrowthGlobal1X128 +} + +// validatePriceLimits ensures the provided price limit is valid for the swap direction +// The function enforces that: +// For zeroForOne (selling token0): +// - Price limit must be below current price +// - Price limit must be above MIN_SQRT_RATIO +// +// For !zeroForOne (selling token1): +// - Price limit must be above current price +// - Price limit must be below MAX_SQRT_RATIO +func validatePriceLimits(slot0 Slot0, zeroForOne bool, sqrtPriceLimitX96 *u256.Uint) { + if zeroForOne { + minSqrtRatio := u256.MustFromDecimal(consts.MIN_SQRT_RATIO) + + cond1 := sqrtPriceLimitX96.Lt(slot0.sqrtPriceX96) + cond2 := sqrtPriceLimitX96.Gt(minSqrtRatio) + if !(cond1 && cond2) { + panic(addDetailToError( + errPriceOutOfRange, + ufmt.Sprintf("sqrtPriceLimitX96(%s) < slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) > consts.MIN_SQRT_RATIO(%s)", + sqrtPriceLimitX96.ToString(), + slot0.sqrtPriceX96.ToString(), + sqrtPriceLimitX96.ToString(), + consts.MIN_SQRT_RATIO), + )) + } + } else { + maxSqrtRatio := u256.MustFromDecimal(consts.MAX_SQRT_RATIO) + + cond1 := sqrtPriceLimitX96.Gt(slot0.sqrtPriceX96) + cond2 := sqrtPriceLimitX96.Lt(maxSqrtRatio) + if !(cond1 && cond2) { + panic(addDetailToError( + errPriceOutOfRange, + ufmt.Sprintf("sqrtPriceLimitX96(%s) > slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) < consts.MAX_SQRT_RATIO(%s)", + sqrtPriceLimitX96.ToString(), + slot0.sqrtPriceX96.ToString(), + sqrtPriceLimitX96.ToString(), + consts.MAX_SQRT_RATIO), + )) + } + } +} + +// getFeeProtocol returns the appropriate fee protocol based on zero for one +func getFeeProtocol(slot0 Slot0, zeroForOne bool) uint8 { + if zeroForOne { + return slot0.feeProtocol % 16 + } + return slot0.feeProtocol / 16 +} + +// getFeeGrowthGlobal returns the appropriate fee growth global based on zero for one +func getFeeGrowthGlobal(pool *Pool, zeroForOne bool) *u256.Uint { + if zeroForOne { + return pool.feeGrowthGlobal0X128.Clone() + } + return pool.feeGrowthGlobal1X128.Clone() +} + +func shouldContinueSwap(state SwapState, sqrtPriceLimitX96 *u256.Uint) bool { + return !(state.amountSpecifiedRemaining.IsZero()) && !(state.sqrtPriceX96.Eq(sqrtPriceLimitX96)) +} + +// computeSwapStep executes a single step of swap and returns new state +func computeSwapStep( + state SwapState, + pool *Pool, + zeroForOne bool, + sqrtPriceLimitX96 *u256.Uint, + exactInput bool, + cache SwapCache, +) (SwapState, error) { + step := computeSwapStepInit(state, pool, zeroForOne) + + // determining the price target for this step + sqrtRatioTargetX96 := computeTargetSqrtRatio(step, sqrtPriceLimitX96, zeroForOne) + + // computing the amounts to be swapped at this step + var newState SwapState + var err error + + newState, step = computeAmounts(state, sqrtRatioTargetX96, pool, step) + newState = updateAmounts(step, newState, exactInput) + + // if the protocol fee is on, calculate how much is owed, + // decrement fee amount, and increment protocol fee + if cache.feeProtocol > 0 { + newState, step, err = updateFeeProtocol(step, cache.feeProtocol, newState) + if err != nil { + return state, err + } + } + + // update global fee tracker + if newState.liquidity.Gt(u256.Zero()) { + update := u256.MulDiv(step.feeAmount, fixedPointQ128, newState.liquidity) + newState.SetFeeGrowthGlobalX128(new(u256.Uint).Add(newState.feeGrowthGlobalX128, update)) + } + + // handling tick transitions + if newState.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { + newState = tickTransition(step, zeroForOne, newState, pool) + } else if newState.sqrtPriceX96.Neq(step.sqrtPriceStartX96) { + newState.SetTick(common.TickMathGetTickAtSqrtRatio(newState.sqrtPriceX96)) + } + + return newState, nil +} + +// updateFeeProtocol calculates and updates protocol fees for the current step. +func updateFeeProtocol(step StepComputations, feeProtocol uint8, state SwapState) (SwapState, StepComputations, error) { + delta := step.feeAmount + delta.Div(delta, u256.NewUint(uint64(feeProtocol))) + + newFeeAmount, overflow := new(u256.Uint).SubOverflow(step.feeAmount, delta) + if overflow { + return state, step, errUnderflow + } + step.feeAmount = newFeeAmount + state.protocolFee.Add(state.protocolFee, delta) + + return state, step, nil +} + +// computeSwapStepInit initializes the computation for a single swap step. +func computeSwapStepInit(state SwapState, pool *Pool, zeroForOne bool) StepComputations { + var step StepComputations + step.sqrtPriceStartX96 = state.sqrtPriceX96 + tickNext, initialized := pool.tickBitmapNextInitializedTickWithInOneWord( + state.tick, + pool.tickSpacing, + zeroForOne, + ) + + step.tickNext = tickNext + step.initialized = initialized + + // prevent overshoot the min/max tick + step.clampTickNext() + + // get the price for the next tick + step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) + return step +} + +// computeTargetSqrtRatio determines the target sqrt price for the current swap step. +func computeTargetSqrtRatio(step StepComputations, sqrtPriceLimitX96 *u256.Uint, zeroForOne bool) *u256.Uint { + if shouldUsePriceLimit(step.sqrtPriceNextX96, sqrtPriceLimitX96, zeroForOne) { + return sqrtPriceLimitX96 + } + return step.sqrtPriceNextX96 +} + +// shouldUsePriceLimit returns true if the price limit should be used instead of the next tick price +func shouldUsePriceLimit(sqrtPriceNext, sqrtPriceLimit *u256.Uint, zeroForOne bool) bool { + isLower := sqrtPriceNext.Lt(sqrtPriceLimit) + isHigher := sqrtPriceNext.Gt(sqrtPriceLimit) + if zeroForOne { + return isLower + } + return isHigher +} + +// computeAmounts calculates the input and output amounts for the current swap step. +func computeAmounts(state SwapState, sqrtRatioTargetX96 *u256.Uint, pool *Pool, step StepComputations) (SwapState, StepComputations) { + sqrtPriceX96Str, amountInStr, amountOutStr, feeAmountStr := plp.SwapMathComputeSwapStepStr( + state.sqrtPriceX96, + sqrtRatioTargetX96, + state.liquidity, + state.amountSpecifiedRemaining, + uint64(pool.fee), + ) + + step.amountIn = u256.MustFromDecimal(amountInStr) + step.amountOut = u256.MustFromDecimal(amountOutStr) + step.feeAmount = u256.MustFromDecimal(feeAmountStr) + + state.SetSqrtPriceX96(sqrtPriceX96Str) + + return state, step +} + +// updateAmounts calculates new remaining and calculated amounts based on the swap step +// For exact input swaps: +// - Decrements remaining input amount by (amountIn + feeAmount) +// - Decrements calculated amount by amountOut +// +// For exact output swaps: +// - Increments remaining output amount by amountOut +// - Increments calculated amount by (amountIn + feeAmount) +func updateAmounts(step StepComputations, state SwapState, exactInput bool) SwapState { + amountInWithFeeU256 := new(u256.Uint).Add(step.amountIn, step.feeAmount) + if amountInWithFeeU256.Gt(u256.MustFromDecimal(consts.MAX_INT256)) { + panic("amountIn + feeAmount overflows int256") + } + amountInWithFee := i256.FromUint256(amountInWithFeeU256) + if step.amountOut.Gt(u256.MustFromDecimal(consts.MAX_INT256)) { + panic("amountOut overflows int256") + } + + if exactInput { + state.amountSpecifiedRemaining = i256.Zero().Sub(state.amountSpecifiedRemaining, amountInWithFee) + state.amountCalculated = i256.Zero().Sub(state.amountCalculated, i256.FromUint256(step.amountOut)) + return state + } else { + state.amountSpecifiedRemaining = i256.Zero().Add(state.amountSpecifiedRemaining, i256.FromUint256(step.amountOut)) + state.amountCalculated = i256.Zero().Add(state.amountCalculated, amountInWithFee) + } + + return state +} + +// tickTransition handles the transition between price ticks during a swap +func tickTransition(step StepComputations, zeroForOne bool, state SwapState, pool *Pool) SwapState { + // ensure existing state to keep immutability + newState := state + + if step.initialized { + fee0, fee1 := u256.Zero(), u256.Zero() + + if zeroForOne { + fee0 = state.feeGrowthGlobalX128 + fee1 = pool.feeGrowthGlobal1X128 + } else { + fee0 = pool.feeGrowthGlobal0X128 + fee1 = state.feeGrowthGlobalX128 + } + + liquidityNet := pool.tickCross(step.tickNext, fee0, fee1) + + if zeroForOne { + liquidityNet = i256.Zero().Neg(liquidityNet) + } + + newState.liquidity = liquidityMathAddDelta(state.liquidity, liquidityNet) + } + + if zeroForOne { + newState.tick = step.tickNext - 1 + } else { + newState.tick = step.tickNext + } + + return newState +} + +func (p *Pool) swapTransfers(zeroForOne bool, payer, recipient std.Address, amount0, amount1 *i256.Int) { + if zeroForOne { + // payer > POOL + p.transferFromAndVerify(payer, consts.POOL_ADDR, p.token0Path, amount0.Abs(), true) + // POOL > recipient + p.transferAndVerify(recipient, p.token1Path, amount1, false) + } else { + // payer > POOL + p.transferFromAndVerify(payer, consts.POOL_ADDR, p.token1Path, amount1.Abs(), false) + // POOL > recipient + p.transferAndVerify(recipient, p.token0Path, amount0, true) + } +} diff --git a/pool/swap_test.gno b/pool/swap_test.gno new file mode 100644 index 000000000..3b73a5139 --- /dev/null +++ b/pool/swap_test.gno @@ -0,0 +1,954 @@ +package pool + +import ( + "std" + "testing" + + "gno.land/p/demo/uassert" + "gno.land/r/demo/users" + + pusers "gno.land/p/demo/users" + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/r/gnoswap/v1/consts" +) + +func TestSaveProtocolFees(t *testing.T) { + tests := []struct { + name string + pool *Pool + amount0 *u256.Uint + amount1 *u256.Uint + want0 *u256.Uint + want1 *u256.Uint + wantFee0 *u256.Uint + wantFee1 *u256.Uint + }{ + { + name: "normal fee deduction", + pool: &Pool{ + protocolFees: ProtocolFees{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + amount0: u256.NewUint(500), + amount1: u256.NewUint(1000), + want0: u256.NewUint(500), + want1: u256.NewUint(1000), + wantFee0: u256.NewUint(500), + wantFee1: u256.NewUint(1000), + }, + { + name: "exact fee deduction (1 deduction)", + pool: &Pool{ + protocolFees: ProtocolFees{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + amount0: u256.NewUint(1000), + amount1: u256.NewUint(2000), + want0: u256.NewUint(999), + want1: u256.NewUint(1999), + wantFee0: u256.NewUint(1), + wantFee1: u256.NewUint(1), + }, + { + name: "0 fee deduction", + pool: &Pool{ + protocolFees: ProtocolFees{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + amount0: u256.NewUint(0), + amount1: u256.NewUint(0), + want0: u256.NewUint(0), + want1: u256.NewUint(0), + wantFee0: u256.NewUint(1000), + wantFee1: u256.NewUint(2000), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got0, got1 := tt.pool.saveProtocolFees(tt.amount0, tt.amount1) + + uassert.Equal(t, got0.ToString(), tt.want0.ToString()) + uassert.Equal(t, got1.ToString(), tt.want1.ToString()) + uassert.Equal(t, tt.pool.protocolFees.token0.ToString(), tt.wantFee0.ToString()) + uassert.Equal(t, tt.pool.protocolFees.token1.ToString(), tt.wantFee1.ToString()) + }) + } +} + +func TestShouldContinueSwap(t *testing.T) { + tests := []struct { + name string + state SwapState + sqrtPriceLimitX96 *u256.Uint + expected bool + }{ + { + name: "Should continue - amount remaining and price not at limit", + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + sqrtPriceX96: u256.MustFromDecimal("1000000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: true, + }, + { + name: "Should stop - no amount remaining", + state: SwapState{ + amountSpecifiedRemaining: i256.Zero(), + sqrtPriceX96: u256.MustFromDecimal("1000000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: false, + }, + { + name: "Should stop - price at limit", + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + sqrtPriceX96: u256.MustFromDecimal("900000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := shouldContinueSwap(tt.state, tt.sqrtPriceLimitX96) + uassert.Equal(t, tt.expected, result) + }) + } +} + +func TestUpdateAmounts(t *testing.T) { + tests := []struct { + name string + step StepComputations + state SwapState + exactInput bool + expectedState SwapState + }{ + { + name: "Exact input update", + step: StepComputations{ + amountIn: u256.MustFromDecimal("100"), + amountOut: u256.MustFromDecimal("97"), + feeAmount: u256.MustFromDecimal("3"), + }, + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + amountCalculated: i256.Zero(), + }, + exactInput: true, + expectedState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("897"), // 1000 - (100 + 3) + amountCalculated: i256.MustFromDecimal("-97"), + }, + }, + { + name: "Exact output update", + step: StepComputations{ + amountIn: u256.MustFromDecimal("100"), + amountOut: u256.MustFromDecimal("97"), + feeAmount: u256.MustFromDecimal("3"), + }, + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("-1000"), + amountCalculated: i256.Zero(), + }, + exactInput: false, + expectedState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("-903"), // -1000 + 97 + amountCalculated: i256.MustFromDecimal("103"), // 100 + 3 + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := updateAmounts(tt.step, tt.state, tt.exactInput) + + uassert.True(t, tt.expectedState.amountSpecifiedRemaining.Eq(result.amountSpecifiedRemaining)) + uassert.True(t, tt.expectedState.amountCalculated.Eq(result.amountCalculated)) + }) + } +} + +func TestComputeSwap(t *testing.T) { + mockPool := &Pool{ + token0Path: "token0", + token1Path: "token1", + fee: 3000, // 0.3% + tickSpacing: 60, + slot0: Slot0{ + sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 + tick: 0, + feeProtocol: 0, + unlocked: true, + }, + liquidity: u256.MustFromDecimal("1000000000000000000"), // 1.0 + protocolFees: ProtocolFees{ + token0: u256.Zero(), + token1: u256.Zero(), + }, + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + tickBitmaps: make(TickBitmaps), + ticks: make(Ticks), + positions: make(Positions), + } + + wordPos, _ := tickBitmapPosition(0) + // TODO: use avl + mockPool.tickBitmaps[wordPos] = u256.NewUint(1) + + t.Run("basic swap", func(t *testing.T) { + comp := SwapComputation{ + AmountSpecified: i256.MustFromDecimal("1000000"), // 1.0 token + SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), // 1.1 + ZeroForOne: true, + ExactInput: true, + InitialState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), + amountCalculated: i256.Zero(), + sqrtPriceX96: mockPool.slot0.sqrtPriceX96, + tick: mockPool.slot0.tick, + feeGrowthGlobalX128: mockPool.feeGrowthGlobal0X128, + protocolFee: u256.Zero(), + liquidity: mockPool.liquidity, + }, + Cache: SwapCache{ + feeProtocol: 0, + liquidityStart: mockPool.liquidity, + }, + } + + result, err := computeSwap(mockPool, comp) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if result.Amount0.IsZero() { + t.Error("expected non-zero amount0") + } + if result.Amount1.IsZero() { + t.Error("expected non-zero amount1") + } + }) + + t.Run("swap with zero liquidity", func(t *testing.T) { + mockPoolZeroLiq := *mockPool + mockPoolZeroLiq.liquidity = u256.Zero() + + comp := SwapComputation{ + AmountSpecified: i256.MustFromDecimal("1000000"), + SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), + ZeroForOne: true, + ExactInput: true, + InitialState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), + amountCalculated: i256.Zero(), + sqrtPriceX96: mockPoolZeroLiq.slot0.sqrtPriceX96, + tick: mockPoolZeroLiq.slot0.tick, + feeGrowthGlobalX128: mockPoolZeroLiq.feeGrowthGlobal0X128, + protocolFee: u256.Zero(), + liquidity: mockPoolZeroLiq.liquidity, + }, + Cache: SwapCache{ + feeProtocol: 0, + liquidityStart: mockPoolZeroLiq.liquidity, + }, + } + + result, err := computeSwap(&mockPoolZeroLiq, comp) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if !result.Amount0.IsZero() || !result.Amount1.IsZero() { + t.Error("expected zero amounts for zero liquidity") + } + }) +} + +func TestSwap_Failures(t *testing.T) { + t.Skip() + const addr = pusers.AddressOrName(consts.ROUTER_ADDR) + + tests := []struct { + name string + setupFn func(t *testing.T) + token0Path string + token1Path string + fee uint32 + recipient std.Address + zeroForOne bool + amountSpecified string + sqrtPriceLimitX96 string + payer std.Address + expectedAmount0 string + expectedAmount1 string + expectError bool + }{ + { + name: "locked pool", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + pool := GetPool(wugnotPath, gnsPath, fee3000) + pool.slot0.unlocked = false + }, + token0Path: wugnotPath, + token1Path: gnsPath, + fee: fee3000, + recipient: users.Resolve(addr), + zeroForOne: true, + amountSpecified: "100", + sqrtPriceLimitX96: "79228162514264337593543950336", + payer: users.Resolve(addr), + expectError: true, + }, + { + name: "zero amount", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + }, + token0Path: wugnotPath, + token1Path: gnsPath, + fee: fee3000, + recipient: users.Resolve(alice), + zeroForOne: true, + amountSpecified: "0", + sqrtPriceLimitX96: "79228162514264337593543950336", + payer: users.Resolve(alice), + expectError: true, + }, + { + name: "zero liquidity", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + pool := GetPool(wugnotPath, gnsPath, fee3000) + pool.liquidity = u256.Zero() + }, + token0Path: wugnotPath, + token1Path: gnsPath, + fee: fee3000, + recipient: users.Resolve(alice), + zeroForOne: true, + amountSpecified: "100", + sqrtPriceLimitX96: "79228162514264337593543950336", + payer: users.Resolve(alice), + expectedAmount0: "0", + expectedAmount1: "0", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetObject(t) + burnTokens(t) + + if tt.setupFn != nil { + tt.setupFn(t) + } + + std.TestSetOrigCaller(tt.payer) + + if tt.expectError { + defer func() { + if r := recover(); r == nil { + t.Errorf("error should be occurred but not occurred") + } + }() + } + + amount0, amount1 := Swap( + tt.token0Path, + tt.token1Path, + tt.fee, + tt.recipient, + tt.zeroForOne, + tt.amountSpecified, + tt.sqrtPriceLimitX96, + tt.payer, + ) + + if !tt.expectError { + uassert.Equal(t, amount0, tt.expectedAmount0) + uassert.Equal(t, amount1, tt.expectedAmount1) + } + }) + } +} + +func TestDrySwap_Failures(t *testing.T) { + mockPool := &Pool{ + token0Path: "token0", + token1Path: "token1", + fee: 3000, + tickSpacing: 60, + slot0: Slot0{ + sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 + tick: 0, + feeProtocol: 0, + unlocked: true, + }, + liquidity: u256.MustFromDecimal("1000000000000000000"), // 1.0 + balances: Balances{ + token0: u256.MustFromDecimal("1000000000"), + token1: u256.MustFromDecimal("1000000000"), + }, + protocolFees: ProtocolFees{ + token0: u256.Zero(), + token1: u256.Zero(), + }, + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + tickBitmaps: make(TickBitmaps), + ticks: make(Ticks), + positions: make(Positions), + } + + originalGetPool := GetPool + defer func() { + GetPool = originalGetPool + }() + GetPool = func(token0Path, token1Path string, fee uint32) *Pool { + return mockPool + } + + tests := []struct { + name string + token0Path string + token1Path string + fee uint32 + zeroForOne bool + amountSpecified string + sqrtPriceLimitX96 string + expectAmount0 string + expectAmount1 string + expectSuccess bool + }{ + { + name: "zero amount token0 to token1", + token0Path: "token0", + token1Path: "token1", + fee: 3000, + zeroForOne: true, + amountSpecified: "0", + sqrtPriceLimitX96: "900000000000000000", + expectAmount0: "0", + expectAmount1: "0", + expectSuccess: false, + }, + { + name: "insufficient balance", + token0Path: "token0", + token1Path: "token1", + fee: 3000, + zeroForOne: false, + amountSpecified: "2000000000", + sqrtPriceLimitX96: "1100000000000000000", + expectAmount0: "0", + expectAmount1: "0", + expectSuccess: false, + }, + { + name: "insufficient balance token1 to token0", + token0Path: "token0", + token1Path: "token1", + fee: 3000, + zeroForOne: false, + amountSpecified: "3000000000", + sqrtPriceLimitX96: "1100000000000000000", + expectAmount0: "0", + expectAmount1: "0", + expectSuccess: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + amount0, amount1, success := DrySwap( + tt.token0Path, + tt.token1Path, + tt.fee, + tt.zeroForOne, + tt.amountSpecified, + tt.sqrtPriceLimitX96, + ) + + uassert.Equal(t, success, tt.expectSuccess) + uassert.Equal(t, amount0, tt.expectAmount0) + uassert.Equal(t, amount1, tt.expectAmount1) + }) + } +} + +func TestSwapAndDrySwapComparison(t *testing.T) { + const addr = pusers.AddressOrName(consts.ROUTER_ADDR) + + tests := []struct { + name string + setupFn func(t *testing.T) + token0Path string + token1Path string + fee uint32 + zeroForOne bool + amountSpecified string + sqrtPriceLimitX96 string + }{ + { + name: "normal swap", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, uint64(1000)) + }, + token0Path: wugnotPath, + token1Path: gnsPath, + fee: fee3000, + zeroForOne: false, + amountSpecified: "100", + sqrtPriceLimitX96: maxSqrtPriceLimitX96, + }, + { + name: "swap - request to swap amount over total liquidty", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + }, + token0Path: wugnotPath, + token1Path: gnsPath, + fee: fee3000, + zeroForOne: false, + amountSpecified: "2000000000", + sqrtPriceLimitX96: maxSqrtPriceLimitX96, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetObject(t) + burnTokens(t) + + if tt.setupFn != nil { + tt.setupFn(t) + } + + dryAmount0, dryAmount1, drySuccess := DrySwap( + tt.token0Path, + tt.token1Path, + tt.fee, + tt.zeroForOne, + tt.amountSpecified, + tt.sqrtPriceLimitX96, + ) + + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + tt.token0Path, + tt.token1Path, + tt.fee, + users.Resolve(addr), + tt.zeroForOne, + tt.amountSpecified, + tt.sqrtPriceLimitX96, + users.Resolve(addr), + ) + + if !drySuccess { + t.Error("DrySwap failed but actual Swap succeeded") + } + + uassert.NotEqual(t, dryAmount0, "0", "amount0 should not be zero") + uassert.NotEqual(t, dryAmount1, "0", "amount1 should not be zero") + uassert.NotEqual(t, actualAmount0, "0", "amount0 should not be zero") + uassert.NotEqual(t, actualAmount1, "0", "amount1 should not be zero") + + uassert.Equal(t, dryAmount0, actualAmount0, + "Amount0 mismatch between DrySwap and actual Swap") + uassert.Equal(t, dryAmount1, actualAmount1, + "Amount1 mismatch between DrySwap and actual Swap") + }) + } +} + +func TestSwapAndDrySwapComparison_amount_zero(t *testing.T) { + const addr = pusers.AddressOrName(consts.ROUTER_ADDR) + + tests := []struct { + name string + setupFn func(t *testing.T) + action func(t *testing.T) + shouldPanic bool + expected string + }{ + { + name: "zero amount swap - zeroForOne = false", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, uint64(1000)) + }, + action: func(t *testing.T) { + dryAmount0, dryAmount1, drySuccess := DrySwap( + wugnotPath, + gnsPath, + fee3000, + false, + "0", + maxSqrtPriceLimitX96, + ) + uassert.Equal(t, "0", dryAmount0) + uassert.Equal(t, "0", dryAmount1) + uassert.Equal(t, false, drySuccess) + }, + shouldPanic: false, + expected: "[GNOSWAP-POOL-014] invalid swap amount || amountSpecified == 0", + }, + { + name: "zero amount swap - zeroForOne = true", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, uint64(1000)) + }, + action: func(t *testing.T) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(addr), + true, + "0", + maxSqrtPriceLimitX96, + users.Resolve(addr), + ) + }, + shouldPanic: true, + expected: "[GNOSWAP-POOL-014] invalid swap amount || amountSpecified == 0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + if tt.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tt.name) + return + } + } else { + switch r.(type) { + case string: + if r.(string) != tt.expected { + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r, tt.expected) + } + case error: + if r.(error).Error() != tt.expected { + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r.(error).Error(), tt.expected) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r, tt.expected) + } + } + }() + + resetObject(t) + burnTokens(t) + + if tt.setupFn != nil { + tt.setupFn(t) + } + + if tt.shouldPanic { + tt.action(t) + } else { + tt.action(t) + } + }) + } +} + +func TestSwap_amount_over_liquidity(t *testing.T) { + const addr = pusers.AddressOrName(consts.ROUTER_ADDR) + + tests := []struct { + name string + setupFn func(t *testing.T) + action func(t *testing.T) (string, string) + shouldPanic bool + expected []string + }{ + { + name: "amount over liquidity - zeroForOne = false, token0:20000", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(alice), + false, + "20000", + consts.MAX_PRICE, + users.Resolve(addr), + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"-1989", "2404"}, + }, + { + name: "amount over liquidity - zeroForOne = false, token1:-20000", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(alice), + false, + "-20000", + consts.MAX_PRICE, + users.Resolve(addr), + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"-1989", "2404"}, + }, + { + name: "amount over liquidity - zeroForOne = true, token0:20000", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + TokenApprove(t, wugnotPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(alice), + true, + "20000", + consts.MIN_PRICE, + users.Resolve(addr), + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"1045", "-990"}, + }, + { + name: "amount over liquidity - zeroForOne = true, token1:-20000", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenFaucet(t, wugnotPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + TokenApprove(t, wugnotPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(alice), + true, + "-20000", + consts.MIN_PRICE, + users.Resolve(addr), + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"1045", "-990"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetObject(t) + burnTokens(t) + + if tt.setupFn != nil { + tt.setupFn(t) + } + + if tt.shouldPanic { + tt.action(t) + } else { + amount0, amount1 := tt.action(t) + uassert.Equal(t, tt.expected[0], amount0) + uassert.Equal(t, tt.expected[1], amount1) + } + }) + } +} + +func TestSwap_EXACTIN_OUT(t *testing.T) { + const addr = pusers.AddressOrName(consts.ROUTER_ADDR) + + tests := []struct { + name string + setupFn func(t *testing.T) + action func(t *testing.T) (string, string) + shouldPanic bool + expected []string + }{ + { + name: "EXACT IN - zeroForOne = false, token1:200", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(alice), + false, + "200", + consts.MAX_PRICE, + users.Resolve(addr), + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"-195", "200"}, + }, + { + name: "EXACT OUT - zeroForOne = false, token0:-200", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(alice), + false, + "-200", + consts.MAX_PRICE, + users.Resolve(addr), + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"-200", "208"}, + }, + { + name: "EXACT IN - zeroForOne = true, token0:200", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + TokenApprove(t, wugnotPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(alice), + true, + "200", + consts.MIN_PRICE, + users.Resolve(addr), + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"200", "-195"}, + }, + { + name: "EXACT OUT - zeroForOne = true, token1:-200", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenFaucet(t, wugnotPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + TokenApprove(t, wugnotPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(alice), + true, + "-200", + consts.MIN_PRICE, + users.Resolve(addr), + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"208", "-200"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetObject(t) + burnTokens(t) + + if tt.setupFn != nil { + tt.setupFn(t) + } + + if tt.shouldPanic { + tt.action(t) + } else { + amount0, amount1 := tt.action(t) + uassert.Equal(t, tt.expected[0], amount0) + uassert.Equal(t, tt.expected[1], amount1) + } + }) + } +} diff --git a/pool/tick.gno b/pool/tick.gno index d325826fb..d0db635c2 100644 --- a/pool/tick.gno +++ b/pool/tick.gno @@ -19,17 +19,47 @@ func calculateMaxLiquidityPerTick(tickSpacing int32) *u256.Uint { return new(u256.Uint).Div(u256.MustFromDecimal(consts.MAX_UINT128), u256.NewUint(numTicks)) } +func getFeeGrowthBelowX128( + tickLower, tickCurrent int32, + feeGrowthGlobal0X128, feeGrowthGlobal1X128 *u256.Uint, + lowerTick TickInfo, +) (*u256.Uint, *u256.Uint) { + if tickCurrent >= tickLower { + return lowerTick.feeGrowthOutside0X128, lowerTick.feeGrowthOutside1X128 + } + + below0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, lowerTick.feeGrowthOutside0X128) + below1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, lowerTick.feeGrowthOutside1X128) + + return below0X128, below1X128 +} + +func getFeeGrowthAboveX128( + tickUpper, tickCurrent int32, + feeGrowthGlobal0X128, feeGrowthGlobal1X128 *u256.Uint, + upperTick TickInfo, +) (*u256.Uint, *u256.Uint) { + if tickCurrent < tickUpper { + return upperTick.feeGrowthOutside0X128, upperTick.feeGrowthOutside1X128 + } + + above0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, upperTick.feeGrowthOutside0X128) + above1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, upperTick.feeGrowthOutside1X128) + + return above0X128, above1X128 +} + // calculateFeeGrowthInside calculates the fee growth inside a tick range, // and returns the fee growth inside for both tokens. -func (pool *Pool) calculateFeeGrowthInside( +func (p *Pool) calculateFeeGrowthInside( tickLower int32, tickUpper int32, tickCurrent int32, feeGrowthGlobal0X128 *u256.Uint, feeGrowthGlobal1X128 *u256.Uint, ) (*u256.Uint, *u256.Uint) { - lower := pool.getTick(tickLower) - upper := pool.getTick(tickUpper) + lower := p.getTick(tickLower) + upper := p.getTick(tickUpper) feeGrowthBelow0X128, feeGrowthBelow1X128 := getFeeGrowthBelowX128(tickLower, tickCurrent, feeGrowthGlobal0X128, feeGrowthGlobal1X128, lower) feeGrowthAbove0X128, feeGrowthAbove1X128 := getFeeGrowthAboveX128(tickUpper, tickCurrent, feeGrowthGlobal0X128, feeGrowthGlobal1X128, upper) @@ -41,7 +71,7 @@ func (pool *Pool) calculateFeeGrowthInside( } // tickUpdate updates a tick's state and returns whether the tick was flipped. -func (pool *Pool) tickUpdate( +func (p *Pool) tickUpdate( tick int32, tickCurrent int32, liquidityDelta *i256.Int, // int128 @@ -54,7 +84,7 @@ func (pool *Pool) tickUpdate( feeGrowthGlobal0X128 = feeGrowthGlobal0X128.NilToZero() feeGrowthGlobal1X128 = feeGrowthGlobal1X128.NilToZero() - thisTick := pool.getTick(tick) + thisTick := p.getTick(tick) liquidityGrossBefore := thisTick.liquidityGross liquidityGrossAfter := liquidityMathAddDelta(liquidityGrossBefore, liquidityDelta) @@ -85,64 +115,35 @@ func (pool *Pool) tickUpdate( thisTick.liquidityNet = i256.Zero().Add(thisTick.liquidityNet, liquidityDelta) } - pool.ticks[tick] = thisTick + p.ticks[tick] = thisTick return flipped } // tickCross updates a tick's state when it is crossed and returns the liquidity net. -func (pool *Pool) tickCross( +func (p *Pool) tickCross( tick int32, feeGrowthGlobal0X128 *u256.Uint, feeGrowthGlobal1X128 *u256.Uint, ) *i256.Int { - thisTick := pool.getTick(tick) + thisTick := p.getTick(tick) thisTick.feeGrowthOutside0X128 = new(u256.Uint).Sub(feeGrowthGlobal0X128, thisTick.feeGrowthOutside0X128) thisTick.feeGrowthOutside1X128 = new(u256.Uint).Sub(feeGrowthGlobal1X128, thisTick.feeGrowthOutside1X128) - pool.ticks[tick] = thisTick + p.ticks[tick] = thisTick - return thisTick.liquidityNet + return thisTick.liquidityNet.Clone() } -func (pool *Pool) getTick(tick int32) TickInfo { - tickInfo := pool.ticks[tick] +// getTick returns a tick's state. +func (p *Pool) getTick(tick int32) TickInfo { + tickInfo := p.ticks[tick] tickInfo.init() return tickInfo } -func getFeeGrowthBelowX128( - tickLower, tickCurrent int32, - feeGrowthGlobal0X128, feeGrowthGlobal1X128 *u256.Uint, - lowerTick TickInfo, -) (*u256.Uint, *u256.Uint) { - if tickCurrent >= tickLower { - return lowerTick.feeGrowthOutside0X128, lowerTick.feeGrowthOutside1X128 - } - - below0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, lowerTick.feeGrowthOutside0X128) - below1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, lowerTick.feeGrowthOutside1X128) - - return below0X128, below1X128 -} - -func getFeeGrowthAboveX128( - tickUpper, tickCurrent int32, - feeGrowthGlobal0X128, feeGrowthGlobal1X128 *u256.Uint, - upperTick TickInfo, -) (*u256.Uint, *u256.Uint) { - if tickCurrent < tickUpper { - return upperTick.feeGrowthOutside0X128, upperTick.feeGrowthOutside1X128 - } - - above0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, upperTick.feeGrowthOutside0X128) - above1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, upperTick.feeGrowthOutside1X128) - - return above0X128, above1X128 -} - // receiver getters func (p *Pool) GetTickLiquidityGross(tick int32) *u256.Uint { return p.mustGetTick(tick).liquidityGross diff --git a/pool/tick_bitmap.gno b/pool/tick_bitmap.gno index 45b4ba9e9..5c4dbf0c8 100644 --- a/pool/tick_bitmap.gno +++ b/pool/tick_bitmap.gno @@ -18,7 +18,7 @@ func tickBitmapPosition(tick int32) (int16, uint8) { // tickBitmapFlipTick flips tthe bit corresponding to the given tick // in the pool's tick bitmap. -func (pool *Pool) tickBitmapFlipTick( +func (p *Pool) tickBitmapFlipTick( tick int32, tickSpacing int32, ) { @@ -32,12 +32,12 @@ func (pool *Pool) tickBitmapFlipTick( wordPos, bitPos := tickBitmapPosition(tick / tickSpacing) mask := new(u256.Uint).Lsh(u256.One(), uint(bitPos)) - pool.setTickBitmap(wordPos, new(u256.Uint).Xor(pool.getTickBitmap(wordPos), mask)) + p.setTickBitmap(wordPos, new(u256.Uint).Xor(p.getTickBitmap(wordPos), mask)) } // tickBitmapNextInitializedTickWithInOneWord finds the next initialized tick within // one word of the bitmap. -func (pool *Pool) tickBitmapNextInitializedTickWithInOneWord( +func (p *Pool) tickBitmapNextInitializedTickWithInOneWord( tick int32, tickSpacing int32, lte bool, @@ -49,7 +49,7 @@ func (pool *Pool) tickBitmapNextInitializedTickWithInOneWord( wordPos, bitPos := getWordAndBitPos(compress, lte) mask := getMaskBit(uint(bitPos), lte) - masked := new(u256.Uint).And(pool.getTickBitmap(wordPos), mask) + masked := new(u256.Uint).And(p.getTickBitmap(wordPos), mask) initialized := !(masked.IsZero()) nextTick := getNextTick(lte, initialized, compress, bitPos, tickSpacing, masked) @@ -58,17 +58,17 @@ func (pool *Pool) tickBitmapNextInitializedTickWithInOneWord( // getTickBitmap gets the tick bitmap for the given word position // if the tick bitmap is not initialized, initialize it to zero -func (pool *Pool) getTickBitmap(wordPos int16) *u256.Uint { - if pool.tickBitmaps[wordPos] == nil { - pool.tickBitmaps[wordPos] = u256.Zero() +func (p *Pool) getTickBitmap(wordPos int16) *u256.Uint { + if p.tickBitmaps[wordPos] == nil { + p.tickBitmaps[wordPos] = u256.Zero() } - return pool.tickBitmaps[wordPos] + return p.tickBitmaps[wordPos] } // setTickBitmap sets the tick bitmap for the given word position -func (pool *Pool) setTickBitmap(wordPos int16, bitmap *u256.Uint) { - pool.tickBitmaps[wordPos] = bitmap +func (p *Pool) setTickBitmap(wordPos int16, bitmap *u256.Uint) { + p.tickBitmaps[wordPos] = bitmap } // getWordAndBitPos gets tick's wordPos and bitPos depending on the swap direction diff --git a/pool/type.gno b/pool/type.gno index 5c54b89ec..5c01d20a5 100644 --- a/pool/type.gno +++ b/pool/type.gno @@ -3,6 +3,8 @@ package pool import ( "std" + "gno.land/p/demo/ufmt" + "gno.land/r/gnoswap/v1/common" "gno.land/r/gnoswap/v1/consts" @@ -57,16 +59,41 @@ func newProtocolFees() ProtocolFees { } } +// ModifyPositionParams repersents the parameters for modifying a liquidity position. +// This structure is used internally both `Mint` and `Burn` operation to manage +// the liquidity positions. type ModifyPositionParams struct { - owner std.Address // address that owns the position + // owner is the address that owns the position + owner std.Address + + // tickLower and atickUpper define the price range + // The actual price range is calculated as 1.0001^tick + // This allows for precision in price range while using integer math. - // the tick range of the position, bounds are included - tickLower int32 - tickUpper int32 + tickLower int32 // lower tick of the position + tickUpper int32 // upper tick of the position - liquidityDelta *i256.Int // any change in liquidity + // liquidityDelta represents the change in liquidity + // Positive for minting, negative for burning + liquidityDelta *i256.Int } +// newModifyPositionParams creates a new `ModifyPositionParams` instance. +// This is used to preare parameters for the `modifyPosition` function, +// which handles both minting and burning of liquidity positions. +// +// Parameters: +// - owner: address that will own (or owns) the position +// - tickLower: lower tick bound of the position +// - tickUpper: upper tick bound of the position +// - liquidityDelta: amount of liquidity to add (positive) or remove (negative) +// +// The tick parameters represent prices as powers of 1.0001: +// - actual_price = 1.0001^tick +// - For example, tick = 100 means price = 1.0001^100 +// +// Returns: +// - ModifyPositionParams: a new instance of ModifyPositionParams func newModifyPositionParams( owner std.Address, tickLower int32, @@ -81,6 +108,7 @@ func newModifyPositionParams( } } +// SwapCache holds data that remains constant throughout a swap. type SwapCache struct { feeProtocol uint8 // protocol fee for the input token liquidityStart *u256.Uint // liquidity at the beginning of the swap @@ -96,6 +124,9 @@ func newSwapCache( } } +// SwapState tracks the changing values during a swap. +// This type helps manage the state transiktions that occur as the swap progresses +// accross different price ranges. type SwapState struct { amountSpecifiedRemaining *i256.Int // amount remaining to be swapped in/out of the input/output token amountCalculated *i256.Int // amount already swapped out/in of the output/input token @@ -139,6 +170,9 @@ func (s *SwapState) SetProtocolFee(fee *u256.Uint) { s.protocolFee = fee } +// StepComputations holds intermediate values used during a single step of a swap. +// Each step represents movement from the current tick to the next initialized tick +// or the target price, whichever comes first. type StepComputations struct { sqrtPriceStartX96 *u256.Uint // price at the beginning of the step tickNext int32 // next tick to swap to from the current tick in the swap direction @@ -154,14 +188,14 @@ func (step *StepComputations) initSwapStep(state SwapState, pool *Pool, zeroForO step.sqrtPriceStartX96 = state.sqrtPriceX96 step.tickNext, step.initialized = pool.tickBitmapNextInitializedTickWithInOneWord( state.tick, - pool.tickSpacing, + pool.tickSpacing, zeroForOne, ) // prevent overshoot the min/max tick step.clampTickNext() - // get the price for the next tick + // get the price for the next tick step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) } @@ -178,12 +212,18 @@ func (step *StepComputations) clampTickNext() { type PositionInfo struct { liquidity *u256.Uint // amount of liquidity owned by this position - // fee growth per unit of liquidity as of the last update to liquidity or fees owed + // Fee growth per unit of liquidity as of the last update + // Used to calculate uncollected fees for token0 feeGrowthInside0LastX128 *u256.Uint + + // Fee growth per unit of liquidity as of the last update + // Used to calculate uncollected fees for token1 feeGrowthInside1LastX128 *u256.Uint - // fees owed to the position owner in token0/token1 + // accumulated fees in token0 waiting to be collected tokensOwed0 *u256.Uint + + // accumulated fees in token1 waiting to be collected tokensOwed1 *u256.Uint } @@ -195,6 +235,8 @@ func (p *PositionInfo) init() { p.tokensOwed1 = p.tokensOwed1.NilToZero() } +// TickInfo stores information about a specific tick in the pool. +// TIcks represent discrete price points that can be used as boundaries for positions. type TickInfo struct { liquidityGross *u256.Uint // total position liquidity that references this tick liquidityNet *i256.Int // amount of net liquidity added (subtracted) when tick is crossed from left to right (right to left) @@ -225,9 +267,11 @@ func (t *TickInfo) init() { t.secondsPerLiquidityOutsideX128 = t.secondsPerLiquidityOutsideX128.NilToZero() } -type Ticks map[int32]TickInfo // tick => TickInfo -type TickBitmaps map[int16]*u256.Uint // tick(wordPos) => bitmap(tickWord ^ mask) -type Positions map[string]PositionInfo // positionKey => PositionInfo +type ( + Ticks map[int32]TickInfo // tick => TickInfo + TickBitmaps map[int16]*u256.Uint // tick(wordPos) => bitmap(tickWord ^ mask) + Positions map[string]PositionInfo // positionKey => PositionInfo +) // type Pool describes a single Pool's state // A pool is identificed with a unique key (token0, token1, fee), where token0 < token1 @@ -284,3 +328,80 @@ func newPool(poolInfo *createPoolParams) *Pool { positions: Positions{}, } } + +func (p *Pool) PoolGetToken0Path() string { + return p.token0Path +} + +func (p *Pool) PoolGetToken1Path() string { + return p.token1Path +} + +func (p *Pool) PoolGetFee() uint32 { + return p.fee +} + +func (p *Pool) PoolGetBalanceToken0() *u256.Uint { + return p.balances.token0 +} + +func (p *Pool) PoolGetBalanceToken1() *u256.Uint { + return p.balances.token1 +} + +func (p *Pool) PoolGetTickSpacing() int32 { + return p.tickSpacing +} + +func (p *Pool) PoolGetMaxLiquidityPerTick() *u256.Uint { + return p.maxLiquidityPerTick +} + +func (p *Pool) PoolGetSlot0() Slot0 { + return p.slot0 +} + +func (p *Pool) PoolGetSlot0SqrtPriceX96() *u256.Uint { + return p.slot0.sqrtPriceX96 +} + +func (p *Pool) PoolGetSlot0Tick() int32 { + return p.slot0.tick +} + +func (p *Pool) PoolGetSlot0FeeProtocol() uint8 { + return p.slot0.feeProtocol +} + +func (p *Pool) PoolGetSlot0Unlocked() bool { + return p.slot0.unlocked +} + +func (p *Pool) PoolGetFeeGrowthGlobal0X128() *u256.Uint { + return p.feeGrowthGlobal0X128 +} + +func (p *Pool) PoolGetFeeGrowthGlobal1X128() *u256.Uint { + return p.feeGrowthGlobal1X128 +} + +func (p *Pool) PoolGetProtocolFeesToken0() *u256.Uint { + return p.protocolFees.token0 +} + +func (p *Pool) PoolGetProtocolFeesToken1() *u256.Uint { + return p.protocolFees.token1 +} + +func (p *Pool) PoolGetLiquidity() *u256.Uint { + return p.liquidity +} + +func mustGetPool(poolPath string) *Pool { + pool, exist := pools[poolPath] + if !exist { + panic(addDetailToError(errDataNotFound, + ufmt.Sprintf("poolPath(%s) does not exist", poolPath))) + } + return pool +} diff --git a/pool/utils.gno b/pool/utils.gno index b8d627a9d..e45116217 100644 --- a/pool/utils.gno +++ b/pool/utils.gno @@ -3,11 +3,25 @@ package pool import ( "std" + "gno.land/p/demo/ufmt" + pusers "gno.land/p/demo/users" u256 "gno.land/p/gnoswap/uint256" ) +func safeConvertToUint64(value *u256.Uint) (uint64, error) { + res, overflow := value.Uint64WithOverflow() + if overflow { + return 0, ufmt.Errorf( + "%v: amount(%s) overflows uint64 range", + errOutOfRange, value.ToString(), + ) + } + + return res, nil +} + func a2u(addr std.Address) pusers.AddressOrName { if !addr.IsValid() { panic(addDetailToError(