Skip to content

Commit

Permalink
GSW-856 feat: mutex lock
Browse files Browse the repository at this point in the history
  • Loading branch information
r3v4s committed Feb 7, 2024
1 parent 8f2f0ed commit 3b334c8
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 46 deletions.
90 changes: 46 additions & 44 deletions pool/pool.gno
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,26 @@ import (
"std"

"gno.land/p/demo/ufmt"

"gno.land/r/demo/consts"

g "gno.land/r/demo/gov"
)

// only position contract can call this function
func Mint(
pToken0Path string,
pToken1Path string,
pFee uint16,
token0Path string,
token1Path string,
fee uint16,
recipient std.Address,
tickLower int32,
tickUpper int32,
liquidityAmount bigint,
) (bigint, bigint) {
require(PrevRealmPath() == "gno.land/r/demo/position", ufmt.Sprintf("[POOL] pool.gno__Mint() || PrevRealmPath(%s) == \"gno.land/r/demo/position\"", PrevRealmPath()))
require(PrevRealmPath() == consts.POSITION_PATH, ufmt.Sprintf("[POOL] pool.gno__Mint() || PrevRealmPath(%s) == \"gno.land/r/demo/position\"", PrevRealmPath()))

require(liquidityAmount > 0, ufmt.Sprintf("[POOL] pool.gno__Mint() || liquidityAmount(%d) > 0", liquidityAmount))

pool := GetPool(pToken0Path, pToken1Path, pFee)
pool := GetPool(token0Path, token1Path, fee)
_, amount0Int, amount1Int := pool.modifyPosition(
ModifyPositionParams{
recipient, // owner
Expand Down Expand Up @@ -87,18 +87,18 @@ func Mint(

// only position contract can call this function
func Burn(
pToken0Path string,
pToken1Path string,
pFee uint16,
token0Path string,
token1Path string,
fee uint16,
tickLower int32,
tickUpper int32,
amount bigint,
) (bigint, bigint) {
require(PrevRealmPath() == "gno.land/r/demo/position", ufmt.Sprintf("[POOL] pool.gno__Burn() || caller(%s) must be position contract", PrevRealmPath()))
require(PrevRealmPath() == consts.POSITION_PATH, ufmt.Sprintf("[POOL] pool.gno__Burn() || caller(%s) must be position contract", PrevRealmPath()))

requireUnsigned(amount, ufmt.Sprintf("[POOL] pool.gno__Burn() || amount(%d) >= 0", amount))

pool := GetPool(pToken0Path, pToken1Path, pFee)
pool := GetPool(token0Path, token1Path, fee)

position, amount0Int, amount1Int := pool.modifyPosition(
ModifyPositionParams{
Expand All @@ -121,73 +121,75 @@ func Burn(
key := positionGetKey(PrevRealmAddr(), tickLower, tickUpper)
pool.positions[key] = position

// actual token transfer happens in Collect()
return amount0, amount1
}

// only position contract can call this function
func Collect(
pToken0Path string,
pToken1Path string,
pFee uint16,
token0Path string,
token1Path string,
fee uint16,
recipient std.Address,
tickLower int32,
tickUpper int32,
amount0Requested bigint,
amount1Requested bigint,
) (bigint, bigint) {
require(PrevRealmPath() == "gno.land/r/demo/position", ufmt.Sprintf("[POOL] pool.gno__Collect() || caller(%s) must be position contract(gno.land/r/demo/position)", PrevRealmPath()))
require(PrevRealmPath() == consts.POSITION_PATH, ufmt.Sprintf("[POOL] pool.gno__Collect() || caller(%s) must be position contract(gno.land/r/demo/position)", PrevRealmPath()))

requireUnsigned(amount0Requested, ufmt.Sprintf("pool.gno__Collect() || amount0Requested(%d) >= 0", amount0Requested))
requireUnsigned(amount1Requested, ufmt.Sprintf("pool.gno__Collect() || amount1Requested(%d) >= 0", amount1Requested))

pool := GetPool(pToken0Path, pToken1Path, pFee)
pool := GetPool(token0Path, token1Path, fee)

key := positionGetKey(PrevRealmAddr(), tickLower, tickUpper)
position, exist := pool.positions[key]
require(exist, ufmt.Sprintf("[POOL] pool.gno__Collect() || position(%s) does not exist", key))

// Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0
amount0 := min(amount0Requested, position.tokensOwed0)
amount0 = min(amount0, pool.balances.token0)
requireUnsigned(amount0, ufmt.Sprintf("[POOL] pool.gno__Collect() || amount0(%d) >= 0", amount0))

amount1 := min(amount1Requested, position.tokensOwed1)
requireUnsigned(amount1, ufmt.Sprintf("[POOL] pool.gno__Collect() || amount1(%d) >= 0", amount1))

require(pool.balances.token0 >= amount0, ufmt.Sprintf("[POOL] pool.gno__Collect() || pool.balances.token0(%d) >= amount0(%d)", pool.balances.token0, amount0))
// Update state first then transfer
position.tokensOwed0 -= amount0
pool.balances.token0 -= amount0
transferByRegisterCall(pool.token0Path, recipient, uint64(amount0))
requireUnsigned(pool.balances.token0, ufmt.Sprintf("[POOL] pool.gno__Burn() || pool.balances.token0(%d) >= 0", pool.balances.token0))

require(pool.balances.token1 >= amount1, ufmt.Sprintf("[POOL] pool.gno__Collect() || pool.balances.token1(%d) >= amount1(%d)", pool.balances.token1, amount1))
transferByRegisterCall(pool.token1Path, recipient, uint64(amount1))
// Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0
amount1 := min(amount1Requested, position.tokensOwed1)
amount1 = min(amount1, pool.balances.token1)
requireUnsigned(amount1, ufmt.Sprintf("[POOL] pool.gno__Collect() || amount1(%d) >= 0", amount1))

// adjust position
position.tokensOwed0 -= amount0
// Update state first then transfer
position.tokensOwed1 -= amount1
pool.positions[key] = position

// adjust pool
pool.balances.token0 -= amount0
pool.balances.token1 -= amount1

requireUnsigned(pool.balances.token0, ufmt.Sprintf("[POOL] pool.gno__Burn() || pool.balances.token0(%d) >= 0", pool.balances.token0))
transferByRegisterCall(pool.token1Path, recipient, uint64(amount1))
requireUnsigned(pool.balances.token1, ufmt.Sprintf("[POOL] pool.gno__Burn() || pool.balances.token1(%d) >= 0", pool.balances.token1))

pool.positions[key] = position

return amount0, amount1
}

func Swap(
pToken0Path string,
pToken1Path string,
pFee uint16,
token0Path string,
token1Path string,
fee uint16,
recipient std.Address,
zeroForOne bool,
amountSpecified bigint,
sqrtPriceLimitX96 bigint,
payer std.Address, // router
) (bigint, bigint) {
require(PrevRealmPath() == "gno.land/r/demo/router", ufmt.Sprintf("[POOL] pool.gno__Swap() || caller(%s) must be router contract(gno.land/r/demo/router)", PrevRealmPath()))
require(PrevRealmPath() == consts.ROUTER_PATH, ufmt.Sprintf("[POOL] pool.gno__Swap() || caller(%s) must be router contract(gno.land/r/demo/router)", PrevRealmPath()))

// early panic
require(amountSpecified != 0, "[POOL] pool.gno__Swap() || amountSpecified can't be zero")

pool := GetPool(pToken0Path, pToken1Path, pFee)
pool := GetPool(token0Path, token1Path, fee)
slot0Start := pool.slot0
require(slot0Start.unlocked, "[POOL] pool.gno__Swap() || slot0 must be unlocked")

Expand Down Expand Up @@ -377,9 +379,9 @@ func Swap(
if zeroForOne {
// payer > pool
balance0Before := bigint(balanceOfByRegisterCall(pool.token0Path, GetOrigPkgAddr()))
ok := transferFromByRegisterCall(pool.token0Path, payer, consts.ADDR_POOL, uint64(amount0))
ok := transferFromByRegisterCall(pool.token0Path, payer, consts.POOL_ADDR, uint64(amount0))
if !ok {
panic("[POOL] pool.gno__Swap() || transferFromByRegisterCall(pool.token0Path, payer, ADDR_POOL, uint64(amount0)) failed")
panic("[POOL] pool.gno__Swap() || transferFromByRegisterCall(pool.token0Path, payer, POOL_ADDR, uint64(amount0)) failed")
}

require(
Expand All @@ -393,7 +395,7 @@ func Swap(
require(pool.balances.token0 >= 0, ufmt.Sprintf("[POOL] pool.gno__Swap() || pool.balances.token0(%d) >= 0__#1", pool.balances.token0))

if amount1 < 0 { // pool > recipient
require(pool.balances.token1 > (-amount1), ufmt.Sprintf("[POOL] pool.gno__Swap() || pool.balances.token1(%d) > (-1 * amount1)(%d)", pool.balances.token1, (-amount1)))
require(pool.balances.token1 > -amount1, ufmt.Sprintf("[POOL] pool.gno__Swap() || pool.balances.token1(%d) > (-1 * amount1)(%d)", pool.balances.token1, (-amount1)))
ok := transferByRegisterCall(pool.token1Path, recipient, uint64(-amount1))
if !ok {
panic("[POOL] pool.gno__Swap() || transferByRegisterCall(pool.token1Path, recipient, uint64(-amount1)) failed")
Expand All @@ -406,9 +408,9 @@ func Swap(
} else {
// payer > pool
balance1Before := bigint(balanceOfByRegisterCall(pool.token1Path, GetOrigPkgAddr()))
ok := transferFromByRegisterCall(pool.token1Path, payer, consts.ADDR_POOL, uint64(amount1))
ok := transferFromByRegisterCall(pool.token1Path, payer, consts.POOL_ADDR, uint64(amount1))
if !ok {
panic("[POOL] pool.gno__Swap() || transferFromByRegisterCall(pool.token1Path, payer, ADDR_POOL, uint64(amount1)) failed")
panic("[POOL] pool.gno__Swap() || transferFromByRegisterCall(pool.token1Path, payer, POOL_ADDR, uint64(amount1)) failed")
}

require(
Expand Down Expand Up @@ -460,9 +462,9 @@ func SetFeeProtocol(

// ADMIN
func CollectProtocol(
pToken0Path string,
pToken1Path string,
pFee uint16,
token0Path string,
token1Path string,
fee uint16,
recipient std.Address,
amount0Requested bigint,
amount1Requested bigint,
Expand All @@ -471,7 +473,7 @@ func CollectProtocol(
requireUnsigned(amount1Requested, ufmt.Sprintf("[POOL] pool.gno__CollectProtocol() || amount1Requested(%d) >= 0", amount1Requested))
require(isAdmin(PrevRealmAddr()), ufmt.Sprintf("[POOL] pool.gno__CollectProtocol() || caller(%s) must be admin", PrevRealmAddr()))

pool := GetPool(pToken0Path, pToken1Path, pFee)
pool := GetPool(token0Path, token1Path, fee)

amount0 := min(amount0Requested, pool.protocolFees.token0)
requireUnsigned(amount0, ufmt.Sprintf("[POOL] pool.gno__CollectProtocol() || amount0(%d) >= 0", amount0))
Expand Down
22 changes: 20 additions & 2 deletions pool/pool_register.gno
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (

var registered = []GRC20Pair{}

var locked bool // mutex flag

type GRC20Interface interface {
Transfer() func(to users.AddressOrName, amount uint64)
TransferFrom() func(from, to users.AddressOrName, amount uint64)
Expand Down Expand Up @@ -95,8 +97,16 @@ func transferByRegisterCall(pkgPath string, to std.Address, amount uint64) bool
return false
}

registered[i].igrc20.Transfer()(users.AddressOrName(to), amount)
if !locked {
locked = true
registered[i].igrc20.Transfer()(users.AddressOrName(to), amount)

defer func() {
locked = false
}()
} else {
panic("[POOl] pool_register.gno transferByRegisterCall: locked")
}
return true
}

Expand All @@ -108,8 +118,16 @@ func transferFromByRegisterCall(pkgPath string, from, to std.Address, amount uin
return false
}

registered[i].igrc20.TransferFrom()(users.AddressOrName(from), users.AddressOrName(to), amount)
if !locked {
locked = true
registered[i].igrc20.TransferFrom()(users.AddressOrName(from), users.AddressOrName(to), amount)

defer func() {
locked = false
}()
} else {
panic("[POOl] pool_register.gno transferFromByRegisterCall: locked")
}
return true
}

Expand Down

0 comments on commit 3b334c8

Please sign in to comment.