Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TokenGatedHook Contract for NFT Ownership Verification #167

Merged
merged 2 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions contracts/hooks/TokenGatedHook.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// SPDX-License-Identifier: BUSL-1.1
pragma solidity ^0.8.19;

import { HookResult } from "contracts/interfaces/hooks/base/IHook.sol";
import { SyncBaseHook } from "contracts/hooks/base/SyncBaseHook.sol";
import { Errors } from "contracts/lib/Errors.sol";
import { TokenGated } from "contracts/lib/hooks/TokenGated.sol";
import { ERC165Checker } from "@openzeppelin/contracts/utils/introspection/ERC165Checker.sol";
import { IERC721 } from "@openzeppelin/contracts/token/ERC721/IERC721.sol";

/// @title TokenGatedHook
/// @notice This contract is a hook that ensures the user is the owner of a specific NFT token.
/// @dev It extends SyncBaseHook and provides the implementation for validating the hook configuration and executing the hook.
contract TokenGatedHook is SyncBaseHook {
using ERC165Checker for address;

/// @notice Constructs the TokenGatedHook contract.
/// @param accessControl_ The address of the access control contract.
constructor(address accessControl_) SyncBaseHook(accessControl_) {}

/// @notice Validates the configuration for the hook.
/// @dev This function checks if the tokenAddress is a valid ERC721 contract.
/// @param hookConfig_ The configuration data for the hook.
function _validateConfig(bytes memory hookConfig_) internal view override {
TokenGated.Config memory config = abi.decode(hookConfig_, (TokenGated.Config));
address tokenAddress = config.tokenAddress;
if (tokenAddress == address(0)) {
revert Errors.ZeroAddress();
}
// Check if the configured token address is a valid ERC 721 contract
if (
!tokenAddress.supportsInterface(
type(IERC721).interfaceId
)
) {
revert Errors.UnsupportedInterface("IERC721");
}
}

/// @notice Executes token gated check in a synchronous manner.
/// @dev This function checks if the "tokenOwner" owns a token of the specified ERC721 token contract.
/// @param hookConfig_ The configuration of the hook.
/// @param hookParams_ The parameters for the hook.
/// @return hookData always return empty string as no return data from this hook.
function _executeSyncCall(
bytes memory hookConfig_,
bytes memory hookParams_
) internal virtual override returns (bytes memory) {
TokenGated.Config memory config = abi.decode(hookConfig_, (TokenGated.Config));
TokenGated.Params memory params = abi.decode(hookParams_, (TokenGated.Params));

if (params.tokenOwner == address(0)) {
revert Errors.ZeroAddress();
}
// check if tokenOwner own any required token
if (IERC721(config.tokenAddress).balanceOf(params.tokenOwner) == 0) {
revert Errors.TokenGatedHook_NotTokenOwner(config.tokenAddress, params.tokenOwner);
}

return "";
}
}
3 changes: 3 additions & 0 deletions contracts/lib/Errors.sol
Original file line number Diff line number Diff line change
Expand Up @@ -337,4 +337,7 @@ library Errors {

/// @notice Invalid async request ID.
error Hook_InvalidAsyncRequestId(bytes32 invalidRequestId);

/// @notice The address is not the owner of the token.
error TokenGatedHook_NotTokenOwner(address tokenAddress, address ownerAddress);
}
23 changes: 23 additions & 0 deletions contracts/lib/hooks/TokenGated.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// SPDX-License-Identifier: BUSL-1.1
pragma solidity ^0.8.19;

/// @title TokenGated
/// @notice This library defines the Config and Params structs used in the TokenGatedHook.
/// @dev The Config struct contains the tokenAddress field, and the Params struct contains the tokenOwner field.
library TokenGated {
/// @notice Defines the required configuration information for the TokenGatedHook.
/// @dev The Config struct contains a single field: tokenAddress.
struct Config {
/// @notice The address of the ERC721 token contract.
/// @dev This address is used to check if the tokenOwner owns a token of the specified ERC721 token contract.
address tokenAddress;
}

/// @notice Defines the required parameter information for executing the TokenGatedHook.
/// @dev The Params struct contains a single field: tokenOwner.
struct Params {
/// @notice The address of the token owner.
/// @dev This address is checked against the tokenAddress in the Config struct to ensure the owner has a token.
address tokenOwner;
}
}
229 changes: 229 additions & 0 deletions test/foundry/hooks/TestTokenGatedHook.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
// SPDX-License-Identifier: BUSL-1.1
pragma solidity ^0.8.19;

import "forge-std/Test.sol";

import { BaseTest } from "test/foundry/utils/BaseTest.sol";
import { TokenGatedHook } from "contracts/hooks/TokenGatedHook.sol";
import { HookResult } from "contracts/interfaces/hooks/base/IHook.sol";
import { MockSyncHook } from "test/foundry/mocks/MockSyncHook.sol";
import { Errors } from "contracts/lib/Errors.sol";
import { AccessControl } from "contracts/lib/AccessControl.sol";
import { Hook } from "contracts/lib/hooks/Hook.sol";
import { MockERC721 } from "test/foundry/mocks/MockERC721.sol";
import { MockERC721Receiver } from "test/foundry/mocks/MockERC721Receiver.sol";
import { TokenGated } from "contracts/lib/hooks/TokenGated.sol";

contract TestTokenGatedHook is BaseTest {
TokenGatedHook hook;
MockERC721 tokenContract;
MockERC721Receiver tokenOwner;

event SyncHookExecuted(
address indexed hookAddress,
HookResult indexed result,
bytes contextData,
bytes returnData
);

function setUp() public override {
super.setUp();

vm.prank(admin);
accessControl.grantRole(AccessControl.HOOK_CALLER_ROLE, address(this));

hook = new TokenGatedHook(address(accessControl));
tokenContract = new MockERC721();
tokenOwner = new MockERC721Receiver(MockERC721Receiver.onERC721Received.selector, false);
// Simulate user has ownership of the NFT
tokenContract.mint(address(tokenOwner), 1);
}

function test_tokenGatedHook_hasOwnership() public {
// create configuration of hook
TokenGated.Config memory hookConfig = TokenGated.Config({
tokenAddress: address(tokenContract)
});
bytes memory encodedConfig = abi.encode(hookConfig);
// Hook validating the configuration
hook.validateConfig(encodedConfig);

// create parameters of executing the hook
TokenGated.Params memory hookParams = TokenGated.Params({
tokenOwner: address(tokenOwner)
});
bytes memory encodedParams = abi.encode(hookParams);

// Create Hook execution context which has hook's config and current parameters
bytes memory context = _getExecutionContext(encodedConfig, encodedParams);

bytes memory expectedHookData = "";

HookResult result;
bytes memory hookData;

// Execute the sync hook
(result, hookData) = hook.executeSync(context);

// Check the result
assertEq(uint(result), uint(HookResult.Completed));

// Check the hook data
assertEq0(hookData, expectedHookData);
}

function test_tokenGatedHook_hasOwnershipVerifyEvent() public {
// create configuration of hook
TokenGated.Config memory hookConfig = TokenGated.Config({
tokenAddress: address(tokenContract)
});
bytes memory encodedConfig = abi.encode(hookConfig);
// Hook validating the configuration
hook.validateConfig(encodedConfig);

// create parameters of executing the hook
TokenGated.Params memory hookParams = TokenGated.Params({
tokenOwner: address(tokenOwner)
});
bytes memory encodedParams = abi.encode(hookParams);

// Create Hook execution context which has hook's config and current parameters
bytes memory context = _getExecutionContext(encodedConfig, encodedParams);

bytes memory expectedHookData = "";

vm.expectEmit(address(hook));
emit SyncHookExecuted(
address(hook),
HookResult.Completed,
context,
expectedHookData
);
// Execute the sync hook
hook.executeSync(context);
}

function test_tokenGatedHook_revert_hasNoOwnership() public {
MockERC721Receiver nonTokenOwner = new MockERC721Receiver(MockERC721Receiver.onERC721Received.selector, false);
// create configuration of hook
TokenGated.Config memory hookConfig = TokenGated.Config({
tokenAddress: address(tokenContract)
});
bytes memory encodedConfig = abi.encode(hookConfig);
// Hook validating the configuration
hook.validateConfig(encodedConfig);

// create parameters of executing the hook
TokenGated.Params memory hookParams = TokenGated.Params({
tokenOwner: address(nonTokenOwner)
});
bytes memory encodedParams = abi.encode(hookParams);

// Create Hook execution context which has hook's config and current parameters
bytes memory context = _getExecutionContext(encodedConfig, encodedParams);

// Try to execute the hook without token ownership
vm.expectRevert(
abi.encodeWithSelector(
Errors.TokenGatedHook_NotTokenOwner.selector,
address(tokenContract),
address(nonTokenOwner)
)
);
hook.executeSync(context);
}

function test_tokenGatedHook_revert_ZeroTokenAddress() public {
// create configuration of hook
TokenGated.Config memory hookConfig = TokenGated.Config({
// Invalid token address
tokenAddress: address(0)
});
bytes memory encodedConfig = abi.encode(hookConfig);

// create parameters of executing the hook
TokenGated.Params memory hookParams = TokenGated.Params({
tokenOwner: address(tokenOwner)
});
bytes memory encodedParams = abi.encode(hookParams);

// Create Hook execution context which has hook's config and current parameters
bytes memory context = _getExecutionContext(encodedConfig, encodedParams);

// Try to execute the hook with invalid token contract address
vm.expectRevert(Errors.ZeroAddress.selector);
hook.executeSync(context);
}

function test_tokenGatedHook_revert_NonERC721Address() public {
// create configuration of hook
TokenGated.Config memory hookConfig = TokenGated.Config({
// Invalid token address
tokenAddress: address(0x77777)
});
bytes memory encodedConfig = abi.encode(hookConfig);

// create parameters of executing the hook
TokenGated.Params memory hookParams = TokenGated.Params({
tokenOwner: address(tokenOwner)
});
bytes memory encodedParams = abi.encode(hookParams);

// Create Hook execution context which has hook's config and current parameters
bytes memory context = _getExecutionContext(encodedConfig, encodedParams);

// Try to execute the hook with invalid token contract address
vm.expectRevert(
abi.encodeWithSelector(
Errors.UnsupportedInterface.selector,
"IERC721"
)
);

hook.executeSync(context);
}

function test_syncBaseHook_revert_InvalidOwnerAddress() public {
// create configuration of hook
TokenGated.Config memory hookConfig = TokenGated.Config({
// Invalid token address
tokenAddress: address(tokenContract)
});
bytes memory encodedConfig = abi.encode(hookConfig);

// create parameters of executing the hook
TokenGated.Params memory hookParams = TokenGated.Params({
tokenOwner: address(0)
});
bytes memory encodedParams = abi.encode(hookParams);

// Create Hook execution context which has hook's config and current parameters
bytes memory context = _getExecutionContext(encodedConfig, encodedParams);

// Try to execute the hook with invalid contract address
vm.expectRevert(Errors.ZeroAddress.selector);

hook.executeSync(context);
}

function test_tokenGatedHook_revert_InvalidConfig() public {
// create configuration of hook
TokenGated.Config memory hookConfig = TokenGated.Config({
// Invalid token address
tokenAddress: address(0)
});
bytes memory encodedConfig = abi.encode(hookConfig);

vm.expectRevert(Errors.ZeroAddress.selector);
hook.validateConfig(encodedConfig);
}

function _getExecutionContext(bytes memory hookConfig_, bytes memory hookParams_) internal pure returns (bytes memory) {
Hook.ExecutionContext memory context = Hook.ExecutionContext({
config: hookConfig_,
params: hookParams_
});
return abi.encode(context);
}

}