1
1
package gnsmath
2
2
3
3
import (
4
+ "gno.land/p/gnoswap/consts"
4
5
i256 "gno.land/p/gnoswap/int256"
5
6
u256 "gno.land/p/gnoswap/uint256"
6
7
)
@@ -11,47 +12,60 @@ const (
11
12
)
12
13
13
14
var (
14
- q96 = u256.MustFromDecimal(Q96)
15
- max160 = u256.MustFromDecimal(MAX_UINT160)
15
+ q96 = u256.MustFromDecimal(consts. Q96)
16
+ max160 = u256.MustFromDecimal(consts. MAX_UINT160)
16
17
)
17
18
18
19
// getNextPriceAmount0Add calculates the next sqrt price when we are adding token0.
19
20
// Preserves the rounding-up logic. No in-place mutation of input arguments.
20
21
func getNextPriceAmount0Add(
21
- sqrtPX96 , liquidity, amount *u256.Uint,
22
+ currentSqrtPriceX96 , liquidity, amountToAdd *u256.Uint,
22
23
) *u256.Uint {
23
- numerator1 := new(u256.Uint).Lsh(liquidity, Q96_RESOLUTION)
24
- product := new(u256.Uint).Mul(amount, sqrtPX96)
25
-
26
- // overflow check
27
- if new(u256.Uint).Div(product, amount).Eq(sqrtPX96) {
28
- denominator := new(u256.Uint).Add(numerator1, product)
29
- if denominator.Gte(numerator1) {
30
- return u256.MulDivRoundingUp(numerator1, sqrtPX96, denominator)
24
+ // Shift liquidity left by Q96 bits to increase precision
25
+ liquidityShifted := new(u256.Uint).Lsh(liquidity, Q96_RESOLUTION)
26
+ // Multiply the amount to add by the current square root price
27
+ amountTimesSqrtPrice := new(u256.Uint).Mul(amountToAdd, currentSqrtPriceX96)
28
+
29
+ // Overflow check: Ensure (amountTimesSqrtPrice / amountToAdd) == currentSqrtPriceX96
30
+ if new(u256.Uint).Div(amountTimesSqrtPrice, amountToAdd).Eq(currentSqrtPriceX96) {
31
+ // Compute denominator: liquidityShifted + (amountToAdd * currentSqrtPriceX96)
32
+ denominator := new(u256.Uint).Add(liquidityShifted, amountTimesSqrtPrice)
33
+ if denominator.Gte(liquidityShifted) {
34
+ return u256.MulDivRoundingUp(liquidityShifted, currentSqrtPriceX96, denominator)
31
35
}
32
36
}
33
37
34
- divValue := new(u256.Uint).Div(numerator1, sqrtPX96)
35
- addValue := new(u256.Uint).Add(divValue, amount)
38
+ // Alternative computation path: (liquidityShifted / currentSqrtPriceX96) + amountToAdd
39
+ divValue := new(u256.Uint).Div(liquidityShifted, currentSqrtPriceX96)
40
+ addValue := new(u256.Uint).Add(divValue, amountToAdd)
36
41
37
- return u256.DivRoundingUp(numerator1, addValue)
42
+ // Compute the next square root price using division rounding up
43
+ return u256.DivRoundingUp(liquidityShifted, addValue)
38
44
}
39
45
40
46
// getNextPriceAmount0Remove calculates the next sqrt price when we are removing token0.
41
47
// Preserves the rounding-up logic. No in-place mutation of input arguments.
42
48
func getNextPriceAmount0Remove(
43
- sqrtPX96 , liquidity, amount *u256.Uint,
49
+ currentSqrtPriceX96 , liquidity, amountToRemove *u256.Uint,
44
50
) *u256.Uint {
45
- numerator1 := new(u256.Uint).Lsh(liquidity, Q96_RESOLUTION)
46
- product := new(u256.Uint).Mul(amount, sqrtPX96)
51
+ // Shift liquidity left by Q96 bits to increase precision
52
+ liquidityShifted := new(u256.Uint).Lsh(liquidity, Q96_RESOLUTION)
53
+
54
+ // Multiply the amount to remove by the current square root price
55
+ amountTimesSqrtPrice := new(u256.Uint).Mul(amountToRemove, currentSqrtPriceX96)
47
56
48
57
// Conditions must hold: product/amount == sqrtPX96 and numerator1 > product
49
- if !new(u256.Uint).Div(product, amount).Eq(sqrtPX96) || !numerator1.Gt(product) {
58
+ if !new(u256.Uint).Div(amountTimesSqrtPrice, amountToRemove).Eq(currentSqrtPriceX96) ||
59
+ !liquidityShifted.Gt(amountTimesSqrtPrice) {
50
60
panic(errInvalidPoolSqrtPrice)
51
61
}
52
62
53
- denominator := new(u256.Uint).Sub(numerator1, product)
54
- nextSqrtPrice := u256.MulDivRoundingUp(numerator1, sqrtPX96, denominator)
63
+ // Compute denominator: liquidityShifted - (amountToRemove * currentSqrtPriceX96)
64
+ denominator := new(u256.Uint).Sub(liquidityShifted, amountTimesSqrtPrice)
65
+ // Calculate next square root price: (liquidityShifted * currentSqrtPriceX96) / denominator (with rounding up)
66
+ nextSqrtPrice := u256.MulDivRoundingUp(liquidityShifted, currentSqrtPriceX96, denominator)
67
+
68
+ // Check for overflow
55
69
if nextSqrtPrice.Gt(max160) {
56
70
panic(errNextSqrtPriceOverflow)
57
71
}
@@ -95,11 +109,9 @@ func getNextSqrtPriceFromAmount0RoundingUp(
95
109
if add {
96
110
return getNextPriceAmount0Add(sqrtPX96, liquidity, amount)
97
111
}
98
-
99
112
return getNextPriceAmount0Remove(sqrtPX96, liquidity, amount)
100
113
}
101
114
102
-
103
115
// getNextPriceAmount1Add calculates the next sqrt price when adding token1.
104
116
// Preserves rounding-down logic for the final result.
105
117
func getNextPriceAmount1Add(
@@ -172,9 +184,16 @@ func getNextPriceAmount1Remove(
172
184
// - The function uses high-precision math (MulDiv and DivRoundingUp) to handle division and prevent precision loss.
173
185
// - The function validates input conditions and panics if the state is invalid.
174
186
func getNextSqrtPriceFromAmount1RoundingDown(
175
- sqrtPX96, liquidity, amount *u256.Uint,
187
+ sqrtPX96,
188
+ liquidity,
189
+ amount *u256.Uint,
176
190
add bool,
177
191
) *u256.Uint {
192
+ // Shortcut: if no amount, return original price
193
+ if amount.IsZero() {
194
+ return sqrtPX96
195
+ }
196
+
178
197
if add {
179
198
return getNextPriceAmount1Add(sqrtPX96, liquidity, amount)
180
199
}
@@ -367,7 +386,7 @@ func GetAmount0DeltaStr(
367
386
) string {
368
387
if liquidity.IsNeg() {
369
388
u := getAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), false)
370
- if u.Gt(u256.MustFromDecimal(MAX_INT256)) {
389
+ if u.Gt(u256.MustFromDecimal(consts. MAX_INT256)) {
371
390
// if u > (2**255 - 1), cannot cast to int256
372
391
panic(errAmount0DeltaOverflow)
373
392
}
@@ -376,7 +395,7 @@ func GetAmount0DeltaStr(
376
395
}
377
396
378
397
u := getAmount0DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true)
379
- if u.Gt(u256.MustFromDecimal(MAX_INT256)) {
398
+ if u.Gt(u256.MustFromDecimal(consts. MAX_INT256)) {
380
399
// if u > (2**255 - 1), cannot cast to int256
381
400
panic(errAmount0DeltaOverflow)
382
401
}
@@ -410,7 +429,7 @@ func GetAmount1DeltaStr(
410
429
) string {
411
430
if liquidity.IsNeg() {
412
431
u := getAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), false)
413
- if u.Gt(u256.MustFromDecimal(MAX_INT256)) {
432
+ if u.Gt(u256.MustFromDecimal(consts. MAX_INT256)) {
414
433
// if u > (2**255 - 1), cannot cast to int256
415
434
panic(errAmount1DeltaOverflow)
416
435
}
@@ -419,7 +438,7 @@ func GetAmount1DeltaStr(
419
438
}
420
439
421
440
u := getAmount1DeltaHelper(sqrtRatioAX96, sqrtRatioBX96, liquidity.Abs(), true)
422
- if u.Gt(u256.MustFromDecimal(MAX_INT256)) {
441
+ if u.Gt(u256.MustFromDecimal(consts. MAX_INT256)) {
423
442
// if u > (2**255 - 1), cannot cast to int256
424
443
panic(errAmount1DeltaOverflow)
425
444
}
0 commit comments