Skip to content

Commit a9bf99c

Browse files
committed
fix: allow partially collecting RAVs (TRST-M05)
Signed-off-by: Tomás Migone <[email protected]>
1 parent 86b998b commit a9bf99c

File tree

4 files changed

+162
-56
lines changed

4 files changed

+162
-56
lines changed

packages/horizon/contracts/interfaces/ITAPCollector.sol

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
pragma solidity 0.8.27;
33

44
import { IPaymentsCollector } from "./IPaymentsCollector.sol";
5+
import { IGraphPayments } from "./IGraphPayments.sol";
56

67
/**
78
* @title Interface for the {TAPCollector} contract
@@ -175,6 +176,13 @@ interface ITAPCollector is IPaymentsCollector {
175176
*/
176177
error TAPCollectorInconsistentRAVTokens(uint256 tokens, uint256 tokensCollected);
177178

179+
/**
180+
* Thrown when the attempting to collect more tokens than what it's owed
181+
* @param tokensToCollect The amount of tokens to collect
182+
* @param maxTokensToCollect The maximum amount of tokens to collect
183+
*/
184+
error TAPCollectorInvalidTokensToCollectAmount(uint256 tokensToCollect, uint256 maxTokensToCollect);
185+
178186
/**
179187
* @notice Authorize a signer to sign on behalf of the payer.
180188
* A signer can not be authorized for multiple payers even after revoking previous authorizations.
@@ -237,4 +245,21 @@ interface ITAPCollector is IPaymentsCollector {
237245
* @return The hash of the RAV.
238246
*/
239247
function encodeRAV(ReceiptAggregateVoucher calldata rav) external view returns (bytes32);
248+
249+
/**
250+
* @notice See {IPaymentsCollector.collect}
251+
* This variant adds the ability to partially collect a RAV by specifying the amount of tokens to collect.
252+
*
253+
* Requirements:
254+
* - The amount of tokens to collect must be less than or equal to the total amount of tokens in the RAV minus
255+
* the tokens already collected.
256+
* @param paymentType The payment type to collect
257+
* @param data Additional data required for the payment collection
258+
* @param tokensToCollect The amount of tokens to collect
259+
*/
260+
function collect(
261+
IGraphPayments.PaymentTypes paymentType,
262+
bytes calldata data,
263+
uint256 tokensToCollect
264+
) external returns (uint256);
240265
}

packages/horizon/contracts/payments/collectors/TAPCollector.sol

Lines changed: 70 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -125,37 +125,15 @@ contract TAPCollector is EIP712, GraphDirectory, ITAPCollector {
125125
* @notice REVERT: This function may revert if ECDSA.recover fails, check ECDSA library for details.
126126
*/
127127
function collect(IGraphPayments.PaymentTypes paymentType, bytes memory data) external override returns (uint256) {
128-
// Ensure caller is the RAV data service
129-
(SignedRAV memory signedRAV, uint256 dataServiceCut) = abi.decode(data, (SignedRAV, uint256));
130-
require(
131-
signedRAV.rav.dataService == msg.sender,
132-
TAPCollectorCallerNotDataService(msg.sender, signedRAV.rav.dataService)
133-
);
134-
135-
// Ensure RAV signer is authorized for a payer
136-
address signer = _recoverRAVSigner(signedRAV);
137-
require(
138-
authorizedSigners[signer].payer != address(0) && !authorizedSigners[signer].revoked,
139-
TAPCollectorInvalidRAVSigner()
140-
);
141-
142-
// Ensure RAV payer matches the authorized payer
143-
address payer = signedRAV.rav.payer;
144-
require(
145-
authorizedSigners[signer].payer == payer,
146-
TAPCollectorInvalidRAVPayer(authorizedSigners[signer].payer, payer)
147-
);
148-
149-
// Check the service provider has an active provision with the data service
150-
// This prevents an attack where the payer can deny the service provider from collecting payments
151-
// by using a signer as data service to syphon off the tokens in the escrow to an account they control
152-
uint256 tokensAvailable = _graphStaking().getProviderTokensAvailable(
153-
signedRAV.rav.serviceProvider,
154-
signedRAV.rav.dataService
155-
);
156-
require(tokensAvailable > 0, TAPCollectorUnauthorizedDataService(signedRAV.rav.dataService));
128+
return _collect(paymentType, data, 0);
129+
}
157130

158-
return _collect(paymentType, authorizedSigners[signer].payer, signedRAV, dataServiceCut);
131+
function collect(
132+
IGraphPayments.PaymentTypes paymentType,
133+
bytes memory data,
134+
uint256 tokensToCollect
135+
) external override returns (uint256) {
136+
return _collect(paymentType, data, tokensToCollect);
159137
}
160138

161139
/**
@@ -177,44 +155,87 @@ contract TAPCollector is EIP712, GraphDirectory, ITAPCollector {
177155
*/
178156
function _collect(
179157
IGraphPayments.PaymentTypes _paymentType,
180-
address _payer,
181-
SignedRAV memory _signedRAV,
182-
uint256 _dataServiceCut
158+
bytes memory _data,
159+
uint256 _tokensToCollect
183160
) private returns (uint256) {
184-
address dataService = _signedRAV.rav.dataService;
185-
address receiver = _signedRAV.rav.serviceProvider;
161+
// Ensure caller is the RAV data service
162+
(SignedRAV memory signedRAV, uint256 dataServiceCut) = abi.decode(_data, (SignedRAV, uint256));
163+
require(
164+
signedRAV.rav.dataService == msg.sender,
165+
TAPCollectorCallerNotDataService(msg.sender, signedRAV.rav.dataService)
166+
);
167+
168+
// Ensure RAV signer is authorized for a payer
169+
address signer = _recoverRAVSigner(signedRAV);
170+
require(
171+
authorizedSigners[signer].payer != address(0) && !authorizedSigners[signer].revoked,
172+
TAPCollectorInvalidRAVSigner()
173+
);
186174

187-
uint256 tokensRAV = _signedRAV.rav.valueAggregate;
188-
uint256 tokensAlreadyCollected = tokensCollected[dataService][receiver][_payer];
175+
// Ensure RAV payer matches the authorized payer
176+
address payer = authorizedSigners[signer].payer;
189177
require(
190-
tokensRAV > tokensAlreadyCollected,
191-
TAPCollectorInconsistentRAVTokens(tokensRAV, tokensAlreadyCollected)
178+
signedRAV.rav.payer == payer,
179+
TAPCollectorInvalidRAVPayer(payer, signedRAV.rav.payer)
192180
);
193181

194-
uint256 tokensToCollect = tokensRAV - tokensAlreadyCollected;
195-
uint256 tokensDataService = tokensToCollect.mulPPM(_dataServiceCut);
182+
address dataService = signedRAV.rav.dataService;
183+
address receiver = signedRAV.rav.serviceProvider;
184+
185+
// Check the service provider has an active provision with the data service
186+
// This prevents an attack where the payer can deny the service provider from collecting payments
187+
// by using a signer as data service to syphon off the tokens in the escrow to an account they control
188+
{
189+
uint256 tokensAvailable = _graphStaking().getProviderTokensAvailable(
190+
signedRAV.rav.serviceProvider,
191+
signedRAV.rav.dataService
192+
);
193+
require(tokensAvailable > 0, TAPCollectorUnauthorizedDataService(signedRAV.rav.dataService));
194+
}
195+
196+
uint256 tokensToCollect = 0;
197+
{
198+
uint256 tokensRAV = signedRAV.rav.valueAggregate;
199+
uint256 tokensAlreadyCollected = tokensCollected[dataService][receiver][payer];
200+
require(
201+
tokensRAV > tokensAlreadyCollected,
202+
TAPCollectorInconsistentRAVTokens(tokensRAV, tokensAlreadyCollected)
203+
);
204+
205+
if (_tokensToCollect == 0) {
206+
tokensToCollect = tokensRAV - tokensAlreadyCollected;
207+
} else {
208+
require(
209+
_tokensToCollect <= tokensRAV - tokensAlreadyCollected,
210+
TAPCollectorInvalidTokensToCollectAmount(_tokensToCollect, tokensRAV - tokensAlreadyCollected)
211+
);
212+
tokensToCollect = _tokensToCollect;
213+
}
214+
}
215+
216+
uint256 tokensDataService = tokensToCollect.mulPPM(dataServiceCut);
196217

197218
if (tokensToCollect > 0) {
198-
tokensCollected[dataService][receiver][_payer] = tokensRAV;
219+
tokensCollected[dataService][receiver][payer] += tokensToCollect;
199220
_graphPaymentsEscrow().collect(
200221
_paymentType,
201-
_payer,
222+
payer,
202223
receiver,
203224
tokensToCollect,
204225
dataService,
205226
tokensDataService
206227
);
207228
}
208229

209-
emit PaymentCollected(_paymentType, _payer, receiver, tokensToCollect, dataService, tokensDataService);
230+
emit PaymentCollected(_paymentType, payer, receiver, tokensToCollect, dataService, tokensDataService);
210231
emit RAVCollected(
211-
_payer,
232+
payer,
212233
dataService,
213234
receiver,
214-
_signedRAV.rav.timestampNs,
215-
_signedRAV.rav.valueAggregate,
216-
_signedRAV.rav.metadata,
217-
_signedRAV.signature
235+
signedRAV.rav.timestampNs,
236+
signedRAV.rav.valueAggregate,
237+
signedRAV.rav.metadata,
238+
signedRAV.signature
218239
);
219240
return tokensToCollect;
220241
}

packages/horizon/test/payments/tap-collector/TAPCollector.t.sol

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,20 @@ contract TAPCollectorTest is HorizonStakingSharedTest, PaymentsEscrowSharedTest
119119
}
120120

121121
function _collect(IGraphPayments.PaymentTypes _paymentType, bytes memory _data) internal {
122+
__collect(_paymentType, _data, 0);
123+
}
124+
125+
function _collect(IGraphPayments.PaymentTypes _paymentType, bytes memory _data, uint256 _tokensToCollect) internal {
126+
__collect(_paymentType, _data, _tokensToCollect);
127+
}
128+
129+
function __collect(IGraphPayments.PaymentTypes _paymentType, bytes memory _data, uint256 _tokensToCollect) internal {
122130
(ITAPCollector.SignedRAV memory signedRAV, uint256 dataServiceCut) = abi.decode(_data, (ITAPCollector.SignedRAV, uint256));
123131
bytes32 messageHash = tapCollector.encodeRAV(signedRAV.rav);
124132
address _signer = ECDSA.recover(messageHash, signedRAV.signature);
125133
(address _payer, , ) = tapCollector.authorizedSigners(_signer);
126134
uint256 tokensAlreadyCollected = tapCollector.tokensCollected(signedRAV.rav.dataService, signedRAV.rav.serviceProvider, _payer);
127-
uint256 tokensToCollect = signedRAV.rav.valueAggregate - tokensAlreadyCollected;
135+
uint256 tokensToCollect = _tokensToCollect == 0 ? signedRAV.rav.valueAggregate - tokensAlreadyCollected : _tokensToCollect;
128136
uint256 tokensDataService = tokensToCollect.mulPPM(dataServiceCut);
129137

130138
vm.expectEmit(address(tapCollector));
@@ -136,6 +144,7 @@ contract TAPCollectorTest is HorizonStakingSharedTest, PaymentsEscrowSharedTest
136144
signedRAV.rav.dataService,
137145
tokensDataService
138146
);
147+
vm.expectEmit(address(tapCollector));
139148
emit ITAPCollector.RAVCollected(
140149
_payer,
141150
signedRAV.rav.dataService,
@@ -145,11 +154,10 @@ contract TAPCollectorTest is HorizonStakingSharedTest, PaymentsEscrowSharedTest
145154
signedRAV.rav.metadata,
146155
signedRAV.signature
147156
);
148-
149-
uint256 tokensCollected = tapCollector.collect(_paymentType, _data);
150-
assertEq(tokensCollected, tokensToCollect);
157+
uint256 tokensCollected = _tokensToCollect == 0 ? tapCollector.collect(_paymentType, _data) : tapCollector.collect(_paymentType, _data, _tokensToCollect);
151158

152159
uint256 tokensCollectedAfter = tapCollector.tokensCollected(signedRAV.rav.dataService, signedRAV.rav.serviceProvider, _payer);
153-
assertEq(tokensCollectedAfter, signedRAV.rav.valueAggregate);
160+
assertEq(tokensCollected, tokensToCollect);
161+
assertEq(tokensCollectedAfter, _tokensToCollect == 0 ? signedRAV.rav.valueAggregate : tokensAlreadyCollected + _tokensToCollect);
154162
}
155163
}

packages/horizon/test/payments/tap-collector/collect/collect.t.sol

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,12 @@ contract TAPCollectorCollectTest is TAPCollectorTest {
203203
tapCollector.collect(IGraphPayments.PaymentTypes.QueryFee, data);
204204
}
205205

206-
function testTAPCollector_Collect_RevertWhen_PayerMismatch(uint256 tokens) public useGateway useSigner {
206+
function testTAPCollector_Collect_RevertWhen_PayerMismatch(
207+
uint256 tokens
208+
) public useIndexer useProvisionDataService(users.verifier, 100, 0, 0) useGateway useSigner {
207209
tokens = bound(tokens, 1, type(uint128).max);
208210

209211
resetPrank(users.gateway);
210-
_approveCollector(address(tapCollector), tokens);
211212
_depositTokens(address(tapCollector), users.indexer, tokens);
212213

213214
(address anotherPayer, ) = makeAddrAndKey("anotherPayer");
@@ -340,4 +341,55 @@ contract TAPCollectorCollectTest is TAPCollectorTest {
340341
resetPrank(users.verifier);
341342
_collect(IGraphPayments.PaymentTypes.QueryFee, data);
342343
}
344+
345+
function testTAPCollector_CollectPartial(
346+
uint256 tokens,
347+
uint256 tokensToCollect
348+
) public useIndexer useProvisionDataService(users.verifier, 100, 0, 0) useGateway useSigner {
349+
tokens = bound(tokens, 1, type(uint128).max);
350+
tokensToCollect = bound(tokensToCollect, 1, tokens);
351+
352+
_depositTokens(address(tapCollector), users.indexer, tokens);
353+
354+
bytes memory data = _getQueryFeeEncodedData(
355+
signerPrivateKey,
356+
users.gateway,
357+
users.indexer,
358+
users.verifier,
359+
uint128(tokens)
360+
);
361+
362+
resetPrank(users.verifier);
363+
_collect(IGraphPayments.PaymentTypes.QueryFee, data, tokensToCollect);
364+
}
365+
366+
function testTAPCollector_CollectPartial_RevertWhen_AmountTooHigh(
367+
uint256 tokens,
368+
uint256 tokensToCollect
369+
) public useIndexer useProvisionDataService(users.verifier, 100, 0, 0) useGateway useSigner {
370+
tokens = bound(tokens, 1, type(uint128).max - 1);
371+
372+
_depositTokens(address(tapCollector), users.indexer, tokens);
373+
374+
bytes memory data = _getQueryFeeEncodedData(
375+
signerPrivateKey,
376+
users.gateway,
377+
users.indexer,
378+
users.verifier,
379+
uint128(tokens)
380+
);
381+
382+
resetPrank(users.verifier);
383+
uint256 tokensAlreadyCollected = tapCollector.tokensCollected(users.verifier, users.indexer, users.gateway);
384+
tokensToCollect = bound(tokensToCollect, tokens - tokensAlreadyCollected + 1, type(uint128).max);
385+
386+
vm.expectRevert(
387+
abi.encodeWithSelector(
388+
ITAPCollector.TAPCollectorInvalidTokensToCollectAmount.selector,
389+
tokensToCollect,
390+
tokens - tokensAlreadyCollected
391+
)
392+
);
393+
tapCollector.collect(IGraphPayments.PaymentTypes.QueryFee, data, tokensToCollect);
394+
}
343395
}

0 commit comments

Comments
 (0)