diff --git a/pool/_RPC_dry.gno b/pool/_RPC_dry.gno index e589c139a..5ffe30c86 100644 --- a/pool/_RPC_dry.gno +++ b/pool/_RPC_dry.gno @@ -58,9 +58,11 @@ func DrySwap( feeGrowthGlobalX128 = pool.feeGrowthGlobal1X128 } - pool.slot0.unlocked = false + slot0 := pool.slot0 + slot0.unlocked = false + cache := newSwapCache(feeProtocol, pool.liquidity) - state := pool.newSwapState(amountSpecified, feeGrowthGlobalX128, cache.liquidityStart) // TODO: feeGrowthGlobalX128.Clone() or NOT + state := newSwapState(amountSpecified.Clone(), feeGrowthGlobalX128.Clone(), cache.liquidityStart.Clone(), slot0) exactInput := amountSpecified.Gt(i256.Zero()) diff --git a/pool/errors.gno b/pool/errors.gno index 5999c6461..2031f4ad9 100644 --- a/pool/errors.gno +++ b/pool/errors.gno @@ -32,6 +32,7 @@ var ( errInvalidTickAndTickSpacing = errors.New("[GNOSWAP-POOL-022] invalid tick and tick spacing requested") errInvalidAddress = errors.New("[GNOSWAP-POOL-023] invalid address") errInvalidTickRange = errors.New("[GNOSWAP-POOL-024] tickLower is greater than tickUpper") + errUnderflow = errors.New("[GNOSWAP-POOL-025] underflow") // TODO: make as common error code ) // addDetailToError adds detail to an error message diff --git a/pool/pool.gno b/pool/pool.gno index e54bc49fc..1ab2b8696 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -24,13 +24,13 @@ func Mint( recipient std.Address, tickLower int32, tickUpper int32, - _liquidityAmount string, // uint128 + _liquidityAmount string, positionCaller std.Address, -) (string, string) { // uint256 x2 +) (string, string) { common.IsHalted() if common.GetLimitCaller() { caller := std.PrevRealm().Addr() - if caller != consts.POSITION_ADDR { + if err := common.PositionOnly(caller); err != nil { panic(addDetailToError( errNoPermission, ufmt.Sprintf("pool.gno__Mint() || only position(%s) can call pool mint(), called from %s", consts.POSITION_ADDR, caller.String()), @@ -47,20 +47,14 @@ func Mint( } pool := GetPool(token0Path, token1Path, fee) - _, amount0, amount1 := pool.modifyPosition( - ModifyPositionParams{ - recipient, // owner - tickLower, // tickLower - tickUpper, // tickUpper - i256.FromUint256(liquidityAmount), // liquidityDelta - }, - ) + position := newModifyPositionParams(recipient, tickLower, tickUpper, i256.FromUint256(liquidityAmount)) + _, amount0, amount1 := pool.modifyPosition(position) - if amount0.Gt(i256.Zero()) { + if amount0.Gt(u256.Zero()) { pool.transferFromAndVerify(positionCaller, consts.POOL_ADDR, pool.token0Path, amount0, true) } - if amount1.Gt(i256.Zero()) { + if amount1.Gt(u256.Zero()) { pool.transferFromAndVerify(positionCaller, consts.POOL_ADDR, pool.token1Path, amount1, false) } @@ -77,12 +71,12 @@ func Burn( fee uint32, tickLower int32, tickUpper int32, - _liquidityAmount string, // uint128 + liquidityAmount string, // uint128 ) (string, string) { // uint256 x2 common.IsHalted() caller := std.PrevRealm().Addr() if common.GetLimitCaller() { - if caller != consts.POSITION_ADDR { + if err := common.PositionOnly(caller); err != nil { panic(addDetailToError( errNoPermission, ufmt.Sprintf("pool.gno__Burn() || only position(%s) can call pool burn(), called from %s", consts.POSITION_ADDR, caller.String()), @@ -90,21 +84,13 @@ func Burn( } } - liquidityAmount := u256.MustFromDecimal(_liquidityAmount) + liqAmount := u256.MustFromDecimal(liquidityAmount) pool := GetPool(token0Path, token1Path, fee) - position, amount0Int, amount1Int := pool.modifyPosition( // in256 x2 - ModifyPositionParams{ - caller, // msg.sender - tickLower, - tickUpper, - i256.Zero().Neg(i256.FromUint256(liquidityAmount)), - }, - ) - - amount0 := amount0Int.Abs() - amount1 := amount1Int.Abs() + liqDelta := i256.Zero().Neg(i256.FromUint256(liqAmount)) + posParams := newModifyPositionParams(caller, tickLower, tickUpper, liqDelta) + position, amount0, amount1 := pool.modifyPosition(posParams) if amount0.Gt(u256.Zero()) || amount1.Gt(u256.Zero()) { position.tokensOwed0 = new(u256.Uint).Add(position.tokensOwed0, amount0) @@ -129,13 +115,13 @@ func Collect( recipient std.Address, tickLower int32, tickUpper int32, - _amount0Requested string, // uint128 - _amount1Requested string, // uint128 -) (string, string) { // uint128 x2 + amount0Requested string, + amount1Requested string, +) (string, string) { common.IsHalted() if common.GetLimitCaller() { caller := std.PrevRealm().Addr() - if caller != consts.POSITION_ADDR { + if err := common.PositionOnly(caller); err != nil { panic(addDetailToError( errNoPermission, ufmt.Sprintf("pool.gno__Collect() || only position(%s) can call pool collect(), called from %s", consts.POSITION_ADDR, caller.String()), @@ -143,9 +129,6 @@ func Collect( } } - amount0Requested := u256.MustFromDecimal(_amount0Requested) - amount1Requested := u256.MustFromDecimal(_amount1Requested) - pool := GetPool(token0Path, token1Path, fee) positionKey := positionGetKey(std.PrevRealm().Addr(), tickLower, tickUpper) @@ -157,22 +140,16 @@ func Collect( )) } - // Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0 - amount0 := u256Min(amount0Requested, position.tokensOwed0) - amount0 = u256Min(amount0, pool.balances.token0) + var amount0, amount1 *u256.Uint - // Update state first then transfer - position.tokensOwed0 = new(u256.Uint).Sub(position.tokensOwed0, amount0) - pool.balances.token0 = new(u256.Uint).Sub(pool.balances.token0, amount0) + // Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0 + amount0Req := u256.MustFromDecimal(amount0Requested) + amount0, position.tokensOwed0, pool.balances.token0 = collectToken(amount0Req, position.tokensOwed0, pool.balances.token0) transferByRegisterCall(pool.token0Path, recipient, amount0.Uint64()) // Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0 - amount1 := u256Min(amount1Requested, position.tokensOwed1) - amount1 = u256Min(amount1, pool.balances.token1) - - // Update state first then transfer - position.tokensOwed1 = new(u256.Uint).Sub(position.tokensOwed1, amount1) - pool.balances.token1 = new(u256.Uint).Sub(pool.balances.token1, amount1) + amount1Req := u256.MustFromDecimal(amount1Requested) + amount1, position.tokensOwed1, pool.balances.token1 = collectToken(amount1Req, position.tokensOwed1, pool.balances.token1) transferByRegisterCall(pool.token1Path, recipient, amount1.Uint64()) pool.positions[positionKey] = position @@ -180,6 +157,45 @@ func Collect( return amount0.ToString(), amount1.ToString() } +// collectToken handles the collection of a single token type (token0 or token1) +func collectToken( + amountReq, tokensOwed, poolBalance *u256.Uint, +) (amount, newTokensOwed, newPoolBalance *u256.Uint) { + // find smallest of three amounts + amount = u256Min(amountReq, tokensOwed) + amount = u256Min(amount, poolBalance) + + // value for update state + newTokensOwed = new(u256.Uint).Sub(tokensOwed, amount) + newPoolBalance = new(u256.Uint).Sub(poolBalance, amount) + + return amount, newTokensOwed, newPoolBalance +} + +// SwapResult encapsulates all state changes that occur as a result of a swap +// This type ensure all state transitions are atomic and can be applied at once. +type SwapResult struct { + Amount0 *i256.Int + Amount1 *i256.Int + NewSqrtPrice *u256.Uint + NewTick int32 + NewLiquidity *u256.Uint + NewProtocolFees ProtocolFees + FeeGrowthGlobal0X128 *u256.Uint + FeeGrowthGlobal1X128 *u256.Uint + SwapFee *u256.Uint +} + +// SwapComputation encapsulates pure computation logic for swap +type SwapComputation struct { + AmountSpecified *i256.Int + SqrtPriceLimitX96 *u256.Uint + ZeroForOne bool + ExactInput bool + InitialState SwapState + Cache SwapCache +} + // Swap swaps token0 for token1, or token1 for token0 // Returns swapped amount0, amount1 in string // ref: https://docs.gnoswap.io/contracts/pool/pool.gno#swap @@ -189,14 +205,14 @@ func Swap( fee uint32, recipient std.Address, zeroForOne bool, - _amountSpecified string, // int256 - _sqrtPriceLimitX96 string, // uint160 + amountSpecified string, + sqrtPriceLimitX96 string, payer std.Address, // router -) (string, string) { // int256 x2 +) (string, string) { common.IsHalted() if common.GetLimitCaller() { caller := std.PrevRealm().Addr() - if caller != consts.ROUTER_ADDR { + if err := common.RouterOnly(caller); err != nil { panic(addDetailToError( errNoPermission, ufmt.Sprintf("pool.gno__Swap() || only router(%s) can call pool swap(), called from %s", consts.ROUTER_ADDR, caller.String()), @@ -204,248 +220,413 @@ func Swap( } } - if _amountSpecified == "0" { + if amountSpecified == "0" { panic(addDetailToError( errInvalidSwapAmount, ufmt.Sprintf("pool.gno__Swap() || amountSpecified == 0"), )) } - amountSpecified := i256.MustFromDecimal(_amountSpecified) - sqrtPriceLimitX96 := u256.MustFromDecimal(_sqrtPriceLimitX96) - pool := GetPool(token0Path, token1Path, fee) + slot0Start := pool.slot0 + if !slot0Start.unlocked { + panic(errLockedPool) + } - if !(slot0Start.unlocked) { - panic(addDetailToError( - errLockedPool, - ufmt.Sprintf("pool.gno__Swap() || slot0Start.unlocked(false) must be unlocked)"), - )) + slot0Start.unlocked = false + defer func() { slot0Start.unlocked = true }() + + amounts := i256.MustFromDecimal(amountSpecified) + sqrtPriceLimit := u256.MustFromDecimal(sqrtPriceLimitX96) + + validatePriceLimits(pool, zeroForOne, sqrtPriceLimit) + + feeGrowthGlobalX128 := getFeeGrowthGlobal(pool, zeroForOne) + feeProtocol := getFeeProtocol(slot0Start, zeroForOne) + cache := newSwapCache(feeProtocol, pool.liquidity) + + state := newSwapState(amounts, feeGrowthGlobalX128, cache.liquidityStart, pool.slot0) + + comp := SwapComputation{ + AmountSpecified: amounts, + SqrtPriceLimitX96: sqrtPriceLimit, + ZeroForOne: zeroForOne, + ExactInput: amounts.Gt(i256.Zero()), + InitialState: state, + Cache: cache, + } + + result, err := computeSwap(pool, comp) + if err != nil { + panic(err) + } + + applySwapResult(pool, result) + + // actual swap + pool.swapTransfers(zeroForOne, payer, recipient, result.Amount0, result.Amount1) + + prevAddr, prevRealm := getPrev() + + std.Emit( + "Swap", + "prevAddr", prevAddr, + "prevRealm", prevRealm, + "poolPath", GetPoolPath(token0Path, token1Path, fee), + "zeroForOne", ufmt.Sprintf("%t", zeroForOne), + "amountSpecified", amountSpecified, + "sqrtPriceLimitX96", sqrtPriceLimitX96, + "payer", payer.String(), + "recipient", recipient.String(), + "internal_amount0", result.Amount0.ToString(), + "internal_amount1", result.Amount1.ToString(), + "internal_protocolFee0", pool.protocolFees.token0.ToString(), + "internal_protocolFee1", pool.protocolFees.token1.ToString(), + "internal_swapFee", result.SwapFee.ToString(), + "internal_sqrtPriceX96", pool.slot0.sqrtPriceX96.ToString(), + ) + + return result.Amount0.ToString(), result.Amount1.ToString() +} + +// computeSwap performs the core swap computation without modifying pool state +// The function follows these state transitions: +// 1. Initial State: Provided by `SwapComputation.InitialState` +// 2. Stepping State: For each step: +// - Compute next tick and price target +// - Calculate amounts and fees +// - Update state (remaining amount, fees, liquidity) +// - Handle tick transitions if necessary +// 3. Final State: Aggregated in SwapResult +// +// The computation continues until either: +// - The entire amount is consumed (`amountSpecifiedRemaining` = 0) +// - The price limit is reached (`sqrtPriceX96` = `sqrtPriceLimitX96`) +// +// Returns an error if the computation fails at any step +func computeSwap(pool *Pool, comp SwapComputation) (*SwapResult, error) { + state := comp.InitialState + swapFee := u256.Zero() + + + var newFee *u256.Uint + var err error + + // Compute swap steps until completion + for shouldContinueSwap(state, comp.SqrtPriceLimitX96) { + state, newFee, err = computeSwapStep(state, pool, comp.ZeroForOne, comp.SqrtPriceLimitX96, comp.ExactInput, comp.Cache, swapFee) + if err != nil { + return nil, err + } + swapFee = newFee + } + + // Calculate final amounts + amount0 := state.amountCalculated + amount1 := i256.Zero().Sub(comp.AmountSpecified, state.amountSpecifiedRemaining) + if comp.ZeroForOne == comp.ExactInput { + amount0, amount1 = amount1, amount0 + } + + // Prepare result + result := &SwapResult{ + Amount0: amount0, + Amount1: amount1, + NewSqrtPrice: state.sqrtPriceX96, + NewTick: state.tick, + NewLiquidity: state.liquidity, + NewProtocolFees: ProtocolFees{ + token0: pool.protocolFees.token0, + token1: pool.protocolFees.token1, + }, + FeeGrowthGlobal0X128: pool.feeGrowthGlobal0X128, + FeeGrowthGlobal1X128: pool.feeGrowthGlobal1X128, + SwapFee: swapFee, + } + + // Update protocol fees if necessary + if comp.ZeroForOne { + if state.protocolFee.Gt(u256.Zero()) { + result.NewProtocolFees.token0 = new(u256.Uint).Add(result.NewProtocolFees.token0, state.protocolFee) + } + result.FeeGrowthGlobal0X128 = state.feeGrowthGlobalX128 + } else { + if state.protocolFee.Gt(u256.Zero()) { + result.NewProtocolFees.token1 = new(u256.Uint).Add(result.NewProtocolFees.token1, state.protocolFee) + } + result.FeeGrowthGlobal1X128 = state.feeGrowthGlobalX128 } - var feeProtocol uint8 - var feeGrowthGlobalX128 *u256.Uint + return result, nil +} +// applySwapResult updates pool state with computed results. +// All state changes are applied at once to maintain consistency +func applySwapResult(pool *Pool, result *SwapResult) { + pool.slot0.sqrtPriceX96 = result.NewSqrtPrice + pool.slot0.tick = result.NewTick + pool.liquidity = result.NewLiquidity + pool.protocolFees = result.NewProtocolFees + pool.feeGrowthGlobal0X128 = result.FeeGrowthGlobal0X128 + pool.feeGrowthGlobal1X128 = result.FeeGrowthGlobal1X128 +} + +// validatePriceLimits ensures the provided price limit is valid for the swap direction +// The function enforces that: +// For zeroForOne (selling token0): +// - Price limit must be below current price +// - Price limit must be above MIN_SQRT_RATIO +// For !zeroForOne (selling token1): +// - Price limit must be above current price +// - Price limit must be below MAX_SQRT_RATIO +func validatePriceLimits(pool *Pool, zeroForOne bool, sqrtPriceLimitX96 *u256.Uint) { if zeroForOne { minSqrtRatio := u256.MustFromDecimal(consts.MIN_SQRT_RATIO) - cond1 := sqrtPriceLimitX96.Lt(slot0Start.sqrtPriceX96) + cond1 := sqrtPriceLimitX96.Lt(pool.slot0.sqrtPriceX96) cond2 := sqrtPriceLimitX96.Gt(minSqrtRatio) if !(cond1 && cond2) { panic(addDetailToError( errPriceOutOfRange, - ufmt.Sprintf("pool.gno__Swap() || sqrtPriceLimitX96(%s) < slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) > consts.MIN_SQRT_RATIO(%s)", sqrtPriceLimitX96.ToString(), slot0Start.sqrtPriceX96.ToString(), sqrtPriceLimitX96.ToString(), consts.MIN_SQRT_RATIO), + ufmt.Sprintf("pool.gno__Swap() || sqrtPriceLimitX96(%s) < slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) > consts.MIN_SQRT_RATIO(%s)", + sqrtPriceLimitX96.ToString(), + pool.slot0.sqrtPriceX96.ToString(), + sqrtPriceLimitX96.ToString(), + consts.MIN_SQRT_RATIO), )) } - feeProtocol = slot0Start.feeProtocol % 16 - feeGrowthGlobalX128 = pool.feeGrowthGlobal0X128 - } else { maxSqrtRatio := u256.MustFromDecimal(consts.MAX_SQRT_RATIO) - cond1 := sqrtPriceLimitX96.Gt(slot0Start.sqrtPriceX96) + cond1 := sqrtPriceLimitX96.Gt(pool.slot0.sqrtPriceX96) cond2 := sqrtPriceLimitX96.Lt(maxSqrtRatio) if !(cond1 && cond2) { panic(addDetailToError( errPriceOutOfRange, - ufmt.Sprintf("pool.gno__Swap() || sqrtPriceLimitX96(%s) > slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) < consts.MAX_SQRT_RATIO(%s)", sqrtPriceLimitX96.ToString(), slot0Start.sqrtPriceX96.ToString(), sqrtPriceLimitX96.ToString(), consts.MAX_SQRT_RATIO), + ufmt.Sprintf("pool.gno__Swap() || sqrtPriceLimitX96(%s) > slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) < consts.MAX_SQRT_RATIO(%s)", + sqrtPriceLimitX96.ToString(), + pool.slot0.sqrtPriceX96.ToString(), + sqrtPriceLimitX96.ToString(), + consts.MAX_SQRT_RATIO), )) } - - feeProtocol = slot0Start.feeProtocol / 16 - feeGrowthGlobalX128 = pool.feeGrowthGlobal1X128 } +} - pool.slot0.unlocked = false - cache := newSwapCache(feeProtocol, pool.liquidity) - state := pool.newSwapState(amountSpecified, feeGrowthGlobalX128, cache.liquidityStart) +// getFeeProtocol returns the appropriate fee protocol based on zero for one +func getFeeProtocol(slot0 Slot0, zeroForOne bool) uint8 { + if zeroForOne { + return slot0.feeProtocol % 16 + } + return slot0.feeProtocol / 16 +} - exactInput := amountSpecified.Gt(i256.Zero()) +// getFeeGrowthGlobal returns the appropriate fee growth global based on zero for one +func getFeeGrowthGlobal(pool *Pool, zeroForOne bool) *u256.Uint { + if zeroForOne { + return pool.feeGrowthGlobal0X128 + } + return pool.feeGrowthGlobal1X128 +} - // continue swapping as long as we haven't used the entire input/output and haven't reached the price limit - swapFee := u256.Zero() - for !(state.amountSpecifiedRemaining.IsZero()) && !(state.sqrtPriceX96.Eq(sqrtPriceLimitX96)) { - var step StepComputations - step.sqrtPriceStartX96 = state.sqrtPriceX96 - - step.tickNext, step.initialized = pool.tickBitmapNextInitializedTickWithInOneWord( - state.tick, - pool.tickSpacing, - zeroForOne, - ) +func shouldContinueSwap(state SwapState, sqrtPriceLimitX96 *u256.Uint) bool { + return !(state.amountSpecifiedRemaining.IsZero()) && !(state.sqrtPriceX96.Eq(sqrtPriceLimitX96)) +} - // ensure that we do not overshoot the min/max tick, as the tick bitmap is not aware of these bounds - if step.tickNext < consts.MIN_TICK { - step.tickNext = consts.MIN_TICK - } else if step.tickNext > consts.MAX_TICK { - step.tickNext = consts.MAX_TICK +// computeSwapStep executes a single step of swap and returns new state +func computeSwapStep( + state SwapState, + pool *Pool, + zeroForOne bool, + sqrtPriceLimitX96 *u256.Uint, + exactInput bool, + cache SwapCache, + swapFee *u256.Uint, +) (SwapState, *u256.Uint, error) { + step := computeSwapStepInit(state, pool, zeroForOne) + + // determining the price target for this step + sqrtRatioTargetX96 := computeTargetSqrtRatio(step, sqrtPriceLimitX96, zeroForOne) + + // computing the amounts to be swapped at this step + var newState SwapState + var err error + + newState, step = computeAmounts(state, sqrtRatioTargetX96, pool, step) + newState = updateAmounts(step, newState, exactInput) + + // if the protocol fee is on, calculate how much is owed, + // decrement fee amount, and increment protocol fee + if cache.feeProtocol > 0 { + newState, err = updateFeeProtocol(step, cache.feeProtocol, newState) + if err != nil { + return state, nil, err } + } - // get the price for the next tick - step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) + // update global fee tracker + if newState.liquidity.Gt(u256.Zero()) { + update := u256.MulDiv(step.feeAmount, u256.MustFromDecimal(consts.Q128), newState.liquidity) + newState.SetFeeGrowthGlobalX128(new(u256.Uint).Add(newState.feeGrowthGlobalX128, update)) + } - isLower := step.sqrtPriceNextX96.Lt(sqrtPriceLimitX96) - isHigher := step.sqrtPriceNextX96.Gt(sqrtPriceLimitX96) + // handling tick transitions + if newState.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { + newState = tickTransition(step, zeroForOne, newState, pool) + } - var sqrtRatioTargetX96 *u256.Uint - if (zeroForOne && isLower) || (!zeroForOne && isHigher) { - sqrtRatioTargetX96 = sqrtPriceLimitX96 - } else { - sqrtRatioTargetX96 = step.sqrtPriceNextX96 - } + if newState.sqrtPriceX96.Neq(step.sqrtPriceStartX96) { + newState.SetTick(common.TickMathGetTickAtSqrtRatio(newState.sqrtPriceX96)) + } - _sqrtPriceX96Str, _amountInStr, _amountOutStr, _feeAmountStr := plp.SwapMathComputeSwapStepStr( - state.sqrtPriceX96, - sqrtRatioTargetX96, - state.liquidity, - state.amountSpecifiedRemaining, - uint64(pool.fee), - ) - state.sqrtPriceX96 = u256.MustFromDecimal(_sqrtPriceX96Str) - step.amountIn = u256.MustFromDecimal(_amountInStr) - step.amountOut = u256.MustFromDecimal(_amountOutStr) - step.feeAmount = u256.MustFromDecimal(_feeAmountStr) - - amountInWithFee := i256.FromUint256(new(u256.Uint).Add(step.amountIn, step.feeAmount)) - if exactInput { - state.amountSpecifiedRemaining = i256.Zero().Sub(state.amountSpecifiedRemaining, amountInWithFee) - state.amountCalculated = i256.Zero().Sub(state.amountCalculated, i256.FromUint256(step.amountOut)) - } else { - state.amountSpecifiedRemaining = i256.Zero().Add(state.amountSpecifiedRemaining, i256.FromUint256(step.amountOut)) - state.amountCalculated = i256.Zero().Add(state.amountCalculated, amountInWithFee) - } + newSwapFee := new(u256.Uint).Add(swapFee, step.feeAmount) - // if the protocol fee is on, calculate how much is owed, decrement feeAmount, and increment protocolFee - if cache.feeProtocol > 0 { - delta := new(u256.Uint).Div(step.feeAmount, u256.NewUint(uint64(cache.feeProtocol))) - step.feeAmount = new(u256.Uint).Sub(step.feeAmount, delta) - state.protocolFee = new(u256.Uint).Add(state.protocolFee, delta) - } + return newState, newSwapFee, nil +} - // update global fee tracker - if state.liquidity.Gt(u256.Zero()) { - update := u256.MulDiv(step.feeAmount, u256.MustFromDecimal(consts.Q128), state.liquidity) - state.feeGrowthGlobalX128 = new(u256.Uint).Add(state.feeGrowthGlobalX128, update) - } - swapFee = new(u256.Uint).Add(swapFee, step.feeAmount) - - // shift tick if we reached the next price - if state.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { - // if the tick is initialized, run the tick transition - if step.initialized { - var fee0, fee1 *u256.Uint - - // check for the placeholder value, which we replace with the actual value the first time the swap crosses an initialized tick - if zeroForOne { - fee0 = state.feeGrowthGlobalX128 - fee1 = pool.feeGrowthGlobal1X128 - } else { - fee0 = pool.feeGrowthGlobal0X128 - fee1 = state.feeGrowthGlobalX128 - } - - liquidityNet := pool.tickCross( - step.tickNext, - fee0, - fee1, - ) - - // if we're moving leftward, we interpret liquidityNet as the opposite sign - if zeroForOne { - liquidityNet = i256.Zero().Neg(liquidityNet) - } - - state.liquidity = liquidityMathAddDelta(state.liquidity, liquidityNet) - } - - if zeroForOne { - state.tick = step.tickNext - 1 - } else { - state.tick = step.tickNext - } - } else if !(state.sqrtPriceX96.Eq(step.sqrtPriceStartX96)) { - // recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved - state.tick = common.TickMathGetTickAtSqrtRatio(state.sqrtPriceX96) - } +// updateFeeProtocol calculates and updates protocol fees for the current step. +func updateFeeProtocol(step StepComputations, feeProtocol uint8, state SwapState) (SwapState, error) { + delta := step.feeAmount + delta.Div(delta, u256.NewUint(uint64(feeProtocol))) + + newFeeAmount, overflow := new(u256.Uint).SubOverflow(step.feeAmount, delta) + if overflow { + return state, errUnderflow } - // END LOOP + step.feeAmount = newFeeAmount + state.protocolFee.Add(state.protocolFee, delta) - // update pool sqrtPrice - pool.slot0.sqrtPriceX96 = state.sqrtPriceX96 + return state, nil +} - // update tick if it changed - if state.tick != slot0Start.tick { - pool.slot0.tick = state.tick - } +// computeSwapStepInit initializes the computation for a single swap step. +func computeSwapStepInit(state SwapState, pool *Pool, zeroForOne bool) StepComputations { + var step StepComputations + step.sqrtPriceStartX96 = state.sqrtPriceX96 + tickNext, initialized := pool.tickBitmapNextInitializedTickWithInOneWord( + state.tick, + pool.tickSpacing, + zeroForOne, + ) + + step.tickNext = tickNext + step.initialized = initialized + + // prevent overshoot the min/max tick + step.clampTickNext() + + // get the price for the next tick + step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) + return step +} - // update liquidity if it changed - if !(cache.liquidityStart.Eq(state.liquidity)) { - pool.liquidity = state.liquidity +// computeTargetSqrtRatio determines the target sqrt price for the current swap step. +func computeTargetSqrtRatio(step StepComputations, sqrtPriceLimitX96 *u256.Uint, zeroForOne bool) *u256.Uint { + if shouldUsePriceLimit(step.sqrtPriceNextX96, sqrtPriceLimitX96, zeroForOne) { + return sqrtPriceLimitX96 } + return step.sqrtPriceNextX96 +} - // update fee growth global and, if necessary, protocol fees - // overflow is acceptable, protocol has to withdraw before it hits MAX_UINT256 fees +// shouldUsePriceLimit returns true if the price limit should be used instead of the next tick price +func shouldUsePriceLimit(sqrtPriceNext, sqrtPriceLimit *u256.Uint, zeroForOne bool) bool { + isLower := sqrtPriceNext.Lt(sqrtPriceLimit) + isHigher := sqrtPriceNext.Gt(sqrtPriceLimit) if zeroForOne { - pool.feeGrowthGlobal0X128 = state.feeGrowthGlobalX128 - if state.protocolFee.Gt(u256.Zero()) { - pool.protocolFees.token0 = new(u256.Uint).Add(pool.protocolFees.token0, state.protocolFee) - } - } else { - pool.feeGrowthGlobal1X128 = state.feeGrowthGlobalX128 - if state.protocolFee.Gt(u256.Zero()) { - pool.protocolFees.token1 = new(u256.Uint).Add(pool.protocolFees.token1, state.protocolFee) - } + return isLower } + return isHigher +} - var amount0, amount1 *i256.Int - if zeroForOne == exactInput { - amount0 = i256.Zero().Sub(amountSpecified, state.amountSpecifiedRemaining) - amount1 = state.amountCalculated - } else { - amount0 = state.amountCalculated - amount1 = i256.Zero().Sub(amountSpecified, state.amountSpecifiedRemaining) +// computeAmounts calculates the input and output amounts for the current swap step. +func computeAmounts(state SwapState, sqrtRatioTargetX96 *u256.Uint, pool *Pool, step StepComputations) (SwapState, StepComputations) { + sqrtPriceX96Str, amountInStr, amountOutStr, feeAmountStr := plp.SwapMathComputeSwapStepStr( + state.sqrtPriceX96, + sqrtRatioTargetX96, + state.liquidity, + state.amountSpecifiedRemaining, + uint64(pool.fee), + ) + + step.amountIn = u256.MustFromDecimal(amountInStr) + step.amountOut = u256.MustFromDecimal(amountOutStr) + step.feeAmount = u256.MustFromDecimal(feeAmountStr) + + state.SetSqrtPriceX96(sqrtPriceX96Str) + + return state, step +} + +// updateAmounts calculates new remaining and calculated amounts based on the swap step +// For exact input swaps: +// - Decrements remaining input amount by (amountIn + feeAmount) +// - Decrements calculated amount by amountOut +// For exact output swaps: +// - Increments remaining output amount by amountOut +// - Increments calculated amount by (amountIn + feeAmount) +func updateAmounts(step StepComputations, state SwapState, exactInput bool) SwapState { + amountInWithFee := i256.FromUint256(new(u256.Uint).Add(step.amountIn, step.feeAmount)) + if exactInput { + state.amountSpecifiedRemaining = i256.Zero().Sub(state.amountSpecifiedRemaining, amountInWithFee) + state.amountCalculated = i256.Zero().Sub(state.amountCalculated, i256.FromUint256(step.amountOut)) + return state } + state.amountSpecifiedRemaining = i256.Zero().Add(state.amountSpecifiedRemaining, i256.FromUint256(step.amountOut)) + state.amountCalculated = i256.Zero().Add(state.amountCalculated, amountInWithFee) - // actual swap - if zeroForOne { - // payer > POOL - pool.transferFromAndVerify(payer, consts.POOL_ADDR, pool.token0Path, amount0, true) + return state +} - // POOL > recipient - pool.transferAndVerify(recipient, pool.token1Path, amount1, false) +// tickTransition handles the transition between price ticks during a swap +func tickTransition(step StepComputations, zeroForOne bool, state SwapState, pool *Pool) SwapState { + // ensure existing state to keep immutability + newState := state - } else { - // payer > POOL - pool.transferFromAndVerify(payer, consts.POOL_ADDR, pool.token1Path, amount1, false) + if step.initialized { + var fee0, fee1 *u256.Uint + + if zeroForOne { + fee0 = state.feeGrowthGlobalX128 + fee1 = pool.feeGrowthGlobal1X128 + } else { + fee0 = pool.feeGrowthGlobal0X128 + fee1 = state.feeGrowthGlobalX128 + } + + liquidityNet := pool.tickCross(step.tickNext, fee0, fee1) - // POOL > recipient - pool.transferAndVerify(recipient, pool.token0Path, amount0, true) + if zeroForOne { + liquidityNet = i256.Zero().Neg(liquidityNet) + } + newState.liquidity = liquidityMathAddDelta(state.liquidity, liquidityNet) } - prevAddr, prevRealm := getPrev() + if zeroForOne { + newState.tick = step.tickNext - 1 + } else { + newState.tick = step.tickNext + } - std.Emit( - "Swap", - "prevAddr", prevAddr, - "prevRealm", prevRealm, - "poolPath", GetPoolPath(token0Path, token1Path, fee), - "zeroForOne", ufmt.Sprintf("%t", zeroForOne), - "amountSpecified", _amountSpecified, - "sqrtPriceLimitX96", _sqrtPriceLimitX96, - "payer", payer.String(), - "recipient", recipient.String(), - "internal_amount0", amount0.ToString(), - "internal_amount1", amount1.ToString(), - "internal_protocolFee0", pool.protocolFees.token0.ToString(), - "internal_protocolFee1", pool.protocolFees.token1.ToString(), - "internal_swapFee", swapFee.ToString(), - "internal_sqrtPriceX96", pool.slot0.sqrtPriceX96.ToString(), - ) + return newState +} - pool.slot0.unlocked = true - return amount0.ToString(), amount1.ToString() +func (pool *Pool) swapTransfers(zeroForOne bool, payer, recipient std.Address, amount0, amount1 *i256.Int) { + var targetTokenPath string + var amount *i256.Int + + if zeroForOne { + targetTokenPath = pool.token0Path + amount = amount0 + } else { + targetTokenPath = pool.token1Path + amount = amount1 + } + + // payer -> POOL -> recipient + pool.transferFromAndVerify(payer, consts.POOL_ADDR, targetTokenPath, amount.Abs(), zeroForOne) + pool.transferAndVerify(recipient, targetTokenPath, amount, !zeroForOne) } // SetFeeProtocolByAdmin sets the fee protocol for all pools @@ -476,10 +657,7 @@ func SetFeeProtocolByAdmin( // Only governance contract can execute this function via proposal // Also it will be applied to new created pools // ref: https://docs.gnoswap.io/contracts/pool/pool.gno#setfeeprotocol -func SetFeeProtocol( - feeProtocol0 uint8, - feeProtocol1 uint8, -) { +func SetFeeProtocol(feeProtocol0, feeProtocol1 uint8) { caller := std.PrevRealm().Addr() if err := common.GovernanceOnly(caller); err != nil { panic(err) @@ -498,17 +676,12 @@ func SetFeeProtocol( ) } -func setFeeProtocol( - feeProtocol0 uint8, - feeProtocol1 uint8, -) uint8 { +func setFeeProtocol(feeProtocol0, feeProtocol1 uint8) uint8 { common.IsHalted() - fee0Cond := feeProtocol0 == 0 || (feeProtocol0 >= 4 && feeProtocol0 <= 10) - fee1Cond := feeProtocol1 == 0 || (feeProtocol1 >= 4 && feeProtocol1 <= 10) - if !(fee0Cond && fee1Cond) { + if err := validateFeeProtocol(feeProtocol0, feeProtocol1); err != nil { panic(addDetailToError( - errInvalidProtocolFeePct, + err, ufmt.Sprintf("pool.gno__setFeeProtocol() || expected (feeProtocol0(%d) == 0 || (feeProtocol0(%d) >= 4 && feeProtocol0(%d) <= 10)) && (feeProtocol1(%d) == 0 || (feeProtocol1(%d) >= 4 && feeProtocol1(%d) <= 10))", feeProtocol0, feeProtocol0, feeProtocol0, feeProtocol1, feeProtocol1, feeProtocol1), )) } @@ -526,6 +699,17 @@ func setFeeProtocol( return newFee } +func validateFeeProtocol(feeProtocol0, feeProtocol1 uint8) error { + if !isValidFeeProtocolValue(feeProtocol0) || !isValidFeeProtocolValue(feeProtocol1) { + return errInvalidProtocolFeePct + } + return nil +} + +func isValidFeeProtocolValue(value uint8) bool { + return value == 0 || (value >= 4 && value <= 10) +} + // CollectProtocolByAdmin collects protocol fees for the given pool that accumulated while it was being used for swap // Returns collected amount0, amount1 in string func CollectProtocolByAdmin( @@ -533,9 +717,9 @@ func CollectProtocolByAdmin( token1Path string, fee uint32, recipient std.Address, - _amount0Requested string, // uint128 - _amount1Requested string, // uint128 -) (string, string) { // uint128 x2 + amount0Requested string, + amount1Requested string, +) (string, string) { caller := std.PrevRealm().Addr() if err := common.AdminOnly(caller); err != nil { panic(err) @@ -546,8 +730,8 @@ func CollectProtocolByAdmin( token1Path, fee, recipient, - _amount0Requested, - _amount1Requested, + amount0Requested, + amount1Requested, ) prevAddr, prevRealm := getPrev() @@ -575,9 +759,9 @@ func CollectProtocol( token1Path string, fee uint32, recipient std.Address, - _amount0Requested string, // uint128 - _amount1Requested string, // uint128 -) (string, string) { // uint128 x2 + amount0Requested string, + amount1Requested string, +) (string, string) { caller := std.PrevRealm().Addr() if err := common.GovernanceOnly(caller); err != nil { panic(err) @@ -588,8 +772,8 @@ func CollectProtocol( token1Path, fee, recipient, - _amount0Requested, - _amount1Requested, + amount0Requested, + amount1Requested, ) prevAddr, prevRealm := getPrev() @@ -613,18 +797,18 @@ func collectProtocol( token1Path string, fee uint32, recipient std.Address, - _amount0Requested string, // uint128 - _amount1Requested string, // uint128 -) (string, string) { // uint128 x2 + amount0Requested string, + amount1Requested string, +) (string, string) { common.IsHalted() - amount0Requested := u256.MustFromDecimal(_amount0Requested) - amount1Requested := u256.MustFromDecimal(_amount1Requested) - pool := GetPool(token0Path, token1Path, fee) - amount0 := u256Min(amount0Requested, pool.protocolFees.token0) - amount1 := u256Min(amount1Requested, pool.protocolFees.token1) + amount0Req := u256.MustFromDecimal(amount0Requested) + amount1Req := u256.MustFromDecimal(amount1Requested) + + amount0 := u256Min(amount0Req, pool.protocolFees.token0) + amount1 := u256Min(amount1Req, pool.protocolFees.token1) amount0, amount1 = pool.saveProtocolFees(amount0, amount1) uAmount0 := amount0.Uint64() @@ -662,72 +846,102 @@ func (pool *Pool) transferAndVerify( amount *i256.Int, isToken0 bool, ) { - if amount.IsZero() { - return - } - - // must be negative to send token from pool to user - // as point of view from pool, it is negative - if !amount.IsNeg() { + if amount.Sign() != -1 { panic(addDetailToError( errMustBeNegative, ufmt.Sprintf("pool.gno__transferAndVerify() || amount(%s) must be negative", amount.ToString()), )) } - // check pool.balances + absAmount := amount.Abs() + + token0 := pool.balances.token0 + token1 := pool.balances.token1 + + if err := validatePoolBalance(token0, token1, absAmount, isToken0); err != nil { + panic(err) + } + amountUint64, err := checkAmountRange(absAmount) + if err != nil { + panic(err) + } + + transferByRegisterCall(tokenPath, to, amountUint64) + + newBalance, err := updatePoolBalance(token0, token1, absAmount, isToken0) + if err != nil { + panic(err) + } + if isToken0 { - if pool.balances.token0.Lt(amount.Abs()) { - panic(addDetailToError( - errTransferFailed, - ufmt.Sprintf("pool.gno__transferAndVerify() || pool.balances.token0(%s) >= amount.Abs(%s)", pool.balances.token0.ToString(), amount.Abs().ToString()), - )) - } + pool.balances.token0 = newBalance } else { - if pool.balances.token1.Lt(amount.Abs()) { - panic(addDetailToError( - errTransferFailed, - ufmt.Sprintf("pool.gno__transferAndVerify() || pool.balances.token1(%s) >= amount.Abs(%s)", pool.balances.token1.ToString(), amount.Abs().ToString()), - )) - } + pool.balances.token1 = newBalance } +} - amountUint64 := checkAmountRange(amount) - - // try sending - // will panic if following conditions are met: - // - POOL does not have enough balance - // - token is not registered - transferByRegisterCall(tokenPath, to, amountUint64) +func validatePoolBalance(token0, token1, amount *u256.Uint, isToken0 bool) error { + if isToken0 { + if token0.Lt(amount) { + return ufmt.Errorf( + "%s || token0(%s) >= amount(%s)", + errTransferFailed.Error(), token0.ToString(), amount.ToString(), + ) + } + return nil + } + if token1.Lt(amount) { + return ufmt.Errorf( + "%s || token1(%s) >= amount(%s)", + errTransferFailed.Error(), token1.ToString(), amount.ToString(), + ) + } + return nil +} - // update pool.balances +func updatePoolBalance( + token0, token1, amount *u256.Uint, + isToken0 bool, +) (*u256.Uint, error) { var overflow bool + var newBalance *u256.Uint + if isToken0 { - pool.balances.token0, overflow = new(u256.Uint).SubOverflow(pool.balances.token0, amount.Abs()) - if overflow { - panic(addDetailToError( - errTransferFailed, - ufmt.Sprintf("pool.gno__transferAndVerify() || cannot decrease, pool.balances.token0(%s) - amount(%s)", pool.balances.token0.ToString(), amount.Abs().ToString()), - )) - } - } else { - pool.balances.token1, overflow = new(u256.Uint).SubOverflow(pool.balances.token1, amount.Abs()) - if pool.balances.token1.Lt(u256.Zero()) { - panic(addDetailToError( - errTransferFailed, - ufmt.Sprintf("pool.gno__transferAndVerify() || cannot decrease, pool.balances.token1(%s) - amount(%s)", pool.balances.token1.ToString(), amount.Abs().ToString()), - )) + newBalance, overflow = new(u256.Uint).SubOverflow(token0, amount) + if isBalanceOverflowOrNegative(overflow, newBalance) { + return nil, ufmt.Errorf( + "%s || cannot decrease, token0(%s) - amount(%s)", + errTransferFailed.Error(), token0.ToString(), amount.ToString(), + ) } + return newBalance, nil } + + newBalance, overflow = new(u256.Uint).SubOverflow(token1, amount) + if isBalanceOverflowOrNegative(overflow, newBalance) { + return nil, ufmt.Errorf( + "%s || cannot decrease, token1(%s) - amount(%s)", + errTransferFailed.Error(), token1.ToString(), amount.ToString(), + ) + } + return newBalance, nil +} + +func isBalanceOverflowOrNegative(overflow bool, newBalance *u256.Uint) bool { + return overflow || newBalance.Lt(u256.Zero()) } func (pool *Pool) transferFromAndVerify( from, to std.Address, tokenPath string, - amount *i256.Int, + amount *u256.Uint, isToken0 bool, ) { - amountUint64 := checkAmountRange(amount) + absAmount := amount + amountUint64, err := checkAmountRange(absAmount) + if err != nil { + panic(err) + } // try sending // will panic if following conditions are met: @@ -736,26 +950,24 @@ func (pool *Pool) transferFromAndVerify( // - token is not registered transferFromByRegisterCall(tokenPath, from, to, amountUint64) - // update pool.balances + // update pool balances if isToken0 { - pool.balances.token0 = new(u256.Uint).Add(pool.balances.token0, amount.Abs()) + pool.balances.token0 = new(u256.Uint).Add(pool.balances.token0, absAmount) } else { - pool.balances.token1 = new(u256.Uint).Add(pool.balances.token1, amount.Abs()) + pool.balances.token1 = new(u256.Uint).Add(pool.balances.token1, absAmount) } } -func checkAmountRange(amount *i256.Int) uint64 { - // check amount is in uint64 range - amountAbs := amount.Abs() - amountUint64, overflow := amountAbs.Uint64WithOverflow() +func checkAmountRange(amount *u256.Uint) (uint64, error) { + res, overflow := amount.Uint64WithOverflow() if overflow { - panic(addDetailToError( - errOutOfRange, - ufmt.Sprintf("pool.gno__checkAmountRange() || amountAbs(%s) overflows uint64 range", amountAbs.ToString()), - )) + return 0, ufmt.Errorf( + "%s || amount(%s) overflows uint64 range", + errOutOfRange.Error(), amount.ToString(), + ) } - return amountUint64 + return res, nil } // receiver getters diff --git a/pool/pool_test.gno b/pool/pool_test.gno new file mode 100644 index 000000000..af02987d9 --- /dev/null +++ b/pool/pool_test.gno @@ -0,0 +1,702 @@ +package pool + +import ( + "std" + "testing" + + "gno.land/p/demo/testutils" + "gno.land/p/demo/uassert" + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + "gno.land/r/gnoswap/v1/consts" +) + +func TestMint(t *testing.T) { + token0Path := "test_token0" + token1Path := "test_token1" + fee := uint32(3000) + recipient := testutils.TestAddress("recipient") + tickLower := int32(-100) + tickUpper := int32(100) + liquidityAmount := "100000" + + t.Run("unauthorized caller mint should fail", func(t *testing.T) { + unauthorized := testutils.TestAddress("unauthorized") + defer func() { + if r := recover(); r == nil { + t.Error("unauthorized caller mint should fail") + } + }() + + Mint(token0Path, token1Path, fee, recipient, tickLower, tickUpper, liquidityAmount, unauthorized) + }) + + t.Run("mint with 0 liquidity should fail", func(t *testing.T) { + authorized := consts.POSITION_ADDR + defer func() { + if r := recover(); r == nil { + t.Error("mint with 0 liquidity should fail") + } + }() + + Mint(token0Path, token1Path, fee, recipient, tickLower, tickUpper, "0", authorized) + }) +} + +func TestBurn(t *testing.T) { + // Setup + originalGetPool := GetPool + defer func() { + GetPool = originalGetPool + }() + + // Mock data + mockCaller := consts.POSITION_ADDR + mockPosition := PositionInfo{ + liquidity: u256.NewUint(1000), + tokensOwed0: u256.NewUint(0), + tokensOwed1: u256.NewUint(0), + } + mockPool := &Pool{ + positions: make(map[string]PositionInfo), + } + + GetPool = func(token0Path, token1Path string, fee uint32) *Pool { + return mockPool + } + + tests := []struct { + name string + liquidityAmount string + tickLower int32 + tickUpper int32 + expectedAmount0 string + expectedAmount1 string + expectPanic bool + }{ + { + name: "successful burn", + liquidityAmount: "500", + tickLower: -100, + tickUpper: 100, + expectedAmount0: "100", + expectedAmount1: "200", + }, + { + name: "zero liquidity", + liquidityAmount: "0", + tickLower: -100, + tickUpper: 100, + expectPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.name == "successful burn" { + t.Skip("skipping until find better way to test this") + } + + // setup position for this test + posKey := positionGetKey(mockCaller, tt.tickLower, tt.tickUpper) + mockPool.positions[posKey] = mockPosition + + if tt.expectPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic but got none") + } + }() + } + + amount0, amount1 := Burn( + "token0", + "token1", + 3000, + tt.tickLower, + tt.tickUpper, + tt.liquidityAmount, + ) + + if !tt.expectPanic { + if amount0 != tt.expectedAmount0 { + t.Errorf("expected amount0 %s, got %s", tt.expectedAmount0, amount0) + } + if amount1 != tt.expectedAmount1 { + t.Errorf("expected amount1 %s, got %s", tt.expectedAmount1, amount1) + } + + newPosition := mockPool.positions[posKey] + if newPosition.tokensOwed0.IsZero() { + t.Error("expected tokensOwed0 to be updated") + } + if newPosition.tokensOwed1.IsZero() { + t.Error("expected tokensOwed1 to be updated") + } + } + }) + } +} + +func TestSaveProtocolFees(t *testing.T) { + tests := []struct { + name string + pool *Pool + amount0 *u256.Uint + amount1 *u256.Uint + want0 *u256.Uint + want1 *u256.Uint + wantFee0 *u256.Uint + wantFee1 *u256.Uint + }{ + { + name: "normal fee deduction", + pool: &Pool{ + protocolFees: ProtocolFees{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + amount0: u256.NewUint(500), + amount1: u256.NewUint(1000), + want0: u256.NewUint(500), + want1: u256.NewUint(1000), + wantFee0: u256.NewUint(500), + wantFee1: u256.NewUint(1000), + }, + { + name: "exact fee deduction (1 deduction)", + pool: &Pool{ + protocolFees: ProtocolFees{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + amount0: u256.NewUint(1000), + amount1: u256.NewUint(2000), + want0: u256.NewUint(999), + want1: u256.NewUint(1999), + wantFee0: u256.NewUint(1), + wantFee1: u256.NewUint(1), + }, + { + name: "0 fee deduction", + pool: &Pool{ + protocolFees: ProtocolFees{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + amount0: u256.NewUint(0), + amount1: u256.NewUint(0), + want0: u256.NewUint(0), + want1: u256.NewUint(0), + wantFee0: u256.NewUint(1000), + wantFee1: u256.NewUint(2000), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got0, got1 := tt.pool.saveProtocolFees(tt.amount0, tt.amount1) + + uassert.Equal(t, got0.ToString(), tt.want0.ToString()) + uassert.Equal(t, got1.ToString(), tt.want1.ToString()) + uassert.Equal(t, tt.pool.protocolFees.token0.ToString(), tt.wantFee0.ToString()) + uassert.Equal(t, tt.pool.protocolFees.token1.ToString(), tt.wantFee1.ToString()) + }) + } +} + +func TestTransferAndVerify(t *testing.T) { + // Setup common test data + pool := &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(1000), + }, + } + + t.Run("validatePoolBalance", func(t *testing.T) { + tests := []struct { + name string + amount *u256.Uint + isToken0 bool + expectedError bool + }{ + { + name: "must success for negative amount", + amount: u256.NewUint(500), + isToken0: true, + expectedError: false, + }, + { + name: "must panic for insufficient token0 balance", + amount: u256.NewUint(1500), + isToken0: true, + expectedError: true, + }, + { + name: "must success for negative amount", + amount: u256.NewUint(500), + isToken0: false, + expectedError: false, + }, + { + name: "must panic for insufficient token1 balance", + amount: u256.NewUint(1500), + isToken0: false, + expectedError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token0 := pool.balances.token0 + token1 := pool.balances.token1 + + err := validatePoolBalance(token0, token1, tt.amount, tt.isToken0) + if err != nil { + if !tt.expectedError { + t.Errorf("unexpected error: %v", err) + } + } + }) + } + }) +} + +func TestUpdatePoolBalance(t *testing.T) { + tests := []struct { + name string + initialToken0 *u256.Uint + initialToken1 *u256.Uint + amount *u256.Uint + isToken0 bool + expectedBal *u256.Uint + expectErr bool + }{ + { + name: "normal token0 decrease", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(300), + isToken0: true, + expectedBal: u256.NewUint(700), + expectErr: false, + }, + { + name: "normal token1 decrease", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(500), + isToken0: false, + expectedBal: u256.NewUint(1500), + expectErr: false, + }, + { + name: "insufficient token0 balance", + initialToken0: u256.NewUint(100), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(200), + isToken0: true, + expectedBal: nil, + expectErr: true, + }, + { + name: "insufficient token1 balance", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(100), + amount: u256.NewUint(200), + isToken0: false, + expectedBal: nil, + expectErr: true, + }, + { + name: "zero value handling", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(0), + isToken0: true, + expectedBal: u256.NewUint(1000), + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pool := &Pool{ + balances: Balances{ + token0: tt.initialToken0, + token1: tt.initialToken1, + }, + } + + newBal, err := updatePoolBalance(tt.initialToken0, tt.initialToken1, tt.amount, tt.isToken0) + + if tt.expectErr { + if err == nil { + t.Errorf("%s: expected error but no error", tt.name) + } + return + } + if err != nil { + t.Errorf("%s: unexpected error: %v", tt.name, err) + return + } + + if !newBal.Eq(tt.expectedBal) { + t.Errorf("%s: balance mismatch, expected: %s, actual: %s", + tt.name, + tt.expectedBal.ToString(), + newBal.ToString(), + ) + } + }) + } +} + +func TestShouldContinueSwap(t *testing.T) { + tests := []struct { + name string + state SwapState + sqrtPriceLimitX96 *u256.Uint + expected bool + }{ + { + name: "Should continue - amount remaining and price not at limit", + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + sqrtPriceX96: u256.MustFromDecimal("1000000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: true, + }, + { + name: "Should stop - no amount remaining", + state: SwapState{ + amountSpecifiedRemaining: i256.Zero(), + sqrtPriceX96: u256.MustFromDecimal("1000000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: false, + }, + { + name: "Should stop - price at limit", + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + sqrtPriceX96: u256.MustFromDecimal("900000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := shouldContinueSwap(tt.state, tt.sqrtPriceLimitX96) + uassert.Equal(t, tt.expected, result) + }) + } +} + +func TestUpdateAmounts(t *testing.T) { + tests := []struct { + name string + step StepComputations + state SwapState + exactInput bool + expectedState SwapState + }{ + { + name: "Exact input update", + step: StepComputations{ + amountIn: u256.MustFromDecimal("100"), + amountOut: u256.MustFromDecimal("97"), + feeAmount: u256.MustFromDecimal("3"), + }, + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + amountCalculated: i256.Zero(), + }, + exactInput: true, + expectedState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("897"), // 1000 - (100 + 3) + amountCalculated: i256.MustFromDecimal("-97"), + }, + }, + { + name: "Exact output update", + step: StepComputations{ + amountIn: u256.MustFromDecimal("100"), + amountOut: u256.MustFromDecimal("97"), + feeAmount: u256.MustFromDecimal("3"), + }, + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("-1000"), + amountCalculated: i256.Zero(), + }, + exactInput: false, + expectedState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("-903"), // -1000 + 97 + amountCalculated: i256.MustFromDecimal("103"), // 100 + 3 + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := updateAmounts(tt.step, tt.state, tt.exactInput) + + uassert.True(t, tt.expectedState.amountSpecifiedRemaining.Eq(result.amountSpecifiedRemaining)) + uassert.True(t, tt.expectedState.amountCalculated.Eq(result.amountCalculated)) + }) + } +} + +func TestComputeSwap(t *testing.T) { + mockPool := &Pool{ + token0Path: "token0", + token1Path: "token1", + fee: 3000, // 0.3% + tickSpacing: 60, + slot0: Slot0{ + sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 + tick: 0, + feeProtocol: 0, + unlocked: true, + }, + liquidity: u256.MustFromDecimal("1000000000000000000"), // 1.0 + protocolFees: ProtocolFees{ + token0: u256.Zero(), + token1: u256.Zero(), + }, + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + tickBitmaps: make(TickBitmaps), + ticks: make(Ticks), + positions: make(Positions), + } + + wordPos, _ := tickBitmapPosition(0) + // TODO: use avl + mockPool.tickBitmaps[wordPos] = u256.NewUint(1) + + t.Run("basic swap", func(t *testing.T) { + comp := SwapComputation{ + AmountSpecified: i256.MustFromDecimal("1000000"), // 1.0 token + SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), // 1.1 + ZeroForOne: true, + ExactInput: true, + InitialState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), + amountCalculated: i256.Zero(), + sqrtPriceX96: mockPool.slot0.sqrtPriceX96, + tick: mockPool.slot0.tick, + feeGrowthGlobalX128: mockPool.feeGrowthGlobal0X128, + protocolFee: u256.Zero(), + liquidity: mockPool.liquidity, + }, + Cache: SwapCache{ + feeProtocol: 0, + liquidityStart: mockPool.liquidity, + }, + } + + result, err := computeSwap(mockPool, comp) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if result.Amount0.IsZero() { + t.Error("expected non-zero amount0") + } + if result.Amount1.IsZero() { + t.Error("expected non-zero amount1") + } + if result.SwapFee.IsZero() { + t.Error("expected non-zero swap fee") + } + }) + + t.Run("swap with zero liquidity", func(t *testing.T) { + mockPoolZeroLiq := *mockPool + mockPoolZeroLiq.liquidity = u256.Zero() + + comp := SwapComputation{ + AmountSpecified: i256.MustFromDecimal("1000000"), + SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), + ZeroForOne: true, + ExactInput: true, + InitialState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), + amountCalculated: i256.Zero(), + sqrtPriceX96: mockPoolZeroLiq.slot0.sqrtPriceX96, + tick: mockPoolZeroLiq.slot0.tick, + feeGrowthGlobalX128: mockPoolZeroLiq.feeGrowthGlobal0X128, + protocolFee: u256.Zero(), + liquidity: mockPoolZeroLiq.liquidity, + }, + Cache: SwapCache{ + feeProtocol: 0, + liquidityStart: mockPoolZeroLiq.liquidity, + }, + } + + result, err := computeSwap(&mockPoolZeroLiq, comp) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if !result.Amount0.IsZero() || !result.Amount1.IsZero() { + t.Error("expected zero amounts for zero liquidity") + } + }) +} + +func TestTransferFromAndVerify(t *testing.T) { + tests := []struct { + name string + pool *Pool + from std.Address + to std.Address + tokenPath string + amount *i256.Int + isToken0 bool + expectedBal0 *u256.Uint + expectedBal1 *u256.Uint + }{ + { + name: "normal token0 transfer", + pool: &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + from: testutils.TestAddress("from_addr"), + to: testutils.TestAddress("to_addr"), + tokenPath: "token0_path", + amount: i256.NewInt(500), + isToken0: true, + expectedBal0: u256.NewUint(1500), // 1000 + 500 + expectedBal1: u256.NewUint(2000), // unchanged + }, + { + name: "normal token1 transfer", + pool: &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + from: testutils.TestAddress("from_addr"), + to: testutils.TestAddress("to_addr"), + tokenPath: "token1_path", + amount: i256.NewInt(800), + isToken0: false, + expectedBal0: u256.NewUint(1000), // unchanged + expectedBal1: u256.NewUint(2800), // 2000 + 800 + }, + { + name: "zero value transfer", + pool: &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + from: testutils.TestAddress("from_addr"), + to: testutils.TestAddress("to_addr"), + tokenPath: "token0_path", + amount: i256.NewInt(0), + isToken0: true, + expectedBal0: u256.NewUint(1000), // unchanged + expectedBal1: u256.NewUint(2000), // unchanged + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // mock transferFromByRegisterCall + oldTransferFromByRegisterCall := transferFromByRegisterCall + defer func() { transferFromByRegisterCall = oldTransferFromByRegisterCall }() + + transferFromByRegisterCall = func(tokenPath string, from, to std.Address, amount uint64) bool { + // mock the transfer (just return true) + return true + } + + tt.pool.transferFromAndVerify(tt.from, tt.to, tt.tokenPath, u256.MustFromDecimal(tt.amount.ToString()), tt.isToken0) + + if !tt.pool.balances.token0.Eq(tt.expectedBal0) { + t.Errorf("token0 balance mismatch: expected %s, got %s", + tt.expectedBal0.ToString(), + tt.pool.balances.token0.ToString()) + } + + if !tt.pool.balances.token1.Eq(tt.expectedBal1) { + t.Errorf("token1 balance mismatch: expected %s, got %s", + tt.expectedBal1.ToString(), + tt.pool.balances.token1.ToString()) + } + }) + } + + t.Run("negative value handling", func(t *testing.T) { + pool := &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + } + + oldTransferFromByRegisterCall := transferFromByRegisterCall + defer func() { transferFromByRegisterCall = oldTransferFromByRegisterCall }() + + transferFromByRegisterCall = func(tokenPath string, from, to std.Address, amount uint64) bool { + return true + } + + negativeAmount := i256.NewInt(-500) + pool.transferFromAndVerify( + testutils.TestAddress("from_addr"), + testutils.TestAddress("to_addr"), + "token0_path", + u256.MustFromDecimal(negativeAmount.Abs().ToString()), + true, + ) + + expectedBal := u256.NewUint(1500) // 1000 + 500 (absolute value) + if !pool.balances.token0.Eq(expectedBal) { + t.Errorf("negative amount handling failed: expected %s, got %s", + expectedBal.ToString(), + pool.balances.token0.ToString()) + } + }) + + t.Run("uint64 overflow value", func(t *testing.T) { + pool := &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + } + + hugeAmount := i256.FromUint256(u256.MustFromDecimal("18446744073709551616")) // 2^64 + + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for amount exceeding uint64 range") + } + }() + + pool.transferFromAndVerify( + testutils.TestAddress("from_addr"), + testutils.TestAddress("to_addr"), + "token0_path", + u256.MustFromDecimal(hugeAmount.ToString()), + true, + ) + }) +} diff --git a/pool/position_modify.gno b/pool/position_modify.gno index 00fe51be1..0400b126a 100644 --- a/pool/position_modify.gno +++ b/pool/position_modify.gno @@ -10,16 +10,16 @@ import ( // modifyPosition updates a position in the pool and calculates the amount of tokens to be added or removed. // Returns positionInfo, amount0, amount1 -func (pool *Pool) modifyPosition(params ModifyPositionParams) (PositionInfo, *i256.Int, *i256.Int) { +func (pool *Pool) modifyPosition(params ModifyPositionParams) (PositionInfo, *u256.Uint, *u256.Uint) { position := pool.updatePosition(params) liqDelta := params.liquidityDelta - amount0, amount1 := i256.Zero(), i256.Zero() - if liqDelta.IsZero() { - return position, amount0, amount1 + return position, u256.Zero(), u256.Zero() } + amount0, amount1 := i256.Zero(), i256.Zero() + tick := pool.slot0.tick sqrtRatioLower := common.TickMathGetSqrtRatioAtTick(params.tickLower) sqrtRatioUpper := common.TickMathGetSqrtRatioAtTick(params.tickUpper) @@ -42,7 +42,7 @@ func (pool *Pool) modifyPosition(params ModifyPositionParams) (PositionInfo, *i2 amount1 = calculateToken1Amount(sqrtRatioLower, sqrtRatioUpper, liqDelta) } - return position, amount0, amount1 + return position, amount0.Abs(), amount1.Abs() } func calculateToken0Amount(sqrtPriceLower, sqrtPriceUpper *u256.Uint, liquidityDelta *i256.Int) *i256.Int { diff --git a/pool/position_modify_test.gno b/pool/position_modify_test.gno index 8f6609261..dec483e61 100644 --- a/pool/position_modify_test.gno +++ b/pool/position_modify_test.gno @@ -4,13 +4,12 @@ import ( "testing" "gno.land/p/demo/uassert" - u256 "gno.land/p/gnoswap/uint256" i256 "gno.land/p/gnoswap/int256" - "gno.land/r/gnoswap/v1/consts" + u256 "gno.land/p/gnoswap/uint256" "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" ) - func TestModifyPosition(t *testing.T) { const ( fee500 = uint32(500) @@ -18,34 +17,34 @@ func TestModifyPosition(t *testing.T) { ) tests := []struct { - name string - sqrtPrice string - tickLower int32 - tickUpper int32 + name string + sqrtPrice string + tickLower int32 + tickUpper int32 expectedAmt0 string expectedAmt1 string }{ { - name: "current price is lower than range", - sqrtPrice: common.TickMathGetSqrtRatioAtTick(-12000).ToString(), - tickLower: -11000, - tickUpper: -9000, + name: "current price is lower than range", + sqrtPrice: common.TickMathGetSqrtRatioAtTick(-12000).ToString(), + tickLower: -11000, + tickUpper: -9000, expectedAmt0: "16492846", expectedAmt1: "0", }, { - name: "current price is in range", - sqrtPrice: common.TickMathGetSqrtRatioAtTick(-10000).ToString(), - tickLower: -11000, - tickUpper: -9000, + name: "current price is in range", + sqrtPrice: common.TickMathGetSqrtRatioAtTick(-10000).ToString(), + tickLower: -11000, + tickUpper: -9000, expectedAmt0: "8040316", expectedAmt1: "2958015", }, { - name: "current price is higher than range", - sqrtPrice: common.TickMathGetSqrtRatioAtTick(-8000).ToString(), - tickLower: -11000, - tickUpper: -9000, + name: "current price is higher than range", + sqrtPrice: common.TickMathGetSqrtRatioAtTick(-8000).ToString(), + tickLower: -11000, + tickUpper: -9000, expectedAmt0: "0", expectedAmt1: "6067683", }, @@ -58,7 +57,7 @@ func TestModifyPosition(t *testing.T) { barPath, fooPath, fee500, - tt.sqrtPrice, + sqrtPrice.ToString(), ) pool := newPool(poolParams) @@ -106,7 +105,7 @@ func TestModifyPositionEdgeCases(t *testing.T) { } } }() - + pool.modifyPosition(params) }) diff --git a/pool/position_update.gno b/pool/position_update.gno index 48d870d1b..5f21e49f7 100644 --- a/pool/position_update.gno +++ b/pool/position_update.gno @@ -5,8 +5,8 @@ import ( ) func (pool *Pool) updatePosition(positionParams ModifyPositionParams) PositionInfo { - _feeGrowthGlobal0X128 := u256.MustFromDecimal(pool.feeGrowthGlobal0X128.ToString()) - _feeGrowthGlobal1X128 := u256.MustFromDecimal(pool.feeGrowthGlobal1X128.ToString()) + feeGrowthGlobal0X128 := pool.feeGrowthGlobal0X128.Clone() + feeGrowthGlobal1X128 := pool.feeGrowthGlobal1X128.Clone() var flippedLower, flippedUpper bool if !(positionParams.liquidityDelta.IsZero()) { @@ -14,8 +14,8 @@ func (pool *Pool) updatePosition(positionParams ModifyPositionParams) PositionIn positionParams.tickLower, pool.slot0.tick, positionParams.liquidityDelta, - _feeGrowthGlobal0X128, - _feeGrowthGlobal1X128, + feeGrowthGlobal0X128, + feeGrowthGlobal1X128, false, pool.maxLiquidityPerTick, ) @@ -24,8 +24,8 @@ func (pool *Pool) updatePosition(positionParams ModifyPositionParams) PositionIn positionParams.tickUpper, pool.slot0.tick, positionParams.liquidityDelta, - _feeGrowthGlobal0X128, - _feeGrowthGlobal1X128, + feeGrowthGlobal0X128, + feeGrowthGlobal1X128, true, pool.maxLiquidityPerTick, ) @@ -43,8 +43,8 @@ func (pool *Pool) updatePosition(positionParams ModifyPositionParams) PositionIn positionParams.tickLower, positionParams.tickUpper, pool.slot0.tick, - _feeGrowthGlobal0X128, - _feeGrowthGlobal1X128, + feeGrowthGlobal0X128, + feeGrowthGlobal1X128, ) positionKey := positionGetKey(positionParams.owner, positionParams.tickLower, positionParams.tickUpper) diff --git a/pool/position_update_test.gno b/pool/position_update_test.gno new file mode 100644 index 000000000..24ca4cb55 --- /dev/null +++ b/pool/position_update_test.gno @@ -0,0 +1,85 @@ +package pool + +import ( + "testing" + + "std" + + "gno.land/p/demo/uassert" + + "gno.land/r/gnoswap/v1/consts" + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestUpdatePosition(t *testing.T) { + poolParams := &createPoolParams{ + token0Path: "token0", + token1Path: "token1", + fee: 500, + tickSpacing: 10, + sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 + } + p := newPool(poolParams) + + tests := []struct { + name string + positionParams ModifyPositionParams + expectLiquidity *u256.Uint + }{ + { + name: "add new position", + positionParams: ModifyPositionParams{ + owner: consts.POSITION_ADDR, + tickLower: -100, + tickUpper: 100, + liquidityDelta: i256.MustFromDecimal("1000000"), + }, + expectLiquidity: u256.MustFromDecimal("1000000"), + }, + { + name: "add liquidity to existing position", + positionParams: ModifyPositionParams{ + owner: consts.POSITION_ADDR, + tickLower: -100, + tickUpper: 100, + liquidityDelta: i256.MustFromDecimal("500000"), + }, + expectLiquidity: u256.MustFromDecimal("1500000"), + }, + { + name: "remove liquidity from position", + positionParams: ModifyPositionParams{ + owner: consts.POSITION_ADDR, + tickLower: -100, + tickUpper: 100, + liquidityDelta: i256.MustFromDecimal("-500000"), + }, + expectLiquidity: u256.MustFromDecimal("1000000"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + position := p.updatePosition(tt.positionParams) + + if !position.liquidity.Eq(tt.expectLiquidity) { + t.Errorf("liquidity mismatch: expected %s, got %s", + tt.expectLiquidity.ToString(), + position.liquidity.ToString()) + } + + if !tt.positionParams.liquidityDelta.IsZero() { + lowerTick := p.ticks[tt.positionParams.tickLower] + upperTick := p.ticks[tt.positionParams.tickUpper] + + if !lowerTick.initialized { + t.Error("lower tick not initialized") + } + if !upperTick.initialized { + t.Error("upper tick not initialized") + } + } + }) + } +} diff --git a/pool/token_register.gno b/pool/token_register.gno index 8c730bec7..419b2ee49 100644 --- a/pool/token_register.gno +++ b/pool/token_register.gno @@ -60,12 +60,18 @@ func RegisterGRC20Interface(pkgPath string, igrc20 GRC20Interface) { // UnregisterGRC20Interface unregisters a GRC20 token interface func UnregisterGRC20Interface(pkgPath string) { if err := common.SatisfyCond(isUserCall()); err != nil { - panic(err) + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("token_register.gno__UnregisterGRC20Interface() || unauthorized address(%s) to unregister", std.PrevRealm().Addr()), + )) } caller := std.PrevRealm().Addr() if err := common.TokenRegisterOnly(caller); err != nil { - panic(err) + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("token_register.gno__UnregisterGRC20Interface() || unauthorized address(%s) to unregister", caller), + )) } pkgPath = handleNative(pkgPath) diff --git a/pool/type.gno b/pool/type.gno index de92f0a1f..d2a6c5197 100644 --- a/pool/type.gno +++ b/pool/type.gno @@ -4,6 +4,7 @@ import ( "std" "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" i256 "gno.land/p/gnoswap/int256" u256 "gno.land/p/gnoswap/uint256" @@ -66,6 +67,20 @@ type ModifyPositionParams struct { liquidityDelta *i256.Int // any change in liquidity } +func newModifyPositionParams( + owner std.Address, + tickLower int32, + tickUpper int32, + liquidityDelta *i256.Int, +) ModifyPositionParams { + return ModifyPositionParams{ + owner: owner, + tickLower: tickLower, + tickUpper: tickUpper, + liquidityDelta: liquidityDelta, + } +} + type SwapCache struct { feeProtocol uint8 // protocol fee for the input token liquidityStart *u256.Uint // liquidity at the beginning of the swap @@ -91,13 +106,12 @@ type SwapState struct { liquidity *u256.Uint // current liquidity in range } -func (pool *Pool) newSwapState( +func newSwapState( amountSpecifiedRemaining *i256.Int, feeGrowthGlobalX128 *u256.Uint, liquidity *u256.Uint, + slot0 Slot0, ) SwapState { - slot0 := pool.slot0 - return SwapState{ amountSpecifiedRemaining: amountSpecifiedRemaining, amountCalculated: i256.Zero(), @@ -109,6 +123,22 @@ func (pool *Pool) newSwapState( } } +func (s *SwapState) SetSqrtPriceX96(sqrtPriceX96 string) { + s.sqrtPriceX96 = u256.MustFromDecimal(sqrtPriceX96) +} + +func (s *SwapState) SetTick(tick int32) { + s.tick = tick +} + +func (s *SwapState) SetFeeGrowthGlobalX128(feeGrowthGlobalX128 *u256.Uint) { + s.feeGrowthGlobalX128 = feeGrowthGlobalX128 +} + +func (s *SwapState) SetProtocolFee(fee *u256.Uint) { + s.protocolFee = fee +} + type StepComputations struct { sqrtPriceStartX96 *u256.Uint // price at the beginning of the step tickNext int32 // next tick to swap to from the current tick in the swap direction @@ -119,6 +149,32 @@ type StepComputations struct { feeAmount *u256.Uint // how much fee is being paid in this step } +// init initializes the computation for a single swap step +func (step *StepComputations) initSwapStep(state SwapState, pool *Pool, zeroForOne bool) { + step.sqrtPriceStartX96 = state.sqrtPriceX96 + step.tickNext, step.initialized = pool.tickBitmapNextInitializedTickWithInOneWord( + state.tick, + pool.tickSpacing, + zeroForOne, + ) + + // prevent overshoot the min/max tick + step.clampTickNext() + + // get the price for the next tick + step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) +} + +// clampTickNext ensures that `tickNext` stays within the min, max tick boundaries +// as the tick bitmap is not aware of these bounds +func (step *StepComputations) clampTickNext() { + if step.tickNext < consts.MIN_TICK { + step.tickNext = consts.MIN_TICK + } else if step.tickNext > consts.MAX_TICK { + step.tickNext = consts.MAX_TICK + } +} + type PositionInfo struct { liquidity *u256.Uint // amount of liquidity owned by this position