From 3d106068b4ace7a75a329bbc5d137834ac143753 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1vid=20Barbora?= Date: Fri, 10 Jan 2025 15:36:52 +0100 Subject: [PATCH] add mint with referral funcionality --- src/BaseGen.sol | 27 +++++++++++++----- test/BaseGenReferral.t.sol | 58 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 7 deletions(-) create mode 100644 test/BaseGenReferral.t.sol diff --git a/src/BaseGen.sol b/src/BaseGen.sol index 95f29dc..7220d6d 100644 --- a/src/BaseGen.sol +++ b/src/BaseGen.sol @@ -7,6 +7,7 @@ import "@openzeppelin/contracts/token/ERC721/extensions/ERC721Royalty.sol"; import "@openzeppelin/contracts/access/Ownable.sol"; import "@openzeppelin/contracts/utils/Base64.sol"; import "@openzeppelin/contracts/utils/Strings.sol"; +import "./lib/Split.sol"; contract BaseGen is ERC721, ERC721Burnable, ERC721Royalty, Ownable { uint256 private _nextTokenId; @@ -47,7 +48,7 @@ contract BaseGen is ERC721, ERC721Burnable, ERC721Royalty, Ownable { _contractURI = contractURI_; _maxSupply = maxSupply_; _receiver = receiver_; - _setDefaultRoyalty(receiver_, 5e2); // + _setDefaultRoyalty(receiver_, 5e2); } function _baseURI() internal view override returns (string memory) { @@ -126,6 +127,10 @@ contract BaseGen is ERC721, ERC721Burnable, ERC721Royalty, Ownable { } function safeBatchMint(address to, uint256 quantity) public payable { + mintWithReferral(to, quantity, address(0)); + } + + function mintWithReferral(address to, uint256 quantity, address referrer) public payable { if (quantity == 0) { revert MintQuantityCannotBeZero(); } @@ -143,7 +148,7 @@ contract BaseGen is ERC721, ERC721Burnable, ERC721Royalty, Ownable { revert InsufficientFunds(totalCost, msg.value); } - _splitPayment(totalCost); + _splitPayment(msg.value, referrer); for (uint256 i = 0; i < quantity; i++) { _safeMint(to, tokenId++); @@ -153,13 +158,21 @@ contract BaseGen is ERC721, ERC721Burnable, ERC721Royalty, Ownable { } function _splitPayment(uint256 amount) private { + _splitPayment(amount, address(0)); + } + + function _splitPayment(uint256 amount, address referrer) private { + uint256 receiverAmount = amount / 2; if (_receiver != address(0)) { - uint256 splitAmount = amount / 2; - payable(_receiver).transfer(splitAmount); - payable(owner()).transfer(splitAmount); - } else { - payable(owner()).transfer(amount); + payable(_receiver).transfer(receiverAmount); + } + if (referrer != address(0) && referrer != _receiver) { + (, uint256 referrerAmount) = royaltyInfo(0, amount); + if (referrerAmount > 0 && referrerAmount < address(this).balance) { + payable(referrer).transfer(referrerAmount); + } } + payable(owner()).transfer(address(this).balance); } function tokenURI(uint256 tokenId) public view override(ERC721) returns (string memory) { diff --git a/test/BaseGenReferral.t.sol b/test/BaseGenReferral.t.sol new file mode 100644 index 0000000..97ab85c --- /dev/null +++ b/test/BaseGenReferral.t.sol @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.20; + +import "forge-std/Test.sol"; +import "../src/BaseGen.sol"; + +contract BaseGenReferralTest is Test { + BaseGen public instance; + uint256 public constant pricePerMint = 0.0015 ether; + uint256 public constant quantity = 10; + address public initialOwner; + address public receiver; + uint256 public maxSupply = 20; + + function setUp() public { + initialOwner = vm.addr(1); + receiver = vm.addr(2); + string memory name = "Generative"; + string memory symbol = "GEN"; + string memory contractURI = "ipfs://"; + string memory baseURI = "https://data.kodadot.xyz/base/"; // suffixed with /?hash= + instance = new BaseGen(initialOwner, name, symbol, contractURI, baseURI, maxSupply, receiver); + } + + function testMintWithReferral() public { + address tokenOwner = vm.addr(3); + address referrer = vm.addr(4); + + // Calculate the total cost for minting + uint256 totalCost = pricePerMint * quantity; + vm.deal(tokenOwner, totalCost); + + // Assert initial state + assertEq(instance.totalSupply(), 0); + + // Prank msg.sender as initialOwner to simulate the mint transaction + vm.prank(tokenOwner); + + // Send sufficient funds to mint + instance.mintWithReferral{value: totalCost}(tokenOwner, quantity, referrer); + + // Assert that totalSupply is updated correctly + assertEq(instance.totalSupply(), quantity); + + // Assert the owner of the newly minted tokens + for (uint256 i = 0; i < quantity; i++) { + assertEq(instance.ownerOf(i), tokenOwner); + } + + // Assert that the contract balance is transferred correctly + uint256 receiverBalance = totalCost / 2; + assertEq(receiver.balance, receiverBalance); + uint256 referrerBalance = totalCost / 20; + assertEq(referrer.balance, referrerBalance); + assertEq(initialOwner.balance, receiverBalance- referrerBalance); + } + +}