Skip to content

Commit 276ccaa

Browse files
authored
perf(StdAssertions): avoid vm call for trivial conditions (#693)
Avoid calling vm when the condition is trivial to check inline. This is a minor performance improvement as a couple of opcodes execute a lot faster than a full cheatcode pipeline (external call, EVM call bookkeeping, abi encoding/decoding, ...) that eventually does nothing. For example, [Uniswap's v4-core](https://github.com/Uniswap/v4-core/) spends ~28% of the entire `forge test` CPU time in a couple of trivial `assert` functions (uint256, bytes32, true, false), ~13% for `forge coverage`. This is without accounting for the actual CALL/abi coding etc. These numbers are skewed due to profiling overhead, however making this change does have ~5% overall test performance improvement, for no compilation time change.
1 parent 369dd01 commit 276ccaa

File tree

1 file changed

+132
-44
lines changed

1 file changed

+132
-44
lines changed

src/StdAssertions.sol

Lines changed: 132 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -49,35 +49,51 @@ abstract contract StdAssertions {
4949
}
5050

5151
function assertTrue(bool data) internal pure virtual {
52-
vm.assertTrue(data);
52+
if (!data) {
53+
vm.assertTrue(data);
54+
}
5355
}
5456

5557
function assertTrue(bool data, string memory err) internal pure virtual {
56-
vm.assertTrue(data, err);
58+
if (!data) {
59+
vm.assertTrue(data, err);
60+
}
5761
}
5862

5963
function assertFalse(bool data) internal pure virtual {
60-
vm.assertFalse(data);
64+
if (data) {
65+
vm.assertFalse(data);
66+
}
6167
}
6268

6369
function assertFalse(bool data, string memory err) internal pure virtual {
64-
vm.assertFalse(data, err);
70+
if (data) {
71+
vm.assertFalse(data, err);
72+
}
6573
}
6674

6775
function assertEq(bool left, bool right) internal pure virtual {
68-
vm.assertEq(left, right);
76+
if (left != right) {
77+
vm.assertEq(left, right);
78+
}
6979
}
7080

7181
function assertEq(bool left, bool right, string memory err) internal pure virtual {
72-
vm.assertEq(left, right, err);
82+
if (left != right) {
83+
vm.assertEq(left, right, err);
84+
}
7385
}
7486

7587
function assertEq(uint256 left, uint256 right) internal pure virtual {
76-
vm.assertEq(left, right);
88+
if (left != right) {
89+
vm.assertEq(left, right);
90+
}
7791
}
7892

7993
function assertEq(uint256 left, uint256 right, string memory err) internal pure virtual {
80-
vm.assertEq(left, right, err);
94+
if (left != right) {
95+
vm.assertEq(left, right, err);
96+
}
8197
}
8298

8399
function assertEqDecimal(uint256 left, uint256 right, uint256 decimals) internal pure virtual {
@@ -89,11 +105,15 @@ abstract contract StdAssertions {
89105
}
90106

91107
function assertEq(int256 left, int256 right) internal pure virtual {
92-
vm.assertEq(left, right);
108+
if (left != right) {
109+
vm.assertEq(left, right);
110+
}
93111
}
94112

95113
function assertEq(int256 left, int256 right, string memory err) internal pure virtual {
96-
vm.assertEq(left, right, err);
114+
if (left != right) {
115+
vm.assertEq(left, right, err);
116+
}
97117
}
98118

99119
function assertEqDecimal(int256 left, int256 right, uint256 decimals) internal pure virtual {
@@ -105,27 +125,39 @@ abstract contract StdAssertions {
105125
}
106126

107127
function assertEq(address left, address right) internal pure virtual {
108-
vm.assertEq(left, right);
128+
if (left != right) {
129+
vm.assertEq(left, right);
130+
}
109131
}
110132

111133
function assertEq(address left, address right, string memory err) internal pure virtual {
112-
vm.assertEq(left, right, err);
134+
if (left != right) {
135+
vm.assertEq(left, right, err);
136+
}
113137
}
114138

115139
function assertEq(bytes32 left, bytes32 right) internal pure virtual {
116-
vm.assertEq(left, right);
140+
if (left != right) {
141+
vm.assertEq(left, right);
142+
}
117143
}
118144

119145
function assertEq(bytes32 left, bytes32 right, string memory err) internal pure virtual {
120-
vm.assertEq(left, right, err);
146+
if (left != right) {
147+
vm.assertEq(left, right, err);
148+
}
121149
}
122150

123151
function assertEq32(bytes32 left, bytes32 right) internal pure virtual {
124-
assertEq(left, right);
152+
if (left != right) {
153+
vm.assertEq(left, right);
154+
}
125155
}
126156

127157
function assertEq32(bytes32 left, bytes32 right, string memory err) internal pure virtual {
128-
assertEq(left, right, err);
158+
if (left != right) {
159+
vm.assertEq(left, right, err);
160+
}
129161
}
130162

131163
function assertEq(string memory left, string memory right) internal pure virtual {
@@ -206,19 +238,27 @@ abstract contract StdAssertions {
206238
}
207239

208240
function assertNotEq(bool left, bool right) internal pure virtual {
209-
vm.assertNotEq(left, right);
241+
if (left == right) {
242+
vm.assertNotEq(left, right);
243+
}
210244
}
211245

212246
function assertNotEq(bool left, bool right, string memory err) internal pure virtual {
213-
vm.assertNotEq(left, right, err);
247+
if (left == right) {
248+
vm.assertNotEq(left, right, err);
249+
}
214250
}
215251

216252
function assertNotEq(uint256 left, uint256 right) internal pure virtual {
217-
vm.assertNotEq(left, right);
253+
if (left == right) {
254+
vm.assertNotEq(left, right);
255+
}
218256
}
219257

220258
function assertNotEq(uint256 left, uint256 right, string memory err) internal pure virtual {
221-
vm.assertNotEq(left, right, err);
259+
if (left == right) {
260+
vm.assertNotEq(left, right, err);
261+
}
222262
}
223263

224264
function assertNotEqDecimal(uint256 left, uint256 right, uint256 decimals) internal pure virtual {
@@ -234,11 +274,15 @@ abstract contract StdAssertions {
234274
}
235275

236276
function assertNotEq(int256 left, int256 right) internal pure virtual {
237-
vm.assertNotEq(left, right);
277+
if (left == right) {
278+
vm.assertNotEq(left, right);
279+
}
238280
}
239281

240282
function assertNotEq(int256 left, int256 right, string memory err) internal pure virtual {
241-
vm.assertNotEq(left, right, err);
283+
if (left == right) {
284+
vm.assertNotEq(left, right, err);
285+
}
242286
}
243287

244288
function assertNotEqDecimal(int256 left, int256 right, uint256 decimals) internal pure virtual {
@@ -250,27 +294,39 @@ abstract contract StdAssertions {
250294
}
251295

252296
function assertNotEq(address left, address right) internal pure virtual {
253-
vm.assertNotEq(left, right);
297+
if (left == right) {
298+
vm.assertNotEq(left, right);
299+
}
254300
}
255301

256302
function assertNotEq(address left, address right, string memory err) internal pure virtual {
257-
vm.assertNotEq(left, right, err);
303+
if (left == right) {
304+
vm.assertNotEq(left, right, err);
305+
}
258306
}
259307

260308
function assertNotEq(bytes32 left, bytes32 right) internal pure virtual {
261-
vm.assertNotEq(left, right);
309+
if (left == right) {
310+
vm.assertNotEq(left, right);
311+
}
262312
}
263313

264314
function assertNotEq(bytes32 left, bytes32 right, string memory err) internal pure virtual {
265-
vm.assertNotEq(left, right, err);
315+
if (left == right) {
316+
vm.assertNotEq(left, right, err);
317+
}
266318
}
267319

268320
function assertNotEq32(bytes32 left, bytes32 right) internal pure virtual {
269-
assertNotEq(left, right);
321+
if (left == right) {
322+
vm.assertNotEq(left, right);
323+
}
270324
}
271325

272326
function assertNotEq32(bytes32 left, bytes32 right, string memory err) internal pure virtual {
273-
assertNotEq(left, right, err);
327+
if (left == right) {
328+
vm.assertNotEq(left, right, err);
329+
}
274330
}
275331

276332
function assertNotEq(string memory left, string memory right) internal pure virtual {
@@ -346,11 +402,15 @@ abstract contract StdAssertions {
346402
}
347403

348404
function assertLt(uint256 left, uint256 right) internal pure virtual {
349-
vm.assertLt(left, right);
405+
if (!(left < right)) {
406+
vm.assertLt(left, right);
407+
}
350408
}
351409

352410
function assertLt(uint256 left, uint256 right, string memory err) internal pure virtual {
353-
vm.assertLt(left, right, err);
411+
if (!(left < right)) {
412+
vm.assertLt(left, right, err);
413+
}
354414
}
355415

356416
function assertLtDecimal(uint256 left, uint256 right, uint256 decimals) internal pure virtual {
@@ -362,11 +422,15 @@ abstract contract StdAssertions {
362422
}
363423

364424
function assertLt(int256 left, int256 right) internal pure virtual {
365-
vm.assertLt(left, right);
425+
if (!(left < right)) {
426+
vm.assertLt(left, right);
427+
}
366428
}
367429

368430
function assertLt(int256 left, int256 right, string memory err) internal pure virtual {
369-
vm.assertLt(left, right, err);
431+
if (!(left < right)) {
432+
vm.assertLt(left, right, err);
433+
}
370434
}
371435

372436
function assertLtDecimal(int256 left, int256 right, uint256 decimals) internal pure virtual {
@@ -378,11 +442,15 @@ abstract contract StdAssertions {
378442
}
379443

380444
function assertGt(uint256 left, uint256 right) internal pure virtual {
381-
vm.assertGt(left, right);
445+
if (!(left > right)) {
446+
vm.assertGt(left, right);
447+
}
382448
}
383449

384450
function assertGt(uint256 left, uint256 right, string memory err) internal pure virtual {
385-
vm.assertGt(left, right, err);
451+
if (!(left > right)) {
452+
vm.assertGt(left, right, err);
453+
}
386454
}
387455

388456
function assertGtDecimal(uint256 left, uint256 right, uint256 decimals) internal pure virtual {
@@ -394,11 +462,15 @@ abstract contract StdAssertions {
394462
}
395463

396464
function assertGt(int256 left, int256 right) internal pure virtual {
397-
vm.assertGt(left, right);
465+
if (!(left > right)) {
466+
vm.assertGt(left, right);
467+
}
398468
}
399469

400470
function assertGt(int256 left, int256 right, string memory err) internal pure virtual {
401-
vm.assertGt(left, right, err);
471+
if (!(left > right)) {
472+
vm.assertGt(left, right, err);
473+
}
402474
}
403475

404476
function assertGtDecimal(int256 left, int256 right, uint256 decimals) internal pure virtual {
@@ -410,11 +482,15 @@ abstract contract StdAssertions {
410482
}
411483

412484
function assertLe(uint256 left, uint256 right) internal pure virtual {
413-
vm.assertLe(left, right);
485+
if (!(left <= right)) {
486+
vm.assertLe(left, right);
487+
}
414488
}
415489

416490
function assertLe(uint256 left, uint256 right, string memory err) internal pure virtual {
417-
vm.assertLe(left, right, err);
491+
if (!(left <= right)) {
492+
vm.assertLe(left, right, err);
493+
}
418494
}
419495

420496
function assertLeDecimal(uint256 left, uint256 right, uint256 decimals) internal pure virtual {
@@ -426,11 +502,15 @@ abstract contract StdAssertions {
426502
}
427503

428504
function assertLe(int256 left, int256 right) internal pure virtual {
429-
vm.assertLe(left, right);
505+
if (!(left <= right)) {
506+
vm.assertLe(left, right);
507+
}
430508
}
431509

432510
function assertLe(int256 left, int256 right, string memory err) internal pure virtual {
433-
vm.assertLe(left, right, err);
511+
if (!(left <= right)) {
512+
vm.assertLe(left, right, err);
513+
}
434514
}
435515

436516
function assertLeDecimal(int256 left, int256 right, uint256 decimals) internal pure virtual {
@@ -442,11 +522,15 @@ abstract contract StdAssertions {
442522
}
443523

444524
function assertGe(uint256 left, uint256 right) internal pure virtual {
445-
vm.assertGe(left, right);
525+
if (!(left >= right)) {
526+
vm.assertGe(left, right);
527+
}
446528
}
447529

448530
function assertGe(uint256 left, uint256 right, string memory err) internal pure virtual {
449-
vm.assertGe(left, right, err);
531+
if (!(left >= right)) {
532+
vm.assertGe(left, right, err);
533+
}
450534
}
451535

452536
function assertGeDecimal(uint256 left, uint256 right, uint256 decimals) internal pure virtual {
@@ -458,11 +542,15 @@ abstract contract StdAssertions {
458542
}
459543

460544
function assertGe(int256 left, int256 right) internal pure virtual {
461-
vm.assertGe(left, right);
545+
if (!(left >= right)) {
546+
vm.assertGe(left, right);
547+
}
462548
}
463549

464550
function assertGe(int256 left, int256 right, string memory err) internal pure virtual {
465-
vm.assertGe(left, right, err);
551+
if (!(left >= right)) {
552+
vm.assertGe(left, right, err);
553+
}
466554
}
467555

468556
function assertGeDecimal(int256 left, int256 right, uint256 decimals) internal pure virtual {

0 commit comments

Comments
 (0)