From 3b334c85dff38ef40b70bfb608289cbebe9f0517 Mon Sep 17 00:00:00 2001 From: n3wbie Date: Wed, 7 Feb 2024 18:31:43 +0900 Subject: [PATCH] GSW-856 feat: mutex lock --- pool/pool.gno | 90 +++++++++++++++++++++--------------------- pool/pool_register.gno | 22 ++++++++++- 2 files changed, 66 insertions(+), 46 deletions(-) diff --git a/pool/pool.gno b/pool/pool.gno index 1d626c7a3..fafe33b72 100644 --- a/pool/pool.gno +++ b/pool/pool.gno @@ -4,26 +4,26 @@ import ( "std" "gno.land/p/demo/ufmt" - "gno.land/r/demo/consts" + g "gno.land/r/demo/gov" ) // only position contract can call this function func Mint( - pToken0Path string, - pToken1Path string, - pFee uint16, + token0Path string, + token1Path string, + fee uint16, recipient std.Address, tickLower int32, tickUpper int32, liquidityAmount bigint, ) (bigint, bigint) { - require(PrevRealmPath() == "gno.land/r/demo/position", ufmt.Sprintf("[POOL] pool.gno__Mint() || PrevRealmPath(%s) == \"gno.land/r/demo/position\"", PrevRealmPath())) + require(PrevRealmPath() == consts.POSITION_PATH, ufmt.Sprintf("[POOL] pool.gno__Mint() || PrevRealmPath(%s) == \"gno.land/r/demo/position\"", PrevRealmPath())) require(liquidityAmount > 0, ufmt.Sprintf("[POOL] pool.gno__Mint() || liquidityAmount(%d) > 0", liquidityAmount)) - pool := GetPool(pToken0Path, pToken1Path, pFee) + pool := GetPool(token0Path, token1Path, fee) _, amount0Int, amount1Int := pool.modifyPosition( ModifyPositionParams{ recipient, // owner @@ -87,18 +87,18 @@ func Mint( // only position contract can call this function func Burn( - pToken0Path string, - pToken1Path string, - pFee uint16, + token0Path string, + token1Path string, + fee uint16, tickLower int32, tickUpper int32, amount bigint, ) (bigint, bigint) { - require(PrevRealmPath() == "gno.land/r/demo/position", ufmt.Sprintf("[POOL] pool.gno__Burn() || caller(%s) must be position contract", PrevRealmPath())) + require(PrevRealmPath() == consts.POSITION_PATH, ufmt.Sprintf("[POOL] pool.gno__Burn() || caller(%s) must be position contract", PrevRealmPath())) requireUnsigned(amount, ufmt.Sprintf("[POOL] pool.gno__Burn() || amount(%d) >= 0", amount)) - pool := GetPool(pToken0Path, pToken1Path, pFee) + pool := GetPool(token0Path, token1Path, fee) position, amount0Int, amount1Int := pool.modifyPosition( ModifyPositionParams{ @@ -121,73 +121,75 @@ func Burn( key := positionGetKey(PrevRealmAddr(), tickLower, tickUpper) pool.positions[key] = position + // actual token transfer happens in Collect() return amount0, amount1 } // only position contract can call this function func Collect( - pToken0Path string, - pToken1Path string, - pFee uint16, + token0Path string, + token1Path string, + fee uint16, recipient std.Address, tickLower int32, tickUpper int32, amount0Requested bigint, amount1Requested bigint, ) (bigint, bigint) { - require(PrevRealmPath() == "gno.land/r/demo/position", ufmt.Sprintf("[POOL] pool.gno__Collect() || caller(%s) must be position contract(gno.land/r/demo/position)", PrevRealmPath())) + require(PrevRealmPath() == consts.POSITION_PATH, ufmt.Sprintf("[POOL] pool.gno__Collect() || caller(%s) must be position contract(gno.land/r/demo/position)", PrevRealmPath())) requireUnsigned(amount0Requested, ufmt.Sprintf("pool.gno__Collect() || amount0Requested(%d) >= 0", amount0Requested)) requireUnsigned(amount1Requested, ufmt.Sprintf("pool.gno__Collect() || amount1Requested(%d) >= 0", amount1Requested)) - pool := GetPool(pToken0Path, pToken1Path, pFee) + pool := GetPool(token0Path, token1Path, fee) key := positionGetKey(PrevRealmAddr(), tickLower, tickUpper) position, exist := pool.positions[key] require(exist, ufmt.Sprintf("[POOL] pool.gno__Collect() || position(%s) does not exist", key)) + // Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0 amount0 := min(amount0Requested, position.tokensOwed0) + amount0 = min(amount0, pool.balances.token0) requireUnsigned(amount0, ufmt.Sprintf("[POOL] pool.gno__Collect() || amount0(%d) >= 0", amount0)) - amount1 := min(amount1Requested, position.tokensOwed1) - requireUnsigned(amount1, ufmt.Sprintf("[POOL] pool.gno__Collect() || amount1(%d) >= 0", amount1)) - - require(pool.balances.token0 >= amount0, ufmt.Sprintf("[POOL] pool.gno__Collect() || pool.balances.token0(%d) >= amount0(%d)", pool.balances.token0, amount0)) + // Update state first then transfer + position.tokensOwed0 -= amount0 + pool.balances.token0 -= amount0 transferByRegisterCall(pool.token0Path, recipient, uint64(amount0)) + requireUnsigned(pool.balances.token0, ufmt.Sprintf("[POOL] pool.gno__Burn() || pool.balances.token0(%d) >= 0", pool.balances.token0)) - require(pool.balances.token1 >= amount1, ufmt.Sprintf("[POOL] pool.gno__Collect() || pool.balances.token1(%d) >= amount1(%d)", pool.balances.token1, amount1)) - transferByRegisterCall(pool.token1Path, recipient, uint64(amount1)) + // Smallest of three: amount0Requested, position.tokensOwed0, pool.balances.token0 + amount1 := min(amount1Requested, position.tokensOwed1) + amount1 = min(amount1, pool.balances.token1) + requireUnsigned(amount1, ufmt.Sprintf("[POOL] pool.gno__Collect() || amount1(%d) >= 0", amount1)) - // adjust position - position.tokensOwed0 -= amount0 + // Update state first then transfer position.tokensOwed1 -= amount1 - pool.positions[key] = position - - // adjust pool - pool.balances.token0 -= amount0 pool.balances.token1 -= amount1 - - requireUnsigned(pool.balances.token0, ufmt.Sprintf("[POOL] pool.gno__Burn() || pool.balances.token0(%d) >= 0", pool.balances.token0)) + transferByRegisterCall(pool.token1Path, recipient, uint64(amount1)) requireUnsigned(pool.balances.token1, ufmt.Sprintf("[POOL] pool.gno__Burn() || pool.balances.token1(%d) >= 0", pool.balances.token1)) + pool.positions[key] = position + return amount0, amount1 } func Swap( - pToken0Path string, - pToken1Path string, - pFee uint16, + token0Path string, + token1Path string, + fee uint16, recipient std.Address, zeroForOne bool, amountSpecified bigint, sqrtPriceLimitX96 bigint, payer std.Address, // router ) (bigint, bigint) { - require(PrevRealmPath() == "gno.land/r/demo/router", ufmt.Sprintf("[POOL] pool.gno__Swap() || caller(%s) must be router contract(gno.land/r/demo/router)", PrevRealmPath())) + require(PrevRealmPath() == consts.ROUTER_PATH, ufmt.Sprintf("[POOL] pool.gno__Swap() || caller(%s) must be router contract(gno.land/r/demo/router)", PrevRealmPath())) + // early panic require(amountSpecified != 0, "[POOL] pool.gno__Swap() || amountSpecified can't be zero") - pool := GetPool(pToken0Path, pToken1Path, pFee) + pool := GetPool(token0Path, token1Path, fee) slot0Start := pool.slot0 require(slot0Start.unlocked, "[POOL] pool.gno__Swap() || slot0 must be unlocked") @@ -377,9 +379,9 @@ func Swap( if zeroForOne { // payer > pool balance0Before := bigint(balanceOfByRegisterCall(pool.token0Path, GetOrigPkgAddr())) - ok := transferFromByRegisterCall(pool.token0Path, payer, consts.ADDR_POOL, uint64(amount0)) + ok := transferFromByRegisterCall(pool.token0Path, payer, consts.POOL_ADDR, uint64(amount0)) if !ok { - panic("[POOL] pool.gno__Swap() || transferFromByRegisterCall(pool.token0Path, payer, ADDR_POOL, uint64(amount0)) failed") + panic("[POOL] pool.gno__Swap() || transferFromByRegisterCall(pool.token0Path, payer, POOL_ADDR, uint64(amount0)) failed") } require( @@ -393,7 +395,7 @@ func Swap( require(pool.balances.token0 >= 0, ufmt.Sprintf("[POOL] pool.gno__Swap() || pool.balances.token0(%d) >= 0__#1", pool.balances.token0)) if amount1 < 0 { // pool > recipient - require(pool.balances.token1 > (-amount1), ufmt.Sprintf("[POOL] pool.gno__Swap() || pool.balances.token1(%d) > (-1 * amount1)(%d)", pool.balances.token1, (-amount1))) + require(pool.balances.token1 > -amount1, ufmt.Sprintf("[POOL] pool.gno__Swap() || pool.balances.token1(%d) > (-1 * amount1)(%d)", pool.balances.token1, (-amount1))) ok := transferByRegisterCall(pool.token1Path, recipient, uint64(-amount1)) if !ok { panic("[POOL] pool.gno__Swap() || transferByRegisterCall(pool.token1Path, recipient, uint64(-amount1)) failed") @@ -406,9 +408,9 @@ func Swap( } else { // payer > pool balance1Before := bigint(balanceOfByRegisterCall(pool.token1Path, GetOrigPkgAddr())) - ok := transferFromByRegisterCall(pool.token1Path, payer, consts.ADDR_POOL, uint64(amount1)) + ok := transferFromByRegisterCall(pool.token1Path, payer, consts.POOL_ADDR, uint64(amount1)) if !ok { - panic("[POOL] pool.gno__Swap() || transferFromByRegisterCall(pool.token1Path, payer, ADDR_POOL, uint64(amount1)) failed") + panic("[POOL] pool.gno__Swap() || transferFromByRegisterCall(pool.token1Path, payer, POOL_ADDR, uint64(amount1)) failed") } require( @@ -460,9 +462,9 @@ func SetFeeProtocol( // ADMIN func CollectProtocol( - pToken0Path string, - pToken1Path string, - pFee uint16, + token0Path string, + token1Path string, + fee uint16, recipient std.Address, amount0Requested bigint, amount1Requested bigint, @@ -471,7 +473,7 @@ func CollectProtocol( requireUnsigned(amount1Requested, ufmt.Sprintf("[POOL] pool.gno__CollectProtocol() || amount1Requested(%d) >= 0", amount1Requested)) require(isAdmin(PrevRealmAddr()), ufmt.Sprintf("[POOL] pool.gno__CollectProtocol() || caller(%s) must be admin", PrevRealmAddr())) - pool := GetPool(pToken0Path, pToken1Path, pFee) + pool := GetPool(token0Path, token1Path, fee) amount0 := min(amount0Requested, pool.protocolFees.token0) requireUnsigned(amount0, ufmt.Sprintf("[POOL] pool.gno__CollectProtocol() || amount0(%d) >= 0", amount0)) diff --git a/pool/pool_register.gno b/pool/pool_register.gno index 786928681..07daaa5f2 100644 --- a/pool/pool_register.gno +++ b/pool/pool_register.gno @@ -10,6 +10,8 @@ import ( var registered = []GRC20Pair{} +var locked bool // mutex flag + type GRC20Interface interface { Transfer() func(to users.AddressOrName, amount uint64) TransferFrom() func(from, to users.AddressOrName, amount uint64) @@ -95,8 +97,16 @@ func transferByRegisterCall(pkgPath string, to std.Address, amount uint64) bool return false } - registered[i].igrc20.Transfer()(users.AddressOrName(to), amount) + if !locked { + locked = true + registered[i].igrc20.Transfer()(users.AddressOrName(to), amount) + defer func() { + locked = false + }() + } else { + panic("[POOl] pool_register.gno transferByRegisterCall: locked") + } return true } @@ -108,8 +118,16 @@ func transferFromByRegisterCall(pkgPath string, from, to std.Address, amount uin return false } - registered[i].igrc20.TransferFrom()(users.AddressOrName(from), users.AddressOrName(to), amount) + if !locked { + locked = true + registered[i].igrc20.TransferFrom()(users.AddressOrName(from), users.AddressOrName(to), amount) + defer func() { + locked = false + }() + } else { + panic("[POOl] pool_register.gno transferFromByRegisterCall: locked") + } return true }