Skip to content

Commit 11097d6

Browse files
authored
fix(router): adds caller verification when the swap callback (#1005)
* fix(router): adds caller verification when the swap callback * test: add swap callback tests * test: assert tests
1 parent d193177 commit 11097d6

File tree

5 files changed

+239
-0
lines changed

5 files changed

+239
-0
lines changed

contract/r/gnoswap/router/v1/_helper_test.gno

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ var (
9898

9999
registeredTestStore = false
100100
registeredTestPool = false
101+
102+
mockInstance *routerV1
101103
)
102104

103105
func TokenFaucet(t *testing.T, tokenPath string, to address) {
@@ -470,6 +472,10 @@ func initRouterTest(t *testing.T) {
470472
registerPositionTest(t)
471473

472474
registeredTestStore = true
475+
476+
mockInstance = &routerV1{
477+
store: mock.NewMockRouterStore(),
478+
}
473479
}
474480

475481
func initPoolTest(t *testing.T) {
@@ -511,3 +517,18 @@ func registerPositionTest(t *testing.T) {
511517
testing.SetRealm(adminRealm)
512518
pn.UpgradeImpl(cross, positionMockPath)
513519
}
520+
521+
// mockInstanceSwapCallback is a helper function to call the SwapCallback function on the mock instance
522+
func mockInstanceSwapCallback(
523+
token0Path string,
524+
token1Path string,
525+
amount0Delta string,
526+
amount1Delta string,
527+
payer address,
528+
) error {
529+
// mock instance using a closure
530+
return func(cur realm) error {
531+
testing.SetRealm(testing.NewCodeRealm(routerPath))
532+
return mockInstance.SwapCallback(token0Path, token1Path, amount0Delta, amount1Delta, payer)
533+
}(cross)
534+
}

contract/r/gnoswap/router/v1/assert.gno

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,12 @@ func assertIsExistsPools(routePathArr string) {
103103
}
104104
}
105105
}
106+
107+
func assertIsRouterV1(caller address) {
108+
if caller != routerV1Addr {
109+
panic(makeErrorWithDetails(
110+
errInvalidInput,
111+
ufmt.Sprintf("caller %s is not router v1(%s)", caller, routerV1Addr),
112+
))
113+
}
114+
}

contract/r/gnoswap/router/v1/consts.gno

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,6 @@ var (
1616
routerAddr = chain.PackageAddress("gno.land/r/gnoswap/router")
1717
positionAddr = chain.PackageAddress("gno.land/r/gnoswap/position")
1818
protocolFeeAddr = chain.PackageAddress("gno.land/r/gnoswap/protocol_fee")
19+
20+
routerV1Addr = chain.PackageAddress("gno.land/r/gnoswap/router/v1")
1921
)

contract/r/gnoswap/router/v1/swap_callback.gno

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package v1
22

33
import (
4+
"chain/runtime"
5+
46
i256 "gno.land/p/gnoswap/int256"
57
u256 "gno.land/p/gnoswap/uint256"
68

79
"gno.land/r/gnoswap/common"
10+
"gno.land/r/gnoswap/halt"
811
)
912

1013
// swapCallback implements the pool's SwapCallback interface.
@@ -15,11 +18,19 @@ import (
1518
// 1. Flash swaps (receive tokens before paying)
1619
// 2. Just-in-time token transfers
1720
// 3. Complex multi-hop swaps without intermediate transfers
21+
//
22+
// Only callable from the router v1 implementation contract.
23+
// It is only used when calling a pool swap function.
1824
func (r *routerV1) SwapCallback(
1925
token0Path, token1Path string,
2026
amount0Delta, amount1Delta string,
2127
payer address,
2228
) error {
29+
halt.AssertIsNotHaltedRouter()
30+
31+
caller := runtime.PreviousRealm().Address()
32+
assertIsRouterV1(caller)
33+
2334
var tokenToPay string
2435

2536
amountToPay := i256.Zero()
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
package v1
2+
3+
import (
4+
"testing"
5+
6+
"gno.land/p/nt/testutils"
7+
"gno.land/p/nt/uassert"
8+
)
9+
10+
// TestSwapCallback tests the SwapCallback function with various scenarios
11+
func TestSwapCallback(t *testing.T) {
12+
tests := []struct {
13+
name string
14+
setupFunc func(t *testing.T) address // returns payer address
15+
amount0Delta string
16+
amount1Delta string
17+
callerRealmPath string
18+
shouldError bool
19+
expectedErrorMsg string
20+
}{
21+
{
22+
name: "success with token0 payment",
23+
setupFunc: func(t *testing.T) address {
24+
initRouterTest(t)
25+
CreatePoolWithoutFee(t)
26+
MakeMintPositionWithoutFee(t)
27+
TokenFaucet(t, barPath, routerAddr)
28+
TokenFaucet(t, bazPath, routerAddr)
29+
return routerAddr
30+
},
31+
amount0Delta: "1000",
32+
amount1Delta: "0",
33+
callerRealmPath: "gno.land/r/gnoswap/router/v1",
34+
shouldError: false,
35+
},
36+
{
37+
name: "success with token1 payment",
38+
setupFunc: func(t *testing.T) address {
39+
initRouterTest(t)
40+
CreatePoolWithoutFee(t)
41+
MakeMintPositionWithoutFee(t)
42+
TokenFaucet(t, barPath, routerAddr)
43+
TokenFaucet(t, bazPath, routerAddr)
44+
return routerAddr
45+
},
46+
amount0Delta: "0",
47+
amount1Delta: "1000",
48+
callerRealmPath: "gno.land/r/gnoswap/router/v1",
49+
shouldError: false,
50+
},
51+
{
52+
name: "success with both deltas zero",
53+
setupFunc: func(t *testing.T) address {
54+
initRouterTest(t)
55+
CreatePoolWithoutFee(t)
56+
return routerAddr
57+
},
58+
amount0Delta: "0",
59+
amount1Delta: "0",
60+
callerRealmPath: "gno.land/r/gnoswap/router/v1",
61+
shouldError: false,
62+
},
63+
{
64+
name: "success with user as payer",
65+
setupFunc: func(t *testing.T) address {
66+
initRouterTest(t)
67+
CreatePoolWithoutFee(t)
68+
MakeMintPositionWithoutFee(t)
69+
testUser := testutils.TestAddress("testUser")
70+
TokenFaucet(t, barPath, testUser)
71+
TokenFaucet(t, bazPath, testUser)
72+
TokenApprove(t, barPath, testUser, routerAddr, maxApprove)
73+
TokenApprove(t, bazPath, testUser, routerAddr, maxApprove)
74+
return testUser
75+
},
76+
amount0Delta: "1000",
77+
amount1Delta: "0",
78+
callerRealmPath: "gno.land/r/gnoswap/router/v1",
79+
shouldError: false,
80+
},
81+
{
82+
name: "fail with invalid caller",
83+
setupFunc: func(t *testing.T) address {
84+
initRouterTest(t)
85+
CreatePoolWithoutFee(t)
86+
return routerAddr
87+
},
88+
amount0Delta: "1000",
89+
amount1Delta: "0",
90+
callerRealmPath: "gno.land/r/unauthorized/contract",
91+
shouldError: true,
92+
expectedErrorMsg: "[GNOSWAP-ROUTER-005]",
93+
},
94+
{
95+
name: "fail with insufficient balance",
96+
setupFunc: func(t *testing.T) address {
97+
initRouterTest(t)
98+
CreatePoolWithoutFee(t)
99+
MakeMintPositionWithoutFee(t)
100+
poorUser := testutils.TestAddress("poorUser")
101+
return poorUser
102+
},
103+
amount0Delta: "1000",
104+
amount1Delta: "0",
105+
callerRealmPath: "gno.land/r/gnoswap/router/v1",
106+
shouldError: true,
107+
expectedErrorMsg: "insufficient balance",
108+
},
109+
}
110+
111+
for _, tt := range tests {
112+
t.Run(tt.name, func(t *testing.T) {
113+
// Setup
114+
payer := tt.setupFunc(t)
115+
116+
// Set caller realm
117+
testing.SetRealm(testing.NewCodeRealm(tt.callerRealmPath))
118+
119+
// Execute
120+
if tt.shouldError {
121+
uassert.AbortsContains(t, tt.expectedErrorMsg, func() {
122+
mockInstanceSwapCallback(
123+
barPath,
124+
bazPath,
125+
tt.amount0Delta,
126+
tt.amount1Delta,
127+
payer,
128+
)
129+
})
130+
} else {
131+
err := mockInstanceSwapCallback(
132+
barPath,
133+
bazPath,
134+
tt.amount0Delta,
135+
tt.amount1Delta,
136+
payer,
137+
)
138+
uassert.NoError(t, err)
139+
}
140+
})
141+
}
142+
}
143+
144+
// TestSwapCallback_AssertIsRouterV1 tests the assertIsRouterV1 function with various caller types
145+
func TestSwapCallback_AssertIsRouterV1(t *testing.T) {
146+
tests := []struct {
147+
name string
148+
callerAddress address
149+
shouldError bool
150+
expectedErrorMsg string
151+
}{
152+
{
153+
name: "valid caller from router v1",
154+
callerAddress: routerV1Addr,
155+
shouldError: false,
156+
},
157+
{
158+
name: "invalid caller from unauthorized contract",
159+
callerAddress: testutils.TestAddress("unauthorized"),
160+
shouldError: true,
161+
expectedErrorMsg: "[GNOSWAP-ROUTER-005]",
162+
},
163+
{
164+
name: "invalid caller from user realm",
165+
callerAddress: testutils.TestAddress("maliciousUser"),
166+
shouldError: true,
167+
expectedErrorMsg: "[GNOSWAP-ROUTER-005]",
168+
},
169+
{
170+
name: "invalid caller from different contract",
171+
callerAddress: testutils.TestAddress("different"),
172+
shouldError: true,
173+
expectedErrorMsg: "[GNOSWAP-ROUTER-005]",
174+
},
175+
{
176+
name: "invalid caller from pool contract",
177+
callerAddress: poolAddr,
178+
shouldError: true,
179+
expectedErrorMsg: "[GNOSWAP-ROUTER-005]",
180+
},
181+
}
182+
183+
for _, tt := range tests {
184+
t.Run(tt.name, func(t *testing.T) {
185+
// Execute and verify
186+
if tt.shouldError {
187+
uassert.PanicsContains(t, tt.expectedErrorMsg, func() {
188+
assertIsRouterV1(tt.callerAddress)
189+
})
190+
} else {
191+
// Should not panic
192+
assertIsRouterV1(tt.callerAddress)
193+
}
194+
})
195+
}
196+
}

0 commit comments

Comments
 (0)