From 1f5a147d1acf1d3cf41c850dcdfb245f13782114 Mon Sep 17 00:00:00 2001 From: n3wbie Date: Tue, 10 Dec 2024 15:47:26 +0900 Subject: [PATCH 1/9] GSW-1838 test: swap_math unitest --- .../p/gnoswap/pool/__TEST_swap_math_test.gnoA | 202 ----------------- _deploy/p/gnoswap/pool/swap_math_test.gno | 206 ++++++++++++++++++ 2 files changed, 206 insertions(+), 202 deletions(-) delete mode 100644 _deploy/p/gnoswap/pool/__TEST_swap_math_test.gnoA create mode 100644 _deploy/p/gnoswap/pool/swap_math_test.gno diff --git a/_deploy/p/gnoswap/pool/__TEST_swap_math_test.gnoA b/_deploy/p/gnoswap/pool/__TEST_swap_math_test.gnoA deleted file mode 100644 index b34032bb6..000000000 --- a/_deploy/p/gnoswap/pool/__TEST_swap_math_test.gnoA +++ /dev/null @@ -1,202 +0,0 @@ -package pool - -import ( - "testing" - - i256 "gno.land/p/gnoswap/int256" - u256 "gno.land/p/gnoswap/uint256" -) - -var ( - amountIn_i256 *i256.Int - amountOut *u256.Uint - zeroForOne bool - expected *u256.Uint - rst bool - price *u256.Uint - priceTarget *u256.Uint - liquidity *u256.Uint - fee uint64 - amountIn_String string - amountOut_String string - sqrtQ_String string - feeAmount_String string -) - -func TestSwapMathComputeSwapStepStr_1(t *testing.T) { - var amountIn_feeAmount *u256.Uint - var priceAfterWholeInputAmount *u256.Uint - // exact amount in that gets capped at price target in one for zero - - price = u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - priceTarget = u256.MustFromDecimal("79623317895830914510639640423") // encodePriceSqrt(101,100) = 79623317895830914510639640423 - liquidity = u256.MustFromDecimal("2000000000000000000") // 2e18 - amountIn_i256 = i256.MustFromDecimal("1000000000000000000") // 1e18 - fee = 600 - zeroForOne = false - - sqrtQ_String, amountIn_String, amountOut_String, feeAmount_String := SwapMathComputeSwapStepStr(price, priceTarget, liquidity, amountIn_i256, fee) - - shouldEQ(t, amountIn_String, "9975124224178055") - shouldEQ(t, feeAmount_String, "5988667735148") - shouldEQ(t, amountOut_String, "9925619580021728") - amountIn_feeAmount = u256.MustFromDecimal(amountIn_String) - amountIn_feeAmount.Add(amountIn_feeAmount, u256.MustFromDecimal(feeAmount_String)) - - if amountIn_feeAmount.Cmp(u256.MustFromDecimal("1000000000000000000")) > 0 { - t.Errorf("entire amount is not used") - } - - priceAfterWholeInputAmount = sqrtPriceMathGetNextSqrtPriceFromInput(price, liquidity, u256.MustFromDecimal("1000000000000000000"), zeroForOne) - - shouldEQ(t, sqrtQ_String, priceTarget.ToString()) - if u256.MustFromDecimal(sqrtQ_String).Cmp(priceAfterWholeInputAmount) > 0 { - t.Errorf("price is less than price after whole input amount") - } -} - -func TestSwapMathComputeSwapStepStr_2(t *testing.T) { - var priceAfterWholeInputAmount *u256.Uint - // exact amount out that gets capped at price target in one for zero - - price = u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - priceTarget = u256.MustFromDecimal("79623317895830914510639640423") // encodePriceSqrt(101,100) = 79623317895830914510639640423 - liquidity = u256.MustFromDecimal("2000000000000000000") // 2e18 - amountIn_i256 = i256.MustFromDecimal("-1000000000000000000") // -1e18 - fee = 600 - zeroForOne = false - - sqrtQ_String, amountIn_String, amountOut_String, feeAmount_String := SwapMathComputeSwapStepStr(price, priceTarget, liquidity, amountIn_i256, fee) - - shouldEQ(t, amountIn_String, "9975124224178055") - shouldEQ(t, feeAmount_String, "5988667735148") - shouldEQ(t, amountOut_String, "9925619580021728") - - if u256.MustFromDecimal(amountOut_String).Cmp(u256.MustFromDecimal("1000000000000000000")) >= 0 { - t.Errorf("entire amount out is not returned") - } - - priceAfterWholeInputAmount = sqrtPriceMathGetNextSqrtPriceFromInput(price, liquidity, u256.MustFromDecimal("1000000000000000000"), zeroForOne) - shouldEQ(t, sqrtQ_String, priceTarget.ToString()) - if u256.MustFromDecimal(sqrtQ_String).Cmp(priceAfterWholeInputAmount) > 0 { - t.Errorf("price is less than price after whole output amount") - } -} - -func TestSwapMathComputeSwapStepStr_3(t *testing.T) { - var amountIn_feeAmount *u256.Uint - var priceAfterWholeInputAmount *u256.Uint - // exact amount in that is fully spent in one for zero - - price = u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - priceTarget = u256.MustFromDecimal("792281625142643375935439503360") // encodePriceSqrt(1000,100) = 792281625142643375935439503360 - liquidity = u256.MustFromDecimal("2000000000000000000") // 2e18 - amountIn_i256 = i256.MustFromDecimal("1000000000000000000") // 1e18 - fee = 600 - zeroForOne = false - - sqrtQ_String, amountIn_String, amountOut_String, feeAmount_String := SwapMathComputeSwapStepStr(price, priceTarget, liquidity, amountIn_i256, fee) - - shouldEQ(t, amountIn_String, "999400000000000000") - shouldEQ(t, feeAmount_String, "600000000000000") - shouldEQ(t, amountOut_String, "666399946655997866") - amountIn_feeAmount = u256.MustFromDecimal(amountIn_String) - amountIn_feeAmount.Add(amountIn_feeAmount, u256.MustFromDecimal(feeAmount_String)) - shouldEQ(t, amountIn_feeAmount.ToString(), "1000000000000000000") - - priceAfterWholeInputAmount = sqrtPriceMathGetNextSqrtPriceFromInput(price, liquidity, u256.MustFromDecimal("999400000000000000"), zeroForOne) - shouldEQ(t, sqrtQ_String, priceAfterWholeInputAmount.ToString()) - if u256.MustFromDecimal(sqrtQ_String).Cmp(priceTarget) > 0 { - t.Errorf("price does not reach price target") - } -} - -func TestSwapMathComputeSwapStepStr_4(t *testing.T) { - // amount out is capped at the desired amount out - - price = u256.MustFromDecimal("417332158212080721273783715441582") - priceTarget = u256.MustFromDecimal("1452870262520218020823638996") - liquidity = u256.MustFromDecimal("159344665391607089467575320103") - amountIn_i256 = i256.MustFromDecimal("-1") - fee = 1 - - sqrtQ_String, amountIn_String, amountOut_String, feeAmount_String := SwapMathComputeSwapStepStr(price, priceTarget, liquidity, amountIn_i256, fee) - - shouldEQ(t, sqrtQ_String, "417332158212080721273783715441581") - shouldEQ(t, amountIn_String, "1") - shouldEQ(t, feeAmount_String, "1") - shouldEQ(t, amountOut_String, "1") -} - -func TestSwapMathComputeSwapStepStr_5(t *testing.T) { - var amountIn_feeAmount *u256.Uint - // target price of 1 uses partial input amount - - price = u256.MustFromDecimal("2") - priceTarget = u256.MustFromDecimal("1") - liquidity = u256.MustFromDecimal("1") - amountIn_i256 = i256.MustFromDecimal("3915081100057732413702495386755767") - fee = 1 - - sqrtQ_String, amountIn_String, amountOut_String, feeAmount_String := SwapMathComputeSwapStepStr(price, priceTarget, liquidity, amountIn_i256, fee) - - shouldEQ(t, sqrtQ_String, "1") - shouldEQ(t, feeAmount_String, "39614120871253040049813") - shouldEQ(t, amountOut_String, "0") - shouldEQ(t, amountIn_String, "39614081257132168796771975168") - amountIn_feeAmount = u256.MustFromDecimal(amountIn_String) - amountIn_feeAmount.Add(amountIn_feeAmount, u256.MustFromDecimal(feeAmount_String)) - - if amountIn_feeAmount.Cmp(u256.MustFromDecimal("3915081100057732413702495386755767")) >= 0 { - t.Errorf("amountIn+feeAmount should be less than or eq to 3915081100057732413702495386755767") - } -} - -func TestSwapMathComputeSwapStepStr_6(t *testing.T) { - // entire input amount taken as fee - price = u256.MustFromDecimal("2413") - priceTarget = u256.MustFromDecimal("79887613182836312") - liquidity = u256.MustFromDecimal("1985041575832132834610021537970") - amountIn_i256 = i256.MustFromDecimal("10") - fee = 1872 - - sqrtQ_String, amountIn_String, amountOut_String, feeAmount_String := SwapMathComputeSwapStepStr(price, priceTarget, liquidity, amountIn_i256, fee) - - shouldEQ(t, amountIn_String, "0") - shouldEQ(t, feeAmount_String, "10") - shouldEQ(t, amountOut_String, "0") - shouldEQ(t, sqrtQ_String, "2413") -} - -func TestSwapMathComputeSwapStepStr_7(t *testing.T) { - // handles intermediate insufficient liquidity in zero for one exact output case - - price = u256.MustFromDecimal("20282409603651670423947251286016") - priceTarget = u256.MulDiv(price, u256.NewUint(11), u256.NewUint(10)) - liquidity = u256.MustFromDecimal("1024") - amountIn_i256 = i256.MustFromDecimal("-4") - fee = 3000 - - sqrtQ_String, amountIn_String, amountOut_String, feeAmount_String := SwapMathComputeSwapStepStr(price, priceTarget, liquidity, amountIn_i256, fee) - - shouldEQ(t, amountOut_String, "0") - shouldEQ(t, sqrtQ_String, priceTarget.ToString()) - shouldEQ(t, amountIn_String, "26215") - shouldEQ(t, feeAmount_String, "79") -} - -func TestSwapMathComputeSwapStepStr_8(t *testing.T) { - // handles intermediate insufficient liquidity in one for zero exact output case - price = u256.MustFromDecimal("20282409603651670423947251286016") - priceTarget = u256.MulDiv(price, u256.NewUint(9), u256.NewUint(10)) - liquidity = u256.MustFromDecimal("1024") - amountIn_i256 = i256.MustFromDecimal("-263000") - fee = 3000 - - sqrtQ_String, amountIn_String, amountOut_String, feeAmount_String := SwapMathComputeSwapStepStr(price, priceTarget, liquidity, amountIn_i256, fee) - - shouldEQ(t, amountOut_String, "26214") - shouldEQ(t, sqrtQ_String, priceTarget.ToString()) - shouldEQ(t, amountIn_String, "1") - shouldEQ(t, feeAmount_String, "1") -} diff --git a/_deploy/p/gnoswap/pool/swap_math_test.gno b/_deploy/p/gnoswap/pool/swap_math_test.gno new file mode 100644 index 000000000..356664824 --- /dev/null +++ b/_deploy/p/gnoswap/pool/swap_math_test.gno @@ -0,0 +1,206 @@ +package pool + +import ( + "testing" + + "gno.land/p/demo/uassert" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestSwapMathComputeSwapStepStr(t *testing.T) { + tests := []struct { + name string + currentX96, targetX96 *u256.Uint + liquidity *u256.Uint + amountRemaining *i256.Int + feePips uint64 + sqrtNextX96 *u256.Uint + chkSqrtNextX96 func(sqrtRatioNextX96, priceTarget *u256.Uint) + amountIn, amountOut, feeAmount string + }{ + { + name: "exact amount in that gets capped at price target in one for zero", + currentX96: encodePriceSqrt("1", "1"), + targetX96: encodePriceSqrt("101", "100"), + liquidity: u256.MustFromDecimal("2000000000000000000"), + amountRemaining: i256.MustFromDecimal("1000000000000000000"), + feePips: 600, + sqrtNextX96: encodePriceSqrt("101", "100"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) + }, + amountIn: "9975124224178055", + amountOut: "9925619580021728", + feeAmount: "5988667735148", + }, + { + name: "exact amount out that gets capped at price target in one for zero", + currentX96: encodePriceSqrt("1", "1"), + targetX96: encodePriceSqrt("101", "100"), + liquidity: u256.MustFromDecimal("2000000000000000000"), + amountRemaining: i256.MustFromDecimal("-1000000000000000000"), + feePips: 600, + sqrtNextX96: encodePriceSqrt("101", "100"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) + }, + amountIn: "9975124224178055", + amountOut: "9925619580021728", + feeAmount: "5988667735148", + }, + { + name: "exact amount in that is fully spent in one for zero", + currentX96: encodePriceSqrt("1", "1"), + targetX96: encodePriceSqrt("1000", "100"), + liquidity: u256.MustFromDecimal("2000000000000000000"), + amountRemaining: i256.MustFromDecimal("1000000000000000000"), + sqrtNextX96: encodePriceSqrt("1000", "100"), + feePips: 600, + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Lte(priceTarget)) + }, + amountIn: "999400000000000000", + amountOut: "666399946655997866", + feeAmount: "600000000000000", + }, + { + name: "exact amount out that is fully received in one for zero", + currentX96: encodePriceSqrt("1", "1"), + targetX96: encodePriceSqrt("1000", "100"), + liquidity: u256.MustFromDecimal("2000000000000000000"), + amountRemaining: i256.MustFromDecimal("-1000000000000000000"), + feePips: 600, + sqrtNextX96: encodePriceSqrt("1000", "100"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Lt(priceTarget)) + }, + amountIn: "2000000000000000000", + amountOut: "1000000000000000000", + feeAmount: "1200720432259356", + }, + { + name: "amount out is capped at the desired amount out", + currentX96: u256.MustFromDecimal("417332158212080721273783715441582"), + targetX96: u256.MustFromDecimal("1452870262520218020823638996"), + liquidity: u256.MustFromDecimal("159344665391607089467575320103"), + amountRemaining: i256.MustFromDecimal("-1"), + feePips: 1, + sqrtNextX96: u256.MustFromDecimal("417332158212080721273783715441581"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) + }, + amountIn: "1", + amountOut: "1", + feeAmount: "1", + }, + { + name: "target price of 1 uses partial input amount", + currentX96: u256.MustFromDecimal("2"), + targetX96: u256.MustFromDecimal("1"), + liquidity: u256.MustFromDecimal("1"), + amountRemaining: i256.MustFromDecimal("3915081100057732413702495386755767"), + feePips: 1, + sqrtNextX96: u256.MustFromDecimal("1"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) + }, + amountIn: "39614081257132168796771975168", + amountOut: "0", + feeAmount: "39614120871253040049813", + }, + { + name: "entire input amount taken as fee", + currentX96: u256.MustFromDecimal("2413"), + targetX96: u256.MustFromDecimal("79887613182836312"), + liquidity: u256.MustFromDecimal("1985041575832132834610021537970"), + amountRemaining: i256.MustFromDecimal("10"), + feePips: 1872, + sqrtNextX96: u256.MustFromDecimal("2413"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) + }, + amountIn: "0", + amountOut: "0", + feeAmount: "10", + }, + { + name: "handles intermediate insufficient liquidity in zero for one exact output case", + currentX96: u256.MustFromDecimal("20282409603651670423947251286016"), + targetX96: u256.MustFromDecimal("22310650564016837466341976414617"), + liquidity: u256.MustFromDecimal("1024"), + amountRemaining: i256.MustFromDecimal("-4"), + feePips: 3000, + sqrtNextX96: u256.MustFromDecimal("22310650564016837466341976414617"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) + }, + amountIn: "26215", + amountOut: "0", + feeAmount: "79", + }, + { + name: "handles intermediate insufficient liquidity in one for zero exact output case", + currentX96: u256.MustFromDecimal("20282409603651670423947251286016"), + targetX96: u256.MustFromDecimal("18254168643286503381552526157414"), + liquidity: u256.MustFromDecimal("1024"), + amountRemaining: i256.MustFromDecimal("-263000"), + feePips: 3000, + sqrtNextX96: u256.MustFromDecimal("18254168643286503381552526157414"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) + }, + amountIn: "1", + amountOut: "26214", + feeAmount: "1", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + sqrtRatioNextX96, amountIn, amountOut, feeAmount := SwapMathComputeSwapStepStr(test.currentX96, test.targetX96, test.liquidity, test.amountRemaining, test.feePips) + test.chkSqrtNextX96(u256.MustFromDecimal(sqrtRatioNextX96), test.sqrtNextX96) + uassert.Equal(t, amountIn, test.amountIn) + uassert.Equal(t, amountOut, test.amountOut) + uassert.Equal(t, feeAmount, test.feeAmount) + }) + } +} + +// encodePriceSqrt calculates the sqrt((reserve1 << 192) / reserve0) +func encodePriceSqrt(reserve1, reserve0 string) *u256.Uint { + reserve1Uint := u256.MustFromDecimal(reserve1) + reserve0Uint := u256.MustFromDecimal(reserve0) + + if reserve0Uint.IsZero() { + panic("division by zero") + } + + // numerator = reserve1 * (2^192) + two192 := new(u256.Uint).Lsh(u256.NewUint(1), 192) + numerator := new(u256.Uint).Mul(reserve1Uint, two192) + + // ratioX192 = numerator / reserve0 + ratioX192 := new(u256.Uint).Div(numerator, reserve0Uint) + + // Return sqrt(ratioX192) + return sqrt(ratioX192) +} + +// sqrt computes the integer square root of a u256.Uint +func sqrt(x *u256.Uint) *u256.Uint { + if x.IsZero() { + return u256.NewUint(0) + } + + z := new(u256.Uint).Set(x) + y := new(u256.Uint).Rsh(z, 1) // Initial guess is x / 2 + + for y.Cmp(z) < 0 { + z.Set(y) + temp := new(u256.Uint).Div(x, z) + y.Add(z, temp).Rsh(y, 1) + } + return z +} From 099f1cb15ad3d0863e728d01735d69aabe1471c3 Mon Sep 17 00:00:00 2001 From: n3wbie Date: Tue, 10 Dec 2024 20:05:41 +0900 Subject: [PATCH 2/9] test: define test helper function --- _deploy/p/gnoswap/pool/swap_math_test.gno | 34 +++++++++++++---------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/_deploy/p/gnoswap/pool/swap_math_test.gno b/_deploy/p/gnoswap/pool/swap_math_test.gno index 356664824..587b83909 100644 --- a/_deploy/p/gnoswap/pool/swap_math_test.gno +++ b/_deploy/p/gnoswap/pool/swap_math_test.gno @@ -22,12 +22,12 @@ func TestSwapMathComputeSwapStepStr(t *testing.T) { }{ { name: "exact amount in that gets capped at price target in one for zero", - currentX96: encodePriceSqrt("1", "1"), - targetX96: encodePriceSqrt("101", "100"), + currentX96: encodePriceSqrt(t, "1", "1"), + targetX96: encodePriceSqrt(t, "101", "100"), liquidity: u256.MustFromDecimal("2000000000000000000"), amountRemaining: i256.MustFromDecimal("1000000000000000000"), feePips: 600, - sqrtNextX96: encodePriceSqrt("101", "100"), + sqrtNextX96: encodePriceSqrt(t, "101", "100"), chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) }, @@ -37,12 +37,12 @@ func TestSwapMathComputeSwapStepStr(t *testing.T) { }, { name: "exact amount out that gets capped at price target in one for zero", - currentX96: encodePriceSqrt("1", "1"), - targetX96: encodePriceSqrt("101", "100"), + currentX96: encodePriceSqrt(t, "1", "1"), + targetX96: encodePriceSqrt(t, "101", "100"), liquidity: u256.MustFromDecimal("2000000000000000000"), amountRemaining: i256.MustFromDecimal("-1000000000000000000"), feePips: 600, - sqrtNextX96: encodePriceSqrt("101", "100"), + sqrtNextX96: encodePriceSqrt(t, "101", "100"), chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) }, @@ -52,11 +52,11 @@ func TestSwapMathComputeSwapStepStr(t *testing.T) { }, { name: "exact amount in that is fully spent in one for zero", - currentX96: encodePriceSqrt("1", "1"), - targetX96: encodePriceSqrt("1000", "100"), + currentX96: encodePriceSqrt(t, "1", "1"), + targetX96: encodePriceSqrt(t, "1000", "100"), liquidity: u256.MustFromDecimal("2000000000000000000"), amountRemaining: i256.MustFromDecimal("1000000000000000000"), - sqrtNextX96: encodePriceSqrt("1000", "100"), + sqrtNextX96: encodePriceSqrt(t, "1000", "100"), feePips: 600, chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { uassert.True(t, sqrtRatioNextX96.Lte(priceTarget)) @@ -67,12 +67,12 @@ func TestSwapMathComputeSwapStepStr(t *testing.T) { }, { name: "exact amount out that is fully received in one for zero", - currentX96: encodePriceSqrt("1", "1"), - targetX96: encodePriceSqrt("1000", "100"), + currentX96: encodePriceSqrt(t, "1", "1"), + targetX96: encodePriceSqrt(t, "1000", "100"), liquidity: u256.MustFromDecimal("2000000000000000000"), amountRemaining: i256.MustFromDecimal("-1000000000000000000"), feePips: 600, - sqrtNextX96: encodePriceSqrt("1000", "100"), + sqrtNextX96: encodePriceSqrt(t, "1000", "100"), chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { uassert.True(t, sqrtRatioNextX96.Lt(priceTarget)) }, @@ -169,7 +169,9 @@ func TestSwapMathComputeSwapStepStr(t *testing.T) { } // encodePriceSqrt calculates the sqrt((reserve1 << 192) / reserve0) -func encodePriceSqrt(reserve1, reserve0 string) *u256.Uint { +func encodePriceSqrt(t *testing.T, reserve1, reserve0 string) *u256.Uint { + t.Helper() + reserve1Uint := u256.MustFromDecimal(reserve1) reserve0Uint := u256.MustFromDecimal(reserve0) @@ -185,11 +187,13 @@ func encodePriceSqrt(reserve1, reserve0 string) *u256.Uint { ratioX192 := new(u256.Uint).Div(numerator, reserve0Uint) // Return sqrt(ratioX192) - return sqrt(ratioX192) + return sqrt(t, ratioX192) } // sqrt computes the integer square root of a u256.Uint -func sqrt(x *u256.Uint) *u256.Uint { +func sqrt(t *testing.T, x *u256.Uint) *u256.Uint { + t.Helper() + if x.IsZero() { return u256.NewUint(0) } From 5654380f6ea9f8dda54fd0d80cb31a317d7ec5c3 Mon Sep 17 00:00:00 2001 From: n3wbie Date: Tue, 10 Dec 2024 20:09:43 +0900 Subject: [PATCH 3/9] test: remove unnecessary initialization --- _deploy/p/gnoswap/pool/swap_math_test.gno | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/_deploy/p/gnoswap/pool/swap_math_test.gno b/_deploy/p/gnoswap/pool/swap_math_test.gno index 587b83909..7455b7e64 100644 --- a/_deploy/p/gnoswap/pool/swap_math_test.gno +++ b/_deploy/p/gnoswap/pool/swap_math_test.gno @@ -201,9 +201,10 @@ func sqrt(t *testing.T, x *u256.Uint) *u256.Uint { z := new(u256.Uint).Set(x) y := new(u256.Uint).Rsh(z, 1) // Initial guess is x / 2 + temp := new(u256.Uint) for y.Cmp(z) < 0 { z.Set(y) - temp := new(u256.Uint).Div(x, z) + temp.Div(x, z) y.Add(z, temp).Rsh(y, 1) } return z From ce2e2923ce5eb80ac2e24c6817fd8b67e3a9490b Mon Sep 17 00:00:00 2001 From: Blake <104744707+r3v4s@users.noreply.github.com> Date: Sun, 15 Dec 2024 14:10:20 +0900 Subject: [PATCH 4/9] chore: remove `Pool` prefix from pool's receiver getter (#423) --- pool/getter.gno | 58 +++++++++++++++---------------- pool/pool.gno | 34 +++++++++--------- pool/position.gno | 10 +++--- pool/tick.gno | 16 ++++----- position/_RPC_api.gno | 16 ++++----- position/liquidity_management.gno | 2 +- position/position.gno | 10 +++--- 7 files changed, 73 insertions(+), 73 deletions(-) diff --git a/pool/getter.gno b/pool/getter.gno index d71ecbf1e..f962846fa 100644 --- a/pool/getter.gno +++ b/pool/getter.gno @@ -11,119 +11,119 @@ func PoolGetPoolList() []string { } func PoolGetToken0Path(poolPath string) string { - return mustGetPool(poolPath).PoolGetToken0Path() + return mustGetPool(poolPath).GetToken0Path() } func PoolGetToken1Path(poolPath string) string { - return mustGetPool(poolPath).PoolGetToken1Path() + return mustGetPool(poolPath).GetToken1Path() } func PoolGetFee(poolPath string) uint32 { - return mustGetPool(poolPath).PoolGetFee() + return mustGetPool(poolPath).GetFee() } func PoolGetBalanceToken0(poolPath string) string { - return mustGetPool(poolPath).PoolGetBalanceToken0().ToString() + return mustGetPool(poolPath).GetBalanceToken0().ToString() } func PoolGetBalanceToken1(poolPath string) string { - return mustGetPool(poolPath).PoolGetBalanceToken1().ToString() + return mustGetPool(poolPath).GetBalanceToken1().ToString() } func PoolGetTickSpacing(poolPath string) int32 { - return mustGetPool(poolPath).PoolGetTickSpacing() + return mustGetPool(poolPath).GetTickSpacing() } func PoolGetMaxLiquidityPerTick(poolPath string) string { - return mustGetPool(poolPath).PoolGetMaxLiquidityPerTick().ToString() + return mustGetPool(poolPath).GetMaxLiquidityPerTick().ToString() } func PoolGetSlot0SqrtPriceX96(poolPath string) string { - return mustGetPool(poolPath).PoolGetSlot0SqrtPriceX96().ToString() + return mustGetPool(poolPath).GetSlot0SqrtPriceX96().ToString() } func PoolGetSlot0Tick(poolPath string) int32 { - return mustGetPool(poolPath).PoolGetSlot0Tick() + return mustGetPool(poolPath).GetSlot0Tick() } func PoolGetSlot0FeeProtocol(poolPath string) uint8 { - return mustGetPool(poolPath).PoolGetSlot0FeeProtocol() + return mustGetPool(poolPath).GetSlot0FeeProtocol() } func PoolGetSlot0Unlocked(poolPath string) bool { - return mustGetPool(poolPath).PoolGetSlot0Unlocked() + return mustGetPool(poolPath).GetSlot0Unlocked() } func PoolGetFeeGrowthGlobal0X128(poolPath string) string { - return mustGetPool(poolPath).PoolGetFeeGrowthGlobal0X128().ToString() + return mustGetPool(poolPath).GetFeeGrowthGlobal0X128().ToString() } func PoolGetFeeGrowthGlobal1X128(poolPath string) string { - return mustGetPool(poolPath).PoolGetFeeGrowthGlobal1X128().ToString() + return mustGetPool(poolPath).GetFeeGrowthGlobal1X128().ToString() } func PoolGetProtocolFeesToken0(poolPath string) string { - return mustGetPool(poolPath).PoolGetProtocolFeesToken0().ToString() + return mustGetPool(poolPath).GetProtocolFeesToken0().ToString() } func PoolGetProtocolFeesToken1(poolPath string) string { - return mustGetPool(poolPath).PoolGetProtocolFeesToken1().ToString() + return mustGetPool(poolPath).GetProtocolFeesToken1().ToString() } func PoolGetLiquidity(poolPath string) string { - return mustGetPool(poolPath).PoolGetLiquidity().ToString() + return mustGetPool(poolPath).GetLiquidity().ToString() } // position func PoolGetPositionLiquidity(poolPath, key string) string { - return mustGetPool(poolPath).PoolGetPositionLiquidity(key).ToString() + return mustGetPool(poolPath).GetPositionLiquidity(key).ToString() } func PoolGetPositionFeeGrowthInside0LastX128(poolPath, key string) string { - return mustGetPool(poolPath).PoolGetPositionFeeGrowthInside0LastX128(key).ToString() + return mustGetPool(poolPath).GetPositionFeeGrowthInside0LastX128(key).ToString() } func PoolGetPositionFeeGrowthInside1LastX128(poolPath, key string) string { - return mustGetPool(poolPath).PoolGetPositionFeeGrowthInside1LastX128(key).ToString() + return mustGetPool(poolPath).GetPositionFeeGrowthInside1LastX128(key).ToString() } func PoolGetPositionTokensOwed0(poolPath, key string) string { - return mustGetPool(poolPath).PoolGetPositionTokensOwed0(key).ToString() + return mustGetPool(poolPath).GetPositionTokensOwed0(key).ToString() } func PoolGetPositionTokensOwed1(poolPath, key string) string { - return mustGetPool(poolPath).PoolGetPositionTokensOwed1(key).ToString() + return mustGetPool(poolPath).GetPositionTokensOwed1(key).ToString() } // tick func PoolGetTickLiquidityGross(poolPath string, tick int32) string { - return mustGetPool(poolPath).PoolGetTickLiquidityGross(tick).ToString() + return mustGetPool(poolPath).GetTickLiquidityGross(tick).ToString() } func PoolGetTickLiquidityNet(poolPath string, tick int32) string { - return mustGetPool(poolPath).PoolGetTickLiquidityNet(tick).ToString() + return mustGetPool(poolPath).GetTickLiquidityNet(tick).ToString() } func PoolGetTickFeeGrowthOutside0X128(poolPath string, tick int32) string { - return mustGetPool(poolPath).PoolGetTickFeeGrowthOutside0X128(tick).ToString() + return mustGetPool(poolPath).GetTickFeeGrowthOutside0X128(tick).ToString() } func PoolGetTickFeeGrowthOutside1X128(poolPath string, tick int32) string { - return mustGetPool(poolPath).PoolGetTickFeeGrowthOutside1X128(tick).ToString() + return mustGetPool(poolPath).GetTickFeeGrowthOutside1X128(tick).ToString() } func PoolGetTickCumulativeOutside(poolPath string, tick int32) int64 { - return mustGetPool(poolPath).PoolGetTickCumulativeOutside(tick) + return mustGetPool(poolPath).GetTickCumulativeOutside(tick) } func PoolGetTickSecondsPerLiquidityOutsideX128(poolPath string, tick int32) string { - return mustGetPool(poolPath).PoolGetTickSecondsPerLiquidityOutsideX128(tick).ToString() + return mustGetPool(poolPath).GetTickSecondsPerLiquidityOutsideX128(tick).ToString() } func PoolGetTickSecondsOutside(poolPath string, tick int32) uint32 { - return mustGetPool(poolPath).PoolGetTickSecondsOutside(tick) + return mustGetPool(poolPath).GetTickSecondsOutside(tick) } func PoolGetTickInitialized(poolPath string, tick int32) bool { - return mustGetPool(poolPath).PoolGetTickInitialized(tick) + return mustGetPool(poolPath).GetTickInitialized(tick) } diff --git a/pool/pool.gno b/pool/pool.gno index 4816c6b11..cfb655f0e 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -985,71 +985,71 @@ func checkAmountRange(amount *u256.Uint) (uint64, error) { } // receiver getters -func (p *Pool) PoolGetToken0Path() string { +func (p *Pool) GetToken0Path() string { return p.token0Path } -func (p *Pool) PoolGetToken1Path() string { +func (p *Pool) GetToken1Path() string { return p.token1Path } -func (p *Pool) PoolGetFee() uint32 { +func (p *Pool) GetFee() uint32 { return p.fee } -func (p *Pool) PoolGetBalanceToken0() *u256.Uint { +func (p *Pool) GetBalanceToken0() *u256.Uint { return p.balances.token0 } -func (p *Pool) PoolGetBalanceToken1() *u256.Uint { +func (p *Pool) GetBalanceToken1() *u256.Uint { return p.balances.token1 } -func (p *Pool) PoolGetTickSpacing() int32 { +func (p *Pool) GetTickSpacing() int32 { return p.tickSpacing } -func (p *Pool) PoolGetMaxLiquidityPerTick() *u256.Uint { +func (p *Pool) GetMaxLiquidityPerTick() *u256.Uint { return p.maxLiquidityPerTick } -func (p *Pool) PoolGetSlot0() Slot0 { +func (p *Pool) GetSlot0() Slot0 { return p.slot0 } -func (p *Pool) PoolGetSlot0SqrtPriceX96() *u256.Uint { +func (p *Pool) GetSlot0SqrtPriceX96() *u256.Uint { return p.slot0.sqrtPriceX96 } -func (p *Pool) PoolGetSlot0Tick() int32 { +func (p *Pool) GetSlot0Tick() int32 { return p.slot0.tick } -func (p *Pool) PoolGetSlot0FeeProtocol() uint8 { +func (p *Pool) GetSlot0FeeProtocol() uint8 { return p.slot0.feeProtocol } -func (p *Pool) PoolGetSlot0Unlocked() bool { +func (p *Pool) GetSlot0Unlocked() bool { return p.slot0.unlocked } -func (p *Pool) PoolGetFeeGrowthGlobal0X128() *u256.Uint { +func (p *Pool) GetFeeGrowthGlobal0X128() *u256.Uint { return p.feeGrowthGlobal0X128 } -func (p *Pool) PoolGetFeeGrowthGlobal1X128() *u256.Uint { +func (p *Pool) GetFeeGrowthGlobal1X128() *u256.Uint { return p.feeGrowthGlobal1X128 } -func (p *Pool) PoolGetProtocolFeesToken0() *u256.Uint { +func (p *Pool) GetProtocolFeesToken0() *u256.Uint { return p.protocolFees.token0 } -func (p *Pool) PoolGetProtocolFeesToken1() *u256.Uint { +func (p *Pool) GetProtocolFeesToken1() *u256.Uint { return p.protocolFees.token1 } -func (p *Pool) PoolGetLiquidity() *u256.Uint { +func (p *Pool) GetLiquidity() *u256.Uint { return p.liquidity } diff --git a/pool/position.gno b/pool/position.gno index 9b020ef5a..64d37e9d5 100644 --- a/pool/position.gno +++ b/pool/position.gno @@ -107,23 +107,23 @@ func positionUpdate( // receiver getters -func (p *Pool) PoolGetPositionLiquidity(key string) *u256.Uint { +func (p *Pool) GetPositionLiquidity(key string) *u256.Uint { return p.mustGetPosition(key).liquidity } -func (p *Pool) PoolGetPositionFeeGrowthInside0LastX128(key string) *u256.Uint { +func (p *Pool) GetPositionFeeGrowthInside0LastX128(key string) *u256.Uint { return p.mustGetPosition(key).feeGrowthInside0LastX128 } -func (p *Pool) PoolGetPositionFeeGrowthInside1LastX128(key string) *u256.Uint { +func (p *Pool) GetPositionFeeGrowthInside1LastX128(key string) *u256.Uint { return p.mustGetPosition(key).feeGrowthInside1LastX128 } -func (p *Pool) PoolGetPositionTokensOwed0(key string) *u256.Uint { +func (p *Pool) GetPositionTokensOwed0(key string) *u256.Uint { return p.mustGetPosition(key).tokensOwed0 } -func (p *Pool) PoolGetPositionTokensOwed1(key string) *u256.Uint { +func (p *Pool) GetPositionTokensOwed1(key string) *u256.Uint { return p.mustGetPosition(key).tokensOwed1 } diff --git a/pool/tick.gno b/pool/tick.gno index da19e939a..d325826fb 100644 --- a/pool/tick.gno +++ b/pool/tick.gno @@ -144,35 +144,35 @@ func getFeeGrowthAboveX128( } // receiver getters -func (p *Pool) PoolGetTickLiquidityGross(tick int32) *u256.Uint { +func (p *Pool) GetTickLiquidityGross(tick int32) *u256.Uint { return p.mustGetTick(tick).liquidityGross } -func (p *Pool) PoolGetTickLiquidityNet(tick int32) *i256.Int { +func (p *Pool) GetTickLiquidityNet(tick int32) *i256.Int { return p.mustGetTick(tick).liquidityNet } -func (p *Pool) PoolGetTickFeeGrowthOutside0X128(tick int32) *u256.Uint { +func (p *Pool) GetTickFeeGrowthOutside0X128(tick int32) *u256.Uint { return p.mustGetTick(tick).feeGrowthOutside0X128 } -func (p *Pool) PoolGetTickFeeGrowthOutside1X128(tick int32) *u256.Uint { +func (p *Pool) GetTickFeeGrowthOutside1X128(tick int32) *u256.Uint { return p.mustGetTick(tick).feeGrowthOutside1X128 } -func (p *Pool) PoolGetTickCumulativeOutside(tick int32) int64 { +func (p *Pool) GetTickCumulativeOutside(tick int32) int64 { return p.mustGetTick(tick).tickCumulativeOutside } -func (p *Pool) PoolGetTickSecondsPerLiquidityOutsideX128(tick int32) *u256.Uint { +func (p *Pool) GetTickSecondsPerLiquidityOutsideX128(tick int32) *u256.Uint { return p.mustGetTick(tick).secondsPerLiquidityOutsideX128 } -func (p *Pool) PoolGetTickSecondsOutside(tick int32) uint32 { +func (p *Pool) GetTickSecondsOutside(tick int32) uint32 { return p.mustGetTick(tick).secondsOutside } -func (p *Pool) PoolGetTickInitialized(tick int32) bool { +func (p *Pool) GetTickInitialized(tick int32) bool { return p.mustGetTick(tick).initialized } diff --git a/position/_RPC_api.gno b/position/_RPC_api.gno index 5abf9a1b6..b42e871d1 100644 --- a/position/_RPC_api.gno +++ b/position/_RPC_api.gno @@ -390,7 +390,7 @@ func rpcMakePosition(lpTokenId uint64) RpcPosition { burned := isBurned(lpTokenId) pool := pl.GetPoolFromPoolPath(position.poolKey) - currentX96 := pool.PoolGetSlot0SqrtPriceX96() + currentX96 := pool.GetSlot0SqrtPriceX96() lowerX96 := common.TickMathGetSqrtRatioAtTick(position.tickLower) upperX96 := common.TickMathGetSqrtRatioAtTick(position.tickUpper) @@ -439,24 +439,24 @@ func unclaimedFee(tokenId uint64) (*i256.Int, *i256.Int) { poolKey := positions[tokenId].poolKey pool := pl.GetPoolFromPoolPath(poolKey) - currentTick := pool.PoolGetSlot0Tick() + currentTick := pool.GetSlot0Tick() - _feeGrowthGlobal0X128 := pool.PoolGetFeeGrowthGlobal0X128() // u256 + _feeGrowthGlobal0X128 := pool.GetFeeGrowthGlobal0X128() // u256 feeGrowthGlobal0X128 := i256.FromUint256(_feeGrowthGlobal0X128) // i256 - _feeGrowthGlobal1X128 := pool.PoolGetFeeGrowthGlobal1X128() // u256 + _feeGrowthGlobal1X128 := pool.GetFeeGrowthGlobal1X128() // u256 feeGrowthGlobal1X128 := i256.FromUint256(_feeGrowthGlobal1X128) // i256 - _tickUpperFeeGrowthOutside0X128 := pool.PoolGetTickFeeGrowthOutside0X128(tickUpper) // u256 + _tickUpperFeeGrowthOutside0X128 := pool.GetTickFeeGrowthOutside0X128(tickUpper) // u256 tickUpperFeeGrowthOutside0X128 := i256.FromUint256(_tickUpperFeeGrowthOutside0X128) // i256 - _tickUpperFeeGrowthOutside1X128 := pool.PoolGetTickFeeGrowthOutside1X128(tickUpper) // u256 + _tickUpperFeeGrowthOutside1X128 := pool.GetTickFeeGrowthOutside1X128(tickUpper) // u256 tickUpperFeeGrowthOutside1X128 := i256.FromUint256(_tickUpperFeeGrowthOutside1X128) // i256 - _tickLowerFeeGrowthOutside0X128 := pool.PoolGetTickFeeGrowthOutside0X128(tickLower) // u256 + _tickLowerFeeGrowthOutside0X128 := pool.GetTickFeeGrowthOutside0X128(tickLower) // u256 tickLowerFeeGrowthOutside0X128 := i256.FromUint256(_tickLowerFeeGrowthOutside0X128) // i256 - _tickLowerFeeGrowthOutside1X128 := pool.PoolGetTickFeeGrowthOutside1X128(tickLower) // u256 + _tickLowerFeeGrowthOutside1X128 := pool.GetTickFeeGrowthOutside1X128(tickLower) // u256 tickLowerFeeGrowthOutside1X128 := i256.FromUint256(_tickLowerFeeGrowthOutside1X128) // i256 _feeGrowthInside0LastX128 := positions[tokenId].feeGrowthInside0LastX128 // u256 diff --git a/position/liquidity_management.gno b/position/liquidity_management.gno index 417b6a4a0..8b74794ca 100644 --- a/position/liquidity_management.gno +++ b/position/liquidity_management.gno @@ -16,7 +16,7 @@ import ( func addLiquidity(params AddLiquidityParams) (*u256.Uint, *u256.Uint, *u256.Uint) { pool := pl.GetPoolFromPoolPath(params.poolKey) - sqrtPriceX96 := pool.PoolGetSlot0SqrtPriceX96() + sqrtPriceX96 := pool.GetSlot0SqrtPriceX96() sqrtRatioAX96 := common.TickMathGetSqrtRatioAtTick(params.tickLower) sqrtRatioBX96 := common.TickMathGetSqrtRatioAtTick(params.tickUpper) diff --git a/position/position.gno b/position/position.gno index 7e5e37a2b..d4ffb1a88 100644 --- a/position/position.gno +++ b/position/position.gno @@ -189,7 +189,7 @@ func mint(params MintParams) (uint64, *u256.Uint, *u256.Uint, *u256.Uint) { nextId++ positionKey := positionKeyCompute(GetOrigPkgAddr(), params.tickLower, params.tickUpper) - _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.PoolGetPositionFeeGrowthInside0LastX128(positionKey), pool.PoolGetPositionFeeGrowthInside1LastX128(positionKey) + _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.GetPositionFeeGrowthInside0LastX128(positionKey), pool.GetPositionFeeGrowthInside1LastX128(positionKey) feeGrowthInside0LastX128 := u256.MustFromDecimal(_feeGrowthInside0LastX128.ToString()) feeGrowthInside1LastX128 := u256.MustFromDecimal(_feeGrowthInside1LastX128.ToString()) @@ -320,7 +320,7 @@ func increaseLiquidity(params IncreaseLiquidityParams) (uint64, *u256.Uint, *u25 pool := pl.GetPoolFromPoolPath(position.poolKey) positionKey := positionKeyCompute(GetOrigPkgAddr(), position.tickLower, position.tickUpper) - _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.PoolGetPositionFeeGrowthInside0LastX128(positionKey), pool.PoolGetPositionFeeGrowthInside1LastX128(positionKey) + _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.GetPositionFeeGrowthInside0LastX128(positionKey), pool.GetPositionFeeGrowthInside1LastX128(positionKey) feeGrowthInside0LastX128 := u256.MustFromDecimal(_feeGrowthInside0LastX128.ToString()) feeGrowthInside1LastX128 := u256.MustFromDecimal(_feeGrowthInside1LastX128.ToString()) @@ -446,7 +446,7 @@ func decreaseLiquidity(params DecreaseLiquidityParams) (uint64, *u256.Uint, *u25 verifyBurnedAmounts(burnedAmount0, burnedAmount1, params.amount0Min, params.amount1Min) positionKey := positionKeyCompute(GetOrigPkgAddr(), position.tickLower, position.tickUpper) - _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.PoolGetPositionFeeGrowthInside0LastX128(positionKey), pool.PoolGetPositionFeeGrowthInside1LastX128(positionKey) + _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.GetPositionFeeGrowthInside0LastX128(positionKey), pool.GetPositionFeeGrowthInside1LastX128(positionKey) feeGrowthInside0LastX128 := u256.MustFromDecimal(_feeGrowthInside0LastX128.ToString()) feeGrowthInside1LastX128 := u256.MustFromDecimal(_feeGrowthInside1LastX128.ToString()) @@ -592,7 +592,7 @@ func Reposition( pool := pl.GetPoolFromPoolPath(position.poolKey) positionKey := positionKeyCompute(GetOrigPkgAddr(), tickLower, tickUpper) - _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.PoolGetPositionFeeGrowthInside0LastX128(positionKey), pool.PoolGetPositionFeeGrowthInside1LastX128(positionKey) + _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.GetPositionFeeGrowthInside0LastX128(positionKey), pool.GetPositionFeeGrowthInside1LastX128(positionKey) feeGrowthInside0LastX128 := u256.MustFromDecimal(_feeGrowthInside0LastX128.ToString()) feeGrowthInside1LastX128 := u256.MustFromDecimal(_feeGrowthInside1LastX128.ToString()) @@ -676,7 +676,7 @@ func CollectFee(tokenId uint64, unwrapResult bool) (uint64, string, string, stri positionKey := positionKeyCompute(GetOrigPkgAddr(), position.tickLower, position.tickUpper) pool := pl.GetPoolFromPoolPath(position.poolKey) - _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.PoolGetPositionFeeGrowthInside0LastX128(positionKey), pool.PoolGetPositionFeeGrowthInside1LastX128(positionKey) + _feeGrowthInside0LastX128, _feeGrowthInside1LastX128 := pool.GetPositionFeeGrowthInside0LastX128(positionKey), pool.GetPositionFeeGrowthInside1LastX128(positionKey) feeGrowthInside0LastX128 := u256.MustFromDecimal(_feeGrowthInside0LastX128.ToString()) feeGrowthInside1LastX128 := u256.MustFromDecimal(_feeGrowthInside1LastX128.ToString()) From ff6eb546833eb035f22f61c7afc2db021cd21e22 Mon Sep 17 00:00:00 2001 From: 0xTopaz Date: Sun, 15 Dec 2024 20:16:32 +0900 Subject: [PATCH 5/9] refactor: Use Clone data in function calls to protect original data --- _deploy/p/gnoswap/pool/swap_math.gno | 127 ++++++++++++++-------- _deploy/p/gnoswap/pool/swap_math_test.gno | 62 +++++++++++ 2 files changed, 145 insertions(+), 44 deletions(-) diff --git a/_deploy/p/gnoswap/pool/swap_math.gno b/_deploy/p/gnoswap/pool/swap_math.gno index 6150981f3..81b9c2d28 100644 --- a/_deploy/p/gnoswap/pool/swap_math.gno +++ b/_deploy/p/gnoswap/pool/swap_math.gno @@ -5,14 +5,34 @@ import ( u256 "gno.land/p/gnoswap/uint256" ) +// SwapMathComputeSwapStepStr computes the next sqrt price, amount in, amount out, and fee amount +// Computes the result of swapping some amount in, or amount out, given the parameters of the swap +// The fee, plus the amount in, will never exceed the amount remaining if the swap's `amountSpecified` is positive +// +// input: +// - sqrtRatioCurrentX96: the current sqrt price of the pool +// - sqrtRatioTargetX96: The price that cannot be exceeded, from which the direction of the swap is inferred +// - liquidity: The usable liquidity of the pool +// - amountRemaining: How much input or output amount is remaining to be swapped in/out +// - feePips: The fee taken from the input amount, expressed in hundredths of a bip +// +// output: +// - sqrtRatioNextX96: The price after swapping the amount in/out, not to exceed the price target +// - amountIn: The amount to be swapped in, of either token0 or token1, based on the direction of the swap +// - amountOut: The amount to be received, of either token0 or token1, based on the direction of the swap +// - feeAmount: The amount of input that will be taken as a fee func SwapMathComputeSwapStepStr( - sqrtRatioCurrentX96 *u256.Uint, // uint160 - sqrtRatioTargetX96 *u256.Uint, // uint160 - liquidity *u256.Uint, // uint128 - amountRemaining *i256.Int, // int256 + sqrtRatioCurrentX96 *u256.Uint, + sqrtRatioTargetX96 *u256.Uint, + liquidity *u256.Uint, + amountRemaining *i256.Int, feePips uint64, -) (string, string, string, string) { // (sqrtRatioNextX96, amountIn, amountOut, feeAmount *u256.Uint) - isToken1Expensive := sqrtRatioCurrentX96.Gte(sqrtRatioTargetX96) +) (string, string, string, string) { + if sqrtRatioCurrentX96 == nil || sqrtRatioTargetX96 == nil || liquidity == nil || amountRemaining == nil { + panic("SwapMathComputeSwapStepStr: invalid input") + } + + zeroForOne := sqrtRatioCurrentX96.Gte(sqrtRatioTargetX96) // POSTIVIE == EXACT_IN => Estimated AmountOut // NEGATIVE == EXACT_OUT => Estimated AmountIn @@ -25,75 +45,94 @@ func SwapMathComputeSwapStepStr( if exactIn { amountRemainingLessFee := u256.MulDiv(amountRemaining.Abs(), u256.NewUint(1000000-feePips), u256.NewUint(1000000)) - - if isToken1Expensive { - amountIn = sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioTargetX96, sqrtRatioCurrentX96, liquidity, true) + if zeroForOne { + amountIn = sqrtPriceMathGetAmount0DeltaHelper( + sqrtRatioTargetX96.Clone(), + sqrtRatioCurrentX96.Clone(), + liquidity.Clone(), + true) } else { - amountIn = sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity, true) + amountIn = sqrtPriceMathGetAmount1DeltaHelper( + sqrtRatioCurrentX96.Clone(), + sqrtRatioTargetX96.Clone(), + liquidity.Clone(), + true) } if amountRemainingLessFee.Gte(amountIn) { - sqrtRatioNextX96 = sqrtRatioTargetX96 + sqrtRatioNextX96 = sqrtRatioTargetX96.Clone() } else { sqrtRatioNextX96 = sqrtPriceMathGetNextSqrtPriceFromInput( - sqrtRatioCurrentX96, - liquidity, - amountRemainingLessFee, - isToken1Expensive, + sqrtRatioCurrentX96.Clone(), + liquidity.Clone(), + amountRemainingLessFee.Clone(), + zeroForOne, ) } - } else { - if isToken1Expensive { - amountOut = sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioTargetX96, sqrtRatioCurrentX96, liquidity, false) + if zeroForOne { + amountOut = sqrtPriceMathGetAmount1DeltaHelper( + sqrtRatioTargetX96.Clone(), + sqrtRatioCurrentX96.Clone(), + liquidity.Clone(), + false) } else { - amountOut = sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioCurrentX96, sqrtRatioTargetX96, liquidity, false) + amountOut = sqrtPriceMathGetAmount0DeltaHelper( + sqrtRatioCurrentX96.Clone(), + sqrtRatioTargetX96.Clone(), + liquidity.Clone(), + false) } if amountRemaining.Abs().Gte(amountOut) { - sqrtRatioNextX96 = sqrtRatioTargetX96 + sqrtRatioNextX96 = sqrtRatioTargetX96.Clone() } else { sqrtRatioNextX96 = sqrtPriceMathGetNextSqrtPriceFromOutput( - sqrtRatioCurrentX96, - liquidity, + sqrtRatioCurrentX96.Clone(), + liquidity.Clone(), amountRemaining.Abs(), - isToken1Expensive, + zeroForOne, ) } } - max := sqrtRatioTargetX96.Eq(sqrtRatioNextX96) + isMax := sqrtRatioTargetX96.Eq(sqrtRatioNextX96) - if isToken1Expensive { - if max && exactIn { - amountIn = amountIn - } else { - amountIn = sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioNextX96, sqrtRatioCurrentX96, liquidity, true) + if zeroForOne { + if !(isMax && exactIn) { + amountIn = sqrtPriceMathGetAmount0DeltaHelper( + sqrtRatioNextX96.Clone(), + sqrtRatioCurrentX96.Clone(), + liquidity.Clone(), + true) } - - if max && !exactIn { - amountOut = amountOut - } else { - amountOut = sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioNextX96, sqrtRatioCurrentX96, liquidity, false) + if !(isMax && !exactIn) { + amountOut = sqrtPriceMathGetAmount1DeltaHelper( + sqrtRatioNextX96.Clone(), + sqrtRatioCurrentX96.Clone(), + liquidity.Clone(), + false) } } else { - if max && exactIn { - amountIn = amountIn - } else { - amountIn = sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioCurrentX96, sqrtRatioNextX96, liquidity, true) + if !(isMax && exactIn) { + amountIn = sqrtPriceMathGetAmount1DeltaHelper( + sqrtRatioCurrentX96.Clone(), + sqrtRatioNextX96.Clone(), + liquidity.Clone(), + true) } - - if max && !exactIn { - amountOut = amountOut - } else { - amountOut = sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioCurrentX96, sqrtRatioNextX96, liquidity, false) + if !(isMax && !exactIn) { + amountOut = sqrtPriceMathGetAmount0DeltaHelper( + sqrtRatioCurrentX96.Clone(), + sqrtRatioNextX96.Clone(), + liquidity.Clone(), + false) } } if !exactIn && amountOut.Gt(amountRemaining.Abs()) { amountOut = amountRemaining.Abs() } - if exactIn && !(sqrtRatioNextX96.Eq(sqrtRatioTargetX96)) { feeAmount = new(u256.Uint).Sub(amountRemaining.Abs(), amountIn) } else { diff --git a/_deploy/p/gnoswap/pool/swap_math_test.gno b/_deploy/p/gnoswap/pool/swap_math_test.gno index 7455b7e64..505a8e716 100644 --- a/_deploy/p/gnoswap/pool/swap_math_test.gno +++ b/_deploy/p/gnoswap/pool/swap_math_test.gno @@ -168,6 +168,68 @@ func TestSwapMathComputeSwapStepStr(t *testing.T) { } } +func TestSwapMathComputeSwapStepStrFail(t *testing.T) { + tests := []struct { + name string + currentX96, targetX96 *u256.Uint + liquidity *u256.Uint + amountRemaining *i256.Int + feePips uint64 + sqrtNextX96 *u256.Uint + chkSqrtNextX96 func(sqrtRatioNextX96, priceTarget *u256.Uint) + amountIn, amountOut, feeAmount string + shouldPanic bool + expectedMessage string + }{ + { + name: "input parameter is nil", + currentX96: nil, + targetX96: nil, + liquidity: nil, + amountRemaining: nil, + feePips: 600, + sqrtNextX96: encodePriceSqrt(t, "101", "100"), + chkSqrtNextX96: func(sqrtRatioNextX96, priceTarget *u256.Uint) { + uassert.True(t, sqrtRatioNextX96.Eq(priceTarget)) + }, + amountIn: "9975124224178055", + amountOut: "9925619580021728", + feeAmount: "5988667735148", + shouldPanic: true, + expectedMessage: "SwapMathComputeSwapStepStr: invalid input", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + if test.shouldPanic { + if errMsg, ok := r.(string); ok { + uassert.Equal(t, test.expectedMessage, errMsg) + } else { + t.Errorf("expected a panic with message, got: %v", r) + } + } else { + t.Errorf("unexpected panic: %v", r) + } + } else { + if test.shouldPanic { + t.Errorf("expected a panic, but none occurred") + } + } + }() + + SwapMathComputeSwapStepStr( + test.currentX96, + test.targetX96, + test.liquidity, + test.amountRemaining, + test.feePips) + }) + } +} + // encodePriceSqrt calculates the sqrt((reserve1 << 192) / reserve0) func encodePriceSqrt(t *testing.T, reserve1, reserve0 string) *u256.Uint { t.Helper() From 5aadc0cc5bab7227120fcfe8a01de8cab8c2f3ed Mon Sep 17 00:00:00 2001 From: Blake <104744707+r3v4s@users.noreply.github.com> Date: Mon, 16 Dec 2024 09:36:01 +0900 Subject: [PATCH 6/9] GSW-1838 test: sqrt_price_math unitest (#425) * refactor: sqrt_price_math - added testcase * refactor: add test code and comments - fix test fail issue * refactor: Overflow check for type conversion to int256 in delta amount calculation --------- Co-authored-by: 0xTopaz Co-authored-by: 0xTopaz <60733299+onlyhyde@users.noreply.github.com> --- .../pool/__TEST_sqrt_price_math_test.gnoA | 690 ------------------ _deploy/p/gnoswap/pool/consts.gno | 1 + _deploy/p/gnoswap/pool/gno.mod | 6 - _deploy/p/gnoswap/pool/sqrt_price_math.gno | 330 +++++++-- .../p/gnoswap/pool/sqrt_price_math_test.gno | 643 ++++++++++++++++ 5 files changed, 908 insertions(+), 762 deletions(-) delete mode 100644 _deploy/p/gnoswap/pool/__TEST_sqrt_price_math_test.gnoA create mode 100644 _deploy/p/gnoswap/pool/sqrt_price_math_test.gno diff --git a/_deploy/p/gnoswap/pool/__TEST_sqrt_price_math_test.gnoA b/_deploy/p/gnoswap/pool/__TEST_sqrt_price_math_test.gnoA deleted file mode 100644 index b3da60894..000000000 --- a/_deploy/p/gnoswap/pool/__TEST_sqrt_price_math_test.gnoA +++ /dev/null @@ -1,690 +0,0 @@ -package pool - -import ( - "testing" - - "gno.land/r/gnoswap/v1/consts" - - i256 "gno.land/p/gnoswap/int256" - u256 "gno.land/p/gnoswap/uint256" -) - -func TestGetNextSqrtPriceFromInput_1(t *testing.T) { - // fails if price is zero - sqrtPX96 := u256.Zero() - liquidity := u256.Zero() - amountIn := u256.MustFromDecimal("1000000000000000000") // 1e18 - zeroForOne := false - amountIn.Div(amountIn, u256.NewUint(10)) - - shouldPanic( - t, - func() { - sqrtPriceMathGetNextSqrtPriceFromInput(sqrtPX96, liquidity, amountIn, zeroForOne) - }, - ) -} - -func TestGetNextSqrtPriceFromInput_2(t *testing.T) { - // fails if liquidity is zero - sqrtPX96 := u256.One() - liquidity := u256.Zero() - amountIn := u256.MustFromDecimal("1000000000000000000") // 1e18 - zeroForOne := true - amountIn.Div(amountIn, u256.NewUint(10)) - - shouldPanic( - t, - func() { - sqrtPriceMathGetNextSqrtPriceFromInput(sqrtPX96, liquidity, amountIn, zeroForOne) - }, - ) -} - -func TestGetNextSqrtPriceFromInput_3(t *testing.T) { - // fails if input amount overflows the price - sqrtPX96 := u256.NewUint(2) - sqrtPX96.Exp(sqrtPX96, u256.NewUint(160)) - sqrtPX96.Sub(sqrtPX96, u256.One()) - - liquidity := u256.NewUint(1024) - amountIn := u256.NewUint(1024) - zeroForOne := false - - shouldPanic( - t, - func() { - sqrtPriceMathGetNextSqrtPriceFromInput(sqrtPX96, liquidity, amountIn, zeroForOne) - }, - ) -} - -func TestGetNextSqrtPriceFromInput_4(t *testing.T) { - // any input amount cannot underflow the price - sqrtPX96 := u256.One() - liquidity := u256.One() - amountIn := u256.NewUint(2) - amountIn.Exp(amountIn, u256.NewUint(225)) - zeroForOne := true - - expected := u256.One() - - got := sqrtPriceMathGetNextSqrtPriceFromInput(sqrtPX96, liquidity, amountIn, zeroForOne) - - rst := got.Eq(expected) - - if !rst { - t.Errorf("any input amount cannot underflow the price") - } -} - -func TestGetNextSqrtPriceFromInput_5(t *testing.T) { - // returns input price if amount in is zero and zeroForOne := true - sqrtPX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - liquidity := u256.MustFromDecimal("1000000000000000000") - liquidity.Div(liquidity, u256.NewUint(10)) - amountIn := u256.Zero() - zeroForOne := true - - expected := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - - got := sqrtPriceMathGetNextSqrtPriceFromInput(sqrtPX96, liquidity, amountIn, zeroForOne) - - rst := got.Eq(expected) - - if !rst { - t.Errorf("returns input price if amount in is zero and zeroForOne := true") - } -} - -func TestGetNextSqrtPriceFromInput_6(t *testing.T) { - // returns input price if amount in is zero and zeroForOne := false - sqrtPX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - liquidity := u256.MustFromDecimal("1000000000000000000") - liquidity.Div(liquidity, u256.NewUint(10)) - amountIn := u256.Zero() - zeroForOne := false - - expected := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - - got := sqrtPriceMathGetNextSqrtPriceFromInput(sqrtPX96, liquidity, amountIn, zeroForOne) - - rst := got.Eq(expected) - - if !rst { - t.Errorf("returns input price if amount in is zero and zeroForOne := false") - } -} - -func TestGetNextSqrtPriceFromInput_7(t *testing.T) { - var maxAmountNoOverflow *u256.Uint - var a *u256.Uint - - // returns the minimum price for max inputs - sqrtPX96 := u256.NewUint(2) - sqrtPX96.Exp(sqrtPX96, u256.NewUint(160)) - sqrtPX96.Sub(sqrtPX96, u256.NewUint(1)) - - liquidity := u256.MustFromDecimal(consts.MAX_UINT128) - a = u256.MustFromDecimal(consts.MAX_UINT128) - maxAmountNoOverflow = u256.MustFromDecimal(consts.MAX_UINT256) - a.Lsh(a, 96) - a.Div(a, sqrtPX96) - maxAmountNoOverflow.Sub(maxAmountNoOverflow, a) - - zeroForOne := true - - expected := u256.One() - - got := sqrtPriceMathGetNextSqrtPriceFromInput(sqrtPX96, liquidity, maxAmountNoOverflow, zeroForOne) - - rst := got.Eq(expected) - - if !rst { - t.Errorf("returns the minimum price for max inputs") - } -} - -func TestGetNextSqrtPriceFromInput_8(t *testing.T) { - // input amount of 0.1 token1 - sqrtPX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - liquidity := u256.MustFromDecimal("1000000000000000000") - amountIn := u256.MustFromDecimal("1000000000000000000") - amountIn.Div(amountIn, u256.NewUint(10)) - zeroForOne := false - - expected := u256.MustFromDecimal("87150978765690771352898345369") - - got := sqrtPriceMathGetNextSqrtPriceFromInput(sqrtPX96, liquidity, amountIn, zeroForOne) - - rst := got.Eq(expected) - - if !rst { - t.Errorf("input amount of 0.1 token1") - } -} - -func TestGetNextSqrtPriceFromInput_9(t *testing.T) { - // input amount of 0.1 token1 - sqrtPX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - liquidity := u256.MustFromDecimal("1000000000000000000") - amountIn := u256.MustFromDecimal("1000000000000000000") - amountIn.Div(amountIn, u256.NewUint(10)) - zeroForOne := true - - expected := u256.MustFromDecimal("72025602285694852357767227579") - - got := sqrtPriceMathGetNextSqrtPriceFromInput(sqrtPX96, liquidity, amountIn, zeroForOne) - - rst := got.Eq(expected) - - if !rst { - t.Errorf("input amount of 0.1 token1") - } -} - -func TestGetNextSqrtPriceFromInput_10(t *testing.T) { - // amountIn > type(uint96).max and zeroForOne := true - sqrtPX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - liquidity := u256.MustFromDecimal("10000000000000000000") // 10e18 - amountIn := u256.MustFromDecimal("2") - amountIn.Exp(amountIn, u256.NewUint(100)) - zeroForOne := true - - expected := u256.MustFromDecimal("624999999995069620") - - got := sqrtPriceMathGetNextSqrtPriceFromInput(sqrtPX96, liquidity, amountIn, zeroForOne) - - rst := got.Eq(expected) - - if !rst { - t.Errorf("amountIn > type(uint96).max and zeroForOne := true") - } -} - -func TestGetNextSqrtPriceFromInput_11(t *testing.T) { - // can return 1 with enough amountIn and zeroForOne := true - sqrtPX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - liquidity := u256.MustFromDecimal("1") - amountIn := u256.MustFromDecimal(consts.MAX_UINT256) - amountIn.Div(amountIn, u256.NewUint(2)) - zeroForOne := true - - expected := u256.MustFromDecimal("1") - - got := sqrtPriceMathGetNextSqrtPriceFromInput(sqrtPX96, liquidity, amountIn, zeroForOne) - - rst := got.Eq(expected) - - if !rst { - t.Errorf("can return 1 with enough amountIn and zeroForOne := true") - } -} - -func TestGetNextSqrtPriceFromOutput_1(t *testing.T) { - // fails if price is zero - sqrtPX96 := u256.Zero() - liquidity := u256.Zero() - amountOut := u256.MustFromDecimal("1000000000000000000") // 1e18 - amountOut.Div(amountOut, u256.NewUint(10)) - zeroForOne := false - - shouldPanic( - t, - func() { - sqrtPriceMathGetNextSqrtPriceFromOutput(sqrtPX96, liquidity, amountOut, zeroForOne) - }, - ) -} - -func TestGetNextSqrtPriceFromOutput_2(t *testing.T) { - // fails if liquidity is zero - sqrtPX96 := u256.One() - liquidity := u256.Zero() - amountOut := u256.MustFromDecimal("1000000000000000000") // 1e18 - amountOut.Div(amountOut, u256.NewUint(10)) - zeroForOne := true - - shouldPanic( - t, - func() { - sqrtPriceMathGetNextSqrtPriceFromOutput(sqrtPX96, liquidity, amountOut, zeroForOne) - }, - ) -} - -func TestGetNextSqrtPriceFromOutput_3(t *testing.T) { - // fails if output amount is exactly the virtual reserves of token0 - sqrtPX96 := u256.MustFromDecimal("20282409603651670423947251286016") - liquidity := u256.NewUint(1024) - amountOut := u256.MustFromDecimal("4") - zeroForOne := false - - shouldPanic( - t, - func() { - sqrtPriceMathGetNextSqrtPriceFromOutput(sqrtPX96, liquidity, amountOut, zeroForOne) - }, - ) -} - -func TestGetNextSqrtPriceFromOutput_4_1(t *testing.T) { - // fails if output amount is greater than virtual reserves of token0 - sqrtPX96 := u256.MustFromDecimal("20282409603651670423947251286016") - liquidity := u256.NewUint(1024) - amountOut := u256.MustFromDecimal("5") - zeroForOne := false - - shouldPanic( - t, - func() { - sqrtPriceMathGetNextSqrtPriceFromOutput(sqrtPX96, liquidity, amountOut, zeroForOne) - }, - ) -} - -func TestGetNextSqrtPriceFromOutput_4_2(t *testing.T) { - // fails if output amount is greater than virtual reserves of token1 - sqrtPX96 := u256.MustFromDecimal("20282409603651670423947251286016") - liquidity := u256.NewUint(1024) - amountOut := u256.MustFromDecimal("262145") - zeroForOne := true - - shouldPanic( - t, - func() { - sqrtPriceMathGetNextSqrtPriceFromOutput(sqrtPX96, liquidity, amountOut, zeroForOne) - }, - ) -} - -func TestGetNextSqrtPriceFromOutput_5(t *testing.T) { - // fails if output amount is exactly the virtual reserves of token1 - sqrtPX96 := u256.MustFromDecimal("20282409603651670423947251286016") - liquidity := u256.NewUint(1024) - amountOut := u256.MustFromDecimal("262144") - zeroForOne := true - - shouldPanic( - t, - func() { - sqrtPriceMathGetNextSqrtPriceFromOutput(sqrtPX96, liquidity, amountOut, zeroForOne) - }, - ) -} - -func TestGetNextSqrtPriceFromOutput_6(t *testing.T) { - // succeeds if output amount is just less than the virtual reserves of token1 - sqrtPX96 := u256.MustFromDecimal("20282409603651670423947251286016") - liquidity := u256.NewUint(1024) - amountOut := u256.MustFromDecimal("262143") - zeroForOne := true - - expected := u256.MustFromDecimal("77371252455336267181195264") - - got := sqrtPriceMathGetNextSqrtPriceFromOutput(sqrtPX96, liquidity, amountOut, zeroForOne) - - rst := got.Eq(expected) - - if !rst { - t.Errorf("The result should be eq to 77371252455336267181195264") - } -} - -func TestGetNextSqrtPriceFromOutput_7(t *testing.T) { - // puzzling echidna test - sqrtPX96 := u256.MustFromDecimal("20282409603651670423947251286016") - liquidity := u256.NewUint(1024) - amountOut := u256.MustFromDecimal("4") - zeroForOne := false - - shouldPanic( - t, - func() { - sqrtPriceMathGetNextSqrtPriceFromOutput(sqrtPX96, liquidity, amountOut, zeroForOne) - }, - ) -} - -func TestGetNextSqrtPriceFromOutput_8(t *testing.T) { - // returns input price if amount in is zero and zeroForOne := true - sqrtPX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - liquidity := u256.MustFromDecimal("1000000000000000000") - liquidity.Div(liquidity, u256.NewUint(10)) - amountOut := u256.Zero() - zeroForOne := true - - expected := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - - got := sqrtPriceMathGetNextSqrtPriceFromOutput(sqrtPX96, liquidity, amountOut, zeroForOne) - - rst := got.Eq(expected) - - if !rst { - t.Errorf("The result should be eq to 79228162514264337593543950336") - } -} - -func TestGetNextSqrtPriceFromOutput_9(t *testing.T) { - // returns input price if amount in is zero and zeroForOne := false - - sqrtPX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - liquidity := u256.MustFromDecimal("1000000000000000000") - liquidity.Div(liquidity, u256.NewUint(10)) - amountOut := u256.Zero() - zeroForOne := false - expected := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - - got := sqrtPriceMathGetNextSqrtPriceFromOutput(sqrtPX96, liquidity, amountOut, zeroForOne) - - rst := got.Eq(expected) - - if !rst { - t.Errorf("The result should be eq to 79228162514264337593543950336") - } -} - -func TestGetNextSqrtPriceFromOutput_10(t *testing.T) { - // output amount of 0.1 token1 - sqrtPX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - liquidity := u256.MustFromDecimal("1000000000000000000") - amountOut := u256.MustFromDecimal("1000000000000000000") - amountOut.Div(amountOut, u256.NewUint(10)) - zeroForOne := false - expected := u256.MustFromDecimal("88031291682515930659493278152") - - got := sqrtPriceMathGetNextSqrtPriceFromOutput(sqrtPX96, liquidity, amountOut, zeroForOne) - - rst := got.Eq(expected) - - if !rst { - t.Errorf("The result should be eq to 88031291682515930659493278152") - } -} - -func TestGetNextSqrtPriceFromOutput_11(t *testing.T) { - // output amount of 0.1 token1 - sqrtPX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - liquidity := u256.MustFromDecimal("1000000000000000000") - amountOut := u256.MustFromDecimal("1000000000000000000") - amountOut.Div(amountOut, u256.NewUint(10)) - zeroForOne := true - expected := u256.MustFromDecimal("71305346262837903834189555302") - - got := sqrtPriceMathGetNextSqrtPriceFromOutput(sqrtPX96, liquidity, amountOut, zeroForOne) - - rst := got.Eq(expected) - - if !rst { - t.Errorf("The result should be eq to 71305346262837903834189555302") - } -} - -func TestGetNextSqrtPriceFromOutput_12(t *testing.T) { - // reverts if amountOut is impossible in zero for one direction - sqrtPX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - liquidity := u256.One() - amountOut := u256.MustFromDecimal(consts.MAX_UINT256) - zeroForOne := true - - shouldPanic( - t, - func() { - sqrtPriceMathGetNextSqrtPriceFromOutput(sqrtPX96, liquidity, amountOut, zeroForOne) - }, - ) -} - -func TestGetNextSqrtPriceFromOutput_13(t *testing.T) { - // reverts if amountOut is impossible in one for zero direction - sqrtPX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - liquidity := u256.One() - amountOut := u256.MustFromDecimal(consts.MAX_UINT256) - zeroForOne := false - - shouldPanic( - t, - func() { - sqrtPriceMathGetNextSqrtPriceFromOutput(sqrtPX96, liquidity, amountOut, zeroForOne) - }, - ) -} - -func TestSqrtPriceMathGetAmount0DeltaStr_1(t *testing.T) { - // returns 0 if liquidity is 0 - - sqrtRatioAX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - sqrtRatioBX96 := u256.MustFromDecimal("112045541949572279837463876454") // encodePriceSqrt(2, 1) = 112045541949572279837463876454 - liquidity_i256 := i256.Zero() - - got_string := SqrtPriceMathGetAmount0DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity_i256) - - if got_string != "0" { - t.Errorf("return value should be eq to 0") - } -} - -func TestSqrtPriceMathGetAmount0DeltaHelper_1(t *testing.T) { - // returns 0 if liquidity is 0 - sqrtRatioAX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - sqrtRatioBX96 := u256.MustFromDecimal("112045541949572279837463876454") // encodePriceSqrt(2, 1) = 112045541949572279837463876454 - liquidity := u256.Zero() - roundUp := true - - expected := u256.Zero() - - got := sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity, roundUp) - - rst := got.Eq(expected) - if !rst { - t.Errorf("return value should be eq to 0") - } -} - -func TestSqrtPriceMathGetAmount0DeltaStr_2(t *testing.T) { - // returns 0 if prices are equal - - sqrtRatioAX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - sqrtRatioBX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - liquidity_i256 := i256.Zero() - - got_string := SqrtPriceMathGetAmount0DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity_i256) - - if got_string != "0" { - t.Errorf("return value should be eq to 0") - } -} - -func TestSqrtPriceMathGetAmount0DeltaStr_3(t *testing.T) { - // return value should be eq to 90909090909090910 - - sqrtRatioAX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - sqrtRatioBX96 := u256.MustFromDecimal("87150978765690771352898345369") // encodePriceSqrt(121, 100) = 87150978765690771352898345369 - liquidity_i256 := i256.MustFromDecimal("1000000000000000000") - - got_string := SqrtPriceMathGetAmount0DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity_i256) - - if got_string != "90909090909090910" { - t.Errorf("return value should be eq to 90909090909090910") - } -} - -func TestSqrtPriceMathGetAmount0DeltaStr_4(t *testing.T) { - // return value should be eq to 90909090909090910 - - sqrtRatioAX96 := u256.MustFromDecimal("2787593149816327892691964784081045188247552") // encodePriceSqrt(BigNumber.from(2).pow(90), 1) = 2787593149816327892691964784081045188247552 - sqrtRatioBX96 := u256.MustFromDecimal("22300745198530623141535718272648361505980416") // encodePriceSqrt(BigNumber.from(2).pow(96), 1) = 22300745198530623141535718272648361505980416 - liquidity_i256 := i256.MustFromDecimal("1000000000000000000") - - got_string := SqrtPriceMathGetAmount0DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity_i256) - if got_string == "0" { - t.Errorf("The result should not return 0") - } -} - -func TestSqrtPriceMathGetAmount0DeltaHelper_2(t *testing.T) { - // returns 0 if prices are equal - - sqrtRatioAX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - sqrtRatioBX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - liquidity := u256.Zero() - roundUp := true - - expected := u256.Zero() - - got := sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity, roundUp) - - rst := got.Eq(expected) - if !rst { - t.Errorf("return value should be eq to 0") - } -} - -func TestSqrtPriceMathGetAmount0DeltaHelper_3(t *testing.T) { - sqrtRatioAX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - sqrtRatioBX96 := u256.MustFromDecimal("87150978765690771352898345369") // encodePriceSqrt(121, 100) = 79228162514264337593543950336 - liquidity := u256.MustFromDecimal("1000000000000000000") - roundUp := true - - expected := u256.MustFromDecimal("90909090909090910") - - got := sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity, roundUp) - - rst := got.Eq(expected) - - if !rst { - t.Errorf("The result should be eq to 90909090909090910") - } -} - -func TestSqrtPriceMathGetAmount0DeltaHelper_4(t *testing.T) { - // the sub between the result of roundup and rounddown should be eq to 1 - var got2 *u256.Uint - sqrtRatioAX96 := u256.MustFromDecimal("112045541949572279837463876454") // encodePriceSqrt(2, 1) = 112045541949572279837463876454 - sqrtRatioBX96 := u256.MustFromDecimal("87150978765690771352898345369") // encodePriceSqrt(121, 100) = 87150978765690771352898345369 - liquidity := u256.MustFromDecimal("1000000000000000000") - roundUp := true - - got := sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity, roundUp) - - roundUp = false - - got2 = sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity, roundUp) - - got.Sub(got, got2) - - rst := got.Eq(u256.One()) - - if !rst { - t.Errorf("the sub between the result of roundup and rounddown should be eq to 1") - } -} - -func TestSqrtPriceMathGetAmount0DeltaHelper_5(t *testing.T) { - // works for prices that overflow - - var got2 *u256.Uint - sqrtRatioAX96 := u256.MustFromDecimal("2787593149816327892691964784081045188247552") // encodePriceSqrt(BigNumber.from(2).pow(90), 1) = 2787593149816327892691964784081045188247552 - sqrtRatioBX96 := u256.MustFromDecimal("22300745198530623141535718272648361505980416") // encodePriceSqrt(BigNumber.from(2).pow(96), 1) = 22300745198530623141535718272648361505980416 - liquidity := u256.MustFromDecimal("1000000000000000000") - roundUp := true - - got := sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity, roundUp) - - roundUp = false - - got2 = sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity, roundUp) - - // println(got.ToString()) - // println(got2.ToString()) - got.Sub(got, got2) - - rst := got.Eq(u256.One()) - - if !rst { - t.Errorf("the sub between the result of roundup and rounddown should be eq to 1") - } -} - -func TestSqrtPriceMathGetAmount1DeltaHelper_1(t *testing.T) { - // returns 0 if liquidity is 0 - - var got2 *u256.Uint - sqrtRatioAX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - sqrtRatioBX96 := u256.MustFromDecimal("112045541949572279837463876454") // encodePriceSqrt(2, 1) = 112045541949572279837463876454 - liquidity := u256.MustFromDecimal("0") - roundUp := true - - rst := sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity, roundUp) - shouldEQ(t, rst.ToString(), "0") -} - -func TestSqrtPriceMathGetAmount1DeltaHelper_2(t *testing.T) { - // returns 0 if prices are equal - - sqrtRatioAX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - sqrtRatioBX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - liquidity := u256.MustFromDecimal("1") - roundUp := true - - expected := u256.Zero() - - got := sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity, roundUp) - - rst := got.Eq(expected) - - if !rst { - t.Errorf("returns 0 if prices are equal") - } -} - -func TestSqrtPriceMathGetAmount1DeltaHelper_3(t *testing.T) { - var got2 *u256.Uint - // returns 0.1 amount1 for price of 1 to 1.21 - sqrtRatioAX96 := u256.MustFromDecimal("79228162514264337593543950336") // encodePriceSqrt(1, 1) = 79228162514264337593543950336 - sqrtRatioBX96 := u256.MustFromDecimal("87150978765690771352898345369") // encodePriceSqrt(121, 100) = 87150978765690771352898345369 - liquidity := u256.MustFromDecimal("1000000000000000000") - roundUp := true - - expected := u256.MustFromDecimal("100000000000000000") // 0.1e18 - - got := sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity, roundUp) - - rst := got.Eq(expected) - - if !rst { - t.Errorf("the result should be eq to expected") - } - roundUp = false - - got2 = sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity, roundUp) - got.Sub(got, got2) - - rst = got.Eq(u256.One()) - if !rst { - t.Errorf("the sub between the result of roundup and rounddown should be eq to 1") - } -} - -func TestSwapComputation(t *testing.T) { - // sqrtP * sqrtQ overflows - - sqrtPX96 := u256.MustFromDecimal("1025574284609383690408304870162715216695788925244") - liquidity := u256.MustFromDecimal("50015962439936049619261659728067971248") - amountIn := u256.NewUint(406) - zeroForOne := true - - expected := u256.MustFromDecimal("1025574284609383582644711336373707553698163132913") - - got := sqrtPriceMathGetNextSqrtPriceFromInput(sqrtPX96, liquidity, amountIn, zeroForOne) - rst := got.Eq(expected) - if !rst { - t.Errorf("The result should eq to expected") - } - - got = sqrtPriceMathGetAmount0DeltaHelper(expected, sqrtPX96, liquidity, true) - rst = got.Eq(u256.NewUint(406)) - if !rst { - t.Errorf("The result should eq to 406") - } -} diff --git a/_deploy/p/gnoswap/pool/consts.gno b/_deploy/p/gnoswap/pool/consts.gno index 17acbe4ef..e8b0b33e1 100644 --- a/_deploy/p/gnoswap/pool/consts.gno +++ b/_deploy/p/gnoswap/pool/consts.gno @@ -8,6 +8,7 @@ const ( MAX_UINT128 string = "340282366920938463463374607431768211455" MAX_UINT160 string = "1461501637330902918203684832716283019655932542975" MAX_UINT256 string = "115792089237316195423570985008687907853269984665640564039457584007913129639935" + MAX_INT256 string = "57896044618658097711785492504343953926634992332820282019728792003956564819967" Q64 string = "18446744073709551616" // 2 ** 64 Q96 string = "79228162514264337593543950336" // 2 ** 96 diff --git a/_deploy/p/gnoswap/pool/gno.mod b/_deploy/p/gnoswap/pool/gno.mod index 23776be3a..ea092a3fa 100644 --- a/_deploy/p/gnoswap/pool/gno.mod +++ b/_deploy/p/gnoswap/pool/gno.mod @@ -1,7 +1 @@ module gno.land/p/gnoswap/pool - -require ( - gno.land/p/gnoswap/int256 v0.0.0-latest - gno.land/p/gnoswap/uint256 v0.0.0-latest - gno.land/r/gnoswap/v1/consts v0.0.0-latest -) diff --git a/_deploy/p/gnoswap/pool/sqrt_price_math.gno b/_deploy/p/gnoswap/pool/sqrt_price_math.gno index 175f9224d..29b16feaf 100644 --- a/_deploy/p/gnoswap/pool/sqrt_price_math.gno +++ b/_deploy/p/gnoswap/pool/sqrt_price_math.gno @@ -5,17 +5,45 @@ import ( u256 "gno.land/p/gnoswap/uint256" ) +var ( + Q96_RESOLUTION = uint(96) + Q160_RESOLUTION = uint(160) +) + +// sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp calculates the next square root price +// based on the amount of token0 added or removed from the pool. +// NOTE: Always rounds up, because in the exact output case (increasing price) we need to move the price at least +// far enough to get the desired output amount, and in the exact input case (decreasing price) we need to move the +// price less in order to not send too much output. +// The most precise formula for this is liquidity * sqrtPX96 / (liquidity +- amount * sqrtPX96), +// if this is impossible because of overflow, we calculate liquidity / (liquidity / sqrtPX96 +- amount). +// +// Parameters: +// - sqrtPX96: The current square root price as a Q96 fixed-point number (uint160). +// - liquidity: The pool's active liquidity as a Q128 fixed-point number (uint128). +// - amount: The amount of token0 to be added or removed from the pool (uint256). +// - add: A boolean indicating whether the amount of token0 is being added (true) or removed (false). +// +// Returns: +// - The price after adding or removing amount, depending on add +// +// Notes: +// - When `add` is true, the function calculates the new square root price after adding `amount` of token0. +// - When `add` is false, the function calculates the new square root price after removing `amount` of token0. +// - The function uses high-precision math (MulDivRoundingUp, DivRoundingUp) to handle division rounding issues. +// - The function validates input conditions and panics if the state is invalid. func sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp( - sqrtPX96 *u256.Uint, // uint160 - liquidity *u256.Uint, // uint128 - amount *u256.Uint, // uint256 + sqrtPX96 *u256.Uint, + liquidity *u256.Uint, + amount *u256.Uint, add bool, -) *u256.Uint { // uint160 +) *u256.Uint { + // we short circuit amount == 0 because the result is otherwise not guaranteed to equal the input price if amount.IsZero() { return sqrtPX96 } - numerator1 := new(u256.Uint).Lsh(liquidity, 96) + numerator1 := new(u256.Uint).Lsh(liquidity, Q96_RESOLUTION) product := new(u256.Uint).Mul(amount, sqrtPX96) if add { @@ -27,22 +55,48 @@ func sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp( } } - div := new(u256.Uint).Div(numerator1, sqrtPX96) - add := new(u256.Uint).Add(div, amount) - return u256.DivRoundingUp(numerator1, add) + divValue := new(u256.Uint).Div(numerator1, sqrtPX96) + addValue := new(u256.Uint).Add(divValue, amount) + return u256.DivRoundingUp(numerator1, addValue) } else { cond1 := new(u256.Uint).Div(product, amount).Eq(sqrtPX96) cond2 := numerator1.Gt(product) if !(cond1 && cond2) { - panic("pool_sqrt price math #1") + panic("invalid pool sqrt price calculation: product/amount != sqrtPX96 or numerator1 <= product") } denominator := new(u256.Uint).Sub(numerator1, product) - return u256.MulDivRoundingUp(numerator1, sqrtPX96, denominator) + nextSqrtPrice := u256.MulDivRoundingUp(numerator1, sqrtPX96, denominator) + max160 := u256.MustFromDecimal(MAX_UINT160) + if nextSqrtPrice.Gt(max160) { + panic("nextSqrtPrice overflows uint160") + } + return nextSqrtPrice } } +// sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown calculates the next square root price +// based on the amount of token1 added or removed from the pool, with rounding down. +// NOTE: Always rounds down, because in the exact output case (decreasing price) we need to move the price at least +// far enough to get the desired output amount, and in the exact input case (increasing price) we need to move the +// price less in order to not send too much output. +// The formula we compute is within <1 wei of the lossless version: sqrtPX96 +- amount / liquidity +// +// Parameters: +// - sqrtPX96: The current square root price as a Q96 fixed-point number (uint160). +// - liquidity: The pool's active liquidity as a Q128 fixed-point number (uint128). +// - amount: The amount of token1 to be added or removed from the pool (uint256). +// - add: A boolean indicating whether the amount of token1 is being added (true) or removed (false). +// +// Returns: +// - The next square root price as a Q96 fixed-point number (uint160). +// +// Notes: +// - When `add` is true, the function calculates the new square root price after adding `amount` of token1. +// - When `add` is false, the function calculates the new square root price after removing `amount` of token1. +// - The function uses high-precision math (MulDiv and DivRoundingUp) to handle division and prevent precision loss. +// - The function validates input conditions and panics if the state is invalid. func sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown( sqrtPX96 *u256.Uint, // uint160 liquidity *u256.Uint, // uint1288 @@ -50,151 +104,295 @@ func sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown( add bool, ) *u256.Uint { // uint160 quotient := u256.Zero() + max160 := u256.MustFromDecimal(MAX_UINT160) + // if we're adding (subtracting), rounding down requires rounding the quotient down (up) + // in both cases, avoid a mulDiv for most inputs if add { if amount.Lte(u256.MustFromDecimal(MAX_UINT160)) { - value1 := new(u256.Uint).Lsh(amount, 96) - quotient = new(u256.Uint).Div(value1, liquidity) + value := new(u256.Uint).Lsh(amount, Q96_RESOLUTION) + quotient = new(u256.Uint).Div(value, liquidity) } else { quotient = u256.MulDiv(amount, u256.MustFromDecimal(Q96), liquidity) } res := new(u256.Uint).Add(sqrtPX96, quotient) - max160 := u256.MustFromDecimal("1461501637330902918203684832716283019655932542975") - if res.Gt(max160) { - panic("sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown sqrtPx96 + quotient overflow uint160") + panic("sqrtPx96 + quotient overflow uint160") } return res - } else { if amount.Lte(u256.MustFromDecimal(MAX_UINT160)) { - value1 := new(u256.Uint).Lsh(amount, 96) - quotient = u256.DivRoundingUp(value1, liquidity) + value := new(u256.Uint).Lsh(amount, Q96_RESOLUTION) + quotient = u256.DivRoundingUp(value, liquidity) } else { quotient = u256.MulDivRoundingUp(amount, u256.MustFromDecimal(Q96), liquidity) } if !(sqrtPX96.Gt(quotient)) { - panic("pool_sqrt price math #2") + panic("sqrt price exceeds calculated quotient") } - return new(u256.Uint).Sub(sqrtPX96, quotient) + res := new(u256.Uint).Sub(sqrtPX96, quotient) + if res.Gt(max160) { + mask := new(u256.Uint).Lsh(u256.One(), Q160_RESOLUTION) + mask = mask.Sub(mask, u256.One()) + res = res.And(res, mask) + } + return res } } +// sqrtPriceMathGetNextSqrtPriceFromInput calculates the next square root price +// based on the amount of token0 or token1 added to the pool. +// NOTE: Always rounds up, because in the exact output case (increasing price) we need to move the price at least +// far enough to get the desired output amount, and in the exact input case (decreasing price) we need to move the +// price less in order to not send too much output. +// The most precise formula for this is liquidity * sqrtPX96 / (liquidity +- amount * sqrtPX96), +// if this is impossible because of overflow, we calculate liquidity / (liquidity / sqrtPX96 +- amount). +// +// Parameters: +// - sqrtPX96: The current square root price as a Q96 fixed-point number (uint160). +// - liquidity: The pool's active liquidity as a Q128 fixed-point number (uint128). +// - amountIn: The amount of token0 or token1 to be added to the pool (uint256). +// - zeroForOne: A boolean indicating whether the amount is being added to token0 (true) or token1 (false). +// +// Returns: +// - The price after adding amountIn, depending on zeroForOne func sqrtPriceMathGetNextSqrtPriceFromInput( - sqrtPX96 *u256.Uint, // uint160 - liquidity *u256.Uint, // uint128 - amountIn *u256.Uint, // uint256 - zeroForOne bool, // bool -) *u256.Uint { // uint160 + sqrtPX96 *u256.Uint, + liquidity *u256.Uint, + amountIn *u256.Uint, + zeroForOne bool, +) *u256.Uint { if sqrtPX96.IsZero() { panic("sqrtPX96 should not be zero") } if liquidity.IsZero() { - panic("pool_sqrtPriceMathGetNextSqrtPriceFromInput_liquidity should not be zero") + panic("liquidity should not be zero") } if zeroForOne { return sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountIn, true) + } else { + return sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountIn, true) } - return sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountIn, true) } +// sqrtPriceMathGetNextSqrtPriceFromOutput calculates the next square root price +// based on the amount of token0 or token1 removed from the pool. +// +// NOTE: +// - For zeroForOne == true (Token0 -> Token1): The calculation uses rounding down. +// - For zeroForOne == false (Token1 -> Token0): The calculation uses rounding up. +// +// The most precise formula for this is: +// - liquidity * sqrtPX96 / (liquidity ± amount * sqrtPX96) +// If overflow occurs, it falls back to: +// - liquidity / (liquidity / sqrtPX96 ± amount) +// +// Parameters: +// - sqrtPX96: The current square root price as a Q96 fixed-point number (uint160). +// - liquidity: The pool's active liquidity as a Q128 fixed-point number (uint128). +// - amountOut: The amount of token0 or token1 to be removed from the pool (uint256). +// - zeroForOne: A boolean indicating whether the amount is being removed from token0 (true) or token1 (false). +// +// Returns: +// - The price after removing amountOut, depending on zeroForOne. +// +// Notes: +// - Rounding direction depends on the swap direction (zeroForOne). +// - Relies on helper functions: +// - `sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown` for Token0 -> Token1. +// - `sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp` for Token1 -> Token0. func sqrtPriceMathGetNextSqrtPriceFromOutput( - sqrtPX96 *u256.Uint, // uint160 - liquidity *u256.Uint, // uint128 - amountOut *u256.Uint, // uint256 - zeroForOne bool, // bool -) *u256.Uint { // uint160 + sqrtPX96 *u256.Uint, + liquidity *u256.Uint, + amountOut *u256.Uint, + zeroForOne bool, +) *u256.Uint { if sqrtPX96.IsZero() { - panic("pool_sqrtPriceMathGetNextSqrtPriceFromOutput_sqrtPX96 should not be zero") + panic("sqrtPX96 should not be zero") } if liquidity.IsZero() { - panic("pool_sqrtPriceMathGetNextSqrtPriceFromOutput_liquidity should not be zero") + panic("liquidity should not be zero") } if zeroForOne { return sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown(sqrtPX96, liquidity, amountOut, false) + } else { + return sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountOut, false) } - - return sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp(sqrtPX96, liquidity, amountOut, false) } +// sqrtPriceMathGetAmount0DeltaHelper calculates the absolute difference between the amounts of token0 in two +// liquidity ranges defined by the square root prices sqrtRatioAX96 and sqrtRatioBX96. The difference is +// calculated relative to the range [sqrtRatioAX96, sqrtRatioBX96]. +// +// If sqrtRatioAX96 > sqrtRatioBX96, their values are swapped to ensure sqrtRatioAX96 is the lower bound. +// +// Parameters: +// - sqrtRatioAX96: The lower bound of the range as a Q96 fixed-point number (uint160). +// - sqrtRatioBX96: The upper bound of the range as a Q96 fixed-point number (uint160). +// - liquidity: The pool's active liquidity as a Q128 fixed-point number (uint128). +// - roundUp: A boolean indicating whether the result should be rounded up (true) or down (false). +// +// Returns: +// - The absolute difference between the amounts of token0 in the two ranges as a uint256. +// +// Notes: +// - If sqrtRatioAX96 is zero or negative, the function panics. +// - The result is calculated using high-precision fixed-point arithmetic. +// - Rounding is applied based on the roundUp parameter. func sqrtPriceMathGetAmount0DeltaHelper( - sqrtRatioAX96 *u256.Uint, // uint160 - sqrtRatioBX96 *u256.Uint, // uint160 - liquidity *u256.Uint, // uint160 + sqrtRatioAX96 *u256.Uint, + sqrtRatioBX96 *u256.Uint, + liquidity *u256.Uint, roundUp bool, -) *u256.Uint { // uint256 +) *u256.Uint { if sqrtRatioAX96.Gt(sqrtRatioBX96) { sqrtRatioAX96, sqrtRatioBX96 = sqrtRatioBX96, sqrtRatioAX96 } - numerator1 := new(u256.Uint).Lsh(liquidity, 96) + numerator1 := new(u256.Uint).Lsh(liquidity, Q96_RESOLUTION) numerator2 := new(u256.Uint).Sub(sqrtRatioBX96, sqrtRatioAX96) if !(sqrtRatioAX96.Gt(u256.Zero())) { - panic("pool_sqrt price math #3") + panic("sqrtRatioAX96 must be greater than zero") } if roundUp { - value1 := u256.MulDivRoundingUp(numerator1, numerator2, sqrtRatioBX96) - return u256.DivRoundingUp(value1, sqrtRatioAX96) + value := u256.MulDivRoundingUp(numerator1, numerator2, sqrtRatioBX96) + return u256.DivRoundingUp(value, sqrtRatioAX96) } else { - value1 := u256.MulDiv(numerator1, numerator2, sqrtRatioBX96) - return new(u256.Uint).Div(value1, sqrtRatioAX96) + value := u256.MulDiv(numerator1, numerator2, sqrtRatioBX96) + return new(u256.Uint).Div(value, sqrtRatioAX96) } } +// sqrtPriceMathGetAmount1DeltaHelper calculates the absolute difference between the amounts of token1 in two +// liquidity ranges defined by the square root prices sqrtRatioAX96 and sqrtRatioBX96. The difference is +// calculated relative to the range [sqrtRatioAX96, sqrtRatioBX96]. +// +// If sqrtRatioAX96 > sqrtRatioBX96, their values are swapped to ensure sqrtRatioAX96 is the lower bound. +// +// Parameters: +// - sqrtRatioAX96: The lower bound of the range as a Q96 fixed-point number (uint160). +// - sqrtRatioBX96: The upper bound of the range as a Q96 fixed-point number (uint160). +// - liquidity: The pool's active liquidity as a Q128 fixed-point number (uint128). +// - roundUp: A boolean indicating whether the result should be rounded up (true) or down (false). +// +// Returns: +// - The absolute difference between the amounts of token1 in the two ranges as a uint256. +// +// Notes: +// - Rounding is applied based on the roundUp parameter. +// - The function swaps sqrtRatioAX96 and sqrtRatioBX96 if sqrtRatioAX96 > sqrtRatioBX96. func sqrtPriceMathGetAmount1DeltaHelper( - sqrtRatioAX96 *u256.Uint, // uint160 - sqrtRatioBX96 *u256.Uint, // uint160 - liquidity *u256.Uint, // uint160 + sqrtRatioAX96 *u256.Uint, + sqrtRatioBX96 *u256.Uint, + liquidity *u256.Uint, roundUp bool, -) *u256.Uint { // uint256 +) *u256.Uint { if sqrtRatioAX96.Gt(sqrtRatioBX96) { sqrtRatioAX96, sqrtRatioBX96 = sqrtRatioBX96, sqrtRatioAX96 } + diff := new(u256.Uint).Sub(sqrtRatioBX96, sqrtRatioAX96) if roundUp { - diff := new(u256.Uint).Sub(sqrtRatioBX96, sqrtRatioAX96) return u256.MulDivRoundingUp(liquidity, diff, u256.MustFromDecimal(Q96)) } else { - diff := new(u256.Uint).Sub(sqrtRatioBX96, sqrtRatioAX96) return u256.MulDiv(liquidity, diff, u256.MustFromDecimal(Q96)) } } +// SqrtPriceMathGetAmount0DeltaStr calculates the difference in the amount of token0 +// within a specified liquidity range defined by two square root prices (sqrtRatioAX96 and sqrtRatioBX96). +// This function returns the result as a string representation of an int256 value. +// +// If the liquidity is negative, the result is also negative. +// +// Parameters: +// - sqrtRatioAX96: The lower bound of the range as a Q96 fixed-point number (uint160). +// - sqrtRatioBX96: The upper bound of the range as a Q96 fixed-point number (uint160). +// - liquidity: The pool's active liquidity as a signed Q128 fixed-point number (int128). +// +// Returns: +// - A string representation of the int256 value representing the difference in token0 amounts +// within the specified range. The value is negative if the liquidity is negative. +// +// Notes: +// - This function relies on the helper function `sqrtPriceMathGetAmount0DeltaHelper` to perform the core calculation. +// - The helper function calculates the absolute difference between token0 amounts within the range. +// - If the computed result exceeds the maximum allowable value for int256 (2**255 - 1), the function will panic +// with an appropriate overflow error. +// - The rounding behavior of the result is controlled by the `roundUp` parameter passed to the helper function: +// - For negative liquidity, rounding is always down. +// - For positive liquidity, rounding is always up. func SqrtPriceMathGetAmount0DeltaStr( - sqrtRatioAX96 *u256.Uint, // uint160 - sqrtRatioBX96 *u256.Uint, // uint160 - liquidity *i256.Int, // int128 -) string { // int256 + sqrtRatioAX96 *u256.Uint, + sqrtRatioBX96 *u256.Uint, + liquidity *i256.Int, +) string { if liquidity.IsNeg() { u := sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), false) + if u.Gt(u256.MustFromDecimal(MAX_INT256)) { + // if u > (2**255 - 1), cannot cast to int256 + panic("SqrtPriceMathGetAmount0DeltaStr: overflow") + } i := i256.FromUint256(u) return i256.Zero().Neg(i).ToString() + } else { + u := sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true) + if u.Gt(u256.MustFromDecimal(MAX_INT256)) { + // if u > (2**255 - 1), cannot cast to int256 + panic("SqrtPriceMathGetAmount0DeltaStr: overflow") + } + return i256.FromUint256(u).ToString() } - - u := sqrtPriceMathGetAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true) - return i256.FromUint256(u).ToString() } +// SqrtPriceMathGetAmount1DeltaStr calculates the difference in the amount of token1 +// within a specified liquidity range defined by two square root prices (sqrtRatioAX96 and sqrtRatioBX96). +// This function returns the result as a string representation of an int256 value. +// +// If the liquidity is negative, the result is also negative. +// +// Parameters: +// - sqrtRatioAX96: The lower bound of the range as a Q96 fixed-point number (uint160). +// - sqrtRatioBX96: The upper bound of the range as a Q96 fixed-point number (uint160). +// - liquidity: The pool's active liquidity as a signed Q128 fixed-point number (int128). +// +// Returns: +// - A string representation of the int256 value representing the difference in token1 amounts +// within the specified range. The value is negative if the liquidity is negative. +// +// Notes: +// - This function relies on the helper function `sqrtPriceMathGetAmount1DeltaHelper` to perform the core calculation. +// - The rounding behavior of the result is controlled by the `roundUp` parameter passed to the helper function: +// - For negative liquidity, rounding is always down. +// - For positive liquidity, rounding is always up. func SqrtPriceMathGetAmount1DeltaStr( - sqrtRatioAX96 *u256.Uint, // uint160 - sqrtRatioBX96 *u256.Uint, // uint160 - liquidity *i256.Int, // int128 -) string { // int256 + sqrtRatioAX96 *u256.Uint, + sqrtRatioBX96 *u256.Uint, + liquidity *i256.Int, +) string { if liquidity.IsNeg() { u := sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), false) + if u.Gt(u256.MustFromDecimal(MAX_INT256)) { + // if u > (2**255 - 1), cannot cast to int256 + panic("SqrtPriceMathGetAmount1DeltaStr: overflow") + } i := i256.FromUint256(u) return i256.Zero().Neg(i).ToString() + } else { + u := sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true) + if u.Gt(u256.MustFromDecimal(MAX_INT256)) { + // if u > (2**255 - 1), cannot cast to int256 + panic("SqrtPriceMathGetAmount1DeltaStr: overflow") + } + return i256.FromUint256(u).ToString() } - - u := sqrtPriceMathGetAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true) - return i256.FromUint256(u).ToString() } diff --git a/_deploy/p/gnoswap/pool/sqrt_price_math_test.gno b/_deploy/p/gnoswap/pool/sqrt_price_math_test.gno new file mode 100644 index 000000000..24f450d51 --- /dev/null +++ b/_deploy/p/gnoswap/pool/sqrt_price_math_test.gno @@ -0,0 +1,643 @@ +package pool + +import ( + "testing" + + "gno.land/p/demo/uassert" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +func TestSqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp(t *testing.T) { + t.Run("zero amount returns same price", func(t *testing.T) { + sqrtPX96 := u256.MustFromDecimal("1000000") + liquidity := u256.MustFromDecimal("2000000") + amount := u256.Zero() + + result := sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp( + sqrtPX96, + liquidity, + amount, + true, + ) + + if !result.Eq(sqrtPX96) { + t.Errorf("Expected %s, got %s", sqrtPX96.ToString(), result.ToString()) + } + }) + + t.Run("remove token0", func(t *testing.T) { + sqrtPX96 := u256.MustFromDecimal("1000000") + liquidity := u256.MustFromDecimal("2000000") + amount := u256.MustFromDecimal("500000") + + result := sqrtPriceMathGetNextSqrtPriceFromAmount0RoundingUp( + sqrtPX96, + liquidity, + amount, + false, + ) + + if result.Lte(sqrtPX96) { + t.Error("Price should increase when removing token0") + } + }) +} + +func TestSqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown(t *testing.T) { + t.Run("add token1 small amount", func(t *testing.T) { + sqrtPX96 := u256.MustFromDecimal("1000000") + liquidity := u256.MustFromDecimal("2000000") + amount := u256.MustFromDecimal("100000") + + result := sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown( + sqrtPX96, + liquidity, + amount, + true, + ) + + if result.Lte(sqrtPX96) { + t.Error("Price should increase when adding token1") + } + }) +} + +func TestSqrtPriceMathGetAmount0DeltaStr(t *testing.T) { + t.Run("positive liquidity", func(t *testing.T) { + ratioA := u256.MustFromDecimal("1000000") + ratioB := u256.MustFromDecimal("2000000") + liquidity := i256.FromUint256(u256.MustFromDecimal("5000000")) + + result := SqrtPriceMathGetAmount0DeltaStr(ratioA, ratioB, liquidity) + + if result[0] == '-' { + t.Error("Result should be positive for positive liquidity") + } + }) + + t.Run("negative liquidity", func(t *testing.T) { + ratioA := u256.MustFromDecimal("1000000") + ratioB := u256.MustFromDecimal("2000000") + liquidity := i256.Zero().Neg(i256.FromUint256(u256.MustFromDecimal("5000000"))) + + result := SqrtPriceMathGetAmount0DeltaStr(ratioA, ratioB, liquidity) + + if result[0] != '-' { + t.Error("Result should be negative for negative liquidity") + } + }) + + t.Run("panic overflow when getting amount0 with positive liquidity", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic for overflow amount0") + } else { + uassert.Equal(t, "SqrtPriceMathGetAmount0DeltaStr: overflow", r) + } + }() + + // Inputs to trigger panic + sqrtRatioAX96 := u256.MustFromDecimal("1") // very low value + sqrtRatioBX96 := u256.MustFromDecimal("340282366920938463463374607431768211455") // very high value(2^128-1) + liquidity := i256.FromUint256(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")) + + SqrtPriceMathGetAmount0DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity) + }) + + t.Run("panic overflow when getting amount0 with negative liquidity", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic for overflow amount0") + } else { + uassert.Equal(t, "SqrtPriceMathGetAmount0DeltaStr: overflow", r) + } + }() + + // Inputs to trigger panic + sqrtRatioAX96 := u256.MustFromDecimal("1") // very low value + sqrtRatioBX96 := u256.MustFromDecimal("340282366920938463463374607431768211455") // very high value(2^128-1) + liquidity := i256.FromUint256(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")) + liquidity = liquidity.Neg(liquidity) // Make liquidity negative + + SqrtPriceMathGetAmount0DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity) + }) +} + +func TestSqrtPriceMathGetAmount1DeltaStr(t *testing.T) { + t.Run("positive liquidity", func(t *testing.T) { + ratioA := u256.MustFromDecimal("1000000") + ratioB := u256.MustFromDecimal("2000000") + liquidity := i256.FromUint256(u256.MustFromDecimal("5000000")) + + result := SqrtPriceMathGetAmount1DeltaStr(ratioA, ratioB, liquidity) + + if result[0] == '-' { + t.Error("Result should be positive for positive liquidity") + } + }) + + t.Run("negative liquidity", func(t *testing.T) { + ratioA := u256.MustFromDecimal("1000000") + ratioB := u256.MustFromDecimal("2000000") + liquidity := i256.Zero().Neg(i256.FromUint256(u256.MustFromDecimal("5000000"))) + + result := SqrtPriceMathGetAmount0DeltaStr(ratioA, ratioB, liquidity) + + if result[0] != '-' { + t.Error("Result should be negative for negative liquidity") + } + }) + + t.Run("panic overflow when getting amount1 with positive liquidity", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic for overflow amount1") + } else { + uassert.Equal(t, "SqrtPriceMathGetAmount1DeltaStr: overflow", r) + } + }() + + // Inputs to trigger panic + sqrtRatioAX96 := u256.MustFromDecimal("1") // very low value + sqrtRatioBX96 := u256.MustFromDecimal("79228162514264337593543950335") // slightly below Q96 + liquidity := i256.FromUint256(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")) + + SqrtPriceMathGetAmount1DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity) + }) + + t.Run("panic overflow when getting amount1 with negative liquidity", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic for overflow amount1") + } else { + uassert.Equal(t, "SqrtPriceMathGetAmount1DeltaStr: overflow", r) + } + }() + + // Inputs to trigger panic + sqrtRatioAX96 := u256.MustFromDecimal("1") // very low value + sqrtRatioBX96 := u256.MustFromDecimal("79228162514264337593543950335") // slightly below Q96 + liquidity := i256.FromUint256(u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935")) + liquidity = liquidity.Neg(liquidity) // Make liquidity negative + + SqrtPriceMathGetAmount1DeltaStr(sqrtRatioAX96, sqrtRatioBX96, liquidity) + }) +} + +func TestSqrtPriceMathGetNextSqrtPriceFromInput(t *testing.T) { + tests := []struct { + name string + sqrtPriceX96 *u256.Uint + liquidity *u256.Uint + amountIn *u256.Uint + zeroForOne bool + shouldPanic bool + panicMsg string + expectedSqrtPriceX96 string + }{ + { + name: "fails if price is zero", + sqrtPriceX96: u256.Zero(), + liquidity: u256.Zero(), + amountIn: u256.MustFromDecimal("100000000000000000"), + zeroForOne: false, + shouldPanic: true, + panicMsg: "sqrtPX96 should not be zero", + }, + { + name: "fails if liquidity is zero", + sqrtPriceX96: u256.One(), + liquidity: u256.Zero(), + amountIn: u256.MustFromDecimal("100000000000000000"), + zeroForOne: true, + shouldPanic: true, + panicMsg: "liquidity should not be zero", + }, + { + name: "fails if input amount overflows the price", + sqrtPriceX96: u256.MustFromDecimal("1461501637330902918203684832716283019655932542975"), // 2^160 - 1 + liquidity: u256.MustFromDecimal("1024"), + amountIn: u256.MustFromDecimal("1024"), + zeroForOne: false, + shouldPanic: true, + panicMsg: "sqrtPx96 + quotient overflow uint160", + }, + { + name: "any input amount cannot underflow the price", + sqrtPriceX96: u256.MustFromDecimal("1"), + liquidity: u256.MustFromDecimal("1"), + amountIn: u256.MustFromDecimal("57896044618658097711785492504343953926634992332820282019728792003956564819968"), // 2^255 + zeroForOne: true, + expectedSqrtPriceX96: "1", + }, + { + name: "returns input price if amount in is zero and zeroForOne = true", + sqrtPriceX96: u256.MustFromDecimal("79228162514264337593543950336"), + liquidity: u256.MustFromDecimal("100000000000000000"), + amountIn: u256.Zero(), + zeroForOne: true, + expectedSqrtPriceX96: "79228162514264337593543950336", + }, + { + name: "returns input price if amount in is zero and zeroForOne = false", + sqrtPriceX96: u256.MustFromDecimal("79228162514264337593543950336"), + liquidity: u256.MustFromDecimal("100000000000000000"), + amountIn: u256.Zero(), + zeroForOne: false, + expectedSqrtPriceX96: "79228162514264337593543950336", + }, + { + name: "returns the minimum price for max inputs", + sqrtPriceX96: u256.MustFromDecimal("1461501637330902918203684832716283019655932542975"), // 2^160 - 1 + liquidity: u256.MustFromDecimal("340282366920938463463374607431768211455"), // 2^128 - 1 + amountIn: u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039439137263839420088320"), + zeroForOne: true, + expectedSqrtPriceX96: "1", + }, + { + name: "input amount of 0.1 token1", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + amountIn: u256.MustFromDecimal("100000000000000000"), + zeroForOne: false, + expectedSqrtPriceX96: "87150978765690771352898345369", + }, + { + name: "input amount of 0.1 token0", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + amountIn: u256.MustFromDecimal("100000000000000000"), + zeroForOne: true, + expectedSqrtPriceX96: "72025602285694852357767227579", + }, + { + name: "amountIn > type(uint96).max and zeroForOne = true", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.MustFromDecimal("10000000000000000000"), + amountIn: u256.MustFromDecimal("1267650600228229401496703205376"), // 2^128 - 1 + zeroForOne: true, + expectedSqrtPriceX96: "624999999995069620", + }, + { + name: "can return 1 with enough amountIn and zeroForOne = true", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.One(), + amountIn: u256.MustFromDecimal("57896044618658097711785492504343953926634992332820282019728792003956564819967"), + zeroForOne: true, + expectedSqrtPriceX96: "1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.shouldPanic { + uassert.PanicsWithMessage(t, tt.panicMsg, func() { + sqrtPriceMathGetNextSqrtPriceFromInput(tt.sqrtPriceX96, tt.liquidity, tt.amountIn, tt.zeroForOne) + }) + } else { + actual := sqrtPriceMathGetNextSqrtPriceFromInput(tt.sqrtPriceX96, tt.liquidity, tt.amountIn, tt.zeroForOne) + uassert.Equal(t, tt.expectedSqrtPriceX96, actual.ToString()) + } + }) + } +} + +func TestSqrtPriceMathGetNextSqrtPriceFromInput2(t *testing.T) { + t.Run("zero sqrtPX96 should panic", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic for zero sqrtPX96") + } + }() + + sqrtPriceMathGetNextSqrtPriceFromInput( + u256.Zero(), + u256.MustFromDecimal("1000000"), + u256.MustFromDecimal("500000"), + true, + ) + }) + + t.Run("zero liquidity should panic", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic for zero liquidity") + } + }() + + sqrtPriceMathGetNextSqrtPriceFromInput( + u256.MustFromDecimal("1000000"), + u256.Zero(), + u256.MustFromDecimal("500000"), + true, + ) + }) +} + +func TestSqrtPriceMathGetNextSqrtPriceFromOutput(t *testing.T) { + tests := []struct { + name string + sqrtPriceX96 *u256.Uint + liquidity *u256.Uint + amountOut *u256.Uint + zeroForOne bool + shouldPanic bool + expectedSqrtPriceX96 string + }{ + { + name: "fails if price is zero", + sqrtPriceX96: u256.Zero(), + liquidity: u256.Zero(), + amountOut: u256.MustFromDecimal("100000000000000000"), + zeroForOne: false, + shouldPanic: true, + }, + { + name: "fails if liquidity is zero", + sqrtPriceX96: u256.One(), + liquidity: u256.Zero(), + amountOut: u256.MustFromDecimal("100000000000000000"), + zeroForOne: true, + shouldPanic: true, + }, + { + name: "fails if output amount is exactly the virtual reserves of token0", + sqrtPriceX96: u256.MustFromDecimal("20282409603651670423947251286016"), + liquidity: u256.MustFromDecimal("1024"), + amountOut: u256.NewUint(4), + zeroForOne: false, + shouldPanic: true, + }, + { + name: "fails if output amount is greater than virtual reserves of token0", + sqrtPriceX96: u256.MustFromDecimal("20282409603651670423947251286016"), + liquidity: u256.MustFromDecimal("1024"), + amountOut: u256.NewUint(5), + zeroForOne: false, + shouldPanic: true, + }, + { + name: "fails if output amount is greater than virtual reserves of token1", + sqrtPriceX96: u256.MustFromDecimal("20282409603651670423947251286016"), + liquidity: u256.MustFromDecimal("1024"), + amountOut: u256.NewUint(262145), + zeroForOne: true, + shouldPanic: true, + }, + { + name: "fails if output amount is exactly the virtual reserves of token1", + sqrtPriceX96: u256.MustFromDecimal("20282409603651670423947251286016"), + liquidity: u256.MustFromDecimal("1024"), + amountOut: u256.NewUint(262144), + zeroForOne: true, + shouldPanic: true, + }, + { + name: "succeeds if output amount is just less than the virtual reserves of token1", + sqrtPriceX96: u256.MustFromDecimal("20282409603651670423947251286016"), + liquidity: u256.MustFromDecimal("1024"), + amountOut: u256.NewUint(262143), + zeroForOne: true, + expectedSqrtPriceX96: "77371252455336267181195264", + }, + { + name: "puzzling echidna test", + sqrtPriceX96: u256.MustFromDecimal("20282409603651670423947251286016"), + liquidity: u256.MustFromDecimal("1024"), + amountOut: u256.NewUint(4), + zeroForOne: false, + shouldPanic: true, + }, + { + name: "returns input price if amount in is zero and zeroForOne = true", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.MustFromDecimal("100000000000000000"), + amountOut: u256.Zero(), + zeroForOne: true, + expectedSqrtPriceX96: encodePriceSqrt("1", "1").ToString(), + }, + { + name: "returns input price if amount in is zero and zeroForOne = false", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.MustFromDecimal("100000000000000000"), + amountOut: u256.Zero(), + zeroForOne: false, + expectedSqrtPriceX96: encodePriceSqrt("1", "1").ToString(), + }, + { + name: "output amount of 0.1 token1, zeroForOne = false", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + amountOut: u256.MustFromDecimal("100000000000000000"), + zeroForOne: false, + expectedSqrtPriceX96: "88031291682515930659493278152", + }, + { + name: "output amount of 0.1 token1, zeroForOne = true", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + amountOut: u256.MustFromDecimal("100000000000000000"), + zeroForOne: true, + expectedSqrtPriceX96: "71305346262837903834189555302", + }, + { + name: "reverts if amountOut is impossible in zero for one direction", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.NewUint(1), + amountOut: u256.MustFromDecimal(MAX_UINT256), + zeroForOne: true, + shouldPanic: true, + }, + { + name: "reverts if amountOut is impossible in one for zero direction", + sqrtPriceX96: encodePriceSqrt("1", "1"), + liquidity: u256.NewUint(1), + amountOut: u256.MustFromDecimal(MAX_UINT256), + zeroForOne: false, + shouldPanic: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.shouldPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic for %s", tt.name) + } + }() + sqrtPriceMathGetNextSqrtPriceFromOutput(tt.sqrtPriceX96, tt.liquidity, tt.amountOut, tt.zeroForOne) + + } else { + actual := sqrtPriceMathGetNextSqrtPriceFromOutput(tt.sqrtPriceX96, tt.liquidity, tt.amountOut, tt.zeroForOne) + uassert.Equal(t, tt.expectedSqrtPriceX96, actual.ToString()) + } + }) + } +} + +func TestSqrtPriceMathGetAmount0DeltaHelper(t *testing.T) { + tests := []struct { + name string + sqrtRatioAX96, sqrtRatioBX96, liquidity *u256.Uint + roundUp bool + expectedAmount0Delta string + }{ + { + name: "returns 0 if liquidity is 0", + sqrtRatioAX96: encodePriceSqrt("1", "1"), + sqrtRatioBX96: encodePriceSqrt("2", "1"), + liquidity: u256.Zero(), + roundUp: true, + expectedAmount0Delta: "0", + }, + { + name: "returns 0 if prices are equal", + sqrtRatioAX96: encodePriceSqrt("1", "1"), + sqrtRatioBX96: encodePriceSqrt("1", "1"), + liquidity: u256.Zero(), + roundUp: true, + expectedAmount0Delta: "0", + }, + { + name: "returns 0.1 amount1 for price of 1 to 1.21, roundUp = true", + sqrtRatioAX96: encodePriceSqrt("1", "1"), + sqrtRatioBX96: encodePriceSqrt("121", "100"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + roundUp: true, + expectedAmount0Delta: "90909090909090910", + }, + { + name: "returns 0.1 amount1 for price of 1 to 1.21, roundUp = false", + sqrtRatioAX96: encodePriceSqrt("1", "1"), + sqrtRatioBX96: encodePriceSqrt("121", "100"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + roundUp: false, + expectedAmount0Delta: "90909090909090909", + }, + { + name: "works for prices that overflow, roundUp = true", + sqrtRatioAX96: u256.MustFromDecimal("43556142965880123323311949751266331066368"), + sqrtRatioBX96: u256.MustFromDecimal("22300745198530623141535718272648361505980416"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + roundUp: true, + expectedAmount0Delta: "1815437", + }, + { + name: "works for prices that overflow, roundUp = false", + sqrtRatioAX96: u256.MustFromDecimal("43556142965880123323311949751266331066368"), + sqrtRatioBX96: u256.MustFromDecimal("22300745198530623141535718272648361505980416"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + roundUp: false, + expectedAmount0Delta: "1815436", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := sqrtPriceMathGetAmount0DeltaHelper(tt.sqrtRatioAX96, tt.sqrtRatioBX96, tt.liquidity, tt.roundUp) + uassert.Equal(t, tt.expectedAmount0Delta, actual.ToString()) + }) + } +} + +func TestSqrtPriceMathGetAmount1DeltaHelper(t *testing.T) { + tests := []struct { + name string + sqrtRatioAX96, sqrtRatioBX96, liquidity *u256.Uint + roundUp bool + expectedAmount1Delta string + }{ + { + name: "returns 0 if liquidity is 0", + sqrtRatioAX96: encodePriceSqrt("1", "1"), + sqrtRatioBX96: encodePriceSqrt("2", "1"), + liquidity: u256.Zero(), + roundUp: true, + expectedAmount1Delta: "0", + }, + { + name: "returns 0 if prices are equal", + sqrtRatioAX96: encodePriceSqrt("1", "1"), + sqrtRatioBX96: encodePriceSqrt("1", "1"), + liquidity: u256.Zero(), + roundUp: true, + expectedAmount1Delta: "0", + }, + { + name: "returns 0.1 amount1 for price of 1 to 1.21, roundUp = true", + sqrtRatioAX96: encodePriceSqrt("1", "1"), + sqrtRatioBX96: encodePriceSqrt("121", "100"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + roundUp: true, + expectedAmount1Delta: "100000000000000000", + }, + { + name: "returns 0.1 amount1 for price of 1 to 1.21, roundUp = false", + sqrtRatioAX96: encodePriceSqrt("1", "1"), + sqrtRatioBX96: encodePriceSqrt("121", "100"), + liquidity: u256.MustFromDecimal("1000000000000000000"), + roundUp: false, + expectedAmount1Delta: "99999999999999999", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := sqrtPriceMathGetAmount1DeltaHelper(tt.sqrtRatioAX96, tt.sqrtRatioBX96, tt.liquidity, tt.roundUp) + uassert.Equal(t, tt.expectedAmount1Delta, actual.ToString()) + }) + } +} + +func TestSwapComputation_SqrtP_SqrtQ_Mul_Overflow(t *testing.T) { + sqrtP := u256.MustFromDecimal("1025574284609383690408304870162715216695788925244") + liquidity := u256.MustFromDecimal("50015962439936049619261659728067971248") + amountIn := u256.MustFromDecimal("406") + zeroForOne := true + + sqrtQ := sqrtPriceMathGetNextSqrtPriceFromInput(sqrtP, liquidity, amountIn, zeroForOne) + uassert.Equal(t, "1025574284609383582644711336373707553698163132913", sqrtQ.ToString()) + + amount0Delta := sqrtPriceMathGetAmount0DeltaHelper(sqrtQ, sqrtP, liquidity, true) + uassert.Equal(t, "406", amount0Delta.ToString()) +} + +// encodePriceSqrt calculates the sqrt((reserve1 << 192) / reserve0) +func encodePriceSqrt(reserve1, reserve0 string) *u256.Uint { + reserve1Uint := u256.MustFromDecimal(reserve1) + reserve0Uint := u256.MustFromDecimal(reserve0) + + if reserve0Uint.IsZero() { + panic("division by zero") + } + + // numerator = reserve1 * (2^192) + two192 := new(u256.Uint).Lsh(u256.NewUint(1), 192) + numerator := new(u256.Uint).Mul(reserve1Uint, two192) + + // ratioX192 = numerator / reserve0 + ratioX192 := new(u256.Uint).Div(numerator, reserve0Uint) + + // Return sqrt(ratioX192) + return sqrt(ratioX192) +} + +// sqrt computes the integer square root of a u256.Uint +func sqrt(x *u256.Uint) *u256.Uint { + if x.IsZero() { + return u256.NewUint(0) + } + + z := new(u256.Uint).Set(x) + y := new(u256.Uint).Rsh(z, 1) // Initial guess is x / 2 + + for y.Cmp(z) < 0 { + z.Set(y) + temp := new(u256.Uint).Div(x, z) + y.Add(z, temp).Rsh(y, 1) + } + return z +} From a94f232eb44f5e350ec7316f3c0cde61a1587b03 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Mon, 16 Dec 2024 12:51:40 +0900 Subject: [PATCH 7/9] refactor: dry swap (#421) * remove duplicate functions * GSW-1838 refactor: Using the new constructor to protect raw data in sqrtRatio calculations * GSW-1838 refactor: need to lock pool.slot0 to prevent re-entry * refactor: Using clone data to protect original data * refactor: remove unused import * fix: sqrtRatio calculation default value issue * refactor: swap and swap math * refactor: computeSwapStep - edge case test (amount : zero and over liquidity) - EXACT IN / EXACT OUT Case * refactor: Separate pool transfer-related test code * refactor: Rename a receiver function param and liquidity math bug fix - Fix absolute value before checking if delta value is negative to fix handling error issue if negative. - Rename the param in the receiver function from pool to p. * refactor: Changed large numbers to const for readability & removed some comments * refactor: Remove unnecessary function usage * refactor: Fix error messages to make their meaning clear --------- Co-authored-by: 0xTopaz Co-authored-by: 0xTopaz <60733299+onlyhyde@users.noreply.github.com> Co-authored-by: Blake <104744707+r3v4s@users.noreply.github.com> --- _deploy/p/gnoswap/pool/sqrt_price_math.gno | 2 +- _deploy/r/gnoswap/common/tick_math.gno | 22 +- _deploy/r/gnoswap/consts/consts.gno | 3 +- pool/_RPC_dry.gno | 208 ----- pool/_helper_test.gno | 157 ++++ pool/liquidity_math.gno | 41 +- pool/liquidity_math_test.gno | 8 +- pool/pool.gno | 742 ++-------------- pool/pool_test.gno | 563 ------------ pool/pool_transfer.gno | 171 ++++ pool/pool_transfer_test.gno | 298 +++++++ pool/position_modify.gno | 46 +- pool/swap.gno | 524 +++++++++++ pool/swap_test.gno | 954 +++++++++++++++++++++ pool/tick.gno | 85 +- pool/tick_bitmap.gno | 20 +- pool/type.gno | 145 +++- pool/utils.gno | 14 + 18 files changed, 2448 insertions(+), 1555 deletions(-) delete mode 100644 pool/_RPC_dry.gno create mode 100644 pool/pool_transfer.gno create mode 100644 pool/pool_transfer_test.gno create mode 100644 pool/swap.gno create mode 100644 pool/swap_test.gno diff --git a/_deploy/p/gnoswap/pool/sqrt_price_math.gno b/_deploy/p/gnoswap/pool/sqrt_price_math.gno index 29b16feaf..34a37d655 100644 --- a/_deploy/p/gnoswap/pool/sqrt_price_math.gno +++ b/_deploy/p/gnoswap/pool/sqrt_price_math.gno @@ -118,7 +118,7 @@ func sqrtPriceMathGetNextSqrtPriceFromAmount1RoundingDown( res := new(u256.Uint).Add(sqrtPX96, quotient) if res.Gt(max160) { - panic("sqrtPx96 + quotient overflow uint160") + panic("GetNextSqrtPriceFromAmount1RoundingDown sqrtPx96 + quotient overflow uint160") } return res } else { diff --git a/_deploy/r/gnoswap/common/tick_math.gno b/_deploy/r/gnoswap/common/tick_math.gno index 40bca145a..2a0bc689a 100644 --- a/_deploy/r/gnoswap/common/tick_math.gno +++ b/_deploy/r/gnoswap/common/tick_math.gno @@ -43,21 +43,28 @@ var binaryLogConsts = [8]*u256.Uint{ var ( shift1By32Left = u256.MustFromDecimal("4294967296") // (1 << 32) + maxTick = int32(887272) ) func TickMathGetSqrtRatioAtTick(tick int32) *u256.Uint { // uint160 sqrtPriceX96 absTick := abs(tick) - if absTick > 887272 { // MAX_TICK + if absTick > maxTick { panic(addDetailToError( errOutOfRange, - ufmt.Sprintf("tick_math.gno__TickMathGetSqrtRatioAtTick() || tick is out of range (larger than 887272), tick: %d", tick), + ufmt.Sprintf("tick is out of range (larger than 887272), tick: %d", tick), )) } - ratio := u256.MustFromDecimal("340282366920938463463374607431768211456") // consts.Q128 + var initialBit int32 = 0x1 + var ratio *u256.Uint + if (absTick & initialBit) != 0 { + ratio = tickRatioMap[initialBit] + } else { + ratio = u256.MustFromDecimal("340282366920938463463374607431768211456") // consts.Q128 + } for mask, value := range tickRatioMap { - if absTick&mask != 0 { + if (mask != initialBit) && absTick&mask != 0 { // ratio = (ratio * value) >> 128 ratio = ratio.Mul(ratio, value) ratio = ratio.Rsh(ratio, 128) @@ -65,12 +72,11 @@ func TickMathGetSqrtRatioAtTick(tick int32) *u256.Uint { // uint160 sqrtPriceX96 } if tick > 0 { - _maxUint256 := u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935") // consts.MAX_UINT256 - _tmp := new(u256.Uint).Div(_maxUint256, ratio) - ratio = _tmp.Clone() + maxUint256 := u256.MustFromDecimal("115792089237316195423570985008687907853269984665640564039457584007913129639935") // consts.MAX_UINT256 + ratio = new(u256.Uint).Div(maxUint256, ratio) } - shifted := ratio.Rsh(ratio, 32).Clone() // ratio >> 32 + shifted := new(u256.Uint).Rsh(ratio, 32) // ratio >> 32 remainder := ratio.Mod(ratio, shift1By32Left) // ratio % (1 << 32) var adj *u256.Uint diff --git a/_deploy/r/gnoswap/consts/consts.gno b/_deploy/r/gnoswap/consts/consts.gno index 10e47e85d..9f9228e9f 100644 --- a/_deploy/r/gnoswap/consts/consts.gno +++ b/_deploy/r/gnoswap/consts/consts.gno @@ -90,9 +90,8 @@ const ( UINT64_MAX uint64 = 18446744073709551615 MAX_UINT128 string = "340282366920938463463374607431768211455" - MAX_UINT160 string = "1461501637330902918203684832716283019655932542975" - + MAX_INT256 string = "57896044618658097711785492504343953926634992332820282019728792003956564819968" MAX_UINT256 string = "115792089237316195423570985008687907853269984665640564039457584007913129639935" // Tick Related diff --git a/pool/_RPC_dry.gno b/pool/_RPC_dry.gno deleted file mode 100644 index 5ffe30c86..000000000 --- a/pool/_RPC_dry.gno +++ /dev/null @@ -1,208 +0,0 @@ -package pool - -import ( - "gno.land/r/gnoswap/v1/common" - - "gno.land/r/gnoswap/v1/consts" - - plp "gno.land/p/gnoswap/pool" // pool package - - i256 "gno.land/p/gnoswap/int256" - u256 "gno.land/p/gnoswap/uint256" -) - -// DrySwap simulates a swap and returns the amount0, amount1 that would be received and a boolean indicating if the swap is possible -func DrySwap( - token0Path string, - token1Path string, - fee uint32, - zeroForOne bool, - _amountSpecified string, - _sqrtPriceLimitX96 string, -) (string, string, bool) { - - if _amountSpecified == "0" { - return "0", "0", false - } - - amountSpecified := i256.MustFromDecimal(_amountSpecified) - sqrtPriceLimitX96 := u256.MustFromDecimal(_sqrtPriceLimitX96) - - pool := GetPool(token0Path, token1Path, fee) - slot0Start := pool.slot0 - - var feeProtocol uint8 - var feeGrowthGlobalX128 *u256.Uint - - if zeroForOne { - minSqrtRatio := u256.MustFromDecimal(consts.MIN_SQRT_RATIO) - - cond1 := sqrtPriceLimitX96.Lt(slot0Start.sqrtPriceX96) - cond2 := sqrtPriceLimitX96.Gt(minSqrtRatio) - if !(cond1 && cond2) { - return "0", "0", false - } - - feeProtocol = slot0Start.feeProtocol % 16 - feeGrowthGlobalX128 = pool.feeGrowthGlobal0X128 - } else { - maxSqrtRatio := u256.MustFromDecimal(consts.MAX_SQRT_RATIO) - - cond1 := sqrtPriceLimitX96.Gt(slot0Start.sqrtPriceX96) - cond2 := sqrtPriceLimitX96.Lt(maxSqrtRatio) - if !(cond1 && cond2) { - return "0", "0", false - } - - feeProtocol = slot0Start.feeProtocol / 16 - feeGrowthGlobalX128 = pool.feeGrowthGlobal1X128 - } - - slot0 := pool.slot0 - slot0.unlocked = false - - cache := newSwapCache(feeProtocol, pool.liquidity) - state := newSwapState(amountSpecified.Clone(), feeGrowthGlobalX128.Clone(), cache.liquidityStart.Clone(), slot0) - - exactInput := amountSpecified.Gt(i256.Zero()) - - // continue swapping as long as we haven't used the entire input/output and haven't reached the price limit - for !(state.amountSpecifiedRemaining.IsZero()) && !(state.sqrtPriceX96.Eq(sqrtPriceLimitX96)) { - var step StepComputations - step.sqrtPriceStartX96 = state.sqrtPriceX96 - - step.tickNext, step.initialized = pool.tickBitmapNextInitializedTickWithInOneWord( - state.tick, - pool.tickSpacing, - zeroForOne, - ) - - // ensure that we do not overshoot the min/max tick, as the tick bitmap is not aware of these bounds - if step.tickNext < consts.MIN_TICK { - step.tickNext = consts.MIN_TICK - } else if step.tickNext > consts.MAX_TICK { - step.tickNext = consts.MAX_TICK - } - - // get the price for the next tick - step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) - - isLower := step.sqrtPriceNextX96.Lt(sqrtPriceLimitX96) - isHigher := step.sqrtPriceNextX96.Gt(sqrtPriceLimitX96) - - var sqrtRatioTargetX96 *u256.Uint - if (zeroForOne && isLower) || (!zeroForOne && isHigher) { - sqrtRatioTargetX96 = sqrtPriceLimitX96 - } else { - sqrtRatioTargetX96 = step.sqrtPriceNextX96 - } - - _sqrtPriceX96Str, _amountInStr, _amountOutStr, _feeAmountStr := plp.SwapMathComputeSwapStepStr( - state.sqrtPriceX96, - sqrtRatioTargetX96, - state.liquidity, - state.amountSpecifiedRemaining, - uint64(pool.fee), - ) - state.sqrtPriceX96 = u256.MustFromDecimal(_sqrtPriceX96Str) - step.amountIn = u256.MustFromDecimal(_amountInStr) - step.amountOut = u256.MustFromDecimal(_amountOutStr) - step.feeAmount = u256.MustFromDecimal(_feeAmountStr) - - amountInWithFee := i256.FromUint256(new(u256.Uint).Add(step.amountIn, step.feeAmount)) - if exactInput { - state.amountSpecifiedRemaining = i256.Zero().Sub(state.amountSpecifiedRemaining, amountInWithFee) - state.amountCalculated = i256.Zero().Sub(state.amountCalculated, i256.FromUint256(step.amountOut)) - } else { - state.amountSpecifiedRemaining = i256.Zero().Add(state.amountSpecifiedRemaining, i256.FromUint256(step.amountOut)) - state.amountCalculated = i256.Zero().Add(state.amountCalculated, amountInWithFee) - } - - // if the protocol fee is on, calculate how much is owed, decrement feeAmount, and increment protocolFee - if cache.feeProtocol > 0 { - delta := new(u256.Uint).Div(step.feeAmount, u256.NewUint(uint64(cache.feeProtocol))) - step.feeAmount = new(u256.Uint).Sub(step.feeAmount, delta) - state.protocolFee = new(u256.Uint).Add(state.protocolFee, delta) - } - - // update global fee tracker - if state.liquidity.Gt(u256.Zero()) { - // OBS if `DrySwap()` update its state, next ACTUAL `Swap()` gets affect - - // value1 := new(u256.Uint).Mul(step.feeAmount, u256.MustFromDecimal(consts.Q128)) - // value2 := new(u256.Uint).Div(value1, state.liquidity) - - // state.feeGrowthGlobalX128 = new(u256.Uint).Add(state.feeGrowthGlobalX128, value2) - } - - // shift tick if we reached the next price - if state.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { - // if the tick is initialized, run the tick transition - if step.initialized { - var fee0, fee1 *u256.Uint - - // check for the placeholder value, which we replace with the actual value the first time the swap crosses an initialized tick - if zeroForOne { - fee0 = state.feeGrowthGlobalX128 - fee1 = pool.feeGrowthGlobal1X128 - } else { - fee0 = pool.feeGrowthGlobal0X128 - fee1 = state.feeGrowthGlobalX128 - } - - liquidityNet := pool.tickCross( - step.tickNext, - fee0, - fee1, - ) - - // if we're moving leftward, we interpret liquidityNet as the opposite sign - if zeroForOne { - liquidityNet = i256.Zero().Neg(liquidityNet) - } - - state.liquidity = liquidityMathAddDelta(state.liquidity, liquidityNet) - } - - if zeroForOne { - state.tick = step.tickNext - 1 - } else { - state.tick = step.tickNext - } - } else if !(state.sqrtPriceX96.Eq(step.sqrtPriceStartX96)) { - // recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved - state.tick = common.TickMathGetTickAtSqrtRatio(state.sqrtPriceX96) - } - } - // END LOOP - - var amount0, amount1 *i256.Int - if zeroForOne == exactInput { - amount0 = i256.Zero().Sub(amountSpecified, state.amountSpecifiedRemaining) - amount1 = state.amountCalculated - } else { - amount0 = state.amountCalculated - amount1 = i256.Zero().Sub(amountSpecified, state.amountSpecifiedRemaining) - } - - pool.slot0.unlocked = true - - if zeroForOne { - if pool.balances.token1.Lt(amount1.Abs()) { - // NOT ENOUGH BALANCE for output token1 - return "0", "0", false - } - } else { - if pool.balances.token0.Lt(amount0.Abs()) { - // NOT ENOUGH BALANCE for output token0 - return "0", "0", false - } - } - - // JUST NOT ENOUGH BALANCE - if amount0.IsZero() || amount1.IsZero() { - return "0", "0", false - } - - return amount0.ToString(), amount1.ToString(), true -} diff --git a/pool/_helper_test.gno b/pool/_helper_test.gno index 992d99cc3..ea719ff2c 100644 --- a/pool/_helper_test.gno +++ b/pool/_helper_test.gno @@ -37,6 +37,8 @@ const ( fee3000 uint32 = 3000 maxApprove uint64 = 18446744073709551615 max_timeout int64 = 9999999999 + + maxSqrtPriceLimitX96 string = "1461446703485210103287273052203988822378723970341" ) const ( @@ -50,6 +52,7 @@ var ( alice = pusers.AddressOrName(testutils.TestAddress("alice")) pool = pusers.AddressOrName(consts.POOL_ADDR) protocolFee = pusers.AddressOrName(consts.PROTOCOL_FEE_ADDR) + router = pusers.AddressOrName(consts.ROUTER_ADDR) adminRealm = std.NewUserRealm(users.Resolve(admin)) posRealm = std.NewCodeRealm(consts.POSITION_PATH) @@ -215,6 +218,160 @@ func MintPosition(t *testing.T, caller) } +func MintPositionAll(t *testing.T, caller std.Address) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(caller)) + TokenApprove(t, gnsPath, pusers.AddressOrName(caller), pool, maxApprove) + TokenApprove(t, gnsPath, pusers.AddressOrName(caller), router, maxApprove) + TokenApprove(t, wugnotPath, pusers.AddressOrName(caller), pool, maxApprove) + TokenApprove(t, wugnotPath, pusers.AddressOrName(caller), router, maxApprove) + + params := []struct { + tickLower int32 + tickUpper int32 + liquidity uint64 + zeroToOne bool + }{ + { + tickLower: -300, + tickUpper: -240, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -240, + tickUpper: -180, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -180, + tickUpper: -120, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -120, + tickUpper: -60, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -60, + tickUpper: 0, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: 0, + tickUpper: 60, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 60, + tickUpper: 120, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 120, + tickUpper: 180, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 180, + tickUpper: 240, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 240, + tickUpper: 300, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: -360, + tickUpper: -300, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -420, + tickUpper: -360, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -480, + tickUpper: -420, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -540, + tickUpper: -480, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: -600, + tickUpper: -540, + liquidity: 10, + zeroToOne: true, + }, + { + tickLower: 300, + tickUpper: 360, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 360, + tickUpper: 420, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 420, + tickUpper: 480, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 480, + tickUpper: 540, + liquidity: 10, + zeroToOne: false, + }, + { + tickLower: 540, + tickUpper: 600, + liquidity: 10, + zeroToOne: false, + }, + } + + for _, p := range params { + MintPosition(t, + wugnotPath, + gnsPath, + fee3000, + p.tickLower, + p.tickUpper, + "100", + "100", + "0", + "0", + max_timeout, + caller, + caller) + } + +} + func wugnotApprove(t *testing.T, owner, spender pusers.AddressOrName, amount uint64) { t.Helper() std.TestSetRealm(std.NewUserRealm(users.Resolve(owner))) diff --git a/pool/liquidity_math.gno b/pool/liquidity_math.gno index 165acf834..47aaea2cc 100644 --- a/pool/liquidity_math.gno +++ b/pool/liquidity_math.gno @@ -7,39 +7,52 @@ import ( "gno.land/p/demo/ufmt" ) -// liquidityMathAddDelta Calculate the new liquidity with delta liquidity. -// If delta liquidity is negative, it will subtract the absolute value of delta liquidity from the current liquidity. -// If delta liquidity is positive, it will add the absolute value of delta liquidity to the current liquidity. -// inputs: -// - x: current liquidity -// - y: delta liquidity -// Returns the new liquidity. +// liquidityMathAddDelta calculates the new liquidity by applying the delta liquidity to the current liquidity. +// If delta liquidity is negative, it subtracts the absolute value of delta liquidity from the current liquidity. +// If delta liquidity is positive, it adds the absolute value of delta liquidity to the current liquidity. +// +// Parameters: +// - x: The current liquidity as a uint256 value. +// - y: The delta liquidity as a signed int256 value. +// +// Returns: +// - The new liquidity as a uint256 value. +// +// Notes: +// - If `x` or `y` is nil, the function panics with an appropriate error message. +// - If `y` is negative, its absolute value is subtracted from `x`. +// - The result must be less than `x`. Otherwise, the function panics to prevent underflow. +// +// - If `y` is positive, it is added to `x`. +// - The result must be greater than or equal to `x`. Otherwise, the function panics to prevent overflow. +// +// - The function ensures correctness by validating the results of the arithmetic operations. func liquidityMathAddDelta(x *u256.Uint, y *i256.Int) *u256.Uint { if x == nil || y == nil { panic(addDetailToError( errInvalidInput, - ufmt.Sprintf("liquidity_math.gno__liquidityMathAddDelta() || x or y is nil"), + "x or y is nil", )) } - absDelta := y.Abs() var z *u256.Uint // Subtract or add based on the sign of y if y.Lt(i256.Zero()) { + absDelta := y.Abs() z = new(u256.Uint).Sub(x, absDelta) - if z.Gte(x) { // z must be < x + if z.Gte(x) { panic(addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("liquidity_math.gno__liquidityMathAddDelta() || LS(z must be < x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), + ufmt.Sprintf("Less than Condition(z must be < x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), )) } } else { - z = new(u256.Uint).Add(x, absDelta) - if z.Lt(x) { // z must be >= x + z = new(u256.Uint).Add(x, y.Abs()) + if z.Lt(x) { panic(addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("liquidity_math.gno__liquidityMathAddDelta() || LA(z must be >= x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), + ufmt.Sprintf("Less than or Equal Condition(z must be >= x) (x: %s, y: %s, z:%s)", x.ToString(), y.ToString(), z.ToString()), )) } } diff --git a/pool/liquidity_math_test.gno b/pool/liquidity_math_test.gno index 794dc2b97..fc2f379b2 100644 --- a/pool/liquidity_math_test.gno +++ b/pool/liquidity_math_test.gno @@ -22,7 +22,7 @@ func TestLiquidityMathAddDelta(t *testing.T) { y = i256.MustFromDecimal("100") liquidityMathAddDelta(nil, y) }, - wantPanic: addDetailToError(errInvalidInput, ufmt.Sprintf("liquidity_math.gno__liquidityMathAddDelta() || x or y is nil")), + wantPanic: addDetailToError(errInvalidInput, "x or y is nil"), }, { name: "y is nil", @@ -31,7 +31,7 @@ func TestLiquidityMathAddDelta(t *testing.T) { x = u256.MustFromDecimal("100") liquidityMathAddDelta(x, nil) }, - wantPanic: addDetailToError(errInvalidInput, ufmt.Sprintf("liquidity_math.gno__liquidityMathAddDelta() || x or y is nil")), + wantPanic: addDetailToError(errInvalidInput, "x or y is nil"), }, { name: "underflow panic with sub delta", @@ -42,7 +42,7 @@ func TestLiquidityMathAddDelta(t *testing.T) { }, wantPanic: addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("liquidity_math.gno__liquidityMathAddDelta() || LS(z must be < x) (x: 0, y: -100, z:115792089237316195423570985008687907853269984665640564039457584007913129639836)")), + ufmt.Sprintf("Less than Condition(z must be < x) (x: 0, y: -100, z:115792089237316195423570985008687907853269984665640564039457584007913129639836)")), }, { name: "overflow panic with add delta", @@ -53,7 +53,7 @@ func TestLiquidityMathAddDelta(t *testing.T) { }, wantPanic: addDetailToError( errLiquidityCalculation, - ufmt.Sprintf("liquidity_math.gno__liquidityMathAddDelta() || LA(z must be >= x) (x: 115792089237316195423570985008687907853269984665640564039457584007913129639935, y: 100, z:99)")), + ufmt.Sprintf("Less than or Equal Condition(z must be >= x) (x: 115792089237316195423570985008687907853269984665640564039457584007913129639935, y: 100, z:99)")), }, } diff --git a/pool/pool.gno b/pool/pool.gno index cfb655f0e..db4df9b8e 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -5,8 +5,6 @@ import ( "gno.land/p/demo/ufmt" - plp "gno.land/p/gnoswap/pool" - "gno.land/r/gnoswap/v1/common" "gno.land/r/gnoswap/v1/consts" @@ -33,17 +31,14 @@ func Mint( if err := common.PositionOnly(caller); err != nil { panic(addDetailToError( errNoPermission, - ufmt.Sprintf("pool.gno__Mint() || only position(%s) can call pool mint(), called from %s", consts.POSITION_ADDR, caller.String()), + ufmt.Sprintf("only position(%s) can call pool mint(), called from %s", consts.POSITION_ADDR, caller.String()), )) } } liquidityAmount := u256.MustFromDecimal(_liquidityAmount) if liquidityAmount.IsZero() { - panic(addDetailToError( - errZeroLiquidity, - ufmt.Sprintf("pool.gno__Mint() || liquidityAmount == 0"), - )) + panic(errZeroLiquidity) } pool := GetPool(token0Path, token1Path, fee) @@ -79,7 +74,7 @@ func Burn( if err := common.PositionOnly(caller); err != nil { panic(addDetailToError( errNoPermission, - ufmt.Sprintf("pool.gno__Burn() || only position(%s) can call pool burn(), called from %s", consts.POSITION_ADDR, caller.String()), + ufmt.Sprintf("only position(%s) can call pool burn(), called from %s", consts.POSITION_ADDR, caller.String()), )) } } @@ -124,7 +119,7 @@ func Collect( if err := common.PositionOnly(caller); err != nil { panic(addDetailToError( errNoPermission, - ufmt.Sprintf("pool.gno__Collect() || only position(%s) can call pool collect(), called from %s", consts.POSITION_ADDR, caller.String()), + ufmt.Sprintf("only position(%s) can call pool collect(), called from %s", consts.POSITION_ADDR, caller.String()), )) } } @@ -136,7 +131,7 @@ func Collect( if !exist { panic(addDetailToError( errDataNotFound, - ufmt.Sprintf("pool.gno__Collect() || positionKey(%s) does not exist", positionKey), + ufmt.Sprintf("positionKey(%s) does not exist", positionKey), )) } @@ -163,7 +158,19 @@ func Collect( return amount0.ToString(), amount1.ToString() } -// collectToken handles the collection of a single token type (token0 or token1) +// collectToken handles the collection of tokens (either token0 or token1) from a position. +// It calculates the actual amount that can be collected based on three constraints: +// the requested amount, tokens owed, and available pool balance. +// +// Parameters: +// - amountReq: amount requested to collect +// - tokensOwed: amount of tokens owed to the position +// - poolBalance: current balance of tokens in the pool +// +// Returns: +// - amount: actual amount that will be collected (minimum of the three inputs) +// - newTokensOwed: remaining tokens owed after collection +// - newPoolBalance: remaining pool balance after collection func collectToken( amountReq, tokensOwed, poolBalance *u256.Uint, ) (amount, newTokensOwed, newPoolBalance *u256.Uint) { @@ -178,465 +185,6 @@ func collectToken( return amount, newTokensOwed, newPoolBalance } -// SwapResult encapsulates all state changes that occur as a result of a swap -// This type ensure all state transitions are atomic and can be applied at once. -type SwapResult struct { - Amount0 *i256.Int - Amount1 *i256.Int - NewSqrtPrice *u256.Uint - NewTick int32 - NewLiquidity *u256.Uint - NewProtocolFees ProtocolFees - FeeGrowthGlobal0X128 *u256.Uint - FeeGrowthGlobal1X128 *u256.Uint - SwapFee *u256.Uint -} - -// SwapComputation encapsulates pure computation logic for swap -type SwapComputation struct { - AmountSpecified *i256.Int - SqrtPriceLimitX96 *u256.Uint - ZeroForOne bool - ExactInput bool - InitialState SwapState - Cache SwapCache -} - -// Swap swaps token0 for token1, or token1 for token0 -// Returns swapped amount0, amount1 in string -// ref: https://docs.gnoswap.io/contracts/pool/pool.gno#swap -func Swap( - token0Path string, - token1Path string, - fee uint32, - recipient std.Address, - zeroForOne bool, - amountSpecified string, - sqrtPriceLimitX96 string, - payer std.Address, // router -) (string, string) { - common.IsHalted() - if common.GetLimitCaller() { - caller := std.PrevRealm().Addr() - if err := common.RouterOnly(caller); err != nil { - panic(addDetailToError( - errNoPermission, - ufmt.Sprintf("pool.gno__Swap() || only router(%s) can call pool swap(), called from %s", consts.ROUTER_ADDR, caller.String()), - )) - } - } - - if amountSpecified == "0" { - panic(addDetailToError( - errInvalidSwapAmount, - ufmt.Sprintf("pool.gno__Swap() || amountSpecified == 0"), - )) - } - - pool := GetPool(token0Path, token1Path, fee) - - slot0Start := pool.slot0 - if !slot0Start.unlocked { - panic(errLockedPool) - } - - slot0Start.unlocked = false - defer func() { slot0Start.unlocked = true }() - - amounts := i256.MustFromDecimal(amountSpecified) - sqrtPriceLimit := u256.MustFromDecimal(sqrtPriceLimitX96) - - validatePriceLimits(pool, zeroForOne, sqrtPriceLimit) - - feeGrowthGlobalX128 := getFeeGrowthGlobal(pool, zeroForOne) - feeProtocol := getFeeProtocol(slot0Start, zeroForOne) - cache := newSwapCache(feeProtocol, pool.liquidity) - - state := newSwapState(amounts, feeGrowthGlobalX128, cache.liquidityStart, pool.slot0) - - comp := SwapComputation{ - AmountSpecified: amounts, - SqrtPriceLimitX96: sqrtPriceLimit, - ZeroForOne: zeroForOne, - ExactInput: amounts.Gt(i256.Zero()), - InitialState: state, - Cache: cache, - } - - result, err := computeSwap(pool, comp) - if err != nil { - panic(err) - } - - applySwapResult(pool, result) - - // actual swap - pool.swapTransfers(zeroForOne, payer, recipient, result.Amount0, result.Amount1) - - prevAddr, prevRealm := getPrev() - - std.Emit( - "Swap", - "prevAddr", prevAddr, - "prevRealm", prevRealm, - "poolPath", GetPoolPath(token0Path, token1Path, fee), - "zeroForOne", ufmt.Sprintf("%t", zeroForOne), - "amountSpecified", amountSpecified, - "sqrtPriceLimitX96", sqrtPriceLimitX96, - "payer", payer.String(), - "recipient", recipient.String(), - "internal_amount0", result.Amount0.ToString(), - "internal_amount1", result.Amount1.ToString(), - "internal_protocolFee0", pool.protocolFees.token0.ToString(), - "internal_protocolFee1", pool.protocolFees.token1.ToString(), - "internal_swapFee", result.SwapFee.ToString(), - "internal_sqrtPriceX96", pool.slot0.sqrtPriceX96.ToString(), - ) - - return result.Amount0.ToString(), result.Amount1.ToString() -} - -// computeSwap performs the core swap computation without modifying pool state -// The function follows these state transitions: -// 1. Initial State: Provided by `SwapComputation.InitialState` -// 2. Stepping State: For each step: -// - Compute next tick and price target -// - Calculate amounts and fees -// - Update state (remaining amount, fees, liquidity) -// - Handle tick transitions if necessary -// -// 3. Final State: Aggregated in SwapResult -// -// The computation continues until either: -// - The entire amount is consumed (`amountSpecifiedRemaining` = 0) -// - The price limit is reached (`sqrtPriceX96` = `sqrtPriceLimitX96`) -// -// Returns an error if the computation fails at any step -func computeSwap(pool *Pool, comp SwapComputation) (*SwapResult, error) { - state := comp.InitialState - swapFee := u256.Zero() - - var newFee *u256.Uint - var err error - - // Compute swap steps until completion - for shouldContinueSwap(state, comp.SqrtPriceLimitX96) { - state, newFee, err = computeSwapStep(state, pool, comp.ZeroForOne, comp.SqrtPriceLimitX96, comp.ExactInput, comp.Cache, swapFee) - if err != nil { - return nil, err - } - swapFee = newFee - } - - // Calculate final amounts - amount0 := state.amountCalculated - amount1 := i256.Zero().Sub(comp.AmountSpecified, state.amountSpecifiedRemaining) - if comp.ZeroForOne == comp.ExactInput { - amount0, amount1 = amount1, amount0 - } - - // Prepare result - result := &SwapResult{ - Amount0: amount0, - Amount1: amount1, - NewSqrtPrice: state.sqrtPriceX96, - NewTick: state.tick, - NewLiquidity: state.liquidity, - NewProtocolFees: ProtocolFees{ - token0: pool.protocolFees.token0, - token1: pool.protocolFees.token1, - }, - FeeGrowthGlobal0X128: pool.feeGrowthGlobal0X128, - FeeGrowthGlobal1X128: pool.feeGrowthGlobal1X128, - SwapFee: swapFee, - } - - // Update protocol fees if necessary - if comp.ZeroForOne { - if state.protocolFee.Gt(u256.Zero()) { - result.NewProtocolFees.token0 = new(u256.Uint).Add(result.NewProtocolFees.token0, state.protocolFee) - } - result.FeeGrowthGlobal0X128 = state.feeGrowthGlobalX128 - } else { - if state.protocolFee.Gt(u256.Zero()) { - result.NewProtocolFees.token1 = new(u256.Uint).Add(result.NewProtocolFees.token1, state.protocolFee) - } - result.FeeGrowthGlobal1X128 = state.feeGrowthGlobalX128 - } - - return result, nil -} - -// applySwapResult updates pool state with computed results. -// All state changes are applied at once to maintain consistency -func applySwapResult(pool *Pool, result *SwapResult) { - pool.slot0.sqrtPriceX96 = result.NewSqrtPrice - pool.slot0.tick = result.NewTick - pool.liquidity = result.NewLiquidity - pool.protocolFees = result.NewProtocolFees - pool.feeGrowthGlobal0X128 = result.FeeGrowthGlobal0X128 - pool.feeGrowthGlobal1X128 = result.FeeGrowthGlobal1X128 -} - -// validatePriceLimits ensures the provided price limit is valid for the swap direction -// The function enforces that: -// For zeroForOne (selling token0): -// - Price limit must be below current price -// - Price limit must be above MIN_SQRT_RATIO -// -// For !zeroForOne (selling token1): -// - Price limit must be above current price -// - Price limit must be below MAX_SQRT_RATIO -func validatePriceLimits(pool *Pool, zeroForOne bool, sqrtPriceLimitX96 *u256.Uint) { - if zeroForOne { - minSqrtRatio := u256.MustFromDecimal(consts.MIN_SQRT_RATIO) - - cond1 := sqrtPriceLimitX96.Lt(pool.slot0.sqrtPriceX96) - cond2 := sqrtPriceLimitX96.Gt(minSqrtRatio) - if !(cond1 && cond2) { - panic(addDetailToError( - errPriceOutOfRange, - ufmt.Sprintf("pool.gno__Swap() || sqrtPriceLimitX96(%s) < slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) > consts.MIN_SQRT_RATIO(%s)", - sqrtPriceLimitX96.ToString(), - pool.slot0.sqrtPriceX96.ToString(), - sqrtPriceLimitX96.ToString(), - consts.MIN_SQRT_RATIO), - )) - } - } else { - maxSqrtRatio := u256.MustFromDecimal(consts.MAX_SQRT_RATIO) - - cond1 := sqrtPriceLimitX96.Gt(pool.slot0.sqrtPriceX96) - cond2 := sqrtPriceLimitX96.Lt(maxSqrtRatio) - if !(cond1 && cond2) { - panic(addDetailToError( - errPriceOutOfRange, - ufmt.Sprintf("pool.gno__Swap() || sqrtPriceLimitX96(%s) > slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) < consts.MAX_SQRT_RATIO(%s)", - sqrtPriceLimitX96.ToString(), - pool.slot0.sqrtPriceX96.ToString(), - sqrtPriceLimitX96.ToString(), - consts.MAX_SQRT_RATIO), - )) - } - } -} - -// getFeeProtocol returns the appropriate fee protocol based on zero for one -func getFeeProtocol(slot0 Slot0, zeroForOne bool) uint8 { - if zeroForOne { - return slot0.feeProtocol % 16 - } - return slot0.feeProtocol / 16 -} - -// getFeeGrowthGlobal returns the appropriate fee growth global based on zero for one -func getFeeGrowthGlobal(pool *Pool, zeroForOne bool) *u256.Uint { - if zeroForOne { - return pool.feeGrowthGlobal0X128 - } - return pool.feeGrowthGlobal1X128 -} - -func shouldContinueSwap(state SwapState, sqrtPriceLimitX96 *u256.Uint) bool { - return !(state.amountSpecifiedRemaining.IsZero()) && !(state.sqrtPriceX96.Eq(sqrtPriceLimitX96)) -} - -// computeSwapStep executes a single step of swap and returns new state -func computeSwapStep( - state SwapState, - pool *Pool, - zeroForOne bool, - sqrtPriceLimitX96 *u256.Uint, - exactInput bool, - cache SwapCache, - swapFee *u256.Uint, -) (SwapState, *u256.Uint, error) { - step := computeSwapStepInit(state, pool, zeroForOne) - - // determining the price target for this step - sqrtRatioTargetX96 := computeTargetSqrtRatio(step, sqrtPriceLimitX96, zeroForOne) - - // computing the amounts to be swapped at this step - var newState SwapState - var err error - - newState, step = computeAmounts(state, sqrtRatioTargetX96, pool, step) - newState = updateAmounts(step, newState, exactInput) - - // if the protocol fee is on, calculate how much is owed, - // decrement fee amount, and increment protocol fee - if cache.feeProtocol > 0 { - newState, err = updateFeeProtocol(step, cache.feeProtocol, newState) - if err != nil { - return state, nil, err - } - } - - // update global fee tracker - if newState.liquidity.Gt(u256.Zero()) { - update := u256.MulDiv(step.feeAmount, u256.MustFromDecimal(consts.Q128), newState.liquidity) - newState.SetFeeGrowthGlobalX128(new(u256.Uint).Add(newState.feeGrowthGlobalX128, update)) - } - - // handling tick transitions - if newState.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { - newState = tickTransition(step, zeroForOne, newState, pool) - } - - if newState.sqrtPriceX96.Neq(step.sqrtPriceStartX96) { - newState.SetTick(common.TickMathGetTickAtSqrtRatio(newState.sqrtPriceX96)) - } - - newSwapFee := new(u256.Uint).Add(swapFee, step.feeAmount) - - return newState, newSwapFee, nil -} - -// updateFeeProtocol calculates and updates protocol fees for the current step. -func updateFeeProtocol(step StepComputations, feeProtocol uint8, state SwapState) (SwapState, error) { - delta := step.feeAmount - delta.Div(delta, u256.NewUint(uint64(feeProtocol))) - - newFeeAmount, overflow := new(u256.Uint).SubOverflow(step.feeAmount, delta) - if overflow { - return state, errUnderflow - } - step.feeAmount = newFeeAmount - state.protocolFee.Add(state.protocolFee, delta) - - return state, nil -} - -// computeSwapStepInit initializes the computation for a single swap step. -func computeSwapStepInit(state SwapState, pool *Pool, zeroForOne bool) StepComputations { - var step StepComputations - step.sqrtPriceStartX96 = state.sqrtPriceX96 - tickNext, initialized := pool.tickBitmapNextInitializedTickWithInOneWord( - state.tick, - pool.tickSpacing, - zeroForOne, - ) - - step.tickNext = tickNext - step.initialized = initialized - - // prevent overshoot the min/max tick - step.clampTickNext() - - // get the price for the next tick - step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) - return step -} - -// computeTargetSqrtRatio determines the target sqrt price for the current swap step. -func computeTargetSqrtRatio(step StepComputations, sqrtPriceLimitX96 *u256.Uint, zeroForOne bool) *u256.Uint { - if shouldUsePriceLimit(step.sqrtPriceNextX96, sqrtPriceLimitX96, zeroForOne) { - return sqrtPriceLimitX96 - } - return step.sqrtPriceNextX96 -} - -// shouldUsePriceLimit returns true if the price limit should be used instead of the next tick price -func shouldUsePriceLimit(sqrtPriceNext, sqrtPriceLimit *u256.Uint, zeroForOne bool) bool { - isLower := sqrtPriceNext.Lt(sqrtPriceLimit) - isHigher := sqrtPriceNext.Gt(sqrtPriceLimit) - if zeroForOne { - return isLower - } - return isHigher -} - -// computeAmounts calculates the input and output amounts for the current swap step. -func computeAmounts(state SwapState, sqrtRatioTargetX96 *u256.Uint, pool *Pool, step StepComputations) (SwapState, StepComputations) { - sqrtPriceX96Str, amountInStr, amountOutStr, feeAmountStr := plp.SwapMathComputeSwapStepStr( - state.sqrtPriceX96, - sqrtRatioTargetX96, - state.liquidity, - state.amountSpecifiedRemaining, - uint64(pool.fee), - ) - - step.amountIn = u256.MustFromDecimal(amountInStr) - step.amountOut = u256.MustFromDecimal(amountOutStr) - step.feeAmount = u256.MustFromDecimal(feeAmountStr) - - state.SetSqrtPriceX96(sqrtPriceX96Str) - - return state, step -} - -// updateAmounts calculates new remaining and calculated amounts based on the swap step -// For exact input swaps: -// - Decrements remaining input amount by (amountIn + feeAmount) -// - Decrements calculated amount by amountOut -// -// For exact output swaps: -// - Increments remaining output amount by amountOut -// - Increments calculated amount by (amountIn + feeAmount) -func updateAmounts(step StepComputations, state SwapState, exactInput bool) SwapState { - amountInWithFee := i256.FromUint256(new(u256.Uint).Add(step.amountIn, step.feeAmount)) - if exactInput { - state.amountSpecifiedRemaining = i256.Zero().Sub(state.amountSpecifiedRemaining, amountInWithFee) - state.amountCalculated = i256.Zero().Sub(state.amountCalculated, i256.FromUint256(step.amountOut)) - return state - } - state.amountSpecifiedRemaining = i256.Zero().Add(state.amountSpecifiedRemaining, i256.FromUint256(step.amountOut)) - state.amountCalculated = i256.Zero().Add(state.amountCalculated, amountInWithFee) - - return state -} - -// tickTransition handles the transition between price ticks during a swap -func tickTransition(step StepComputations, zeroForOne bool, state SwapState, pool *Pool) SwapState { - // ensure existing state to keep immutability - newState := state - - if step.initialized { - var fee0, fee1 *u256.Uint - - if zeroForOne { - fee0 = state.feeGrowthGlobalX128 - fee1 = pool.feeGrowthGlobal1X128 - } else { - fee0 = pool.feeGrowthGlobal0X128 - fee1 = state.feeGrowthGlobalX128 - } - - liquidityNet := pool.tickCross(step.tickNext, fee0, fee1) - - if zeroForOne { - liquidityNet = i256.Zero().Neg(liquidityNet) - } - - newState.liquidity = liquidityMathAddDelta(state.liquidity, liquidityNet) - } - - if zeroForOne { - newState.tick = step.tickNext - 1 - } else { - newState.tick = step.tickNext - } - - return newState -} - -func (pool *Pool) swapTransfers(zeroForOne bool, payer, recipient std.Address, amount0, amount1 *i256.Int) { - var targetTokenPath string - var amount *i256.Int - - if zeroForOne { - targetTokenPath = pool.token0Path - amount = amount0 - } else { - targetTokenPath = pool.token1Path - amount = amount1 - } - - // payer -> POOL -> recipient - pool.transferFromAndVerify(payer, consts.POOL_ADDR, targetTokenPath, amount.Abs(), zeroForOne) - pool.transferAndVerify(recipient, targetTokenPath, amount, !zeroForOne) -} - // SetFeeProtocolByAdmin sets the fee protocol for all pools // Also it will be applied to new created pools func SetFeeProtocolByAdmin( @@ -684,16 +232,44 @@ func SetFeeProtocol(feeProtocol0, feeProtocol1 uint8) { ) } +// setFeeProtocol updates the protocol fee configuration for all existing pools and sets +// the default for new pools. This is an internal function called by both `admin` and `governance` +// protocol fee management functions. +// +// The protocol fee is stored as a single `uint8` value where: +// - Lower 4 bits store feeProtocol0 (for token0) +// - Upper 4 bits store feeProtocol1 (for token1) +// +// This compact representation allows storing both fee values in a single byte. +// +// Parameters (must be 0 or between 4 and 10 inclusive): +// - feeProtocol0: protocol fee for token0 +// - feeProtocol1: protocol fee for token1 +// +// Returns: +// - newFee (uint8): the combined fee protocol value +// +// Example: +// If feeProtocol0 = 4 and feeProtocol1 = 5: +// +// newFee = 4 + (5 << 4) +// // Results in: 0x54 (84 in decimal) +// // Binary: 0101 0100 +// // ^^^^ ^^^^ +// // fee1=5 fee0=4 func setFeeProtocol(feeProtocol0, feeProtocol1 uint8) uint8 { common.IsHalted() if err := validateFeeProtocol(feeProtocol0, feeProtocol1); err != nil { panic(addDetailToError( err, - ufmt.Sprintf("pool.gno__setFeeProtocol() || expected (feeProtocol0(%d) == 0 || (feeProtocol0(%d) >= 4 && feeProtocol0(%d) <= 10)) && (feeProtocol1(%d) == 0 || (feeProtocol1(%d) >= 4 && feeProtocol1(%d) <= 10))", feeProtocol0, feeProtocol0, feeProtocol0, feeProtocol1, feeProtocol1, feeProtocol1), + ufmt.Sprintf("expected (feeProtocol0(%d) == 0 || (feeProtocol0(%d) >= 4 && feeProtocol0(%d) <= 10)) && (feeProtocol1(%d) == 0 || (feeProtocol1(%d) >= 4 && feeProtocol1(%d) <= 10))", feeProtocol0, feeProtocol0, feeProtocol0, feeProtocol1, feeProtocol1, feeProtocol1), )) } + // combine both protocol fee into a single byte: + // - feePrtocol0 occupies the lower 4 bits + // - feeProtocol1 is shifted the lower 4 positions to occupy the upper 4 bits newFee := feeProtocol0 + (feeProtocol1 << 4) // ( << 4 ) = ( * 16 ) // iterate all pool @@ -714,6 +290,8 @@ func validateFeeProtocol(feeProtocol0, feeProtocol1 uint8) error { return nil } +// isValidFeeProtocolValue checks if a fee protocol value is within acceptable range. +// valid values are either 0 or between 4 and 10 inclusive. func isValidFeeProtocolValue(value uint8) bool { return value == 0 || (value >= 4 && value <= 10) } @@ -837,227 +415,29 @@ func collectProtocol( return amount0.ToString(), amount1.ToString() } -func (pool *Pool) saveProtocolFees(amount0, amount1 *u256.Uint) (*u256.Uint, *u256.Uint) { +// saveProtocolFees updates the protocol fee balances after collection. +// +// Parameters: +// - amount0: amount of token0 fees to collect +// - amount1: amount of token1 fees to collect +// +// Returns the adjusted amounts that will actually be collected for both tokens. +func (p *Pool) saveProtocolFees(amount0, amount1 *u256.Uint) (*u256.Uint, *u256.Uint) { cond01 := amount0.Gt(u256.Zero()) - cond02 := amount0.Eq(pool.protocolFees.token0) + cond02 := amount0.Eq(p.protocolFees.token0) if cond01 && cond02 { amount0 = new(u256.Uint).Sub(amount0, u256.One()) } cond11 := amount1.Gt(u256.Zero()) - cond12 := amount1.Eq(pool.protocolFees.token1) + cond12 := amount1.Eq(p.protocolFees.token1) if cond11 && cond12 { amount1 = new(u256.Uint).Sub(amount1, u256.One()) } - pool.protocolFees.token0 = new(u256.Uint).Sub(pool.protocolFees.token0, amount0) - pool.protocolFees.token1 = new(u256.Uint).Sub(pool.protocolFees.token1, amount1) + p.protocolFees.token0 = new(u256.Uint).Sub(p.protocolFees.token0, amount0) + p.protocolFees.token1 = new(u256.Uint).Sub(p.protocolFees.token1, amount1) // return rest fee return amount0, amount1 } - -func (pool *Pool) transferAndVerify( - to std.Address, - tokenPath string, - amount *i256.Int, - isToken0 bool, -) { - if amount.Sign() != -1 { - panic(addDetailToError( - errMustBeNegative, - ufmt.Sprintf("pool.gno__transferAndVerify() || amount(%s) must be negative", amount.ToString()), - )) - } - - absAmount := amount.Abs() - - token0 := pool.balances.token0 - token1 := pool.balances.token1 - - if err := validatePoolBalance(token0, token1, absAmount, isToken0); err != nil { - panic(err) - } - amountUint64, err := checkAmountRange(absAmount) - if err != nil { - panic(err) - } - - token := common.GetTokenTeller(tokenPath) - checkTransferError(token.Transfer(to, amountUint64)) - - newBalance, err := updatePoolBalance(token0, token1, absAmount, isToken0) - if err != nil { - panic(err) - } - - if isToken0 { - pool.balances.token0 = newBalance - } else { - pool.balances.token1 = newBalance - } -} - -func validatePoolBalance(token0, token1, amount *u256.Uint, isToken0 bool) error { - if isToken0 { - if token0.Lt(amount) { - return ufmt.Errorf( - "%s || token0(%s) >= amount(%s)", - errTransferFailed.Error(), token0.ToString(), amount.ToString(), - ) - } - return nil - } - if token1.Lt(amount) { - return ufmt.Errorf( - "%s || token1(%s) >= amount(%s)", - errTransferFailed.Error(), token1.ToString(), amount.ToString(), - ) - } - return nil -} - -func updatePoolBalance( - token0, token1, amount *u256.Uint, - isToken0 bool, -) (*u256.Uint, error) { - var overflow bool - var newBalance *u256.Uint - - if isToken0 { - newBalance, overflow = new(u256.Uint).SubOverflow(token0, amount) - if isBalanceOverflowOrNegative(overflow, newBalance) { - return nil, ufmt.Errorf( - "%s || cannot decrease, token0(%s) - amount(%s)", - errTransferFailed.Error(), token0.ToString(), amount.ToString(), - ) - } - return newBalance, nil - } - - newBalance, overflow = new(u256.Uint).SubOverflow(token1, amount) - if isBalanceOverflowOrNegative(overflow, newBalance) { - return nil, ufmt.Errorf( - "%s || cannot decrease, token1(%s) - amount(%s)", - errTransferFailed.Error(), token1.ToString(), amount.ToString(), - ) - } - return newBalance, nil -} - -func isBalanceOverflowOrNegative(overflow bool, newBalance *u256.Uint) bool { - return overflow || newBalance.Lt(u256.Zero()) -} - -func (pool *Pool) transferFromAndVerify( - from, to std.Address, - tokenPath string, - amount *u256.Uint, - isToken0 bool, -) { - absAmount := amount - amountUint64, err := checkAmountRange(absAmount) - if err != nil { - panic(err) - } - - token := common.GetTokenTeller(tokenPath) - checkTransferError(token.TransferFrom(from, to, amountUint64)) - - // update pool balances - if isToken0 { - pool.balances.token0 = new(u256.Uint).Add(pool.balances.token0, absAmount) - } else { - pool.balances.token1 = new(u256.Uint).Add(pool.balances.token1, absAmount) - } -} - -func checkAmountRange(amount *u256.Uint) (uint64, error) { - res, overflow := amount.Uint64WithOverflow() - if overflow { - return 0, ufmt.Errorf( - "%s || amount(%s) overflows uint64 range", - errOutOfRange.Error(), amount.ToString(), - ) - } - - return res, nil -} - -// receiver getters -func (p *Pool) GetToken0Path() string { - return p.token0Path -} - -func (p *Pool) GetToken1Path() string { - return p.token1Path -} - -func (p *Pool) GetFee() uint32 { - return p.fee -} - -func (p *Pool) GetBalanceToken0() *u256.Uint { - return p.balances.token0 -} - -func (p *Pool) GetBalanceToken1() *u256.Uint { - return p.balances.token1 -} - -func (p *Pool) GetTickSpacing() int32 { - return p.tickSpacing -} - -func (p *Pool) GetMaxLiquidityPerTick() *u256.Uint { - return p.maxLiquidityPerTick -} - -func (p *Pool) GetSlot0() Slot0 { - return p.slot0 -} - -func (p *Pool) GetSlot0SqrtPriceX96() *u256.Uint { - return p.slot0.sqrtPriceX96 -} - -func (p *Pool) GetSlot0Tick() int32 { - return p.slot0.tick -} - -func (p *Pool) GetSlot0FeeProtocol() uint8 { - return p.slot0.feeProtocol -} - -func (p *Pool) GetSlot0Unlocked() bool { - return p.slot0.unlocked -} - -func (p *Pool) GetFeeGrowthGlobal0X128() *u256.Uint { - return p.feeGrowthGlobal0X128 -} - -func (p *Pool) GetFeeGrowthGlobal1X128() *u256.Uint { - return p.feeGrowthGlobal1X128 -} - -func (p *Pool) GetProtocolFeesToken0() *u256.Uint { - return p.protocolFees.token0 -} - -func (p *Pool) GetProtocolFeesToken1() *u256.Uint { - return p.protocolFees.token1 -} - -func (p *Pool) GetLiquidity() *u256.Uint { - return p.liquidity -} - -func mustGetPool(poolPath string) *Pool { - pool, exist := pools[poolPath] - if !exist { - panic(addDetailToError(errDataNotFound, - ufmt.Sprintf("poolPath(%s) does not exist", poolPath))) - } - return pool -} diff --git a/pool/pool_test.gno b/pool/pool_test.gno index c6ef82d46..b58d95ac3 100644 --- a/pool/pool_test.gno +++ b/pool/pool_test.gno @@ -140,566 +140,3 @@ func TestBurn(t *testing.T) { } } -func TestSaveProtocolFees(t *testing.T) { - tests := []struct { - name string - pool *Pool - amount0 *u256.Uint - amount1 *u256.Uint - want0 *u256.Uint - want1 *u256.Uint - wantFee0 *u256.Uint - wantFee1 *u256.Uint - }{ - { - name: "normal fee deduction", - pool: &Pool{ - protocolFees: ProtocolFees{ - token0: u256.NewUint(1000), - token1: u256.NewUint(2000), - }, - }, - amount0: u256.NewUint(500), - amount1: u256.NewUint(1000), - want0: u256.NewUint(500), - want1: u256.NewUint(1000), - wantFee0: u256.NewUint(500), - wantFee1: u256.NewUint(1000), - }, - { - name: "exact fee deduction (1 deduction)", - pool: &Pool{ - protocolFees: ProtocolFees{ - token0: u256.NewUint(1000), - token1: u256.NewUint(2000), - }, - }, - amount0: u256.NewUint(1000), - amount1: u256.NewUint(2000), - want0: u256.NewUint(999), - want1: u256.NewUint(1999), - wantFee0: u256.NewUint(1), - wantFee1: u256.NewUint(1), - }, - { - name: "0 fee deduction", - pool: &Pool{ - protocolFees: ProtocolFees{ - token0: u256.NewUint(1000), - token1: u256.NewUint(2000), - }, - }, - amount0: u256.NewUint(0), - amount1: u256.NewUint(0), - want0: u256.NewUint(0), - want1: u256.NewUint(0), - wantFee0: u256.NewUint(1000), - wantFee1: u256.NewUint(2000), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got0, got1 := tt.pool.saveProtocolFees(tt.amount0, tt.amount1) - - uassert.Equal(t, got0.ToString(), tt.want0.ToString()) - uassert.Equal(t, got1.ToString(), tt.want1.ToString()) - uassert.Equal(t, tt.pool.protocolFees.token0.ToString(), tt.wantFee0.ToString()) - uassert.Equal(t, tt.pool.protocolFees.token1.ToString(), tt.wantFee1.ToString()) - }) - } -} - -func TestTransferAndVerify(t *testing.T) { - // Setup common test data - pool := &Pool{ - balances: Balances{ - token0: u256.NewUint(1000), - token1: u256.NewUint(1000), - }, - } - - t.Run("validatePoolBalance", func(t *testing.T) { - tests := []struct { - name string - amount *u256.Uint - isToken0 bool - expectedError bool - }{ - { - name: "must success for negative amount", - amount: u256.NewUint(500), - isToken0: true, - expectedError: false, - }, - { - name: "must panic for insufficient token0 balance", - amount: u256.NewUint(1500), - isToken0: true, - expectedError: true, - }, - { - name: "must success for negative amount", - amount: u256.NewUint(500), - isToken0: false, - expectedError: false, - }, - { - name: "must panic for insufficient token1 balance", - amount: u256.NewUint(1500), - isToken0: false, - expectedError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - token0 := pool.balances.token0 - token1 := pool.balances.token1 - - err := validatePoolBalance(token0, token1, tt.amount, tt.isToken0) - if err != nil { - if !tt.expectedError { - t.Errorf("unexpected error: %v", err) - } - } - }) - } - }) -} - -func TestUpdatePoolBalance(t *testing.T) { - tests := []struct { - name string - initialToken0 *u256.Uint - initialToken1 *u256.Uint - amount *u256.Uint - isToken0 bool - expectedBal *u256.Uint - expectErr bool - }{ - { - name: "normal token0 decrease", - initialToken0: u256.NewUint(1000), - initialToken1: u256.NewUint(2000), - amount: u256.NewUint(300), - isToken0: true, - expectedBal: u256.NewUint(700), - expectErr: false, - }, - { - name: "normal token1 decrease", - initialToken0: u256.NewUint(1000), - initialToken1: u256.NewUint(2000), - amount: u256.NewUint(500), - isToken0: false, - expectedBal: u256.NewUint(1500), - expectErr: false, - }, - { - name: "insufficient token0 balance", - initialToken0: u256.NewUint(100), - initialToken1: u256.NewUint(2000), - amount: u256.NewUint(200), - isToken0: true, - expectedBal: nil, - expectErr: true, - }, - { - name: "insufficient token1 balance", - initialToken0: u256.NewUint(1000), - initialToken1: u256.NewUint(100), - amount: u256.NewUint(200), - isToken0: false, - expectedBal: nil, - expectErr: true, - }, - { - name: "zero value handling", - initialToken0: u256.NewUint(1000), - initialToken1: u256.NewUint(2000), - amount: u256.NewUint(0), - isToken0: true, - expectedBal: u256.NewUint(1000), - expectErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - pool := &Pool{ - balances: Balances{ - token0: tt.initialToken0, - token1: tt.initialToken1, - }, - } - - newBal, err := updatePoolBalance(tt.initialToken0, tt.initialToken1, tt.amount, tt.isToken0) - - if tt.expectErr { - if err == nil { - t.Errorf("%s: expected error but no error", tt.name) - } - return - } - if err != nil { - t.Errorf("%s: unexpected error: %v", tt.name, err) - return - } - - if !newBal.Eq(tt.expectedBal) { - t.Errorf("%s: balance mismatch, expected: %s, actual: %s", - tt.name, - tt.expectedBal.ToString(), - newBal.ToString(), - ) - } - }) - } -} - -func TestShouldContinueSwap(t *testing.T) { - tests := []struct { - name string - state SwapState - sqrtPriceLimitX96 *u256.Uint - expected bool - }{ - { - name: "Should continue - amount remaining and price not at limit", - state: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("1000"), - sqrtPriceX96: u256.MustFromDecimal("1000000"), - }, - sqrtPriceLimitX96: u256.MustFromDecimal("900000"), - expected: true, - }, - { - name: "Should stop - no amount remaining", - state: SwapState{ - amountSpecifiedRemaining: i256.Zero(), - sqrtPriceX96: u256.MustFromDecimal("1000000"), - }, - sqrtPriceLimitX96: u256.MustFromDecimal("900000"), - expected: false, - }, - { - name: "Should stop - price at limit", - state: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("1000"), - sqrtPriceX96: u256.MustFromDecimal("900000"), - }, - sqrtPriceLimitX96: u256.MustFromDecimal("900000"), - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := shouldContinueSwap(tt.state, tt.sqrtPriceLimitX96) - uassert.Equal(t, tt.expected, result) - }) - } -} - -func TestUpdateAmounts(t *testing.T) { - tests := []struct { - name string - step StepComputations - state SwapState - exactInput bool - expectedState SwapState - }{ - { - name: "Exact input update", - step: StepComputations{ - amountIn: u256.MustFromDecimal("100"), - amountOut: u256.MustFromDecimal("97"), - feeAmount: u256.MustFromDecimal("3"), - }, - state: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("1000"), - amountCalculated: i256.Zero(), - }, - exactInput: true, - expectedState: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("897"), // 1000 - (100 + 3) - amountCalculated: i256.MustFromDecimal("-97"), - }, - }, - { - name: "Exact output update", - step: StepComputations{ - amountIn: u256.MustFromDecimal("100"), - amountOut: u256.MustFromDecimal("97"), - feeAmount: u256.MustFromDecimal("3"), - }, - state: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("-1000"), - amountCalculated: i256.Zero(), - }, - exactInput: false, - expectedState: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("-903"), // -1000 + 97 - amountCalculated: i256.MustFromDecimal("103"), // 100 + 3 - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := updateAmounts(tt.step, tt.state, tt.exactInput) - - uassert.True(t, tt.expectedState.amountSpecifiedRemaining.Eq(result.amountSpecifiedRemaining)) - uassert.True(t, tt.expectedState.amountCalculated.Eq(result.amountCalculated)) - }) - } -} - -func TestComputeSwap(t *testing.T) { - mockPool := &Pool{ - token0Path: "token0", - token1Path: "token1", - fee: 3000, // 0.3% - tickSpacing: 60, - slot0: Slot0{ - sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 - tick: 0, - feeProtocol: 0, - unlocked: true, - }, - liquidity: u256.MustFromDecimal("1000000000000000000"), // 1.0 - protocolFees: ProtocolFees{ - token0: u256.Zero(), - token1: u256.Zero(), - }, - feeGrowthGlobal0X128: u256.Zero(), - feeGrowthGlobal1X128: u256.Zero(), - tickBitmaps: make(TickBitmaps), - ticks: make(Ticks), - positions: make(Positions), - } - - wordPos, _ := tickBitmapPosition(0) - // TODO: use avl - mockPool.tickBitmaps[wordPos] = u256.NewUint(1) - - t.Run("basic swap", func(t *testing.T) { - comp := SwapComputation{ - AmountSpecified: i256.MustFromDecimal("1000000"), // 1.0 token - SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), // 1.1 - ZeroForOne: true, - ExactInput: true, - InitialState: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), - amountCalculated: i256.Zero(), - sqrtPriceX96: mockPool.slot0.sqrtPriceX96, - tick: mockPool.slot0.tick, - feeGrowthGlobalX128: mockPool.feeGrowthGlobal0X128, - protocolFee: u256.Zero(), - liquidity: mockPool.liquidity, - }, - Cache: SwapCache{ - feeProtocol: 0, - liquidityStart: mockPool.liquidity, - }, - } - - result, err := computeSwap(mockPool, comp) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - if result.Amount0.IsZero() { - t.Error("expected non-zero amount0") - } - if result.Amount1.IsZero() { - t.Error("expected non-zero amount1") - } - if result.SwapFee.IsZero() { - t.Error("expected non-zero swap fee") - } - }) - - t.Run("swap with zero liquidity", func(t *testing.T) { - mockPoolZeroLiq := *mockPool - mockPoolZeroLiq.liquidity = u256.Zero() - - comp := SwapComputation{ - AmountSpecified: i256.MustFromDecimal("1000000"), - SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), - ZeroForOne: true, - ExactInput: true, - InitialState: SwapState{ - amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), - amountCalculated: i256.Zero(), - sqrtPriceX96: mockPoolZeroLiq.slot0.sqrtPriceX96, - tick: mockPoolZeroLiq.slot0.tick, - feeGrowthGlobalX128: mockPoolZeroLiq.feeGrowthGlobal0X128, - protocolFee: u256.Zero(), - liquidity: mockPoolZeroLiq.liquidity, - }, - Cache: SwapCache{ - feeProtocol: 0, - liquidityStart: mockPoolZeroLiq.liquidity, - }, - } - - result, err := computeSwap(&mockPoolZeroLiq, comp) - if err != nil { - t.Fatalf("expected no error, got %v", err) - } - - if !result.Amount0.IsZero() || !result.Amount1.IsZero() { - t.Error("expected zero amounts for zero liquidity") - } - }) -} - -func TestTransferFromAndVerify(t *testing.T) { - tests := []struct { - name string - pool *Pool - from std.Address - to std.Address - tokenPath string - amount *i256.Int - isToken0 bool - expectedBal0 *u256.Uint - expectedBal1 *u256.Uint - }{ - { - name: "normal token0 transfer", - pool: &Pool{ - balances: Balances{ - token0: u256.NewUint(1000), - token1: u256.NewUint(2000), - }, - }, - from: testutils.TestAddress("from_addr"), - to: testutils.TestAddress("to_addr"), - tokenPath: "gno.land/r/onbloc/bar", - amount: i256.NewInt(500), - isToken0: true, - expectedBal0: u256.NewUint(1500), // 1000 + 500 - expectedBal1: u256.NewUint(2000), // unchanged - }, - { - name: "normal token1 transfer", - pool: &Pool{ - balances: Balances{ - token0: u256.NewUint(1000), - token1: u256.NewUint(2000), - }, - }, - from: testutils.TestAddress("from_addr"), - to: testutils.TestAddress("to_addr"), - tokenPath: "gno.land/r/onbloc/foo", - amount: i256.NewInt(800), - isToken0: false, - expectedBal0: u256.NewUint(1000), // unchanged - expectedBal1: u256.NewUint(2800), // 2000 + 800 - }, - { - name: "zero value transfer", - pool: &Pool{ - balances: Balances{ - token0: u256.NewUint(1000), - token1: u256.NewUint(2000), - }, - }, - from: testutils.TestAddress("from_addr"), - to: testutils.TestAddress("to_addr"), - tokenPath: "gno.land/r/onbloc/bar", - amount: i256.NewInt(0), - isToken0: true, - expectedBal0: u256.NewUint(1000), // unchanged - expectedBal1: u256.NewUint(2000), // unchanged - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - oldCheckTransferError := checkTransferError - defer func() { - checkTransferError = oldCheckTransferError - }() - - checkTransferError = func(err error) { - return - } - - tt.pool.transferFromAndVerify(tt.from, tt.to, tt.tokenPath, u256.MustFromDecimal(tt.amount.ToString()), tt.isToken0) - - if !tt.pool.balances.token0.Eq(tt.expectedBal0) { - t.Errorf("token0 balance mismatch: expected %s, got %s", - tt.expectedBal0.ToString(), - tt.pool.balances.token0.ToString()) - } - - if !tt.pool.balances.token1.Eq(tt.expectedBal1) { - t.Errorf("token1 balance mismatch: expected %s, got %s", - tt.expectedBal1.ToString(), - tt.pool.balances.token1.ToString()) - } - }) - } - - t.Run("negative value handling", func(t *testing.T) { - pool := &Pool{ - balances: Balances{ - token0: u256.NewUint(1000), - token1: u256.NewUint(2000), - }, - } - - oldCheckTransferError := checkTransferError - defer func() { checkTransferError = oldCheckTransferError }() - - checkTransferError = func(err error) { - return - } - - negativeAmount := i256.NewInt(-500) - pool.transferFromAndVerify( - testutils.TestAddress("from_addr"), - testutils.TestAddress("to_addr"), - "gno.land/r/onbloc/qux", - u256.MustFromDecimal(negativeAmount.Abs().ToString()), - true, - ) - - expectedBal := u256.NewUint(1500) // 1000 + 500 (absolute value) - if !pool.balances.token0.Eq(expectedBal) { - t.Errorf("negative amount handling failed: expected %s, got %s", - expectedBal.ToString(), - pool.balances.token0.ToString()) - } - }) - - t.Run("uint64 overflow value", func(t *testing.T) { - pool := &Pool{ - balances: Balances{ - token0: u256.NewUint(1000), - token1: u256.NewUint(2000), - }, - } - - hugeAmount := i256.FromUint256(u256.MustFromDecimal("18446744073709551616")) // 2^64 - - defer func() { - if r := recover(); r == nil { - t.Error("expected panic for amount exceeding uint64 range") - } - }() - - pool.transferFromAndVerify( - testutils.TestAddress("from_addr"), - testutils.TestAddress("to_addr"), - "gno.land/r/onbloc/qux", - u256.MustFromDecimal(hugeAmount.ToString()), - true, - ) - }) -} diff --git a/pool/pool_transfer.gno b/pool/pool_transfer.gno new file mode 100644 index 000000000..201ae2cfb --- /dev/null +++ b/pool/pool_transfer.gno @@ -0,0 +1,171 @@ +package pool + +import ( + "std" + + "gno.land/p/demo/ufmt" + + "gno.land/r/gnoswap/v1/common" + + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" +) + +// transferAndVerify performs a token transfer out of the pool while ensuring +// the pool has sufficient balance and updating internal accounting. +// This function is typically used during swaps and liquidity removals. +// +// Important requirements: +// - The amount must be negative (representing an outflow from the pool) +// - The pool must have sufficient balance for the transfer +// - The transfer amount must fit within uint64 range +// +// Parameters: +// - to: destination address for the transfer +// - tokenPath: path identifier of the token to transfer +// - amount: amount to transfer (must be negative) +// - isToken0: true if transferring token0, false for token1 +// +// The function will: +// 1. Validate the amount is negative +// 2. Check pool has sufficient balance +// 3. Execute the transfer +// 4. Update pool's internal balance +// +// Panics if any validation fails or if the transfer fails +func (p *Pool) transferAndVerify( + to std.Address, + tokenPath string, + amount *i256.Int, + isToken0 bool, +) { + if amount.Sign() != -1 { + panic(ufmt.Sprintf( + "%v. got: %s", errMustBeNegative, amount.ToString(), + )) + } + + absAmount := amount.Abs() + + token0 := p.balances.token0 + token1 := p.balances.token1 + + if err := validatePoolBalance(token0, token1, absAmount, isToken0); err != nil { + panic(err) + } + amountUint64, err := safeConvertToUint64(absAmount) + if err != nil { + panic(err) + } + + token := common.GetTokenTeller(tokenPath) + checkTransferError(token.Transfer(to, amountUint64)) + + newBalance, err := updatePoolBalance(token0, token1, absAmount, isToken0) + if err != nil { + panic(err) + } + + if isToken0 { + p.balances.token0 = newBalance + } else { + p.balances.token1 = newBalance + } +} + +// transferFromAndVerify performs a token transfer into the pool using transferFrom +// while updating the pool's internal accounting. This function is typically used +// during swaps and liquidity additions. +// +// The function assumes the sender has approved the pool to spend their tokens. +// +// Parameters: +// - from: source address for the transfer +// - to: destination address (typically the pool) +// - tokenPath: path identifier of the token to transfer +// - amount: amount to transfer (must be positive) +// - isToken0: true if transferring token0, false for token1 +// +// The function will: +// 1. Convert amount to uint64 (must fit) +// 2. Execute the transferFrom +// 3. Update pool's internal balance +// +// Panics if the amount conversion fails or if the transfer fails +func (p *Pool) transferFromAndVerify( + from, to std.Address, + tokenPath string, + amount *u256.Uint, + isToken0 bool, +) { + absAmount := amount + amountUint64, err := safeConvertToUint64(absAmount) + if err != nil { + panic(err) + } + + token := common.GetTokenTeller(tokenPath) + checkTransferError(token.TransferFrom(from, to, amountUint64)) + + // update pool balances + if isToken0 { + p.balances.token0 = new(u256.Uint).Add(p.balances.token0, absAmount) + } else { + p.balances.token1 = new(u256.Uint).Add(p.balances.token1, absAmount) + } +} + +// validatePoolBalance checks if the pool has sufficient balance of either token0 and token1 +// before proceeding with a transfer. This prevents the pool won't go into a negative balance. +func validatePoolBalance(token0, token1, amount *u256.Uint, isToken0 bool) error { + if isToken0 { + if token0.Lt(amount) { + return ufmt.Errorf( + "%v. token0(%s) >= amount(%s)", + errTransferFailed, token0.ToString(), amount.ToString(), + ) + } + return nil + } + if token1.Lt(amount) { + return ufmt.Errorf( + "%v. token1(%s) >= amount(%s)", + errTransferFailed, token1.ToString(), amount.ToString(), + ) + } + return nil +} + +// updatePoolBalance calculates the new balance after a transfer and validate. +// It ensures the resulting balance won't be negative or overflow. +func updatePoolBalance( + token0, token1, amount *u256.Uint, + isToken0 bool, +) (*u256.Uint, error) { + var overflow bool + var newBalance *u256.Uint + + if isToken0 { + newBalance, overflow = new(u256.Uint).SubOverflow(token0, amount) + if isBalanceOverflowOrNegative(overflow, newBalance) { + return nil, ufmt.Errorf( + "%v. cannot decrease, token0(%s) - amount(%s)", + errTransferFailed, token0.ToString(), amount.ToString(), + ) + } + return newBalance, nil + } + + newBalance, overflow = new(u256.Uint).SubOverflow(token1, amount) + if isBalanceOverflowOrNegative(overflow, newBalance) { + return nil, ufmt.Errorf( + "%v. cannot decrease, token1(%s) - amount(%s)", + errTransferFailed, token1.ToString(), amount.ToString(), + ) + } + return newBalance, nil +} + +func isBalanceOverflowOrNegative(overflow bool, newBalance *u256.Uint) bool { + return overflow || newBalance.Lt(u256.Zero()) +} diff --git a/pool/pool_transfer_test.gno b/pool/pool_transfer_test.gno new file mode 100644 index 000000000..562825b9a --- /dev/null +++ b/pool/pool_transfer_test.gno @@ -0,0 +1,298 @@ +package pool + +import ( + "std" + "testing" + + "gno.land/p/demo/testutils" + pusers "gno.land/p/demo/users" + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/r/gnoswap/v1/consts" +) + +func TestTransferAndVerify(t *testing.T) { + // Setup common test data + pool := &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(1000), + }, + } + + t.Run("validatePoolBalance", func(t *testing.T) { + tests := []struct { + name string + amount *u256.Uint + isToken0 bool + expectedError bool + }{ + { + name: "must success for negative amount", + amount: u256.NewUint(500), + isToken0: true, + expectedError: false, + }, + { + name: "must panic for insufficient token0 balance", + amount: u256.NewUint(1500), + isToken0: true, + expectedError: true, + }, + { + name: "must success for negative amount", + amount: u256.NewUint(500), + isToken0: false, + expectedError: false, + }, + { + name: "must panic for insufficient token1 balance", + amount: u256.NewUint(1500), + isToken0: false, + expectedError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token0 := pool.balances.token0 + token1 := pool.balances.token1 + + err := validatePoolBalance(token0, token1, tt.amount, tt.isToken0) + if err != nil { + if !tt.expectedError { + t.Errorf("unexpected error: %v", err) + } + } + }) + } + }) +} + +func TestTransferFromAndVerify(t *testing.T) { + tests := []struct { + name string + pool *Pool + from std.Address + to std.Address + tokenPath string + amount *i256.Int + isToken0 bool + expectedBal0 *u256.Uint + expectedBal1 *u256.Uint + }{ + { + name: "normal token0 transfer", + pool: &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + from: testutils.TestAddress("from_addr"), + to: testutils.TestAddress("to_addr"), + tokenPath: fooPath, + amount: i256.NewInt(500), + isToken0: true, + expectedBal0: u256.NewUint(1500), // 1000 + 500 + expectedBal1: u256.NewUint(2000), // unchanged + }, + { + name: "normal token1 transfer", + pool: &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + from: testutils.TestAddress("from_addr"), + to: testutils.TestAddress("to_addr"), + tokenPath: fooPath, + amount: i256.NewInt(800), + isToken0: false, + expectedBal0: u256.NewUint(1000), // unchanged + expectedBal1: u256.NewUint(2800), // 2000 + 800 + }, + { + name: "zero value transfer", + pool: &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + from: testutils.TestAddress("from_addr"), + to: testutils.TestAddress("to_addr"), + tokenPath: fooPath, + amount: i256.NewInt(0), + isToken0: true, + expectedBal0: u256.NewUint(1000), // unchanged + expectedBal1: u256.NewUint(2000), // unchanged + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + TokenFaucet(t, fooPath, pusers.AddressOrName(tt.from)) + TokenApprove(t, fooPath, pusers.AddressOrName(tt.from), pool, u256.MustFromDecimal(tt.amount.ToString()).Uint64()) + + tt.pool.transferFromAndVerify(tt.from, tt.to, tt.tokenPath, u256.MustFromDecimal(tt.amount.ToString()), tt.isToken0) + + if !tt.pool.balances.token0.Eq(tt.expectedBal0) { + t.Errorf("token0 balance mismatch: expected %s, got %s", + tt.expectedBal0.ToString(), + tt.pool.balances.token0.ToString()) + } + + if !tt.pool.balances.token1.Eq(tt.expectedBal1) { + t.Errorf("token1 balance mismatch: expected %s, got %s", + tt.expectedBal1.ToString(), + tt.pool.balances.token1.ToString()) + } + }) + } + + t.Run("negative value handling", func(t *testing.T) { + pool := &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + } + + negativeAmount := i256.NewInt(-500) + + TokenFaucet(t, fooPath, pusers.AddressOrName(testutils.TestAddress("from_addr"))) + TokenApprove(t, fooPath, pusers.AddressOrName(testutils.TestAddress("from_addr")), pusers.AddressOrName(consts.POOL_ADDR), u256.MustFromDecimal(negativeAmount.Abs().ToString()).Uint64()) + pool.transferFromAndVerify( + testutils.TestAddress("from_addr"), + testutils.TestAddress("to_addr"), + fooPath, + u256.MustFromDecimal(negativeAmount.Abs().ToString()), + true, + ) + + expectedBal := u256.NewUint(1500) // 1000 + 500 (absolute value) + if !pool.balances.token0.Eq(expectedBal) { + t.Errorf("negative amount handling failed: expected %s, got %s", + expectedBal.ToString(), + pool.balances.token0.ToString()) + } + }) + + t.Run("uint64 overflow value", func(t *testing.T) { + pool := &Pool{ + balances: Balances{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + } + + hugeAmount := i256.FromUint256(u256.MustFromDecimal("18446744073709551616")) // 2^64 + + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for amount exceeding uint64 range") + } + }() + + pool.transferFromAndVerify( + testutils.TestAddress("from_addr"), + testutils.TestAddress("to_addr"), + fooPath, + u256.MustFromDecimal(hugeAmount.ToString()), + true, + ) + }) +} + +func TestUpdatePoolBalance(t *testing.T) { + tests := []struct { + name string + initialToken0 *u256.Uint + initialToken1 *u256.Uint + amount *u256.Uint + isToken0 bool + expectedBal *u256.Uint + expectErr bool + }{ + { + name: "normal token0 decrease", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(300), + isToken0: true, + expectedBal: u256.NewUint(700), + expectErr: false, + }, + { + name: "normal token1 decrease", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(500), + isToken0: false, + expectedBal: u256.NewUint(1500), + expectErr: false, + }, + { + name: "insufficient token0 balance", + initialToken0: u256.NewUint(100), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(200), + isToken0: true, + expectedBal: nil, + expectErr: true, + }, + { + name: "insufficient token1 balance", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(100), + amount: u256.NewUint(200), + isToken0: false, + expectedBal: nil, + expectErr: true, + }, + { + name: "zero value handling", + initialToken0: u256.NewUint(1000), + initialToken1: u256.NewUint(2000), + amount: u256.NewUint(0), + isToken0: true, + expectedBal: u256.NewUint(1000), + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pool := &Pool{ + balances: Balances{ + token0: tt.initialToken0, + token1: tt.initialToken1, + }, + } + + newBal, err := updatePoolBalance(tt.initialToken0, tt.initialToken1, tt.amount, tt.isToken0) + + if tt.expectErr { + if err == nil { + t.Errorf("%s: expected error but no error", tt.name) + } + return + } + if err != nil { + t.Errorf("%s: unexpected error: %v", tt.name, err) + return + } + + if !newBal.Eq(tt.expectedBal) { + t.Errorf("%s: balance mismatch, expected: %s, actual: %s", + tt.name, + tt.expectedBal.ToString(), + newBal.ToString(), + ) + } + }) + } +} diff --git a/pool/position_modify.gno b/pool/position_modify.gno index 0400b126a..c6c574005 100644 --- a/pool/position_modify.gno +++ b/pool/position_modify.gno @@ -4,14 +4,29 @@ import ( "gno.land/r/gnoswap/v1/common" i256 "gno.land/p/gnoswap/int256" - u256 "gno.land/p/gnoswap/uint256" plp "gno.land/p/gnoswap/pool" + u256 "gno.land/p/gnoswap/uint256" ) -// modifyPosition updates a position in the pool and calculates the amount of tokens to be added or removed. -// Returns positionInfo, amount0, amount1 -func (pool *Pool) modifyPosition(params ModifyPositionParams) (PositionInfo, *u256.Uint, *u256.Uint) { - position := pool.updatePosition(params) +// modifyPosition updates a position in the pool and calculates the amount of tokens +// needed (for minting) or returned (for burning). The calculation depends on the current +// price (tick) relative to the position's price range. +// +// The function handles three cases: +// 1. Current price below range (tick < tickLower): only token0 is used/returned +// 2. Current price in range (tickLower <= tick < tickUpper): both tokens are used/returned +// 3. Current price above range (tick >= tickUpper): only token1 is used/returned +// +// Parameters: +// - params: ModifyPositionParams containing owner, tickLower, tickUpper, and liquidityDelta +// +// Returns: +// - PositionInfo: updated position information +// - *u256.Uint: amount of token0 needed/returned +// - *u256.Uint: amount of token1 needed/returned +func (p *Pool) modifyPosition(params ModifyPositionParams) (PositionInfo, *u256.Uint, *u256.Uint) { + // update position state + position := p.updatePosition(params) liqDelta := params.liquidityDelta if liqDelta.IsZero() { @@ -20,25 +35,36 @@ func (pool *Pool) modifyPosition(params ModifyPositionParams) (PositionInfo, *u2 amount0, amount1 := i256.Zero(), i256.Zero() - tick := pool.slot0.tick + // get current state and price bounds + tick := p.slot0.tick + // covert ticks to sqrt price to use in amount calculations + // price = 1.0001^tick, but we use sqrtPriceX96 sqrtRatioLower := common.TickMathGetSqrtRatioAtTick(params.tickLower) sqrtRatioUpper := common.TickMathGetSqrtRatioAtTick(params.tickUpper) - sqrtPriceX96 := pool.slot0.sqrtPriceX96 + sqrtPriceX96 := p.slot0.sqrtPriceX96 - // calculate amount0, amount1 based on the current tick position + // calculate token amounts based on current price position relative to range switch { case tick < params.tickLower: + // case 1 + // full range between lower and upper tick is used for token0 amount0 = calculateToken0Amount(sqrtRatioLower, sqrtRatioUpper, liqDelta) case tick < params.tickUpper: - liquidityBefore := pool.liquidity + // case 2 + liquidityBefore := p.liquidity + // token0 used from current price to upper tick amount0 = calculateToken0Amount(sqrtPriceX96, sqrtRatioUpper, liqDelta) + // token1 used from lower tick to current price amount1 = calculateToken1Amount(sqrtRatioLower, sqrtPriceX96, liqDelta) - pool.liquidity = liquidityMathAddDelta(liquidityBefore, liqDelta) + // update pool's active liquidity since price is in range + p.liquidity = liquidityMathAddDelta(liquidityBefore, liqDelta) default: + // case 3 + // full range between lower and upper tick is used for token1 amount1 = calculateToken1Amount(sqrtRatioLower, sqrtRatioUpper, liqDelta) } diff --git a/pool/swap.gno b/pool/swap.gno new file mode 100644 index 000000000..47c5e2d67 --- /dev/null +++ b/pool/swap.gno @@ -0,0 +1,524 @@ +package pool + +import ( + "std" + + "gno.land/p/demo/ufmt" + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" + + i256 "gno.land/p/gnoswap/int256" + plp "gno.land/p/gnoswap/pool" // pool package + u256 "gno.land/p/gnoswap/uint256" +) + +// SwapResult encapsulates all state changes that occur as a result of a swap +// This type ensure all state transitions are atomic and can be applied at once. +type SwapResult struct { + Amount0 *i256.Int + Amount1 *i256.Int + NewSqrtPrice *u256.Uint + NewTick int32 + NewLiquidity *u256.Uint + NewProtocolFees ProtocolFees + FeeGrowthGlobal0X128 *u256.Uint + FeeGrowthGlobal1X128 *u256.Uint +} + +// SwapComputation encapsulates pure computation logic for swap +type SwapComputation struct { + AmountSpecified *i256.Int + SqrtPriceLimitX96 *u256.Uint + ZeroForOne bool + ExactInput bool + InitialState SwapState + Cache SwapCache +} + +var ( + fixedPointQ128 = u256.MustFromDecimal(consts.Q128) +) + +// Swap swaps token0 for token1, or token1 for token0 +// Returns swapped amount0, amount1 in string +// ref: https://docs.gnoswap.io/contracts/pool/pool.gno#swap +func Swap( + token0Path string, + token1Path string, + fee uint32, + recipient std.Address, + zeroForOne bool, + amountSpecified string, + sqrtPriceLimitX96 string, + payer std.Address, // router +) (string, string) { + common.IsHalted() + if common.GetLimitCaller() { + caller := std.PrevRealm().Addr() + if err := common.RouterOnly(caller); err != nil { + panic(addDetailToError( + errNoPermission, + ufmt.Sprintf("only router(%s) can call pool swap(), called from %s", consts.ROUTER_ADDR, caller.String()), + )) + } + } + + if amountSpecified == "0" { + panic(addDetailToError( + errInvalidSwapAmount, + ufmt.Sprintf("amountSpecified == 0"), + )) + } + + pool := GetPool(token0Path, token1Path, fee) + + slot0Start := pool.slot0 + if !slot0Start.unlocked { + panic(errLockedPool) + } + pool.slot0.unlocked = false + defer func() { pool.slot0.unlocked = true }() + + sqrtPriceLimit := u256.MustFromDecimal(sqrtPriceLimitX96) + validatePriceLimits(slot0Start, zeroForOne, sqrtPriceLimit) + + amounts := i256.MustFromDecimal(amountSpecified) + feeGrowthGlobalX128 := getFeeGrowthGlobal(pool, zeroForOne) + feeProtocol := getFeeProtocol(slot0Start, zeroForOne) + cache := newSwapCache(feeProtocol, pool.liquidity.Clone()) + state := newSwapState(amounts, feeGrowthGlobalX128, cache.liquidityStart, slot0Start) + + comp := SwapComputation{ + AmountSpecified: amounts, + SqrtPriceLimitX96: sqrtPriceLimit, + ZeroForOne: zeroForOne, + ExactInput: amounts.Gt(i256.Zero()), + InitialState: state, + Cache: cache, + } + + result, err := computeSwap(pool, comp) + if err != nil { + panic(err) + } + + applySwapResult(pool, result) + + // actual swap + pool.swapTransfers(zeroForOne, payer, recipient, result.Amount0, result.Amount1) + + prevAddr, prevPkgPath := getPrev() + + std.Emit( + "Swap", + "prevAddr", prevAddr, + "prevRealm", prevPkgPath, + "poolPath", GetPoolPath(token0Path, token1Path, fee), + "zeroForOne", ufmt.Sprintf("%t", zeroForOne), + "amountSpecified", amountSpecified, + "sqrtPriceLimitX96", sqrtPriceLimitX96, + "payer", payer.String(), + "recipient", recipient.String(), + "internal_amount0", result.Amount0.ToString(), + "internal_amount1", result.Amount1.ToString(), + "internal_protocolFee0", pool.protocolFees.token0.ToString(), + "internal_protocolFee1", pool.protocolFees.token1.ToString(), + "internal_sqrtPriceX96", pool.slot0.sqrtPriceX96.ToString(), + ) + + return result.Amount0.ToString(), result.Amount1.ToString() +} + +// DrySwap simulates a swap and returns the amount0, amount1 that would be received and a boolean indicating if the swap is possible +func DrySwap( + token0Path string, + token1Path string, + fee uint32, + zeroForOne bool, + amountSpecified string, + sqrtPriceLimitX96 string, +) (string, string, bool) { + if amountSpecified == "0" { + return "0", "0", false + } + + pool := GetPool(token0Path, token1Path, fee) + + slot0Start := pool.slot0 + sqrtPriceLimit := u256.MustFromDecimal(sqrtPriceLimitX96) + validatePriceLimits(slot0Start, zeroForOne, sqrtPriceLimit) + + amounts := i256.MustFromDecimal(amountSpecified) + feeGrowthGlobalX128 := getFeeGrowthGlobal(pool, zeroForOne) + feeProtocol := getFeeProtocol(slot0Start, zeroForOne) + cache := newSwapCache(feeProtocol, pool.liquidity.Clone()) + state := newSwapState(amounts, feeGrowthGlobalX128, cache.liquidityStart, slot0Start) + + comp := SwapComputation{ + AmountSpecified: amounts, + SqrtPriceLimitX96: sqrtPriceLimit, + ZeroForOne: zeroForOne, + ExactInput: amounts.Gt(i256.Zero()), + InitialState: state, + Cache: cache, + } + + result, err := computeSwap(pool, comp) + if err != nil { + return "0", "0", false + } + + if zeroForOne { + if pool.balances.token1.Lt(result.Amount1.Abs()) { + return "0", "0", false + } + } else { + if pool.balances.token0.Lt(result.Amount0.Abs()) { + return "0", "0", false + } + } + + // Validate non-zero amounts + if result.Amount0.IsZero() || result.Amount1.IsZero() { + return "0", "0", false + } + + return result.Amount0.ToString(), result.Amount1.ToString(), true +} + +// computeSwap performs the core swap computation without modifying pool state +// The function follows these state transitions: +// 1. Initial State: Provided by `SwapComputation.InitialState` +// 2. Stepping State: For each step: +// - Compute next tick and price target +// - Calculate amounts and fees +// - Update state (remaining amount, fees, liquidity) +// - Handle tick transitions if necessary +// +// 3. Final State: Aggregated in SwapResult +// +// The computation continues until either: +// - The entire amount is consumed (`amountSpecifiedRemaining` = 0) +// - The price limit is reached (`sqrtPriceX96` = `sqrtPriceLimitX96`) +// +// Returns an error if the computation fails at any step +func computeSwap(pool *Pool, comp SwapComputation) (*SwapResult, error) { + state := comp.InitialState + var err error + + // Compute swap steps until completion + for shouldContinueSwap(state, comp.SqrtPriceLimitX96) { + state, err = computeSwapStep(state, pool, comp.ZeroForOne, comp.SqrtPriceLimitX96, comp.ExactInput, comp.Cache) + if err != nil { + return nil, err + } + } + + // Calculate final amounts + amount0 := state.amountCalculated + amount1 := i256.Zero().Sub(comp.AmountSpecified, state.amountSpecifiedRemaining) + if comp.ZeroForOne == comp.ExactInput { + amount0, amount1 = amount1, amount0 + } + + // Prepare result + result := &SwapResult{ + Amount0: amount0, + Amount1: amount1, + NewSqrtPrice: state.sqrtPriceX96, + NewTick: state.tick, + NewLiquidity: state.liquidity, + NewProtocolFees: ProtocolFees{ + token0: pool.protocolFees.token0, + token1: pool.protocolFees.token1, + }, + FeeGrowthGlobal0X128: pool.feeGrowthGlobal0X128, + FeeGrowthGlobal1X128: pool.feeGrowthGlobal1X128, + } + + // Update protocol fees if necessary + if comp.ZeroForOne { + if state.protocolFee.Gt(u256.Zero()) { + result.NewProtocolFees.token0 = new(u256.Uint).Add(result.NewProtocolFees.token0, state.protocolFee) + } + result.FeeGrowthGlobal0X128 = state.feeGrowthGlobalX128 + } else { + if state.protocolFee.Gt(u256.Zero()) { + result.NewProtocolFees.token1 = new(u256.Uint).Add(result.NewProtocolFees.token1, state.protocolFee) + } + result.FeeGrowthGlobal1X128 = state.feeGrowthGlobalX128 + } + + return result, nil +} + +// applySwapResult updates pool state with computed results. +// All state changes are applied at once to maintain consistency +func applySwapResult(pool *Pool, result *SwapResult) { + pool.slot0.sqrtPriceX96 = result.NewSqrtPrice + pool.slot0.tick = result.NewTick + pool.liquidity = result.NewLiquidity + pool.protocolFees = result.NewProtocolFees + pool.feeGrowthGlobal0X128 = result.FeeGrowthGlobal0X128 + pool.feeGrowthGlobal1X128 = result.FeeGrowthGlobal1X128 +} + +// validatePriceLimits ensures the provided price limit is valid for the swap direction +// The function enforces that: +// For zeroForOne (selling token0): +// - Price limit must be below current price +// - Price limit must be above MIN_SQRT_RATIO +// +// For !zeroForOne (selling token1): +// - Price limit must be above current price +// - Price limit must be below MAX_SQRT_RATIO +func validatePriceLimits(slot0 Slot0, zeroForOne bool, sqrtPriceLimitX96 *u256.Uint) { + if zeroForOne { + minSqrtRatio := u256.MustFromDecimal(consts.MIN_SQRT_RATIO) + + cond1 := sqrtPriceLimitX96.Lt(slot0.sqrtPriceX96) + cond2 := sqrtPriceLimitX96.Gt(minSqrtRatio) + if !(cond1 && cond2) { + panic(addDetailToError( + errPriceOutOfRange, + ufmt.Sprintf("sqrtPriceLimitX96(%s) < slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) > consts.MIN_SQRT_RATIO(%s)", + sqrtPriceLimitX96.ToString(), + slot0.sqrtPriceX96.ToString(), + sqrtPriceLimitX96.ToString(), + consts.MIN_SQRT_RATIO), + )) + } + } else { + maxSqrtRatio := u256.MustFromDecimal(consts.MAX_SQRT_RATIO) + + cond1 := sqrtPriceLimitX96.Gt(slot0.sqrtPriceX96) + cond2 := sqrtPriceLimitX96.Lt(maxSqrtRatio) + if !(cond1 && cond2) { + panic(addDetailToError( + errPriceOutOfRange, + ufmt.Sprintf("sqrtPriceLimitX96(%s) > slot0Start.sqrtPriceX96(%s) && sqrtPriceLimitX96(%s) < consts.MAX_SQRT_RATIO(%s)", + sqrtPriceLimitX96.ToString(), + slot0.sqrtPriceX96.ToString(), + sqrtPriceLimitX96.ToString(), + consts.MAX_SQRT_RATIO), + )) + } + } +} + +// getFeeProtocol returns the appropriate fee protocol based on zero for one +func getFeeProtocol(slot0 Slot0, zeroForOne bool) uint8 { + if zeroForOne { + return slot0.feeProtocol % 16 + } + return slot0.feeProtocol / 16 +} + +// getFeeGrowthGlobal returns the appropriate fee growth global based on zero for one +func getFeeGrowthGlobal(pool *Pool, zeroForOne bool) *u256.Uint { + if zeroForOne { + return pool.feeGrowthGlobal0X128.Clone() + } + return pool.feeGrowthGlobal1X128.Clone() +} + +func shouldContinueSwap(state SwapState, sqrtPriceLimitX96 *u256.Uint) bool { + return !(state.amountSpecifiedRemaining.IsZero()) && !(state.sqrtPriceX96.Eq(sqrtPriceLimitX96)) +} + +// computeSwapStep executes a single step of swap and returns new state +func computeSwapStep( + state SwapState, + pool *Pool, + zeroForOne bool, + sqrtPriceLimitX96 *u256.Uint, + exactInput bool, + cache SwapCache, +) (SwapState, error) { + step := computeSwapStepInit(state, pool, zeroForOne) + + // determining the price target for this step + sqrtRatioTargetX96 := computeTargetSqrtRatio(step, sqrtPriceLimitX96, zeroForOne) + + // computing the amounts to be swapped at this step + var newState SwapState + var err error + + newState, step = computeAmounts(state, sqrtRatioTargetX96, pool, step) + newState = updateAmounts(step, newState, exactInput) + + // if the protocol fee is on, calculate how much is owed, + // decrement fee amount, and increment protocol fee + if cache.feeProtocol > 0 { + newState, step, err = updateFeeProtocol(step, cache.feeProtocol, newState) + if err != nil { + return state, err + } + } + + // update global fee tracker + if newState.liquidity.Gt(u256.Zero()) { + update := u256.MulDiv(step.feeAmount, fixedPointQ128, newState.liquidity) + newState.SetFeeGrowthGlobalX128(new(u256.Uint).Add(newState.feeGrowthGlobalX128, update)) + } + + // handling tick transitions + if newState.sqrtPriceX96.Eq(step.sqrtPriceNextX96) { + newState = tickTransition(step, zeroForOne, newState, pool) + } else if newState.sqrtPriceX96.Neq(step.sqrtPriceStartX96) { + newState.SetTick(common.TickMathGetTickAtSqrtRatio(newState.sqrtPriceX96)) + } + + return newState, nil +} + +// updateFeeProtocol calculates and updates protocol fees for the current step. +func updateFeeProtocol(step StepComputations, feeProtocol uint8, state SwapState) (SwapState, StepComputations, error) { + delta := step.feeAmount + delta.Div(delta, u256.NewUint(uint64(feeProtocol))) + + newFeeAmount, overflow := new(u256.Uint).SubOverflow(step.feeAmount, delta) + if overflow { + return state, step, errUnderflow + } + step.feeAmount = newFeeAmount + state.protocolFee.Add(state.protocolFee, delta) + + return state, step, nil +} + +// computeSwapStepInit initializes the computation for a single swap step. +func computeSwapStepInit(state SwapState, pool *Pool, zeroForOne bool) StepComputations { + var step StepComputations + step.sqrtPriceStartX96 = state.sqrtPriceX96 + tickNext, initialized := pool.tickBitmapNextInitializedTickWithInOneWord( + state.tick, + pool.tickSpacing, + zeroForOne, + ) + + step.tickNext = tickNext + step.initialized = initialized + + // prevent overshoot the min/max tick + step.clampTickNext() + + // get the price for the next tick + step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) + return step +} + +// computeTargetSqrtRatio determines the target sqrt price for the current swap step. +func computeTargetSqrtRatio(step StepComputations, sqrtPriceLimitX96 *u256.Uint, zeroForOne bool) *u256.Uint { + if shouldUsePriceLimit(step.sqrtPriceNextX96, sqrtPriceLimitX96, zeroForOne) { + return sqrtPriceLimitX96 + } + return step.sqrtPriceNextX96 +} + +// shouldUsePriceLimit returns true if the price limit should be used instead of the next tick price +func shouldUsePriceLimit(sqrtPriceNext, sqrtPriceLimit *u256.Uint, zeroForOne bool) bool { + isLower := sqrtPriceNext.Lt(sqrtPriceLimit) + isHigher := sqrtPriceNext.Gt(sqrtPriceLimit) + if zeroForOne { + return isLower + } + return isHigher +} + +// computeAmounts calculates the input and output amounts for the current swap step. +func computeAmounts(state SwapState, sqrtRatioTargetX96 *u256.Uint, pool *Pool, step StepComputations) (SwapState, StepComputations) { + sqrtPriceX96Str, amountInStr, amountOutStr, feeAmountStr := plp.SwapMathComputeSwapStepStr( + state.sqrtPriceX96, + sqrtRatioTargetX96, + state.liquidity, + state.amountSpecifiedRemaining, + uint64(pool.fee), + ) + + step.amountIn = u256.MustFromDecimal(amountInStr) + step.amountOut = u256.MustFromDecimal(amountOutStr) + step.feeAmount = u256.MustFromDecimal(feeAmountStr) + + state.SetSqrtPriceX96(sqrtPriceX96Str) + + return state, step +} + +// updateAmounts calculates new remaining and calculated amounts based on the swap step +// For exact input swaps: +// - Decrements remaining input amount by (amountIn + feeAmount) +// - Decrements calculated amount by amountOut +// +// For exact output swaps: +// - Increments remaining output amount by amountOut +// - Increments calculated amount by (amountIn + feeAmount) +func updateAmounts(step StepComputations, state SwapState, exactInput bool) SwapState { + amountInWithFeeU256 := new(u256.Uint).Add(step.amountIn, step.feeAmount) + if amountInWithFeeU256.Gt(u256.MustFromDecimal(consts.MAX_INT256)) { + panic("amountIn + feeAmount overflows int256") + } + amountInWithFee := i256.FromUint256(amountInWithFeeU256) + if step.amountOut.Gt(u256.MustFromDecimal(consts.MAX_INT256)) { + panic("amountOut overflows int256") + } + + if exactInput { + state.amountSpecifiedRemaining = i256.Zero().Sub(state.amountSpecifiedRemaining, amountInWithFee) + state.amountCalculated = i256.Zero().Sub(state.amountCalculated, i256.FromUint256(step.amountOut)) + return state + } else { + state.amountSpecifiedRemaining = i256.Zero().Add(state.amountSpecifiedRemaining, i256.FromUint256(step.amountOut)) + state.amountCalculated = i256.Zero().Add(state.amountCalculated, amountInWithFee) + } + + return state +} + +// tickTransition handles the transition between price ticks during a swap +func tickTransition(step StepComputations, zeroForOne bool, state SwapState, pool *Pool) SwapState { + // ensure existing state to keep immutability + newState := state + + if step.initialized { + fee0, fee1 := u256.Zero(), u256.Zero() + + if zeroForOne { + fee0 = state.feeGrowthGlobalX128 + fee1 = pool.feeGrowthGlobal1X128 + } else { + fee0 = pool.feeGrowthGlobal0X128 + fee1 = state.feeGrowthGlobalX128 + } + + liquidityNet := pool.tickCross(step.tickNext, fee0, fee1) + + if zeroForOne { + liquidityNet = i256.Zero().Neg(liquidityNet) + } + + newState.liquidity = liquidityMathAddDelta(state.liquidity, liquidityNet) + } + + if zeroForOne { + newState.tick = step.tickNext - 1 + } else { + newState.tick = step.tickNext + } + + return newState +} + +func (p *Pool) swapTransfers(zeroForOne bool, payer, recipient std.Address, amount0, amount1 *i256.Int) { + if zeroForOne { + // payer > POOL + p.transferFromAndVerify(payer, consts.POOL_ADDR, p.token0Path, amount0.Abs(), true) + // POOL > recipient + p.transferAndVerify(recipient, p.token1Path, amount1, false) + } else { + // payer > POOL + p.transferFromAndVerify(payer, consts.POOL_ADDR, p.token1Path, amount1.Abs(), false) + // POOL > recipient + p.transferAndVerify(recipient, p.token0Path, amount0, true) + } +} diff --git a/pool/swap_test.gno b/pool/swap_test.gno new file mode 100644 index 000000000..3b73a5139 --- /dev/null +++ b/pool/swap_test.gno @@ -0,0 +1,954 @@ +package pool + +import ( + "std" + "testing" + + "gno.land/p/demo/uassert" + "gno.land/r/demo/users" + + pusers "gno.land/p/demo/users" + i256 "gno.land/p/gnoswap/int256" + u256 "gno.land/p/gnoswap/uint256" + + "gno.land/r/gnoswap/v1/consts" +) + +func TestSaveProtocolFees(t *testing.T) { + tests := []struct { + name string + pool *Pool + amount0 *u256.Uint + amount1 *u256.Uint + want0 *u256.Uint + want1 *u256.Uint + wantFee0 *u256.Uint + wantFee1 *u256.Uint + }{ + { + name: "normal fee deduction", + pool: &Pool{ + protocolFees: ProtocolFees{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + amount0: u256.NewUint(500), + amount1: u256.NewUint(1000), + want0: u256.NewUint(500), + want1: u256.NewUint(1000), + wantFee0: u256.NewUint(500), + wantFee1: u256.NewUint(1000), + }, + { + name: "exact fee deduction (1 deduction)", + pool: &Pool{ + protocolFees: ProtocolFees{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + amount0: u256.NewUint(1000), + amount1: u256.NewUint(2000), + want0: u256.NewUint(999), + want1: u256.NewUint(1999), + wantFee0: u256.NewUint(1), + wantFee1: u256.NewUint(1), + }, + { + name: "0 fee deduction", + pool: &Pool{ + protocolFees: ProtocolFees{ + token0: u256.NewUint(1000), + token1: u256.NewUint(2000), + }, + }, + amount0: u256.NewUint(0), + amount1: u256.NewUint(0), + want0: u256.NewUint(0), + want1: u256.NewUint(0), + wantFee0: u256.NewUint(1000), + wantFee1: u256.NewUint(2000), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got0, got1 := tt.pool.saveProtocolFees(tt.amount0, tt.amount1) + + uassert.Equal(t, got0.ToString(), tt.want0.ToString()) + uassert.Equal(t, got1.ToString(), tt.want1.ToString()) + uassert.Equal(t, tt.pool.protocolFees.token0.ToString(), tt.wantFee0.ToString()) + uassert.Equal(t, tt.pool.protocolFees.token1.ToString(), tt.wantFee1.ToString()) + }) + } +} + +func TestShouldContinueSwap(t *testing.T) { + tests := []struct { + name string + state SwapState + sqrtPriceLimitX96 *u256.Uint + expected bool + }{ + { + name: "Should continue - amount remaining and price not at limit", + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + sqrtPriceX96: u256.MustFromDecimal("1000000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: true, + }, + { + name: "Should stop - no amount remaining", + state: SwapState{ + amountSpecifiedRemaining: i256.Zero(), + sqrtPriceX96: u256.MustFromDecimal("1000000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: false, + }, + { + name: "Should stop - price at limit", + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + sqrtPriceX96: u256.MustFromDecimal("900000"), + }, + sqrtPriceLimitX96: u256.MustFromDecimal("900000"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := shouldContinueSwap(tt.state, tt.sqrtPriceLimitX96) + uassert.Equal(t, tt.expected, result) + }) + } +} + +func TestUpdateAmounts(t *testing.T) { + tests := []struct { + name string + step StepComputations + state SwapState + exactInput bool + expectedState SwapState + }{ + { + name: "Exact input update", + step: StepComputations{ + amountIn: u256.MustFromDecimal("100"), + amountOut: u256.MustFromDecimal("97"), + feeAmount: u256.MustFromDecimal("3"), + }, + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000"), + amountCalculated: i256.Zero(), + }, + exactInput: true, + expectedState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("897"), // 1000 - (100 + 3) + amountCalculated: i256.MustFromDecimal("-97"), + }, + }, + { + name: "Exact output update", + step: StepComputations{ + amountIn: u256.MustFromDecimal("100"), + amountOut: u256.MustFromDecimal("97"), + feeAmount: u256.MustFromDecimal("3"), + }, + state: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("-1000"), + amountCalculated: i256.Zero(), + }, + exactInput: false, + expectedState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("-903"), // -1000 + 97 + amountCalculated: i256.MustFromDecimal("103"), // 100 + 3 + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := updateAmounts(tt.step, tt.state, tt.exactInput) + + uassert.True(t, tt.expectedState.amountSpecifiedRemaining.Eq(result.amountSpecifiedRemaining)) + uassert.True(t, tt.expectedState.amountCalculated.Eq(result.amountCalculated)) + }) + } +} + +func TestComputeSwap(t *testing.T) { + mockPool := &Pool{ + token0Path: "token0", + token1Path: "token1", + fee: 3000, // 0.3% + tickSpacing: 60, + slot0: Slot0{ + sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 + tick: 0, + feeProtocol: 0, + unlocked: true, + }, + liquidity: u256.MustFromDecimal("1000000000000000000"), // 1.0 + protocolFees: ProtocolFees{ + token0: u256.Zero(), + token1: u256.Zero(), + }, + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + tickBitmaps: make(TickBitmaps), + ticks: make(Ticks), + positions: make(Positions), + } + + wordPos, _ := tickBitmapPosition(0) + // TODO: use avl + mockPool.tickBitmaps[wordPos] = u256.NewUint(1) + + t.Run("basic swap", func(t *testing.T) { + comp := SwapComputation{ + AmountSpecified: i256.MustFromDecimal("1000000"), // 1.0 token + SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), // 1.1 + ZeroForOne: true, + ExactInput: true, + InitialState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), + amountCalculated: i256.Zero(), + sqrtPriceX96: mockPool.slot0.sqrtPriceX96, + tick: mockPool.slot0.tick, + feeGrowthGlobalX128: mockPool.feeGrowthGlobal0X128, + protocolFee: u256.Zero(), + liquidity: mockPool.liquidity, + }, + Cache: SwapCache{ + feeProtocol: 0, + liquidityStart: mockPool.liquidity, + }, + } + + result, err := computeSwap(mockPool, comp) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if result.Amount0.IsZero() { + t.Error("expected non-zero amount0") + } + if result.Amount1.IsZero() { + t.Error("expected non-zero amount1") + } + }) + + t.Run("swap with zero liquidity", func(t *testing.T) { + mockPoolZeroLiq := *mockPool + mockPoolZeroLiq.liquidity = u256.Zero() + + comp := SwapComputation{ + AmountSpecified: i256.MustFromDecimal("1000000"), + SqrtPriceLimitX96: u256.MustFromDecimal("1100000000000000000"), + ZeroForOne: true, + ExactInput: true, + InitialState: SwapState{ + amountSpecifiedRemaining: i256.MustFromDecimal("1000000"), + amountCalculated: i256.Zero(), + sqrtPriceX96: mockPoolZeroLiq.slot0.sqrtPriceX96, + tick: mockPoolZeroLiq.slot0.tick, + feeGrowthGlobalX128: mockPoolZeroLiq.feeGrowthGlobal0X128, + protocolFee: u256.Zero(), + liquidity: mockPoolZeroLiq.liquidity, + }, + Cache: SwapCache{ + feeProtocol: 0, + liquidityStart: mockPoolZeroLiq.liquidity, + }, + } + + result, err := computeSwap(&mockPoolZeroLiq, comp) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if !result.Amount0.IsZero() || !result.Amount1.IsZero() { + t.Error("expected zero amounts for zero liquidity") + } + }) +} + +func TestSwap_Failures(t *testing.T) { + t.Skip() + const addr = pusers.AddressOrName(consts.ROUTER_ADDR) + + tests := []struct { + name string + setupFn func(t *testing.T) + token0Path string + token1Path string + fee uint32 + recipient std.Address + zeroForOne bool + amountSpecified string + sqrtPriceLimitX96 string + payer std.Address + expectedAmount0 string + expectedAmount1 string + expectError bool + }{ + { + name: "locked pool", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + pool := GetPool(wugnotPath, gnsPath, fee3000) + pool.slot0.unlocked = false + }, + token0Path: wugnotPath, + token1Path: gnsPath, + fee: fee3000, + recipient: users.Resolve(addr), + zeroForOne: true, + amountSpecified: "100", + sqrtPriceLimitX96: "79228162514264337593543950336", + payer: users.Resolve(addr), + expectError: true, + }, + { + name: "zero amount", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + }, + token0Path: wugnotPath, + token1Path: gnsPath, + fee: fee3000, + recipient: users.Resolve(alice), + zeroForOne: true, + amountSpecified: "0", + sqrtPriceLimitX96: "79228162514264337593543950336", + payer: users.Resolve(alice), + expectError: true, + }, + { + name: "zero liquidity", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + pool := GetPool(wugnotPath, gnsPath, fee3000) + pool.liquidity = u256.Zero() + }, + token0Path: wugnotPath, + token1Path: gnsPath, + fee: fee3000, + recipient: users.Resolve(alice), + zeroForOne: true, + amountSpecified: "100", + sqrtPriceLimitX96: "79228162514264337593543950336", + payer: users.Resolve(alice), + expectedAmount0: "0", + expectedAmount1: "0", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetObject(t) + burnTokens(t) + + if tt.setupFn != nil { + tt.setupFn(t) + } + + std.TestSetOrigCaller(tt.payer) + + if tt.expectError { + defer func() { + if r := recover(); r == nil { + t.Errorf("error should be occurred but not occurred") + } + }() + } + + amount0, amount1 := Swap( + tt.token0Path, + tt.token1Path, + tt.fee, + tt.recipient, + tt.zeroForOne, + tt.amountSpecified, + tt.sqrtPriceLimitX96, + tt.payer, + ) + + if !tt.expectError { + uassert.Equal(t, amount0, tt.expectedAmount0) + uassert.Equal(t, amount1, tt.expectedAmount1) + } + }) + } +} + +func TestDrySwap_Failures(t *testing.T) { + mockPool := &Pool{ + token0Path: "token0", + token1Path: "token1", + fee: 3000, + tickSpacing: 60, + slot0: Slot0{ + sqrtPriceX96: u256.MustFromDecimal("1000000000000000000"), // 1.0 + tick: 0, + feeProtocol: 0, + unlocked: true, + }, + liquidity: u256.MustFromDecimal("1000000000000000000"), // 1.0 + balances: Balances{ + token0: u256.MustFromDecimal("1000000000"), + token1: u256.MustFromDecimal("1000000000"), + }, + protocolFees: ProtocolFees{ + token0: u256.Zero(), + token1: u256.Zero(), + }, + feeGrowthGlobal0X128: u256.Zero(), + feeGrowthGlobal1X128: u256.Zero(), + tickBitmaps: make(TickBitmaps), + ticks: make(Ticks), + positions: make(Positions), + } + + originalGetPool := GetPool + defer func() { + GetPool = originalGetPool + }() + GetPool = func(token0Path, token1Path string, fee uint32) *Pool { + return mockPool + } + + tests := []struct { + name string + token0Path string + token1Path string + fee uint32 + zeroForOne bool + amountSpecified string + sqrtPriceLimitX96 string + expectAmount0 string + expectAmount1 string + expectSuccess bool + }{ + { + name: "zero amount token0 to token1", + token0Path: "token0", + token1Path: "token1", + fee: 3000, + zeroForOne: true, + amountSpecified: "0", + sqrtPriceLimitX96: "900000000000000000", + expectAmount0: "0", + expectAmount1: "0", + expectSuccess: false, + }, + { + name: "insufficient balance", + token0Path: "token0", + token1Path: "token1", + fee: 3000, + zeroForOne: false, + amountSpecified: "2000000000", + sqrtPriceLimitX96: "1100000000000000000", + expectAmount0: "0", + expectAmount1: "0", + expectSuccess: false, + }, + { + name: "insufficient balance token1 to token0", + token0Path: "token0", + token1Path: "token1", + fee: 3000, + zeroForOne: false, + amountSpecified: "3000000000", + sqrtPriceLimitX96: "1100000000000000000", + expectAmount0: "0", + expectAmount1: "0", + expectSuccess: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + amount0, amount1, success := DrySwap( + tt.token0Path, + tt.token1Path, + tt.fee, + tt.zeroForOne, + tt.amountSpecified, + tt.sqrtPriceLimitX96, + ) + + uassert.Equal(t, success, tt.expectSuccess) + uassert.Equal(t, amount0, tt.expectAmount0) + uassert.Equal(t, amount1, tt.expectAmount1) + }) + } +} + +func TestSwapAndDrySwapComparison(t *testing.T) { + const addr = pusers.AddressOrName(consts.ROUTER_ADDR) + + tests := []struct { + name string + setupFn func(t *testing.T) + token0Path string + token1Path string + fee uint32 + zeroForOne bool + amountSpecified string + sqrtPriceLimitX96 string + }{ + { + name: "normal swap", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, uint64(1000)) + }, + token0Path: wugnotPath, + token1Path: gnsPath, + fee: fee3000, + zeroForOne: false, + amountSpecified: "100", + sqrtPriceLimitX96: maxSqrtPriceLimitX96, + }, + { + name: "swap - request to swap amount over total liquidty", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + }, + token0Path: wugnotPath, + token1Path: gnsPath, + fee: fee3000, + zeroForOne: false, + amountSpecified: "2000000000", + sqrtPriceLimitX96: maxSqrtPriceLimitX96, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetObject(t) + burnTokens(t) + + if tt.setupFn != nil { + tt.setupFn(t) + } + + dryAmount0, dryAmount1, drySuccess := DrySwap( + tt.token0Path, + tt.token1Path, + tt.fee, + tt.zeroForOne, + tt.amountSpecified, + tt.sqrtPriceLimitX96, + ) + + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + tt.token0Path, + tt.token1Path, + tt.fee, + users.Resolve(addr), + tt.zeroForOne, + tt.amountSpecified, + tt.sqrtPriceLimitX96, + users.Resolve(addr), + ) + + if !drySuccess { + t.Error("DrySwap failed but actual Swap succeeded") + } + + uassert.NotEqual(t, dryAmount0, "0", "amount0 should not be zero") + uassert.NotEqual(t, dryAmount1, "0", "amount1 should not be zero") + uassert.NotEqual(t, actualAmount0, "0", "amount0 should not be zero") + uassert.NotEqual(t, actualAmount1, "0", "amount1 should not be zero") + + uassert.Equal(t, dryAmount0, actualAmount0, + "Amount0 mismatch between DrySwap and actual Swap") + uassert.Equal(t, dryAmount1, actualAmount1, + "Amount1 mismatch between DrySwap and actual Swap") + }) + } +} + +func TestSwapAndDrySwapComparison_amount_zero(t *testing.T) { + const addr = pusers.AddressOrName(consts.ROUTER_ADDR) + + tests := []struct { + name string + setupFn func(t *testing.T) + action func(t *testing.T) + shouldPanic bool + expected string + }{ + { + name: "zero amount swap - zeroForOne = false", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, uint64(1000)) + }, + action: func(t *testing.T) { + dryAmount0, dryAmount1, drySuccess := DrySwap( + wugnotPath, + gnsPath, + fee3000, + false, + "0", + maxSqrtPriceLimitX96, + ) + uassert.Equal(t, "0", dryAmount0) + uassert.Equal(t, "0", dryAmount1) + uassert.Equal(t, false, drySuccess) + }, + shouldPanic: false, + expected: "[GNOSWAP-POOL-014] invalid swap amount || amountSpecified == 0", + }, + { + name: "zero amount swap - zeroForOne = true", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, uint64(1000)) + }, + action: func(t *testing.T) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(addr), + true, + "0", + maxSqrtPriceLimitX96, + users.Resolve(addr), + ) + }, + shouldPanic: true, + expected: "[GNOSWAP-POOL-014] invalid swap amount || amountSpecified == 0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + if tt.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tt.name) + return + } + } else { + switch r.(type) { + case string: + if r.(string) != tt.expected { + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r, tt.expected) + } + case error: + if r.(error).Error() != tt.expected { + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r.(error).Error(), tt.expected) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tt.name, r, tt.expected) + } + } + }() + + resetObject(t) + burnTokens(t) + + if tt.setupFn != nil { + tt.setupFn(t) + } + + if tt.shouldPanic { + tt.action(t) + } else { + tt.action(t) + } + }) + } +} + +func TestSwap_amount_over_liquidity(t *testing.T) { + const addr = pusers.AddressOrName(consts.ROUTER_ADDR) + + tests := []struct { + name string + setupFn func(t *testing.T) + action func(t *testing.T) (string, string) + shouldPanic bool + expected []string + }{ + { + name: "amount over liquidity - zeroForOne = false, token0:20000", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(alice), + false, + "20000", + consts.MAX_PRICE, + users.Resolve(addr), + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"-1989", "2404"}, + }, + { + name: "amount over liquidity - zeroForOne = false, token1:-20000", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(alice), + false, + "-20000", + consts.MAX_PRICE, + users.Resolve(addr), + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"-1989", "2404"}, + }, + { + name: "amount over liquidity - zeroForOne = true, token0:20000", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + TokenApprove(t, wugnotPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(alice), + true, + "20000", + consts.MIN_PRICE, + users.Resolve(addr), + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"1045", "-990"}, + }, + { + name: "amount over liquidity - zeroForOne = true, token1:-20000", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenFaucet(t, wugnotPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + TokenApprove(t, wugnotPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(alice), + true, + "-20000", + consts.MIN_PRICE, + users.Resolve(addr), + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"1045", "-990"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetObject(t) + burnTokens(t) + + if tt.setupFn != nil { + tt.setupFn(t) + } + + if tt.shouldPanic { + tt.action(t) + } else { + amount0, amount1 := tt.action(t) + uassert.Equal(t, tt.expected[0], amount0) + uassert.Equal(t, tt.expected[1], amount1) + } + }) + } +} + +func TestSwap_EXACTIN_OUT(t *testing.T) { + const addr = pusers.AddressOrName(consts.ROUTER_ADDR) + + tests := []struct { + name string + setupFn func(t *testing.T) + action func(t *testing.T) (string, string) + shouldPanic bool + expected []string + }{ + { + name: "EXACT IN - zeroForOne = false, token1:200", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(alice), + false, + "200", + consts.MAX_PRICE, + users.Resolve(addr), + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"-195", "200"}, + }, + { + name: "EXACT OUT - zeroForOne = false, token0:-200", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(alice), + false, + "-200", + consts.MAX_PRICE, + users.Resolve(addr), + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"-200", "208"}, + }, + { + name: "EXACT IN - zeroForOne = true, token0:200", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + TokenApprove(t, wugnotPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(alice), + true, + "200", + consts.MIN_PRICE, + users.Resolve(addr), + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"200", "-195"}, + }, + { + name: "EXACT OUT - zeroForOne = true, token1:-200", + setupFn: func(t *testing.T) { + InitialisePoolTest(t) + MintPositionAll(t, users.Resolve(admin)) + TokenFaucet(t, gnsPath, addr) + TokenFaucet(t, wugnotPath, addr) + TokenApprove(t, gnsPath, addr, pool, maxApprove) + TokenApprove(t, wugnotPath, addr, pool, maxApprove) + }, + action: func(t *testing.T) (string, string) { + std.TestSetOrigCaller(consts.ROUTER_ADDR) + actualAmount0, actualAmount1 := Swap( + wugnotPath, + gnsPath, + fee3000, + users.Resolve(alice), + true, + "-200", + consts.MIN_PRICE, + users.Resolve(addr), + ) + return actualAmount0, actualAmount1 + }, + shouldPanic: false, + expected: []string{"208", "-200"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetObject(t) + burnTokens(t) + + if tt.setupFn != nil { + tt.setupFn(t) + } + + if tt.shouldPanic { + tt.action(t) + } else { + amount0, amount1 := tt.action(t) + uassert.Equal(t, tt.expected[0], amount0) + uassert.Equal(t, tt.expected[1], amount1) + } + }) + } +} diff --git a/pool/tick.gno b/pool/tick.gno index d325826fb..d0db635c2 100644 --- a/pool/tick.gno +++ b/pool/tick.gno @@ -19,17 +19,47 @@ func calculateMaxLiquidityPerTick(tickSpacing int32) *u256.Uint { return new(u256.Uint).Div(u256.MustFromDecimal(consts.MAX_UINT128), u256.NewUint(numTicks)) } +func getFeeGrowthBelowX128( + tickLower, tickCurrent int32, + feeGrowthGlobal0X128, feeGrowthGlobal1X128 *u256.Uint, + lowerTick TickInfo, +) (*u256.Uint, *u256.Uint) { + if tickCurrent >= tickLower { + return lowerTick.feeGrowthOutside0X128, lowerTick.feeGrowthOutside1X128 + } + + below0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, lowerTick.feeGrowthOutside0X128) + below1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, lowerTick.feeGrowthOutside1X128) + + return below0X128, below1X128 +} + +func getFeeGrowthAboveX128( + tickUpper, tickCurrent int32, + feeGrowthGlobal0X128, feeGrowthGlobal1X128 *u256.Uint, + upperTick TickInfo, +) (*u256.Uint, *u256.Uint) { + if tickCurrent < tickUpper { + return upperTick.feeGrowthOutside0X128, upperTick.feeGrowthOutside1X128 + } + + above0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, upperTick.feeGrowthOutside0X128) + above1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, upperTick.feeGrowthOutside1X128) + + return above0X128, above1X128 +} + // calculateFeeGrowthInside calculates the fee growth inside a tick range, // and returns the fee growth inside for both tokens. -func (pool *Pool) calculateFeeGrowthInside( +func (p *Pool) calculateFeeGrowthInside( tickLower int32, tickUpper int32, tickCurrent int32, feeGrowthGlobal0X128 *u256.Uint, feeGrowthGlobal1X128 *u256.Uint, ) (*u256.Uint, *u256.Uint) { - lower := pool.getTick(tickLower) - upper := pool.getTick(tickUpper) + lower := p.getTick(tickLower) + upper := p.getTick(tickUpper) feeGrowthBelow0X128, feeGrowthBelow1X128 := getFeeGrowthBelowX128(tickLower, tickCurrent, feeGrowthGlobal0X128, feeGrowthGlobal1X128, lower) feeGrowthAbove0X128, feeGrowthAbove1X128 := getFeeGrowthAboveX128(tickUpper, tickCurrent, feeGrowthGlobal0X128, feeGrowthGlobal1X128, upper) @@ -41,7 +71,7 @@ func (pool *Pool) calculateFeeGrowthInside( } // tickUpdate updates a tick's state and returns whether the tick was flipped. -func (pool *Pool) tickUpdate( +func (p *Pool) tickUpdate( tick int32, tickCurrent int32, liquidityDelta *i256.Int, // int128 @@ -54,7 +84,7 @@ func (pool *Pool) tickUpdate( feeGrowthGlobal0X128 = feeGrowthGlobal0X128.NilToZero() feeGrowthGlobal1X128 = feeGrowthGlobal1X128.NilToZero() - thisTick := pool.getTick(tick) + thisTick := p.getTick(tick) liquidityGrossBefore := thisTick.liquidityGross liquidityGrossAfter := liquidityMathAddDelta(liquidityGrossBefore, liquidityDelta) @@ -85,64 +115,35 @@ func (pool *Pool) tickUpdate( thisTick.liquidityNet = i256.Zero().Add(thisTick.liquidityNet, liquidityDelta) } - pool.ticks[tick] = thisTick + p.ticks[tick] = thisTick return flipped } // tickCross updates a tick's state when it is crossed and returns the liquidity net. -func (pool *Pool) tickCross( +func (p *Pool) tickCross( tick int32, feeGrowthGlobal0X128 *u256.Uint, feeGrowthGlobal1X128 *u256.Uint, ) *i256.Int { - thisTick := pool.getTick(tick) + thisTick := p.getTick(tick) thisTick.feeGrowthOutside0X128 = new(u256.Uint).Sub(feeGrowthGlobal0X128, thisTick.feeGrowthOutside0X128) thisTick.feeGrowthOutside1X128 = new(u256.Uint).Sub(feeGrowthGlobal1X128, thisTick.feeGrowthOutside1X128) - pool.ticks[tick] = thisTick + p.ticks[tick] = thisTick - return thisTick.liquidityNet + return thisTick.liquidityNet.Clone() } -func (pool *Pool) getTick(tick int32) TickInfo { - tickInfo := pool.ticks[tick] +// getTick returns a tick's state. +func (p *Pool) getTick(tick int32) TickInfo { + tickInfo := p.ticks[tick] tickInfo.init() return tickInfo } -func getFeeGrowthBelowX128( - tickLower, tickCurrent int32, - feeGrowthGlobal0X128, feeGrowthGlobal1X128 *u256.Uint, - lowerTick TickInfo, -) (*u256.Uint, *u256.Uint) { - if tickCurrent >= tickLower { - return lowerTick.feeGrowthOutside0X128, lowerTick.feeGrowthOutside1X128 - } - - below0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, lowerTick.feeGrowthOutside0X128) - below1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, lowerTick.feeGrowthOutside1X128) - - return below0X128, below1X128 -} - -func getFeeGrowthAboveX128( - tickUpper, tickCurrent int32, - feeGrowthGlobal0X128, feeGrowthGlobal1X128 *u256.Uint, - upperTick TickInfo, -) (*u256.Uint, *u256.Uint) { - if tickCurrent < tickUpper { - return upperTick.feeGrowthOutside0X128, upperTick.feeGrowthOutside1X128 - } - - above0X128 := new(u256.Uint).Sub(feeGrowthGlobal0X128, upperTick.feeGrowthOutside0X128) - above1X128 := new(u256.Uint).Sub(feeGrowthGlobal1X128, upperTick.feeGrowthOutside1X128) - - return above0X128, above1X128 -} - // receiver getters func (p *Pool) GetTickLiquidityGross(tick int32) *u256.Uint { return p.mustGetTick(tick).liquidityGross diff --git a/pool/tick_bitmap.gno b/pool/tick_bitmap.gno index 45b4ba9e9..5c4dbf0c8 100644 --- a/pool/tick_bitmap.gno +++ b/pool/tick_bitmap.gno @@ -18,7 +18,7 @@ func tickBitmapPosition(tick int32) (int16, uint8) { // tickBitmapFlipTick flips tthe bit corresponding to the given tick // in the pool's tick bitmap. -func (pool *Pool) tickBitmapFlipTick( +func (p *Pool) tickBitmapFlipTick( tick int32, tickSpacing int32, ) { @@ -32,12 +32,12 @@ func (pool *Pool) tickBitmapFlipTick( wordPos, bitPos := tickBitmapPosition(tick / tickSpacing) mask := new(u256.Uint).Lsh(u256.One(), uint(bitPos)) - pool.setTickBitmap(wordPos, new(u256.Uint).Xor(pool.getTickBitmap(wordPos), mask)) + p.setTickBitmap(wordPos, new(u256.Uint).Xor(p.getTickBitmap(wordPos), mask)) } // tickBitmapNextInitializedTickWithInOneWord finds the next initialized tick within // one word of the bitmap. -func (pool *Pool) tickBitmapNextInitializedTickWithInOneWord( +func (p *Pool) tickBitmapNextInitializedTickWithInOneWord( tick int32, tickSpacing int32, lte bool, @@ -49,7 +49,7 @@ func (pool *Pool) tickBitmapNextInitializedTickWithInOneWord( wordPos, bitPos := getWordAndBitPos(compress, lte) mask := getMaskBit(uint(bitPos), lte) - masked := new(u256.Uint).And(pool.getTickBitmap(wordPos), mask) + masked := new(u256.Uint).And(p.getTickBitmap(wordPos), mask) initialized := !(masked.IsZero()) nextTick := getNextTick(lte, initialized, compress, bitPos, tickSpacing, masked) @@ -58,17 +58,17 @@ func (pool *Pool) tickBitmapNextInitializedTickWithInOneWord( // getTickBitmap gets the tick bitmap for the given word position // if the tick bitmap is not initialized, initialize it to zero -func (pool *Pool) getTickBitmap(wordPos int16) *u256.Uint { - if pool.tickBitmaps[wordPos] == nil { - pool.tickBitmaps[wordPos] = u256.Zero() +func (p *Pool) getTickBitmap(wordPos int16) *u256.Uint { + if p.tickBitmaps[wordPos] == nil { + p.tickBitmaps[wordPos] = u256.Zero() } - return pool.tickBitmaps[wordPos] + return p.tickBitmaps[wordPos] } // setTickBitmap sets the tick bitmap for the given word position -func (pool *Pool) setTickBitmap(wordPos int16, bitmap *u256.Uint) { - pool.tickBitmaps[wordPos] = bitmap +func (p *Pool) setTickBitmap(wordPos int16, bitmap *u256.Uint) { + p.tickBitmaps[wordPos] = bitmap } // getWordAndBitPos gets tick's wordPos and bitPos depending on the swap direction diff --git a/pool/type.gno b/pool/type.gno index d2a6c5197..2c2fd3d47 100644 --- a/pool/type.gno +++ b/pool/type.gno @@ -3,6 +3,8 @@ package pool import ( "std" + "gno.land/p/demo/ufmt" + "gno.land/r/gnoswap/v1/common" "gno.land/r/gnoswap/v1/consts" @@ -57,16 +59,41 @@ func newProtocolFees() ProtocolFees { } } +// ModifyPositionParams repersents the parameters for modifying a liquidity position. +// This structure is used internally both `Mint` and `Burn` operation to manage +// the liquidity positions. type ModifyPositionParams struct { - owner std.Address // address that owns the position + // owner is the address that owns the position + owner std.Address + + // tickLower and atickUpper define the price range + // The actual price range is calculated as 1.0001^tick + // This allows for precision in price range while using integer math. - // the tick range of the position, bounds are included - tickLower int32 - tickUpper int32 + tickLower int32 // lower tick of the position + tickUpper int32 // upper tick of the position - liquidityDelta *i256.Int // any change in liquidity + // liquidityDelta represents the change in liquidity + // Positive for minting, negative for burning + liquidityDelta *i256.Int } +// newModifyPositionParams creates a new `ModifyPositionParams` instance. +// This is used to preare parameters for the `modifyPosition` function, +// which handles both minting and burning of liquidity positions. +// +// Parameters: +// - owner: address that will own (or owns) the position +// - tickLower: lower tick bound of the position +// - tickUpper: upper tick bound of the position +// - liquidityDelta: amount of liquidity to add (positive) or remove (negative) +// +// The tick parameters represent prices as powers of 1.0001: +// - actual_price = 1.0001^tick +// - For example, tick = 100 means price = 1.0001^100 +// +// Returns: +// - ModifyPositionParams: a new instance of ModifyPositionParams func newModifyPositionParams( owner std.Address, tickLower int32, @@ -81,6 +108,7 @@ func newModifyPositionParams( } } +// SwapCache holds data that remains constant throughout a swap. type SwapCache struct { feeProtocol uint8 // protocol fee for the input token liquidityStart *u256.Uint // liquidity at the beginning of the swap @@ -96,6 +124,9 @@ func newSwapCache( } } +// SwapState tracks the changing values during a swap. +// This type helps manage the state transiktions that occur as the swap progresses +// accross different price ranges. type SwapState struct { amountSpecifiedRemaining *i256.Int // amount remaining to be swapped in/out of the input/output token amountCalculated *i256.Int // amount already swapped out/in of the output/input token @@ -139,6 +170,9 @@ func (s *SwapState) SetProtocolFee(fee *u256.Uint) { s.protocolFee = fee } +// StepComputations holds intermediate values used during a single step of a swap. +// Each step represents movement from the current tick to the next initialized tick +// or the target price, whichever comes first. type StepComputations struct { sqrtPriceStartX96 *u256.Uint // price at the beginning of the step tickNext int32 // next tick to swap to from the current tick in the swap direction @@ -154,14 +188,14 @@ func (step *StepComputations) initSwapStep(state SwapState, pool *Pool, zeroForO step.sqrtPriceStartX96 = state.sqrtPriceX96 step.tickNext, step.initialized = pool.tickBitmapNextInitializedTickWithInOneWord( state.tick, - pool.tickSpacing, + pool.tickSpacing, zeroForOne, ) // prevent overshoot the min/max tick step.clampTickNext() - // get the price for the next tick + // get the price for the next tick step.sqrtPriceNextX96 = common.TickMathGetSqrtRatioAtTick(step.tickNext) } @@ -178,12 +212,18 @@ func (step *StepComputations) clampTickNext() { type PositionInfo struct { liquidity *u256.Uint // amount of liquidity owned by this position - // fee growth per unit of liquidity as of the last update to liquidity or fees owed + // Fee growth per unit of liquidity as of the last update + // Used to calculate uncollected fees for token0 feeGrowthInside0LastX128 *u256.Uint + + // Fee growth per unit of liquidity as of the last update + // Used to calculate uncollected fees for token1 feeGrowthInside1LastX128 *u256.Uint - // fees owed to the position owner in token0/token1 + // accumulated fees in token0 waiting to be collected tokensOwed0 *u256.Uint + + // accumulated fees in token1 waiting to be collected tokensOwed1 *u256.Uint } @@ -195,6 +235,8 @@ func (p *PositionInfo) init() { p.tokensOwed1 = p.tokensOwed1.NilToZero() } +// TickInfo stores information about a specific tick in the pool. +// TIcks represent discrete price points that can be used as boundaries for positions. type TickInfo struct { liquidityGross *u256.Uint // total position liquidity that references this tick liquidityNet *i256.Int // amount of net liquidity added (subtracted) when tick is crossed from left to right (right to left) @@ -225,9 +267,11 @@ func (t *TickInfo) init() { t.secondsPerLiquidityOutsideX128 = t.secondsPerLiquidityOutsideX128.NilToZero() } -type Ticks map[int32]TickInfo // tick => TickInfo -type TickBitmaps map[int16]*u256.Uint // tick(wordPos) => bitmap(tickWord ^ mask) -type Positions map[string]PositionInfo // positionKey => PositionInfo +type ( + Ticks map[int32]TickInfo // tick => TickInfo + TickBitmaps map[int16]*u256.Uint // tick(wordPos) => bitmap(tickWord ^ mask) + Positions map[string]PositionInfo // positionKey => PositionInfo +) // type Pool describes a single Pool's state // A pool is identificed with a unique key (token0, token1, fee), where token0 < token1 @@ -282,3 +326,80 @@ func newPool(poolInfo *createPoolParams) *Pool { positions: Positions{}, } } + +func (p *Pool) PoolGetToken0Path() string { + return p.token0Path +} + +func (p *Pool) PoolGetToken1Path() string { + return p.token1Path +} + +func (p *Pool) PoolGetFee() uint32 { + return p.fee +} + +func (p *Pool) PoolGetBalanceToken0() *u256.Uint { + return p.balances.token0 +} + +func (p *Pool) PoolGetBalanceToken1() *u256.Uint { + return p.balances.token1 +} + +func (p *Pool) PoolGetTickSpacing() int32 { + return p.tickSpacing +} + +func (p *Pool) PoolGetMaxLiquidityPerTick() *u256.Uint { + return p.maxLiquidityPerTick +} + +func (p *Pool) PoolGetSlot0() Slot0 { + return p.slot0 +} + +func (p *Pool) PoolGetSlot0SqrtPriceX96() *u256.Uint { + return p.slot0.sqrtPriceX96 +} + +func (p *Pool) PoolGetSlot0Tick() int32 { + return p.slot0.tick +} + +func (p *Pool) PoolGetSlot0FeeProtocol() uint8 { + return p.slot0.feeProtocol +} + +func (p *Pool) PoolGetSlot0Unlocked() bool { + return p.slot0.unlocked +} + +func (p *Pool) PoolGetFeeGrowthGlobal0X128() *u256.Uint { + return p.feeGrowthGlobal0X128 +} + +func (p *Pool) PoolGetFeeGrowthGlobal1X128() *u256.Uint { + return p.feeGrowthGlobal1X128 +} + +func (p *Pool) PoolGetProtocolFeesToken0() *u256.Uint { + return p.protocolFees.token0 +} + +func (p *Pool) PoolGetProtocolFeesToken1() *u256.Uint { + return p.protocolFees.token1 +} + +func (p *Pool) PoolGetLiquidity() *u256.Uint { + return p.liquidity +} + +func mustGetPool(poolPath string) *Pool { + pool, exist := pools[poolPath] + if !exist { + panic(addDetailToError(errDataNotFound, + ufmt.Sprintf("poolPath(%s) does not exist", poolPath))) + } + return pool +} diff --git a/pool/utils.gno b/pool/utils.gno index b8d627a9d..e45116217 100644 --- a/pool/utils.gno +++ b/pool/utils.gno @@ -3,11 +3,25 @@ package pool import ( "std" + "gno.land/p/demo/ufmt" + pusers "gno.land/p/demo/users" u256 "gno.land/p/gnoswap/uint256" ) +func safeConvertToUint64(value *u256.Uint) (uint64, error) { + res, overflow := value.Uint64WithOverflow() + if overflow { + return 0, ufmt.Errorf( + "%v: amount(%s) overflows uint64 range", + errOutOfRange, value.ToString(), + ) + } + + return res, nil +} + func a2u(addr std.Address) pusers.AddressOrName { if !addr.IsValid() { panic(addDetailToError( From 0cb7ce0a5569a36904620cca82a0c4215ba976a7 Mon Sep 17 00:00:00 2001 From: 0xTopaz <60733299+onlyhyde@users.noreply.github.com> Date: Mon, 16 Dec 2024 12:54:37 +0900 Subject: [PATCH 8/9] GSW-1839 refactor: integrated helper and test code (#432) - integrated helper with nft helper - add test helper code - add test code for helper - change file filename --- ..._receiver.gno.gno => _GET_no_receiver.gno} | 2 +- position/_RPC_api.gno | 12 +- position/_helper_test.gno | 101 +++-- position/errors.gno | 32 +- position/helper.gno | 113 ++++- position/helper_test.gno | 397 ++++++++++++++++++ position/liquidity_management.gno | 2 +- position/nft_helper.gno | 70 --- position/position.gno | 22 +- position/utils.gno | 27 +- 10 files changed, 625 insertions(+), 153 deletions(-) rename position/{_GET_no_receiver.gno.gno => _GET_no_receiver.gno} (97%) create mode 100644 position/helper_test.gno delete mode 100644 position/nft_helper.gno diff --git a/position/_GET_no_receiver.gno.gno b/position/_GET_no_receiver.gno similarity index 97% rename from position/_GET_no_receiver.gno.gno rename to position/_GET_no_receiver.gno index 7a29d05f7..ffd26f733 100644 --- a/position/_GET_no_receiver.gno.gno +++ b/position/_GET_no_receiver.gno @@ -68,5 +68,5 @@ func PositionIsInRange(tokenId uint64) bool { } func PositionGetPositionOwner(tokenId uint64) std.Address { - return gnft.OwnerOf(tid(tokenId)) + return gnft.OwnerOf(tokenIdFrom(tokenId)) } diff --git a/position/_RPC_api.gno b/position/_RPC_api.gno index b42e871d1..10df9acf2 100644 --- a/position/_RPC_api.gno +++ b/position/_RPC_api.gno @@ -80,7 +80,7 @@ func ApiGetPositions() string { _positionNode := json.ObjectNode("", map[string]*json.Node{ "lpTokenId": json.NumberNode("lpTokenId", float64(position.LpTokenId)), "burned": json.BoolNode("burned", position.Burned), - "owner": json.StringNode("owner", gnft.OwnerOf(tid(position.LpTokenId)).String()), + "owner": json.StringNode("owner", gnft.OwnerOf(tokenIdFrom(position.LpTokenId)).String()), "operator": json.StringNode("operator", position.Operator), "poolKey": json.StringNode("poolKey", position.PoolKey), "tickLower": json.NumberNode("tickLower", float64(position.TickLower)), @@ -140,7 +140,7 @@ func ApiGetPosition(lpTokenId uint64) string { _positionNode := json.ObjectNode("", map[string]*json.Node{ "lpTokenId": json.NumberNode("lpTokenId", float64(position.LpTokenId)), "burned": json.BoolNode("burned", position.Burned), - "owner": json.StringNode("owner", gnft.OwnerOf(tid(position.LpTokenId)).String()), + "owner": json.StringNode("owner", gnft.OwnerOf(tokenIdFrom(position.LpTokenId)).String()), "operator": json.StringNode("operator", position.Operator), "poolKey": json.StringNode("poolKey", position.PoolKey), "tickLower": json.NumberNode("tickLower", float64(position.TickLower)), @@ -203,7 +203,7 @@ func ApiGetPositionsByPoolPath(poolPath string) string { _positionNode := json.ObjectNode("", map[string]*json.Node{ "lpTokenId": json.NumberNode("lpTokenId", float64(position.LpTokenId)), "burned": json.BoolNode("burned", position.Burned), - "owner": json.StringNode("owner", gnft.OwnerOf(tid(position.LpTokenId)).String()), + "owner": json.StringNode("owner", gnft.OwnerOf(tokenIdFrom(position.LpTokenId)).String()), "operator": json.StringNode("operator", position.Operator), "poolKey": json.StringNode("poolKey", position.PoolKey), "tickLower": json.NumberNode("tickLower", float64(position.TickLower)), @@ -238,7 +238,7 @@ func ApiGetPositionsByAddress(address std.Address) string { rpcPositions := []RpcPosition{} for lpTokenId, position := range positions { - if !(position.operator == address || gnft.OwnerOf(tid(lpTokenId)) == address) { + if !(position.operator == address || gnft.OwnerOf(tokenIdFrom(lpTokenId)) == address) { continue } @@ -266,7 +266,7 @@ func ApiGetPositionsByAddress(address std.Address) string { _positionNode := json.ObjectNode("", map[string]*json.Node{ "lpTokenId": json.NumberNode("lpTokenId", float64(position.LpTokenId)), "burned": json.BoolNode("burned", position.Burned), - "owner": json.StringNode("owner", gnft.OwnerOf(tid(position.LpTokenId)).String()), + "owner": json.StringNode("owner", gnft.OwnerOf(tokenIdFrom(position.LpTokenId)).String()), "operator": json.StringNode("operator", position.Operator), "poolKey": json.StringNode("poolKey", position.PoolKey), "tickLower": json.NumberNode("tickLower", float64(position.TickLower)), @@ -410,7 +410,7 @@ func rpcMakePosition(lpTokenId uint64) RpcPosition { return RpcPosition{ LpTokenId: lpTokenId, Burned: burned, - Owner: gnft.OwnerOf(tid(lpTokenId)).String(), + Owner: gnft.OwnerOf(tokenIdFrom(lpTokenId)).String(), Operator: position.operator.String(), PoolKey: position.poolKey, TickLower: position.tickLower, diff --git a/position/_helper_test.gno b/position/_helper_test.gno index be52f83be..f998f2ab9 100644 --- a/position/_helper_test.gno +++ b/position/_helper_test.gno @@ -14,6 +14,7 @@ import ( "gno.land/r/gnoswap/v1/gnft" "gno.land/r/gnoswap/v1/gns" pl "gno.land/r/gnoswap/v1/pool" + sr "gno.land/r/gnoswap/v1/staker" "gno.land/r/onbloc/bar" "gno.land/r/onbloc/baz" "gno.land/r/onbloc/foo" @@ -37,6 +38,10 @@ const ( fee3000 uint32 = 3000 maxApprove uint64 = 18446744073709551615 max_timeout int64 = 9999999999 + + TIER_1 uint64 = 1 + TIER_2 uint64 = 2 + TIER_3 uint64 = 3 ) const ( @@ -165,6 +170,7 @@ func init() { var ( admin = pusers.AddressOrName(consts.ADMIN) alice = pusers.AddressOrName(testutils.TestAddress("alice")) + bob = pusers.AddressOrName(testutils.TestAddress("bob")) pool = pusers.AddressOrName(consts.POOL_ADDR) protocolFee = pusers.AddressOrName(consts.PROTOCOL_FEE_ADDR) adminRealm = std.NewUserRealm(users.Resolve(admin)) @@ -182,10 +188,7 @@ func InitialisePoolTest(t *testing.T) { std.TestSetOrigCaller(users.Resolve(admin)) TokenApprove(t, gnsPath, admin, pool, maxApprove) - poolPath := pl.GetPoolPath(wugnotPath, gnsPath, fee3000) - if !pl.DoesPoolPathExist(poolPath) { - pl.CreatePool(wugnotPath, gnsPath, fee3000, "79228162514264337593543950336") - } + CreatePool(t, wugnotPath, gnsPath, fee3000, "79228162514264337593543950336", users.Resolve(admin)) //2. create position std.TestSetOrigCaller(users.Resolve(alice)) @@ -300,6 +303,22 @@ func TokenApprove(t *testing.T, tokenPath string, owner, spender pusers.AddressO } } +func CreatePool(t *testing.T, + token0 string, + token1 string, + fee uint32, + sqrtPriceX96 string, + caller std.Address) { + t.Helper() + + std.TestSetRealm(std.NewUserRealm(caller)) + poolPath := pl.GetPoolPath(token0, token1, fee) + if !pl.DoesPoolPathExist(poolPath) { + pl.CreatePool(token0, token1, fee, sqrtPriceX96) + sr.SetPoolTierByAdmin(poolPath, TIER_1) + } +} + func MintPosition(t *testing.T, token0 string, token1 string, @@ -332,6 +351,54 @@ func MintPosition(t *testing.T, caller) } +func MakeMintPositionWithoutFee(t *testing.T) (uint64, string, string, string) { + t.Helper() + + // make actual data to test resetting not only position's state but also pool's state + std.TestSetRealm(adminRealm) + + // set pool create fee to 0 for testing + pl.SetPoolCreationFeeByAdmin(0) + CreatePool(t, barPath, fooPath, fee500, common.TickMathGetSqrtRatioAtTick(0).ToString(), users.Resolve(admin)) + + TokenApprove(t, barPath, admin, pool, consts.UINT64_MAX) + TokenApprove(t, fooPath, admin, pool, consts.UINT64_MAX) + + // mint position + return Mint( + barPath, + fooPath, + fee500, + -887270, + 887270, + "50000", + "50000", + "0", + "0", + max_timeout, + users.Resolve(admin), + users.Resolve(admin), + ) +} + +func LPTokenApprove(t *testing.T, owner, operator pusers.AddressOrName, tokenId uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(users.Resolve(owner))) + gnft.Approve(operator, tokenIdFrom(tokenId)) +} + +func LPTokenStake(t *testing.T, owner pusers.AddressOrName, tokenId uint64) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(users.Resolve(owner))) + sr.StakeToken(tokenId) +} + +func LPTokenUnStake(t *testing.T, owner pusers.AddressOrName, tokenId uint64, unwrap bool) { + t.Helper() + std.TestSetRealm(std.NewUserRealm(users.Resolve(owner))) + sr.UnstakeToken(tokenId, unwrap) +} + func wugnotApprove(t *testing.T, owner, spender pusers.AddressOrName, amount uint64) { t.Helper() std.TestSetRealm(std.NewUserRealm(users.Resolve(owner))) @@ -487,7 +554,7 @@ func burnAllNFT(t *testing.T) { std.TestSetRealm(std.NewCodeRealm(consts.POSITION_PATH)) for i := uint64(1); i <= gnft.TotalSupply(); i++ { - gnft.Burn(tid(i)) + gnft.Burn(tokenIdFrom(i)) } } @@ -495,29 +562,7 @@ func TestBeforeResetObject(t *testing.T) { // make actual data to test resetting not only position's state but also pool's state std.TestSetRealm(adminRealm) - // set pool create fee to 0 for testing - pl.SetPoolCreationFeeByAdmin(0) - pl.CreatePool(barPath, fooPath, fee500, common.TickMathGetSqrtRatioAtTick(0).ToString()) - - // mint position - bar.Approve(a2u(consts.POOL_ADDR), consts.UINT64_MAX) - foo.Approve(a2u(consts.POOL_ADDR), consts.UINT64_MAX) - - tokenId, liquidity, amount0, amount1 := Mint( - barPath, - fooPath, - fee500, - -887270, - 887270, - "50000", - "50000", - "0", - "0", - max_timeout, - users.Resolve(admin), - users.Resolve(admin), - ) - + tokenId, liquidity, amount0, amount1 := MakeMintPositionWithoutFee(t) uassert.Equal(t, tokenId, uint64(1), "tokenId should be 1") uassert.Equal(t, liquidity, "50000", "liquidity should be 50000") uassert.Equal(t, amount0, "50000", "amount0 should be 50000") diff --git a/position/errors.gno b/position/errors.gno index 9cb698faf..60ef43ceb 100644 --- a/position/errors.gno +++ b/position/errors.gno @@ -7,19 +7,31 @@ import ( ) var ( - errNoPermission = errors.New("[GNOSWAP-POSITION-001] caller has no permission") - errSlippage = errors.New("[GNOSWAP-POSITION-002] slippage failed") - errWrapUnwrap = errors.New("[GNOSWAP-POSITION-003] wrap, unwrap failed") - errOutOfRange = errors.New("[GNOSWAP-POSITION-004] out of range for numeric value") - errInvalidInput = errors.New("[GNOSWAP-POSITION-005] invalid input data") - errDataNotFound = errors.New("[GNOSWAP-POSITION-006] requested data not found") - errExpired = errors.New("[GNOSWAP-POSITION-007] transaction expired") - errWugnotMinimum = errors.New("[GNOSWAP-POSITION-008] can not wrap less than minimum amount") - errNotClear = errors.New("[GNOSWAP-POSITION-009] position is not clear") - errZeroLiquidity = errors.New("[GNOSWAP-POSITION-010] zero liquidity") + errNoPermission = errors.New("[GNOSWAP-POSITION-001] caller has no permission") + errSlippage = errors.New("[GNOSWAP-POSITION-002] slippage failed") + errWrapUnwrap = errors.New("[GNOSWAP-POSITION-003] wrap, unwrap failed") + errOutOfRange = errors.New("[GNOSWAP-POSITION-004] out of range for numeric value") + errInvalidInput = errors.New("[GNOSWAP-POSITION-005] invalid input data") + errDataNotFound = errors.New("[GNOSWAP-POSITION-006] requested data not found") + errExpired = errors.New("[GNOSWAP-POSITION-007] transaction expired") + errWugnotMinimum = errors.New("[GNOSWAP-POSITION-008] can not wrap less than minimum amount") + errNotClear = errors.New("[GNOSWAP-POSITION-009] position is not clear") + errZeroLiquidity = errors.New("[GNOSWAP-POSITION-010] zero liquidity") + errInvalidAddress = errors.New("[GNOSWAP-POSITION-011] invalid address") ) +// TODO: +// addDetailToError -> newErrorWithDetail func addDetailToError(err error, detail string) string { finalErr := ufmt.Errorf("%s || %s", err.Error(), detail) return finalErr.Error() } + +// newErrorWithDetail returns a new error with the given detail +// e.g. newErrorWithDetail(err, "detail") +// +// input: err error, detail string +// output: "err.Error() || detail" +func newErrorWithDetail(err error, detail string) string { + return ufmt.Errorf("%s || %s", err.Error(), detail).Error() +} diff --git a/position/helper.gno b/position/helper.gno index 848717628..532e0fdc9 100644 --- a/position/helper.gno +++ b/position/helper.gno @@ -1,21 +1,30 @@ package position import ( + "std" "strconv" "gno.land/p/demo/grc/grc721" + "gno.land/p/demo/ufmt" + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" + "gno.land/r/gnoswap/v1/gnft" ) +// nextId is the next tokenId to be minted func getNextId() uint64 { return nextId } -func tid(tokenId interface{}) grc721.TokenID { +// tokenIdFrom converts tokenId to grc721.TokenID type +// NOTE: input parameter tokenId can be string, int, uint64, or grc721.TokenID +// if tokenId is nil or not supported, it will panic +// if tokenId is not found, it will panic +// input: tokenId interface{} +// output: grc721.TokenID +func tokenIdFrom(tokenId interface{}) grc721.TokenID { if tokenId == nil { - panic(addDetailToError( - errDataNotFound, - "helper.gno__tid() || tokenId is nil", - )) + panic(newErrorWithDetail(errInvalidInput, "tokenId is nil")) } switch tokenId.(type) { @@ -28,9 +37,95 @@ func tid(tokenId interface{}) grc721.TokenID { case grc721.TokenID: return tokenId.(grc721.TokenID) default: - panic(addDetailToError( - errInvalidInput, - "helper.gno__tid() || unsupported tokenId type", - )) + panic(newErrorWithDetail(errInvalidInput, "unsupported tokenId type")) } } + +// exists checks whether tokenId exists +// If tokenId doesn't exist, return false, otherwise return true +// input: tokenId uint64 +// output: bool +func exists(tokenId uint64) bool { + return gnft.Exists(tokenIdFrom(tokenId)) +} + +// isOwner checks whether the caller is the owner of the tokenId +// If the caller is the owner of the tokenId, return true, otherwise return false +// input: tokenId uint64, addr std.Address +// output: bool +func isOwner(tokenId uint64, addr std.Address) bool { + owner := gnft.OwnerOf(tokenIdFrom(tokenId)) + if owner == addr { + return true + } + return false +} + +// isOperator checks whether the caller is the approved operator of the tokenId +// If the caller is the approved operator of the tokenId, return true, otherwise return false +// input: tokenId uint64, addr std.Address +// output: bool +func isOperator(tokenId uint64, addr std.Address) bool { + operator, ok := gnft.GetApproved(tokenIdFrom(tokenId)) + if ok && operator == addr { + return true + } + return false +} + +// isStaked checks whether tokenId is staked +// If tokenId is staked, owner of tokenId is staker contract +// If tokenId is staked, return true, otherwise return false +// input: tokenId grc721.TokenID +// output: bool +func isStaked(tokenId grc721.TokenID) bool { + exist := gnft.Exists(tokenId) + if exist { + owner := gnft.OwnerOf(tokenId) + if owner == consts.STAKER_ADDR { + return true + } + } + return false +} + +// isOwnerOrOperator checks whether the caller is the owner or approved operator of the tokenId +// If the caller is the owner or approved operator of the tokenId, return true, otherwise return false +// input: addr std.Address, tokenId uint64 +// output: bool +func isOwnerOrOperator(addr std.Address, tokenId uint64) bool { + assertOnlyValidAddress(addr) + if !exists(tokenId) { + return false + } + if isOwner(tokenId, addr) || isOperator(tokenId, addr) { + return true + } + if isStaked(tokenIdFrom(tokenId)) { + position, exist := positions[tokenId] + if exist && addr == position.operator { + return true + } + } + return false +} + +// splitOf divides poolKey into pToken0, pToken1, and pFee +// If poolKey is invalid, it will panic +// +// input: poolKey string +// output: +// - token0Path string +// - token1Path string +// - fee uint32 +func splitOf(poolKey string) (string, string, uint32) { + res, err := common.Split(poolKey, ":", 3) + if err != nil { + panic(newErrorWithDetail(errInvalidInput, ufmt.Sprintf("invalid poolKey(%s)", poolKey))) + } + + pToken0, pToken1, pFeeStr := res[0], res[1], res[2] + + pFee, _ := strconv.Atoi(pFeeStr) + return pToken0, pToken1, uint32(pFee) +} diff --git a/position/helper_test.gno b/position/helper_test.gno new file mode 100644 index 000000000..35be8cb56 --- /dev/null +++ b/position/helper_test.gno @@ -0,0 +1,397 @@ +package position + +import ( + "std" + "testing" + + "gno.land/p/demo/grc/grc721" + "gno.land/p/demo/uassert" + pusers "gno.land/p/demo/users" + "gno.land/r/demo/users" +) + +func TestGetNextId(t *testing.T) { + tests := []struct { + name string + newMint bool + expected uint64 + }{ + { + name: "Success - initial nextId", + newMint: false, + expected: 1, + }, + { + name: "Success - after mint", + newMint: true, + expected: 2, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.newMint { + MakeMintPositionWithoutFee(t) + } + got := getNextId() + uassert.Equal(t, tc.expected, got) + }) + } +} + +func TestTokenIdFrom(t *testing.T) { + + tests := []struct { + name string + input interface{} + expected string + shouldPanic bool + }{ + { + name: "Panic - nil", + input: nil, + expected: "[GNOSWAP-POSITION-005] invalid input data || tokenId is nil", + shouldPanic: true, + }, + { + name: "Panic - unsupported type", + input: float64(1), + expected: "[GNOSWAP-POSITION-005] invalid input data || unsupported tokenId type", + shouldPanic: true, + }, + { + name: "Success - string", + input: "1", + expected: "1", + shouldPanic: false, + }, + { + name: "Success - int", + input: int(1), + expected: "1", + shouldPanic: false, + }, + { + name: "Success - uint64", + input: uint64(1), + expected: "1", + shouldPanic: false, + }, + { + name: "Success - grc721.TokenID", + input: grc721.TokenID("1"), + expected: "1", + shouldPanic: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + if tc.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tc.name) + return + } + } else { + switch r.(type) { + case string: + if r.(string) != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + case error: + if r.(error).Error() != tc.expected { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r.(error).Error(), tc.expected) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expected) + } + } + }() + + if !tc.shouldPanic { + got := tokenIdFrom(tc.input) + uassert.Equal(t, tc.expected, string(got)) + } else { + tokenIdFrom(tc.input) + } + }) + } +} + +func TestExists(t *testing.T) { + tests := []struct { + name string + tokenId uint64 + expected bool + }{ + { + name: "Fail - not exists", + tokenId: 2, + expected: false, + }, + { + name: "Success - exists", + tokenId: 1, + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := exists(tc.tokenId) + uassert.Equal(t, tc.expected, got) + }) + } +} + +func TestIsOwner(t *testing.T) { + tests := []struct { + name string + tokenId uint64 + addr std.Address + expected bool + }{ + { + name: "Fail - is not owner", + tokenId: 1, + addr: users.Resolve(alice), + expected: false, + }, + { + name: "Success - is owner", + tokenId: 1, + addr: users.Resolve(admin), + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + MakeMintPositionWithoutFee(t) + got := isOwner(tc.tokenId, tc.addr) + uassert.Equal(t, tc.expected, got) + }) + } +} + +func TestIsOperator(t *testing.T) { + MakeMintPositionWithoutFee(t) + tests := []struct { + name string + tokenId uint64 + addr pusers.AddressOrName + expected bool + }{ + { + name: "Fail - is not operator", + tokenId: 1, + addr: alice, + expected: false, + }, + { + name: "Success - is operator", + tokenId: 1, + addr: bob, + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected { + LPTokenApprove(t, admin, tc.addr, tc.tokenId) + } + got := isOperator(tc.tokenId, users.Resolve(tc.addr)) + uassert.Equal(t, tc.expected, got) + }) + } +} + +func TestIsStaked(t *testing.T) { + MakeMintPositionWithoutFee(t) + tests := []struct { + name string + owner pusers.AddressOrName + operator pusers.AddressOrName + tokenId uint64 + expected bool + }{ + { + name: "Fail - is not staked", + owner: bob, + operator: alice, + tokenId: 1, + expected: false, + }, + { + name: "Fail - is not exist tokenId", + owner: admin, + operator: bob, + tokenId: 100, + expected: false, + }, + { + name: "Success - is staked", + owner: admin, + operator: admin, + tokenId: 1, + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected && tc.owner == tc.operator { + LPTokenStake(t, tc.owner, tc.tokenId) + } + got := isStaked(tokenIdFrom(tc.tokenId)) + uassert.Equal(t, tc.expected, got) + if tc.expected && tc.owner == tc.operator { + LPTokenUnStake(t, tc.owner, tc.tokenId, false) + } + }) + } +} + +func TestIsOwnerOrOperator(t *testing.T) { + MakeMintPositionWithoutFee(t) + tests := []struct { + name string + owner pusers.AddressOrName + operator pusers.AddressOrName + tokenId uint64 + expected bool + }{ + { + name: "Fail - is not owner or operator", + owner: admin, + operator: alice, + tokenId: 1, + expected: false, + }, + { + name: "Success - is operator", + owner: admin, + operator: bob, + tokenId: 1, + expected: true, + }, + { + name: "Success - is owner", + owner: admin, + operator: admin, + tokenId: 1, + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected && tc.owner != tc.operator { + LPTokenApprove(t, tc.owner, tc.operator, tc.tokenId) + } + var got bool + if tc.owner == tc.operator { + got = isOwnerOrOperator(users.Resolve(tc.owner), tc.tokenId) + } else { + got = isOwnerOrOperator(users.Resolve(tc.operator), tc.tokenId) + } + uassert.Equal(t, tc.expected, got) + }) + } +} + +func TestIsOwnerOrOperatorWithStake(t *testing.T) { + MakeMintPositionWithoutFee(t) + tests := []struct { + name string + owner pusers.AddressOrName + operator pusers.AddressOrName + tokenId uint64 + isStake bool + expected bool + }{ + { + name: "Fail - is not token staked", + owner: admin, + operator: alice, + tokenId: 1, + isStake: false, + expected: false, + }, + { + name: "Success - is token staked (position operator)", + owner: admin, + operator: admin, + tokenId: 1, + isStake: true, + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.isStake { + LPTokenStake(t, tc.owner, tc.tokenId) + } + got := isOwnerOrOperator(users.Resolve(tc.operator), tc.tokenId) + uassert.Equal(t, tc.expected, got) + }) + } +} + +func TestPoolKeyDivide(t *testing.T) { + tests := []struct { + name string + poolKey string + expectedPath0 string + expectedPath1 string + expectedFee uint32 + expectedError string + shouldPanic bool + }{ + { + name: "Fail - invalid poolKey", + poolKey: "gno.land/r/onbloc", + expectedError: "[GNOSWAP-POSITION-005] invalid input data || invalid poolKey(gno.land/r/onbloc)", + shouldPanic: true, + }, + { + name: "Success - split poolKey", + poolKey: "gno.land/r/gnoswap/v1/gns:gno.land/r/demo/wugnot:500", + expectedPath0: gnsPath, + expectedPath1: wugnotPath, + expectedFee: fee500, + shouldPanic: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + if tc.shouldPanic { + t.Errorf(">>> %s: expected panic but got none", tc.name) + return + } + } else { + switch r.(type) { + case string: + if r.(string) != tc.expectedError { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expectedError) + } + case error: + if r.(error).Error() != tc.expectedError { + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r.(error).Error(), tc.expectedError) + } + default: + t.Errorf(">>> %s: got panic %v, want %v", tc.name, r, tc.expectedError) + } + } + }() + + if !tc.shouldPanic { + gotToken0, gotToken1, gotFee := splitOf(tc.poolKey) + uassert.Equal(t, tc.expectedPath0, gotToken0) + uassert.Equal(t, tc.expectedPath1, gotToken1) + uassert.Equal(t, tc.expectedFee, gotFee) + } else { + splitOf(tc.poolKey) + } + }) + } +} diff --git a/position/liquidity_management.gno b/position/liquidity_management.gno index 8b74794ca..be926f912 100644 --- a/position/liquidity_management.gno +++ b/position/liquidity_management.gno @@ -28,7 +28,7 @@ func addLiquidity(params AddLiquidityParams) (*u256.Uint, *u256.Uint, *u256.Uint params.amount1Desired, ) - pToken0, pToken1, pFee := poolKeyDivide(params.poolKey) + pToken0, pToken1, pFee := splitOf(params.poolKey) amount0, amount1 := pl.Mint( pToken0, pToken1, diff --git a/position/nft_helper.gno b/position/nft_helper.gno deleted file mode 100644 index ff8d0a76f..000000000 --- a/position/nft_helper.gno +++ /dev/null @@ -1,70 +0,0 @@ -package position - -import ( - "std" - - "gno.land/p/demo/ufmt" - "gno.land/r/gnoswap/v1/consts" - - "gno.land/r/gnoswap/v1/gnft" -) - -func exists(tokenId uint64) bool { - // non exist tokenId will panic - // use defer to catch the panic - defer func() { - if err := recover(); err != nil { - panic(addDetailToError( - errDataNotFound, - ufmt.Sprintf("nft_helper.gno__exists() || tokenId(%d) doesn't exist", tokenId), - )) - } - }() - - // exists method in grc721 is private - // we don't have much choice but to use ownerOf - owner := gnft.OwnerOf(tid(tokenId)) - if owner == consts.ZERO_ADDRESS { - panic(addDetailToError( - errDataNotFound, - ufmt.Sprintf("nft_helper.gno__exists() || tokenId(%d) doesn't exist__ZeroAddressOwner", tokenId), - )) - return false - } - - return true -} - -func isApprovedOrOwner(addr std.Address, tokenId uint64) bool { - tid := tid(tokenId) - - // check whether token exists - if !exists(tokenId) { - panic(addDetailToError( - errDataNotFound, - ufmt.Sprintf("nft_helper.gno__isApprovedOrOwner() || tokenId(%d) doesn't exist", tokenId), - )) - } - - // check owner first - owner := gnft.OwnerOf(tid) - if addr == owner { - return true - } - - // if not owner, check whether approved in position contract - position, exist := positions[tokenId] - if exist { - if addr == position.operator { - return true - } - } - - // if not owner, check whether approved in actual grc721 contract - operator, ok := gnft.GetApproved(tid) - if ok && addr == operator { - return true - } - - return false -} diff --git a/position/position.gno b/position/position.gno index d4ffb1a88..60cb0f62f 100644 --- a/position/position.gno +++ b/position/position.gno @@ -185,7 +185,7 @@ func mint(params MintParams) (uint64, *u256.Uint, *u256.Uint, *u256.Uint) { ) tokenId := nextId - gnft.Mint(a2u(params.mintTo), tid(tokenId)) // owner, tokenId + gnft.Mint(a2u(params.mintTo), tokenIdFrom(tokenId)) // owner, tokenId nextId++ positionKey := positionKeyCompute(GetOrigPkgAddr(), params.tickLower, params.tickUpper) @@ -240,7 +240,7 @@ func IncreaseLiquidity( // wrap if target pool has wugnot position := positions[tokenId] - pToken0, pToken1, _ := poolKeyDivide(position.poolKey) + pToken0, pToken1, _ := splitOf(position.poolKey) isToken0Wugnot := pToken0 == consts.WRAPPED_WUGNOT isToken1Wugnot := pToken1 == consts.WRAPPED_WUGNOT @@ -293,7 +293,7 @@ func increaseLiquidity(params IncreaseLiquidityParams) (uint64, *u256.Uint, *u25 // MUST BE OWNER TO INCREASE LIQUIDITY // can not be approved address ≈ staked position can't be modified - owner := gnft.OwnerOf(tid(params.tokenId)) + owner := gnft.OwnerOf(tokenIdFrom(params.tokenId)) caller := std.PrevRealm().Addr() if owner != caller { panic(addDetailToError( @@ -434,7 +434,7 @@ func decreaseLiquidity(params DecreaseLiquidityParams) (uint64, *u256.Uint, *u25 liquidityToRemove := calculateLiquidityToRemove(positionLiquidity, params.liquidityRatio) - pToken0, pToken1, pFee := poolKeyDivide(position.poolKey) + pToken0, pToken1, pFee := splitOf(position.poolKey) pool := pl.GetPoolFromPoolPath(position.poolKey) // BURN HERE @@ -537,7 +537,7 @@ func Reposition( // MUST BE OWNER TO REPOSITION // can not be approved address > staked position can't be modified - owner := gnft.OwnerOf(tid(tokenId)) + owner := gnft.OwnerOf(tokenIdFrom(tokenId)) caller := std.PrevRealm().Addr() if owner != caller { panic(addDetailToError( @@ -558,7 +558,7 @@ func Reposition( )) } - token0, token1, _ := poolKeyDivide(position.poolKey) + token0, token1, _ := splitOf(position.poolKey) // check if gnot pool token0IsNative := false token1IsNative := false @@ -663,7 +663,7 @@ func CollectFee(tokenId uint64, unwrapResult bool) (uint64, string, string, stri )) } - token0, token1, fee := poolKeyDivide(position.poolKey) + token0, token1, fee := splitOf(position.poolKey) pl.Burn( token0, @@ -726,7 +726,7 @@ func CollectFee(tokenId uint64, unwrapResult bool) (uint64, string, string, stri withoutFee0, withoutFee1 := pl.HandleWithdrawalFee(tokenId, token0, amount0, token1, amount1, position.poolKey, std.PrevRealm().Addr()) // UNWRAP - pToken0, pToken1, _ := poolKeyDivide(position.poolKey) + pToken0, pToken1, _ := splitOf(position.poolKey) if (pToken0 == consts.WUGNOT_PATH || pToken1 == consts.WUGNOT_PATH) && unwrapResult { userNewWugnot := wugnot.BalanceOf(a2u(std.PrevRealm().Addr())) unwrapAmount := userNewWugnot - userWugnot @@ -785,7 +785,7 @@ func burnNFT(tokenId uint64) { )) } delete(positions, tokenId) - gnft.Burn(tid(tokenId)) + gnft.Burn(tokenIdFrom(tokenId)) } func burnPosition(tokenId uint64) { @@ -802,7 +802,7 @@ func burnPosition(tokenId uint64) { } func isAuthorizedForToken(tokenId uint64) { - if !(isApprovedOrOwner(std.PrevRealm().Addr(), tokenId)) { + if !(isOwnerOrOperator(std.PrevRealm().Addr(), tokenId)) { panic(addDetailToError( errNoPermission, ufmt.Sprintf("position.gno__isAuthorizedForToken() || caller(%s) is not approved or owner of tokenId(%d)", std.PrevRealm().Addr(), tokenId), @@ -818,7 +818,7 @@ func verifyTokenIdAndOwnership(tokenId uint64) { )) } - owner := gnft.OwnerOf(tid(tokenId)) + owner := gnft.OwnerOf(tokenIdFrom(tokenId)) caller := std.PrevRealm().Addr() if owner != caller { panic(addDetailToError( diff --git a/position/utils.gno b/position/utils.gno index 81b644770..dedba8661 100644 --- a/position/utils.gno +++ b/position/utils.gno @@ -2,12 +2,10 @@ package position import ( "std" - "strconv" "time" "gno.land/p/demo/ufmt" pusers "gno.land/p/demo/users" - "gno.land/r/gnoswap/v1/common" ) func checkDeadline(deadline int64) { @@ -24,21 +22,6 @@ func a2u(addr std.Address) pusers.AddressOrName { return pusers.AddressOrName(addr) } -func poolKeyDivide(poolKey string) (string, string, uint32) { - res, err := common.Split(poolKey, ":", 3) - if err != nil { - panic(addDetailToError( - errInvalidInput, - ufmt.Sprintf("utils.gno__poolKeyDivide() || invalid poolKey(%s)", poolKey), - )) - } - - pToken0, pToken1, pFeeStr := res[0], res[1], res[2] - - pFee, _ := strconv.Atoi(pFeeStr) - return pToken0, pToken1, uint32(pFee) -} - func prevRealm() string { return std.PrevRealm().PkgPath() } @@ -51,3 +34,13 @@ func getPrev() (string, string) { prev := std.PrevRealm() return prev.Addr().String(), prev.PkgPath() } + +// assertOnlyValidAddress panics if the address is invalid. +func assertOnlyValidAddress(addr std.Address) { + if !addr.IsValid() { + panic(newErrorWithDetail( + errInvalidAddress, + ufmt.Sprintf("(%s)", addr), + )) + } +} From 61f36d43e0ef580d67c51fd947eeff601f0d8080 Mon Sep 17 00:00:00 2001 From: 0xTopaz <60733299+onlyhyde@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:58:34 +0900 Subject: [PATCH 9/9] GSW-1839 Refactor/position contract utils (#433) * GSW-1839 refactor: integrated helper and test code - integrated helper with nft helper - add test helper code - add test code for helper - change file filename * GSW-1839 refactor: utils - add assert functions - refactor original util functions * Update position/utils_test.gno * test: Update to use the correct test values --------- Co-authored-by: Blake <104744707+r3v4s@users.noreply.github.com> --- position/position.gno | 20 +-- position/utils.gno | 95 +++++++++-- position/utils_test.gno | 344 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 434 insertions(+), 25 deletions(-) create mode 100644 position/utils_test.gno diff --git a/position/position.gno b/position/position.gno index 60cb0f62f..f33ffd0b6 100644 --- a/position/position.gno +++ b/position/position.gno @@ -105,12 +105,12 @@ func Mint( poolSqrtPriceX96 := pl.PoolGetSlot0SqrtPriceX96(poolPath) - prevAddr, prevRealm := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "Mint", "prevAddr", prevAddr, - "prevRealm", prevRealm, + "prevRealm", prevPkgPath, "tickLower", ufmt.Sprintf("%d", tickLower), "tickUpper", ufmt.Sprintf("%d", tickUpper), "poolPath", poolPath, @@ -265,12 +265,12 @@ func IncreaseLiquidity( poolSqrtPriceX96 := pl.PoolGetSlot0SqrtPriceX96(poolPath) - prevAddr, prevRealm := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "IncreaseLiquidity", "prevAddr", prevAddr, - "prevRealm", prevRealm, + "prevRealm", prevPkgPath, "lpTokenId", ufmt.Sprintf("%d", tokenId), "internal_poolPath", poolPath, "internal_liquidity", liquidity.ToString(), @@ -386,12 +386,12 @@ func DecreaseLiquidity( poolSqrtPriceX96 := pl.PoolGetSlot0SqrtPriceX96(poolPath) - prevAddr, prevRealm := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "DecreaseLiquidity", "prevAddr", prevAddr, - "prevRealm", prevRealm, + "prevRealm", prevPkgPath, "lpTokenId", ufmt.Sprintf("%d", tokenId), "liquidityRatio", ufmt.Sprintf("%d", liquidityRatio), "internal_poolPath", poolPath, @@ -615,12 +615,12 @@ func Reposition( poolSqrtPriceX96 := pl.PoolGetSlot0SqrtPriceX96(position.poolKey) - prevAddr, prevRealm := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "Reposition", "prevAddr", prevAddr, - "prevRealm", prevRealm, + "prevRealm", prevPkgPath, "lpTokenId", ufmt.Sprintf("%d", tokenId), "tickLower", ufmt.Sprintf("%d", tickLower), "tickUpper", ufmt.Sprintf("%d", tickUpper), @@ -736,12 +736,12 @@ func CollectFee(tokenId uint64, unwrapResult bool) (uint64, string, string, stri } } - prevAddr, prevRealm := getPrev() + prevAddr, prevPkgPath := getPrevAsString() std.Emit( "CollectSwapFee", "prevAddr", prevAddr, - "prevRealm", prevRealm, + "prevRealm", prevPkgPath, "lpTokenId", ufmt.Sprintf("%d", tokenId), "internal_fee0", withoutFee0, "internal_fee1", withoutFee1, diff --git a/position/utils.gno b/position/utils.gno index dedba8661..2121fb4fb 100644 --- a/position/utils.gno +++ b/position/utils.gno @@ -6,35 +6,86 @@ import ( "gno.land/p/demo/ufmt" pusers "gno.land/p/demo/users" + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" ) -func checkDeadline(deadline int64) { - now := time.Now().Unix() - if now > deadline { - panic(addDetailToError( - errExpired, - ufmt.Sprintf("utils.gno__checkDeadline() || transaction too old, now(%d) > deadline(%d)", now, deadline), - )) - } -} - +// a2u converts std.Address to pusers.AddressOrName. +// pusers is a package that contains the user-related functions. +// +// Input: +// - addr: the address to convert +// +// Output: +// - pusers.AddressOrName: the converted address func a2u(addr std.Address) pusers.AddressOrName { return pusers.AddressOrName(addr) } -func prevRealm() string { - return std.PrevRealm().PkgPath() +// derivePkgAddr derives the Realm address from it's pkgpath parameter +func derivePkgAddr(pkgPath string) std.Address { + return std.DerivePkgAddr(pkgPath) } -func isUserCall() bool { - return std.PrevRealm().IsUser() +// getOrigPkgAddr returns the original package address. +// In position contract, original package address is the position address. +func getOrigPkgAddr() std.Address { + return consts.POSITION_ADDR } -func getPrev() (string, string) { +// getPrevRealm returns object of the previous realm. +func getPrevRealm() std.Realm { + return std.PrevRealm() +} + +// getPrevAddr returns the address of the previous realm. +func getPrevAddr() std.Address { + return std.PrevRealm().Addr() +} + +// getPrev returns the address and package path of the previous realm. +func getPrevAsString() (string, string) { prev := std.PrevRealm() return prev.Addr().String(), prev.PkgPath() } +// isUserCall returns true if the caller is a user. +func isUserCall() bool { + return std.PrevRealm().IsUser() +} + +// checkDeadline checks if the deadline is expired. +// If the deadline is expired, it panics. +// The deadline is expired if the current time is greater than the deadline. +// Input: +// - deadline: the deadline to check +func checkDeadline(deadline int64) { + now := time.Now().Unix() + if now > deadline { + panic(newErrorWithDetail( + errExpired, + ufmt.Sprintf("transaction too old, now(%d) > deadline(%d)", now, deadline), + )) + } +} + +// assertOnlyUserOrStaker panics if the caller is not a user or staker. +func assertOnlyUserOrStaker(caller std.Realm) { + if !caller.IsUser() { + if err := common.StakerOnly(caller.Addr()); err != nil { + panic(newErrorWithDetail( + errNoPermission, + ufmt.Sprintf("from (%s)", caller.Addr()), + )) + } + } +} + +// assertOnlyNotHalted panics if the contract is halted. +func assertOnlyNotHalted() { + common.IsHalted() +} + // assertOnlyValidAddress panics if the address is invalid. func assertOnlyValidAddress(addr std.Address) { if !addr.IsValid() { @@ -44,3 +95,17 @@ func assertOnlyValidAddress(addr std.Address) { )) } } + +// assertOnlyValidAddress panics if the address is invalid or previous address is not +// different from the other address. +func assertOnlyValidAddressWith(prevAddr, otherAddr std.Address) { + assertOnlyValidAddress(prevAddr) + assertOnlyValidAddress(otherAddr) + + if prevAddr != otherAddr { + panic(newErrorWithDetail( + errInvalidAddress, + ufmt.Sprintf("(%s, %s)", prevAddr, otherAddr), + )) + } +} diff --git a/position/utils_test.gno b/position/utils_test.gno new file mode 100644 index 000000000..26051a271 --- /dev/null +++ b/position/utils_test.gno @@ -0,0 +1,344 @@ +package position + +import ( + "std" + "testing" + + "gno.land/p/demo/uassert" + pusers "gno.land/p/demo/users" + "gno.land/r/demo/users" + "gno.land/r/gnoswap/v1/common" + "gno.land/r/gnoswap/v1/consts" +) + +func TestA2u(t *testing.T) { + var ( + addr = std.Address("g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8c") + ) + + tests := []struct { + name string + input std.Address + expected pusers.AddressOrName + }{ + { + name: "Success - a2u", + input: addr, + expected: pusers.AddressOrName(addr), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := a2u(tc.input) + uassert.Equal(t, users.Resolve(got).String(), users.Resolve(tc.expected).String()) + }) + } +} + +func TestDerivePkgAddr(t *testing.T) { + var ( + pkgPath = "gno.land/r/gnoswap/v1/position" + ) + tests := []struct { + name string + input string + expected string + }{ + { + name: "Success - derivePkgAddr", + input: pkgPath, + expected: "g1q646ctzhvn60v492x8ucvyqnrj2w30cwh6efk5", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := derivePkgAddr(tc.input) + uassert.Equal(t, got.String(), tc.expected) + }) + } +} + +func TestGetOrigPkgAddr(t *testing.T) { + tests := []struct { + name string + expected std.Address + }{ + { + name: "Success - getOrigPkgAddr", + expected: consts.POSITION_ADDR, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := getOrigPkgAddr() + uassert.Equal(t, got, tc.expected) + }) + } +} + +func TestGetPrevRealm(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected []string + }{ + { + name: "Success - prevRealm is User", + originCaller: consts.ADMIN, + expected: []string{"g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", ""}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got := getPrevRealm() + uassert.Equal(t, got.Addr().String(), tc.expected[0]) + uassert.Equal(t, got.PkgPath(), tc.expected[1]) + }) + } +} + +func TestGetPrevAddr(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected std.Address + }{ + { + name: "Success - prev Address is User", + originCaller: consts.ADMIN, + expected: "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got := getPrevAddr() + uassert.Equal(t, got.String(), tc.expected.String()) + }) + } +} + +func TestGetPrevAsString(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected []string + }{ + { + name: "Success - prev Realm of user info as string", + originCaller: consts.ADMIN, + expected: []string{"g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", ""}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(std.Address(tc.originCaller)) + got1, got2 := getPrevAsString() + uassert.Equal(t, got1, tc.expected[0]) + uassert.Equal(t, got2, tc.expected[1]) + }) + } +} + +func TestIsUserCall(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + originPkgPath string + expected bool + }{ + { + name: "Success - User Call", + originCaller: consts.ADMIN, + expected: true, + }, + { + name: "Failure - Not User Call", + originCaller: consts.ROUTER_ADDR, + originPkgPath: consts.ROUTER_PATH, + expected: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(tc.originCaller) + if !tc.expected { + std.TestSetRealm(std.NewCodeRealm(tc.originPkgPath)) + } + got := isUserCall() + uassert.Equal(t, got, tc.expected) + }) + } +} + +func TestCheckDeadline(t *testing.T) { + tests := []struct { + name string + deadline int64 + now int64 + expected string + }{ + { + name: "Success - checkDeadline", + deadline: 1234567890 + 100, + now: 1234567890, + expected: "", + }, + { + name: "Failure - checkDeadline", + deadline: 1234567890 - 100, + now: 1234567890, + expected: "[GNOSWAP-POSITION-007] transaction expired || transaction too old, now(1234567890) > deadline(1234567790)", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected != "" { + uassert.PanicsWithMessage(t, tc.expected, func() { + checkDeadline(tc.deadline) + }) + } else { + uassert.NotPanics(t, func() { + checkDeadline(tc.deadline) + }) + } + }) + } +} + +func TestAssertOnlyUserOrStaker(t *testing.T) { + tests := []struct { + name string + originCaller std.Address + expected bool + }{ + { + name: "Failure - Not User or Staker", + originCaller: consts.ROUTER_ADDR, + expected: false, + }, + { + name: "Success - User Call", + originCaller: consts.ADMIN, + expected: true, + }, + { + name: "Success - Staker Call", + originCaller: consts.STAKER_ADDR, + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + std.TestSetOrigCaller(tc.originCaller) + assertOnlyUserOrStaker(std.PrevRealm()) + }) + } +} + +func TestAssertOnlyNotHalted(t *testing.T) { + tests := []struct { + name string + expected bool + panicMsg string + }{ + { + name: "Failure - Halted", + expected: false, + panicMsg: "[GNOSWAP-COMMON-002] halted || gnoswap halted", + }, + { + name: "Success - Not Halted", + expected: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected { + uassert.NotPanics(t, func() { + assertOnlyNotHalted() + }) + } else { + std.TestSetRealm(std.NewUserRealm(users.Resolve(admin))) + common.SetHaltByAdmin(true) + uassert.PanicsWithMessage(t, tc.panicMsg, func() { + assertOnlyNotHalted() + }) + common.SetHaltByAdmin(false) + } + }) + } +} + +func TestAssertOnlyValidAddress(t *testing.T) { + tests := []struct { + name string + addr std.Address + expected bool + errorMsg string + }{ + { + name: "Success - valid address", + addr: consts.ADMIN, + expected: true, + }, + { + name: "Failure - invalid address", + addr: "g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8", // invalid length + expected: false, + errorMsg: "[GNOSWAP-POSITION-011] invalid address || (g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8)", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected { + uassert.NotPanics(t, func() { + assertOnlyValidAddress(tc.addr) + }) + } else { + uassert.PanicsWithMessage(t, tc.errorMsg, func() { + assertOnlyValidAddress(tc.addr) + }) + } + }) + } +} + +func TestAssertOnlyValidAddressWith(t *testing.T) { + tests := []struct { + name string + addr std.Address + other std.Address + expected bool + errorMsg string + }{ + { + name: "Success - validation address check to compare with other address", + addr: consts.ADMIN, + other: std.Address("g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d"), + expected: true, + }, + { + name: "Failure - two address is different", + addr: "g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8", + other: "g17290cwvmrapvp869xfnhhawa8sm9edpufzat7d", + expected: false, + errorMsg: "[GNOSWAP-POSITION-011] invalid address || (g1lmvrrrr4er2us84h2732sru76c9zl2nvknha8)", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.expected { + uassert.NotPanics(t, func() { + assertOnlyValidAddressWith(tc.addr, tc.other) + }) + } else { + uassert.PanicsWithMessage(t, tc.errorMsg, func() { + assertOnlyValidAddressWith(tc.addr, tc.other) + }) + } + }) + } +}