diff --git a/target_chains/ethereum/contracts/contracts/pulse/scheduler/IScheduler.sol b/target_chains/ethereum/contracts/contracts/pulse/scheduler/IScheduler.sol new file mode 100644 index 0000000000..aae5638346 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/scheduler/IScheduler.sol @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; +import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol"; +import "./SchedulerEvents.sol"; +import "./SchedulerState.sol"; + +interface IScheduler is SchedulerEvents { + // CORE FUNCTIONS + + /** + * @notice Adds a new subscription + * @param subscriptionParams The parameters for the subscription + * @return subscriptionId The ID of the newly created subscription + */ + function addSubscription( + SchedulerState.SubscriptionParams calldata subscriptionParams + ) external returns (uint256 subscriptionId); + + /** + * @notice Gets a subscription's parameters and status + * @param subscriptionId The ID of the subscription + * @return params The subscription parameters + * @return status The subscription status + */ + function getSubscription( + uint256 subscriptionId + ) + external + view + returns ( + SchedulerState.SubscriptionParams memory params, + SchedulerState.SubscriptionStatus memory status + ); + + /** + * @notice Updates an existing subscription + * @param subscriptionId The ID of the subscription to update + * @param newSubscriptionParams The new parameters for the subscription + */ + function updateSubscription( + uint256 subscriptionId, + SchedulerState.SubscriptionParams calldata newSubscriptionParams + ) external; + + /** + * @notice Deactivates a subscription + * @param subscriptionId The ID of the subscription to deactivate + */ + function deactivateSubscription(uint256 subscriptionId) external; + + /** + * @notice Updates price feeds for a subscription. + * Verifies the updateData using the Pyth contract and validates that all feeds have the same timestamp. + * @param subscriptionId The ID of the subscription + * @param updateData The price update data from Pyth + * @param priceIds The IDs of the price feeds to update + */ + function updatePriceFeeds( + uint256 subscriptionId, + bytes[] calldata updateData, + bytes32[] calldata priceIds + ) external; + + /** + * @notice Gets the latest prices for a subscription + * @param subscriptionId The ID of the subscription + * @param priceIds Optional array of price IDs to retrieve. If empty, returns all price feeds for the subscription. + * @return The latest price feeds for the requested price IDs + */ + function getLatestPrices( + uint256 subscriptionId, + bytes32[] calldata priceIds + ) external view returns (PythStructs.PriceFeed[] memory); + + /** + * @notice Adds funds to a subscription's balance + * @param subscriptionId The ID of the subscription + */ + function addFunds(uint256 subscriptionId) external payable; + + /** + * @notice Withdraws funds from a subscription's balance + * @param subscriptionId The ID of the subscription + * @param amount The amount to withdraw + */ + function withdrawFunds(uint256 subscriptionId, uint256 amount) external; + + /** + * @notice Gets all active subscriptions with their parameters + * @dev This function has no access control to allow keepers to discover active subscriptions + * @return subscriptionIds Array of active subscription IDs + * @return subscriptionParams Array of subscription parameters for each active subscription + */ + function getActiveSubscriptions() + external + view + returns ( + uint256[] memory subscriptionIds, + SchedulerState.SubscriptionParams[] memory subscriptionParams + ); +} diff --git a/target_chains/ethereum/contracts/contracts/pulse/scheduler/Scheduler.sol b/target_chains/ethereum/contracts/contracts/pulse/scheduler/Scheduler.sol new file mode 100644 index 0000000000..8f6d5897e5 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/scheduler/Scheduler.sol @@ -0,0 +1,482 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "@openzeppelin/contracts/utils/math/SafeCast.sol"; +import "@openzeppelin/contracts/utils/math/SignedMath.sol"; +import "@openzeppelin/contracts/utils/math/Math.sol"; +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; +import "./IScheduler.sol"; +import "./SchedulerState.sol"; +import "./SchedulerErrors.sol"; + +abstract contract Scheduler is IScheduler, SchedulerState { + function _initialize(address admin, address pythAddress) internal { + require(admin != address(0), "admin is zero address"); + require(pythAddress != address(0), "pyth is zero address"); + + _state.pyth = pythAddress; + _state.subscriptionNumber = 1; + } + + function addSubscription( + SubscriptionParams calldata subscriptionParams + ) external override returns (uint256 subscriptionId) { + if (subscriptionParams.priceIds.length > MAX_PRICE_IDS) { + revert TooManyPriceIds( + subscriptionParams.priceIds.length, + MAX_PRICE_IDS + ); + } + + // Validate update criteria + if ( + !subscriptionParams.updateCriteria.updateOnHeartbeat && + !subscriptionParams.updateCriteria.updateOnDeviation + ) { + revert InvalidUpdateCriteria(); + } + + // Validate gas config + if ( + subscriptionParams.gasConfig.maxGasPrice == 0 || + subscriptionParams.gasConfig.maxGasLimit == 0 + ) { + revert InvalidGasConfig(); + } + + subscriptionId = _state.subscriptionNumber++; + + // Store the subscription parameters + _state.subscriptionParams[subscriptionId] = subscriptionParams; + + // Initialize subscription status + SubscriptionStatus storage status = _state.subscriptionStatuses[ + subscriptionId + ]; + status.priceLastUpdatedAt = 0; + status.balanceInWei = 0; + status.totalUpdates = 0; + status.totalSpent = 0; + status.isActive = true; + + // Map subscription ID to manager + _state.subscriptionManager[subscriptionId] = msg.sender; + + emit SubscriptionCreated(subscriptionId, msg.sender); + return subscriptionId; + } + + function getSubscription( + uint256 subscriptionId + ) + external + view + override + returns ( + SubscriptionParams memory params, + SubscriptionStatus memory status + ) + { + return ( + _state.subscriptionParams[subscriptionId], + _state.subscriptionStatuses[subscriptionId] + ); + } + + function updateSubscription( + uint256 subscriptionId, + SubscriptionParams calldata newSubscriptionParams + ) external override onlyManager(subscriptionId) { + if (!_state.subscriptionStatuses[subscriptionId].isActive) { + revert InactiveSubscription(); + } + + if (newSubscriptionParams.priceIds.length > MAX_PRICE_IDS) { + revert TooManyPriceIds( + newSubscriptionParams.priceIds.length, + MAX_PRICE_IDS + ); + } + + // Validate update criteria + if ( + !newSubscriptionParams.updateCriteria.updateOnHeartbeat && + !newSubscriptionParams.updateCriteria.updateOnDeviation + ) { + revert InvalidUpdateCriteria(); + } + + // Validate gas config + if ( + newSubscriptionParams.gasConfig.maxGasPrice == 0 || + newSubscriptionParams.gasConfig.maxGasLimit == 0 + ) { + revert InvalidGasConfig(); + } + + // Update subscription parameters + _state.subscriptionParams[subscriptionId] = newSubscriptionParams; + + emit SubscriptionUpdated(subscriptionId); + } + + function deactivateSubscription( + uint256 subscriptionId + ) external override onlyManager(subscriptionId) { + if (!_state.subscriptionStatuses[subscriptionId].isActive) { + revert InactiveSubscription(); + } + + _state.subscriptionStatuses[subscriptionId].isActive = false; + + emit SubscriptionDeactivated(subscriptionId); + } + + function updatePriceFeeds( + uint256 subscriptionId, + bytes[] calldata updateData, + bytes32[] calldata priceIds + ) external override onlyPusher { + SubscriptionStatus storage status = _state.subscriptionStatuses[ + subscriptionId + ]; + SubscriptionParams storage params = _state.subscriptionParams[ + subscriptionId + ]; + + if (!status.isActive) { + revert InactiveSubscription(); + } + + // Verify price IDs match subscription + if (priceIds.length != params.priceIds.length) { + revert InvalidPriceIdsLength(priceIds[0], params.priceIds[0]); + } + + // Keepers must provide priceIds in the exact same order as defined in the subscription + for (uint8 i = 0; i < priceIds.length; i++) { + if (priceIds[i] != params.priceIds[i]) { + revert InvalidPriceId(priceIds[i], params.priceIds[i]); + } + } + + // Get the Pyth contract and parse price updates + IPyth pyth = IPyth(_state.pyth); + uint256 pythFee = pyth.getUpdateFee(updateData); + + // Check if subscription has enough balance + if (status.balanceInWei < pythFee) { + revert InsufficientBalance(); + } + + // Parse price feed updates with an expected timestamp range of [-10s, now] + // We will validate the trigger conditions and timestamps ourselves + // using the returned PriceFeeds. + uint64 maxPublishTime = SafeCast.toUint64(block.timestamp); + uint64 minPublishTime = maxPublishTime - 10 seconds; + PythStructs.PriceFeed[] memory priceFeeds = pyth.parsePriceFeedUpdates{ + value: pythFee + }(updateData, priceIds, minPublishTime, maxPublishTime); + + // Verify all price feeds have the same timestamp + uint256 timestamp = priceFeeds[0].price.publishTime; + for (uint8 i = 1; i < priceFeeds.length; i++) { + if (priceFeeds[i].price.publishTime != timestamp) { + revert PriceTimestampMismatch(); + } + } + + // Verify that update conditions are met, and that the timestamp + // is more recent than latest stored update's. Reverts if not. + _validateShouldUpdatePrices(subscriptionId, params, status, priceFeeds); + + // Store the price updates, update status, and emit event + _storePriceUpdatesAndStatus( + subscriptionId, + status, + priceFeeds, + pythFee + ); + } + + /** + * @notice Stores the price updates, updates subscription status, and emits event. + */ + function _storePriceUpdatesAndStatus( + uint256 subscriptionId, + SubscriptionStatus storage status, + PythStructs.PriceFeed[] memory priceFeeds, + uint256 pythFee + ) internal { + // Store the price updates + for (uint8 i = 0; i < priceFeeds.length; i++) { + _state.priceUpdates[subscriptionId][priceFeeds[i].id] = priceFeeds[ + i + ]; + } + status.priceLastUpdatedAt = priceFeeds[0].price.publishTime; + status.balanceInWei -= pythFee; + status.totalUpdates += 1; + status.totalSpent += pythFee; + + emit PricesUpdated(subscriptionId, priceFeeds[0].price.publishTime); + } + + /** + * @notice Validates whether the update trigger criteria is met for a subscription. Reverts if not met. + * @dev This function assumes that all updates in priceFeeds have the same timestamp. The caller is expected to enforce this invariant. + * @param subscriptionId The ID of the subscription (needed for reading previous prices). + * @param params The subscription's parameters struct. + * @param status The subscription's status struct. + * @param priceFeeds The array of price feeds to validate. + */ + function _validateShouldUpdatePrices( + uint256 subscriptionId, + SubscriptionParams storage params, + SubscriptionStatus storage status, + PythStructs.PriceFeed[] memory priceFeeds + ) internal view returns (bool) { + // SECURITY NOTE: this check assumes that all updates in priceFeeds have the same timestamp. + // The caller is expected to enforce this invariant. + uint256 updateTimestamp = priceFeeds[0].price.publishTime; + + // Reject updates if they're older than the latest stored ones + if ( + status.priceLastUpdatedAt > 0 && + updateTimestamp <= status.priceLastUpdatedAt + ) { + revert TimestampOlderThanLastUpdate( + updateTimestamp, + status.priceLastUpdatedAt + ); + } + + // If updateOnHeartbeat is enabled and the heartbeat interval has passed, trigger update + if (params.updateCriteria.updateOnHeartbeat) { + uint256 lastUpdateTime = status.priceLastUpdatedAt; + + if ( + lastUpdateTime == 0 || + updateTimestamp >= + lastUpdateTime + params.updateCriteria.heartbeatSeconds + ) { + return true; + } + } + + // If updateOnDeviation is enabled, check if any price has deviated enough + if (params.updateCriteria.updateOnDeviation) { + for (uint8 i = 0; i < priceFeeds.length; i++) { + // Get the previous price feed for this price ID using subscriptionId + PythStructs.PriceFeed storage previousFeed = _state + .priceUpdates[subscriptionId][priceFeeds[i].id]; + + // If there's no previous price, this is the first update + if (previousFeed.id == bytes32(0)) { + return true; + } + + // Calculate the deviation percentage + int64 currentPrice = priceFeeds[i].price.price; + int64 previousPrice = previousFeed.price.price; + + // Skip if either price is zero to avoid division by zero + if (previousPrice == 0 || currentPrice == 0) { + continue; + } + + // Calculate absolute deviation basis points (scaled by 1e4) + uint256 numerator = SignedMath.abs( + currentPrice - previousPrice + ); + uint256 denominator = SignedMath.abs(previousPrice); + uint256 deviationBps = Math.mulDiv( + numerator, + 10_000, + denominator + ); + + // If deviation exceeds threshold, trigger update + if ( + deviationBps >= params.updateCriteria.deviationThresholdBps + ) { + return true; + } + } + } + + revert UpdateConditionsNotMet(); + } + + function getLatestPrices( + uint256 subscriptionId, + bytes32[] calldata priceIds + ) + external + view + override + onlyWhitelistedReader(subscriptionId) + returns (PythStructs.PriceFeed[] memory) + { + if (!_state.subscriptionStatuses[subscriptionId].isActive) { + revert InactiveSubscription(); + } + + SubscriptionParams storage params = _state.subscriptionParams[ + subscriptionId + ]; + + // If no price IDs provided, return all price feeds for the subscription + if (priceIds.length == 0) { + PythStructs.PriceFeed[] + memory allFeeds = new PythStructs.PriceFeed[]( + params.priceIds.length + ); + for (uint8 i = 0; i < params.priceIds.length; i++) { + allFeeds[i] = _state.priceUpdates[subscriptionId][ + params.priceIds[i] + ]; + } + return allFeeds; + } + + // Return only the requested price feeds + PythStructs.PriceFeed[] + memory requestedFeeds = new PythStructs.PriceFeed[]( + priceIds.length + ); + for (uint8 i = 0; i < priceIds.length; i++) { + // Verify the requested price ID is part of the subscription + bool validPriceId = false; + for (uint8 j = 0; j < params.priceIds.length; j++) { + if (priceIds[i] == params.priceIds[j]) { + validPriceId = true; + break; + } + } + + if (!validPriceId) { + revert InvalidPriceId(priceIds[i], params.priceIds[0]); + } + + requestedFeeds[i] = _state.priceUpdates[subscriptionId][ + priceIds[i] + ]; + } + + return requestedFeeds; + } + + function addFunds( + uint256 subscriptionId + ) external payable override onlyManager(subscriptionId) { + if (!_state.subscriptionStatuses[subscriptionId].isActive) { + revert InactiveSubscription(); + } + + _state.subscriptionStatuses[subscriptionId].balanceInWei += msg.value; + } + + function withdrawFunds( + uint256 subscriptionId, + uint256 amount + ) external override onlyManager(subscriptionId) { + SubscriptionStatus storage status = _state.subscriptionStatuses[ + subscriptionId + ]; + + if (status.balanceInWei < amount) { + revert InsufficientBalance(); + } + + status.balanceInWei -= amount; + + (bool sent, ) = msg.sender.call{value: amount}(""); + require(sent, "Failed to send funds"); + } + + // This function is intentionally public with no access control to allow keepers to discover active subscriptions + function getActiveSubscriptions() + external + view + override + returns ( + uint256[] memory subscriptionIds, + SubscriptionParams[] memory subscriptionParams + ) + { + // TODO: This is gonna be expensive because we're iterating through + // all subscriptions, including deactivated ones. But because its a view + // function maybe it's not bad? We can optimize this. + + // Count active subscriptions first to determine array size + uint256 activeCount = 0; + for (uint256 i = 1; i < _state.subscriptionNumber; i++) { + if (_state.subscriptionStatuses[i].isActive) { + activeCount++; + } + } + + // Create arrays for subscription IDs and parameters + subscriptionIds = new uint256[](activeCount); + subscriptionParams = new SubscriptionParams[](activeCount); + + // Populate arrays with active subscription data + uint256 index = 0; + for (uint256 i = 1; i < _state.subscriptionNumber; i++) { + if (_state.subscriptionStatuses[i].isActive) { + subscriptionIds[index] = i; + subscriptionParams[index] = _state.subscriptionParams[i]; + index++; + } + } + + return (subscriptionIds, subscriptionParams); + } + + // ACCESS CONTROL MODIFIERS + + modifier onlyPusher() { + // TODO: we may not make this permissioned. + _; + } + + modifier onlyManager(uint256 subscriptionId) { + if (_state.subscriptionManager[subscriptionId] != msg.sender) { + revert Unauthorized(); + } + _; + } + + modifier onlyWhitelistedReader(uint256 subscriptionId) { + // Manager is always allowed + if (_state.subscriptionManager[subscriptionId] == msg.sender) { + _; + return; + } + + // If whitelist is not used, allow any reader + if (!_state.subscriptionParams[subscriptionId].whitelistEnabled) { + _; + return; + } + + // Check if caller is in whitelist + address[] storage whitelist = _state + .subscriptionParams[subscriptionId] + .readerWhitelist; + bool isWhitelisted = false; + for (uint i = 0; i < whitelist.length; i++) { + if (whitelist[i] == msg.sender) { + isWhitelisted = true; + break; + } + } + + if (!isWhitelisted) { + revert Unauthorized(); + } + _; + } +} diff --git a/target_chains/ethereum/contracts/contracts/pulse/scheduler/SchedulerErrors.sol b/target_chains/ethereum/contracts/contracts/pulse/scheduler/SchedulerErrors.sol new file mode 100644 index 0000000000..8f63682bac --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/scheduler/SchedulerErrors.sol @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +error InactiveSubscription(); +error InsufficientBalance(); +error Unauthorized(); +error InvalidPriceId(bytes32 providedPriceId, bytes32 expectedPriceId); +error InvalidPriceIdsLength(bytes32 providedLength, bytes32 expectedLength); +error InvalidUpdateCriteria(); +error InvalidGasConfig(); +error PriceTimestampMismatch(); +error TooManyPriceIds(uint256 provided, uint256 maximum); +error UpdateConditionsNotMet(); +error TimestampOlderThanLastUpdate( + uint256 providedUpdateTimestamp, + uint256 lastUpdatedAt +); diff --git a/target_chains/ethereum/contracts/contracts/pulse/scheduler/SchedulerEvents.sol b/target_chains/ethereum/contracts/contracts/pulse/scheduler/SchedulerEvents.sol new file mode 100644 index 0000000000..f8acce50a4 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/scheduler/SchedulerEvents.sol @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.0; + +import "./SchedulerState.sol"; + +interface SchedulerEvents { + event SubscriptionCreated( + uint256 indexed subscriptionId, + address indexed manager + ); + event SubscriptionUpdated(uint256 indexed subscriptionId); + event SubscriptionDeactivated(uint256 indexed subscriptionId); + event PricesUpdated(uint256 indexed subscriptionId, uint256 timestamp); +} diff --git a/target_chains/ethereum/contracts/contracts/pulse/scheduler/SchedulerState.sol b/target_chains/ethereum/contracts/contracts/pulse/scheduler/SchedulerState.sol new file mode 100644 index 0000000000..d638da6f90 --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/scheduler/SchedulerState.sol @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol"; + +contract SchedulerState { + // Maximum number of price feeds per subscription + uint8 public constant MAX_PRICE_IDS = 10; + + struct State { + // Monotonically increasing counter for subscription IDs + uint256 subscriptionNumber; + // Pyth contract for parsing updates and verifying sigs & timestamps + address pyth; + // Sub ID -> subscription parameters (which price feeds, when to update, etc) + mapping(uint256 => SubscriptionParams) subscriptionParams; + // Sub ID -> subscription status (metadata about their sub) + mapping(uint256 => SubscriptionStatus) subscriptionStatuses; + // Sub ID -> price ID -> latest parsed price update for the subscribed feed + mapping(uint256 => mapping(bytes32 => PythStructs.PriceFeed)) priceUpdates; + // Sub ID -> manager address + mapping(uint256 => address) subscriptionManager; + } + State internal _state; + + struct SubscriptionParams { + bytes32[] priceIds; + address[] readerWhitelist; + bool whitelistEnabled; + UpdateCriteria updateCriteria; + GasConfig gasConfig; + } + + struct SubscriptionStatus { + uint256 priceLastUpdatedAt; + uint256 balanceInWei; + uint256 totalUpdates; + uint256 totalSpent; + bool isActive; + } + + struct GasConfig { + // TODO: Figure out what controls to give users for gas strategy + + // Gas price limit to prevent runaway costs in high-gas environments + uint256 maxGasPrice; + // Gas limit for update operations + uint256 maxGasLimit; + } + + struct UpdateCriteria { + bool updateOnHeartbeat; + uint32 heartbeatSeconds; + bool updateOnDeviation; + uint32 deviationThresholdBps; + + // TODO: add updateOnConfidenceRatio? + + // TODO: add explicit "early update" support? i.e. update all feeds when at least one feed + // meets the triggering conditions, rather than waiting for all feeds + // to meet the conditions. Currently, "early update" is the only mode of operation. + } +} diff --git a/target_chains/ethereum/contracts/contracts/pulse/scheduler/SchedulerUpgradeable.sol b/target_chains/ethereum/contracts/contracts/pulse/scheduler/SchedulerUpgradeable.sol new file mode 100644 index 0000000000..45620cc9cb --- /dev/null +++ b/target_chains/ethereum/contracts/contracts/pulse/scheduler/SchedulerUpgradeable.sol @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "@openzeppelin/contracts-upgradeable/proxy/utils/Initializable.sol"; +import "@openzeppelin/contracts-upgradeable/proxy/utils/UUPSUpgradeable.sol"; +import "@openzeppelin/contracts-upgradeable/access/Ownable2StepUpgradeable.sol"; +import "./Scheduler.sol"; + +contract SchedulerUpgradeable is + Initializable, + Ownable2StepUpgradeable, + UUPSUpgradeable, + Scheduler +{ + event ContractUpgraded( + address oldImplementation, + address newImplementation + ); + + function initialize( + address owner, + address admin, + address pythAddress + ) external initializer { + require(owner != address(0), "owner is zero address"); + require(admin != address(0), "admin is zero address"); + + __Ownable_init(); + __UUPSUpgradeable_init(); + + Scheduler._initialize(admin, pythAddress); + + _transferOwnership(owner); + } + + /// @custom:oz-upgrades-unsafe-allow constructor + constructor() initializer {} + + function _authorizeUpgrade(address) internal override onlyOwner {} + + function upgradeTo(address newImplementation) external override onlyProxy { + address oldImplementation = _getImplementation(); + _authorizeUpgrade(newImplementation); + _upgradeToAndCallUUPS(newImplementation, new bytes(0), false); + + emit ContractUpgraded(oldImplementation, _getImplementation()); + } + + function upgradeToAndCall( + address newImplementation, + bytes memory data + ) external payable override onlyProxy { + address oldImplementation = _getImplementation(); + _authorizeUpgrade(newImplementation); + _upgradeToAndCallUUPS(newImplementation, data, true); + + emit ContractUpgraded(oldImplementation, _getImplementation()); + } + + function version() public pure returns (string memory) { + return "1.0.0"; + } +} diff --git a/target_chains/ethereum/contracts/forge-test/PulseScheduler.t.sol b/target_chains/ethereum/contracts/forge-test/PulseScheduler.t.sol new file mode 100644 index 0000000000..6a831d21e0 --- /dev/null +++ b/target_chains/ethereum/contracts/forge-test/PulseScheduler.t.sol @@ -0,0 +1,896 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.0; + +import "forge-std/Test.sol"; +import "@pythnetwork/pyth-sdk-solidity/IPyth.sol"; +import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol"; +import "./utils/PulseTestUtils.t.sol"; +import "../contracts/pulse/scheduler/SchedulerUpgradeable.sol"; +import "../contracts/pulse/scheduler/IScheduler.sol"; +import "../contracts/pulse/scheduler/SchedulerState.sol"; +import "../contracts/pulse/scheduler/SchedulerEvents.sol"; +import "../contracts/pulse/scheduler/SchedulerErrors.sol"; + +contract MockReader { + address private _scheduler; + + constructor(address scheduler) { + _scheduler = scheduler; + } + + function getLatestPrices( + uint256 subscriptionId, + bytes32[] memory priceIds + ) external view returns (PythStructs.PriceFeed[] memory) { + return IScheduler(_scheduler).getLatestPrices(subscriptionId, priceIds); + } + + function verifyPriceFeeds( + uint256 subscriptionId, + bytes32[] memory priceIds, + PythStructs.PriceFeed[] memory expectedFeeds + ) external view returns (bool) { + PythStructs.PriceFeed[] memory actualFeeds = IScheduler(_scheduler) + .getLatestPrices(subscriptionId, priceIds); + + if (actualFeeds.length != expectedFeeds.length) { + return false; + } + + for (uint i = 0; i < actualFeeds.length; i++) { + if ( + actualFeeds[i].id != expectedFeeds[i].id || + actualFeeds[i].price.price != expectedFeeds[i].price.price || + actualFeeds[i].price.conf != expectedFeeds[i].price.conf || + actualFeeds[i].price.publishTime != + expectedFeeds[i].price.publishTime + ) { + return false; + } + } + + return true; + } +} + +contract SchedulerTest is Test, SchedulerEvents, PulseTestUtils { + ERC1967Proxy public proxy; + SchedulerUpgradeable public scheduler; + MockReader public reader; + address public owner; + address public admin; + address public pyth; + address public pusher; + + // Constants + uint96 constant PYTH_FEE = 1 wei; + + function setUp() public { + owner = address(1); + admin = address(2); + pyth = address(3); + pusher = address(4); + + SchedulerUpgradeable _scheduler = new SchedulerUpgradeable(); + proxy = new ERC1967Proxy(address(_scheduler), ""); + scheduler = SchedulerUpgradeable(address(proxy)); + + scheduler.initialize(owner, admin, pyth); + + reader = new MockReader(address(proxy)); + + // Start tests at timestamp 100 to avoid underflow when we set + // `minPublishTime = timestamp - 10 seconds` in updatePriceFeeds + vm.warp(100); + } + + function testAddSubscription() public { + // Create subscription parameters + bytes32[] memory priceIds = createPriceIds(); + address[] memory readerWhitelist = new address[](1); + readerWhitelist[0] = address(reader); + + SchedulerState.UpdateCriteria memory updateCriteria = SchedulerState + .UpdateCriteria({ + updateOnHeartbeat: true, + heartbeatSeconds: 60, + updateOnDeviation: true, + deviationThresholdBps: 100 + }); + + SchedulerState.GasConfig memory gasConfig = SchedulerState.GasConfig({ + maxGasPrice: 100 gwei, + maxGasLimit: 1_000_000 + }); + + SchedulerState.SubscriptionParams memory params = SchedulerState + .SubscriptionParams({ + priceIds: priceIds, + readerWhitelist: readerWhitelist, + whitelistEnabled: true, + updateCriteria: updateCriteria, + gasConfig: gasConfig + }); + + // Add subscription + vm.expectEmit(); + emit SubscriptionCreated(1, address(this)); + + uint256 subscriptionId = scheduler.addSubscription(params); + assertEq(subscriptionId, 1, "Subscription ID should be 1"); + + // Verify subscription was added correctly + ( + SchedulerState.SubscriptionParams memory storedParams, + SchedulerState.SubscriptionStatus memory status + ) = scheduler.getSubscription(subscriptionId); + + assertEq( + storedParams.priceIds.length, + priceIds.length, + "Price IDs length mismatch" + ); + assertEq( + storedParams.readerWhitelist.length, + readerWhitelist.length, + "Whitelist length mismatch" + ); + assertEq( + storedParams.whitelistEnabled, + true, + "whitelistEnabled should be true" + ); + assertEq( + storedParams.updateCriteria.heartbeatSeconds, + 60, + "Heartbeat seconds mismatch" + ); + assertEq( + storedParams.updateCriteria.deviationThresholdBps, + 100, + "Deviation threshold mismatch" + ); + assertEq( + storedParams.gasConfig.maxGasPrice, + 100 gwei, + "Max gas price mismatch" + ); + + assertTrue(status.isActive, "Subscription should be active"); + assertEq(status.balanceInWei, 0, "Initial balance should be 0"); + } + + function testUpdateSubscription() public { + // First add a subscription + uint256 subscriptionId = addTestSubscription(); + + // Create updated parameters + bytes32[] memory newPriceIds = createPriceIds(3); // Add one more price ID + address[] memory newReaderWhitelist = new address[](2); + newReaderWhitelist[0] = address(reader); + newReaderWhitelist[1] = address(0x123); + + SchedulerState.UpdateCriteria memory newUpdateCriteria = SchedulerState + .UpdateCriteria({ + updateOnHeartbeat: true, + heartbeatSeconds: 120, // Changed from 60 + updateOnDeviation: true, + deviationThresholdBps: 200 // Changed from 100 + }); + + SchedulerState.GasConfig memory newGasConfig = SchedulerState + .GasConfig({ + maxGasPrice: 200 gwei, // Changed from 100 gwei + maxGasLimit: 2_000_000 // Changed from 1_000_000 + }); + + SchedulerState.SubscriptionParams memory newParams = SchedulerState + .SubscriptionParams({ + priceIds: newPriceIds, + readerWhitelist: newReaderWhitelist, + whitelistEnabled: false, // Changed from true + updateCriteria: newUpdateCriteria, + gasConfig: newGasConfig + }); + + // Update subscription + vm.expectEmit(); + emit SubscriptionUpdated(subscriptionId); + + scheduler.updateSubscription(subscriptionId, newParams); + + // Verify subscription was updated correctly + (SchedulerState.SubscriptionParams memory storedParams, ) = scheduler + .getSubscription(subscriptionId); + + assertEq( + storedParams.priceIds.length, + newPriceIds.length, + "Price IDs length mismatch" + ); + assertEq( + storedParams.readerWhitelist.length, + newReaderWhitelist.length, + "Whitelist length mismatch" + ); + assertEq( + storedParams.whitelistEnabled, + false, + "whitelistEnabled should be false" + ); + assertEq( + storedParams.updateCriteria.heartbeatSeconds, + 120, + "Heartbeat seconds mismatch" + ); + assertEq( + storedParams.updateCriteria.deviationThresholdBps, + 200, + "Deviation threshold mismatch" + ); + assertEq( + storedParams.gasConfig.maxGasPrice, + 200 gwei, + "Max gas price mismatch" + ); + } + + function testDeactivateSubscription() public { + // First add a subscription + uint256 subscriptionId = addTestSubscription(); + + // Deactivate subscription + vm.expectEmit(); + emit SubscriptionDeactivated(subscriptionId); + + scheduler.deactivateSubscription(subscriptionId); + + // Verify subscription was deactivated + (, SchedulerState.SubscriptionStatus memory status) = scheduler + .getSubscription(subscriptionId); + + assertFalse(status.isActive, "Subscription should be inactive"); + } + + function testAddFunds() public { + // First add a subscription + uint256 subscriptionId = addTestSubscription(); + + // Add funds + uint256 fundAmount = 1 ether; + scheduler.addFunds{value: fundAmount}(subscriptionId); + + // Verify funds were added + (, SchedulerState.SubscriptionStatus memory status) = scheduler + .getSubscription(subscriptionId); + + assertEq( + status.balanceInWei, + fundAmount, + "Balance should match added funds" + ); + } + + function testWithdrawFunds() public { + // First add a subscription and funds + uint256 subscriptionId = addTestSubscription(); + uint256 fundAmount = 1 ether; + scheduler.addFunds{value: fundAmount}(subscriptionId); + + // Get initial balance + uint256 initialBalance = address(this).balance; + + // Withdraw half the funds + uint256 withdrawAmount = fundAmount / 2; + scheduler.withdrawFunds(subscriptionId, withdrawAmount); + + // Verify funds were withdrawn + (, SchedulerState.SubscriptionStatus memory status) = scheduler + .getSubscription(subscriptionId); + + assertEq( + status.balanceInWei, + fundAmount - withdrawAmount, + "Remaining balance incorrect" + ); + assertEq( + address(this).balance, + initialBalance + withdrawAmount, + "Withdrawn amount not received" + ); + } + + function testUpdatePriceFeedsWorks() public { + // --- First Update --- + // Add a subscription and funds + uint256 subscriptionId = addTestSubscription(); // Uses heartbeat 60s, deviation 100bps + uint256 fundAmount = 2 ether; // Add enough for two updates + scheduler.addFunds{value: fundAmount}(subscriptionId); + + // Create price feeds and mock Pyth response for first update + bytes32[] memory priceIds = createPriceIds(); + uint64 publishTime1 = SafeCast.toUint64(block.timestamp); + PythStructs.PriceFeed[] memory priceFeeds1 = createMockPriceFeeds( + publishTime1 + ); + mockParsePriceFeedUpdates(pyth, priceFeeds1); + bytes[] memory updateData1 = createMockUpdateData(priceFeeds1); + + // Perform first update + vm.expectEmit(); + emit PricesUpdated(subscriptionId, publishTime1); + vm.prank(pusher); + + vm.breakpoint("a"); + scheduler.updatePriceFeeds(subscriptionId, updateData1, priceIds); + + // Verify first update + (, SchedulerState.SubscriptionStatus memory status1) = scheduler + .getSubscription(subscriptionId); + assertEq( + status1.priceLastUpdatedAt, + publishTime1, + "First update timestamp incorrect" + ); + assertEq( + status1.totalUpdates, + 1, + "Total updates should be 1 after first update" + ); + assertTrue( + status1.totalSpent > 0, + "Total spent should be > 0 after first update" + ); + uint256 spentAfterFirst = status1.totalSpent; // Store spent amount + + // --- Second Update --- + // Advance time beyond heartbeat interval (e.g., 100 seconds) + vm.warp(block.timestamp + 100); + + // Create price feeds for second update by cloning first update and modifying + uint64 publishTime2 = SafeCast.toUint64(block.timestamp); + PythStructs.PriceFeed[] + memory priceFeeds2 = new PythStructs.PriceFeed[]( + priceFeeds1.length + ); + for (uint i = 0; i < priceFeeds1.length; i++) { + priceFeeds2[i] = priceFeeds1[i]; // Clone the feed struct + priceFeeds2[i].price.publishTime = publishTime2; // Update timestamp + + // Apply a 100 bps price increase (satisfies update criteria) + int64 priceDiff = int64( + (uint64(priceFeeds1[i].price.price) * 100) / 10_000 + ); + priceFeeds2[i].price.price = priceFeeds1[i].price.price + priceDiff; + priceFeeds2[i].emaPrice.publishTime = publishTime2; + } + + mockParsePriceFeedUpdates(pyth, priceFeeds2); // Mock for the second call + bytes[] memory updateData2 = createMockUpdateData(priceFeeds2); + + // Perform second update + vm.expectEmit(); + emit PricesUpdated(subscriptionId, publishTime2); + vm.prank(pusher); + + vm.breakpoint("b"); + scheduler.updatePriceFeeds(subscriptionId, updateData2, priceIds); + + // Verify second update + (, SchedulerState.SubscriptionStatus memory status2) = scheduler + .getSubscription(subscriptionId); + assertEq( + status2.priceLastUpdatedAt, + publishTime2, + "Second update timestamp incorrect" + ); + assertEq( + status2.totalUpdates, + 2, + "Total updates should be 2 after second update" + ); + assertTrue( + status2.totalSpent > spentAfterFirst, + "Total spent should increase after second update" + ); + // Verify price feed data using the reader contract for the second update + assertTrue( + reader.verifyPriceFeeds( + subscriptionId, + new bytes32[](0), + priceFeeds2 + ), + "Price feeds verification failed after second update" + ); + } + + function testUpdatePriceFeedsRevertsOnUpdateConditionsNotMet_Heartbeat() + public + { + // Add a subscription with only heartbeat criteria (60 seconds) + uint32 heartbeat = 60; + SchedulerState.UpdateCriteria memory criteria = SchedulerState + .UpdateCriteria({ + updateOnHeartbeat: true, + heartbeatSeconds: heartbeat, + updateOnDeviation: false, + deviationThresholdBps: 0 + }); + uint256 subscriptionId = addTestSubscriptionWithUpdateCriteria( + criteria + ); + uint256 fundAmount = 1 ether; + scheduler.addFunds{value: fundAmount}(subscriptionId); + + // First update to set initial timestamp + bytes32[] memory priceIds = createPriceIds(); + uint64 publishTime1 = SafeCast.toUint64(block.timestamp); + PythStructs.PriceFeed[] memory priceFeeds1 = createMockPriceFeeds( + publishTime1 + ); + mockParsePriceFeedUpdates(pyth, priceFeeds1); + bytes[] memory updateData1 = createMockUpdateData(priceFeeds1); + vm.prank(pusher); + scheduler.updatePriceFeeds(subscriptionId, updateData1, priceIds); + + // Prepare second update within heartbeat interval + vm.warp(block.timestamp + 30); // Advance time by 30 seconds (less than 60) + uint64 publishTime2 = SafeCast.toUint64(block.timestamp); + PythStructs.PriceFeed[] memory priceFeeds2 = createMockPriceFeeds( + publishTime2 // Same prices, just new timestamp + ); + mockParsePriceFeedUpdates(pyth, priceFeeds2); // Mock the response for the second update + bytes[] memory updateData2 = createMockUpdateData(priceFeeds2); + + // Expect revert because heartbeat condition is not met + vm.expectRevert( + abi.encodeWithSelector(UpdateConditionsNotMet.selector) + ); + vm.prank(pusher); + scheduler.updatePriceFeeds(subscriptionId, updateData2, priceIds); + } + + function testUpdatePriceFeedsRevertsOnUpdateConditionsNotMet_Deviation() + public + { + // Add a subscription with only deviation criteria (100 bps / 1%) + uint16 deviationBps = 100; + SchedulerState.UpdateCriteria memory criteria = SchedulerState + .UpdateCriteria({ + updateOnHeartbeat: false, + heartbeatSeconds: 0, + updateOnDeviation: true, + deviationThresholdBps: deviationBps + }); + uint256 subscriptionId = addTestSubscriptionWithUpdateCriteria( + criteria + ); + uint256 fundAmount = 1 ether; + scheduler.addFunds{value: fundAmount}(subscriptionId); + + // First update to set initial price + bytes32[] memory priceIds = createPriceIds(); + uint64 publishTime1 = SafeCast.toUint64(block.timestamp); + PythStructs.PriceFeed[] memory priceFeeds1 = createMockPriceFeeds( + publishTime1 + ); + mockParsePriceFeedUpdates(pyth, priceFeeds1); + bytes[] memory updateData1 = createMockUpdateData(priceFeeds1); + vm.prank(pusher); + scheduler.updatePriceFeeds(subscriptionId, updateData1, priceIds); + + // Prepare second update with price deviation less than threshold (e.g., 50 bps) + vm.warp(block.timestamp + 1000); // Advance time significantly (doesn't matter for deviation) + uint64 publishTime2 = SafeCast.toUint64(block.timestamp); + + // Clone priceFeeds1 and apply a 50 bps deviation to its prices + PythStructs.PriceFeed[] + memory priceFeeds2 = new PythStructs.PriceFeed[]( + priceFeeds1.length + ); + for (uint i = 0; i < priceFeeds1.length; i++) { + priceFeeds2[i].id = priceFeeds1[i].id; + // Apply 50 bps deviation to the price + int64 priceDiff = int64( + (uint64(priceFeeds1[i].price.price) * 50) / 10_000 + ); + priceFeeds2[i].price.price = priceFeeds1[i].price.price + priceDiff; + priceFeeds2[i].price.conf = priceFeeds1[i].price.conf; + priceFeeds2[i].price.expo = priceFeeds1[i].price.expo; + priceFeeds2[i].price.publishTime = publishTime2; + } + + mockParsePriceFeedUpdates(pyth, priceFeeds2); + bytes[] memory updateData2 = createMockUpdateData(priceFeeds2); + + // Expect revert because deviation condition is not met + vm.expectRevert( + abi.encodeWithSelector(UpdateConditionsNotMet.selector) + ); + vm.prank(pusher); + scheduler.updatePriceFeeds(subscriptionId, updateData2, priceIds); + } + + function testUpdatePriceFeedsRevertsOnOlderTimestamp() public { + // Add a subscription and funds + uint256 subscriptionId = addTestSubscription(); + uint256 fundAmount = 1 ether; + scheduler.addFunds{value: fundAmount}(subscriptionId); + + // First update to establish last updated timestamp + bytes32[] memory priceIds = createPriceIds(); + uint64 publishTime1 = SafeCast.toUint64(block.timestamp); + PythStructs.PriceFeed[] memory priceFeeds1 = createMockPriceFeeds( + publishTime1 + ); + mockParsePriceFeedUpdates(pyth, priceFeeds1); + bytes[] memory updateData1 = createMockUpdateData(priceFeeds1); + + vm.prank(pusher); + scheduler.updatePriceFeeds(subscriptionId, updateData1, priceIds); + + // Prepare second update with an older timestamp + uint64 publishTime2 = publishTime1 - 10; // Timestamp older than the first update + PythStructs.PriceFeed[] memory priceFeeds2 = createMockPriceFeeds( + publishTime2 + ); + // Mock Pyth response to return feeds with the older timestamp + mockParsePriceFeedUpdates(pyth, priceFeeds2); + bytes[] memory updateData2 = createMockUpdateData(priceFeeds2); + + // Expect revert with TimestampOlderThanLastUpdate (checked in _validateShouldUpdatePrices) + vm.expectRevert( + abi.encodeWithSelector( + TimestampOlderThanLastUpdate.selector, + publishTime2, + publishTime1 + ) + ); + + // Attempt to update price feeds + vm.prank(pusher); + scheduler.updatePriceFeeds(subscriptionId, updateData2, priceIds); + } + + function testUpdatePriceFeedsRevertsOnMismatchedTimestamps() public { + // First add a subscription and funds + uint256 subscriptionId = addTestSubscription(); + uint256 fundAmount = 1 ether; + scheduler.addFunds{value: fundAmount}(subscriptionId); + + // Create two price feeds with mismatched timestamps + bytes32[] memory priceIds = createPriceIds(2); + uint64 time1 = SafeCast.toUint64(block.timestamp); + uint64 time2 = time1 + 10; + PythStructs.PriceFeed[] memory priceFeeds = new PythStructs.PriceFeed[]( + 2 + ); + priceFeeds[0] = createSingleMockPriceFeed(time1); + priceFeeds[1] = createSingleMockPriceFeed(time2); + + // Mock Pyth response to return these feeds + mockParsePriceFeedUpdates(pyth, priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); // Data needs to match expected length + + // Expect revert with PriceTimestampMismatch error + vm.expectRevert( + abi.encodeWithSelector(PriceTimestampMismatch.selector) + ); + + // Attempt to update price feeds + vm.prank(pusher); + scheduler.updatePriceFeeds(subscriptionId, updateData, priceIds); + } + + function testGetLatestPricesAllFeeds() public { + // First add a subscription, funds, and update price feeds + uint256 subscriptionId = addTestSubscription(); + uint256 fundAmount = 1 ether; + scheduler.addFunds{value: fundAmount}(subscriptionId); + + bytes32[] memory priceIds = createPriceIds(); + uint64 publishTime = SafeCast.toUint64(block.timestamp); + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(pyth, priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + vm.prank(pusher); + scheduler.updatePriceFeeds(subscriptionId, updateData, priceIds); + + // Get all latest prices (empty priceIds array) + bytes32[] memory emptyPriceIds = new bytes32[](0); + PythStructs.PriceFeed[] memory latestPrices = scheduler.getLatestPrices( + subscriptionId, + emptyPriceIds + ); + + // Verify all price feeds were returned + assertEq( + latestPrices.length, + priceIds.length, + "Should return all price feeds" + ); + + // Verify price feed data using the reader contract + assertTrue( + reader.verifyPriceFeeds(subscriptionId, emptyPriceIds, priceFeeds), + "Price feeds verification failed" + ); + } + + function testGetLatestPricesSelectiveFeeds() public { + // First add a subscription with 3 price feeds, funds, and update price feeds + uint256 subscriptionId = addTestSubscriptionWithFeeds(3); + uint256 fundAmount = 1 ether; + scheduler.addFunds{value: fundAmount}(subscriptionId); + + bytes32[] memory priceIds = createPriceIds(3); + uint64 publishTime = SafeCast.toUint64(block.timestamp); + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime, + 3 + ); + mockParsePriceFeedUpdates(pyth, priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + vm.prank(pusher); + scheduler.updatePriceFeeds(subscriptionId, updateData, priceIds); + + // Get only the first price feed + bytes32[] memory selectedPriceIds = new bytes32[](1); + selectedPriceIds[0] = priceIds[0]; + + PythStructs.PriceFeed[] memory latestPrices = scheduler.getLatestPrices( + subscriptionId, + selectedPriceIds + ); + + // Verify only one price feed was returned + assertEq(latestPrices.length, 1, "Should return only one price feed"); + + // Create expected price feed array with just the first feed + PythStructs.PriceFeed[] + memory expectedFeeds = new PythStructs.PriceFeed[](1); + expectedFeeds[0] = priceFeeds[0]; + + // Verify price feed data using the reader contract + assertTrue( + reader.verifyPriceFeeds( + subscriptionId, + selectedPriceIds, + expectedFeeds + ), + "Price feeds verification failed" + ); + } + + function testOptionalWhitelist() public { + // Add a subscription with whitelistEnabled = false + bytes32[] memory priceIds = createPriceIds(); + address[] memory emptyWhitelist = new address[](0); + + SchedulerState.UpdateCriteria memory updateCriteria = SchedulerState + .UpdateCriteria({ + updateOnHeartbeat: true, + heartbeatSeconds: 60, + updateOnDeviation: true, + deviationThresholdBps: 100 + }); + + SchedulerState.GasConfig memory gasConfig = SchedulerState.GasConfig({ + maxGasPrice: 100 gwei, + maxGasLimit: 1_000_000 + }); + + SchedulerState.SubscriptionParams memory params = SchedulerState + .SubscriptionParams({ + priceIds: priceIds, + readerWhitelist: emptyWhitelist, + whitelistEnabled: false, // No whitelist + updateCriteria: updateCriteria, + gasConfig: gasConfig + }); + + uint256 subscriptionId = scheduler.addSubscription(params); + + // Update price feeds + uint256 fundAmount = 1 ether; + scheduler.addFunds{value: fundAmount}(subscriptionId); + + uint64 publishTime = SafeCast.toUint64(block.timestamp); + PythStructs.PriceFeed[] memory priceFeeds = createMockPriceFeeds( + publishTime + ); + mockParsePriceFeedUpdates(pyth, priceFeeds); + bytes[] memory updateData = createMockUpdateData(priceFeeds); + + vm.prank(pusher); + scheduler.updatePriceFeeds(subscriptionId, updateData, priceIds); + + // Try to access from a non-whitelisted address + address randomUser = address(0xdead); + vm.startPrank(randomUser); + bytes32[] memory emptyPriceIds = new bytes32[](0); + + // Should not revert since whitelist is disabled + // We'll just check that it doesn't revert + scheduler.getLatestPrices(subscriptionId, emptyPriceIds); + vm.stopPrank(); + + // Verify the data is correct + assertTrue( + reader.verifyPriceFeeds(subscriptionId, emptyPriceIds, priceFeeds), + "Price feeds verification failed" + ); + } + + function testGetActiveSubscriptions() public { + // Add multiple subscriptions with the test contract as manager + addTestSubscription(); + addTestSubscription(); + uint256 subscriptionId = addTestSubscription(); + + // Verify we can deactivate our own subscription + scheduler.deactivateSubscription(subscriptionId); + + // Create a subscription with pusher as manager + vm.startPrank(pusher); + bytes32[] memory priceIds = createPriceIds(); + address[] memory emptyWhitelist = new address[](0); + + SchedulerState.UpdateCriteria memory updateCriteria = SchedulerState + .UpdateCriteria({ + updateOnHeartbeat: true, + heartbeatSeconds: 60, + updateOnDeviation: true, + deviationThresholdBps: 100 + }); + + SchedulerState.GasConfig memory gasConfig = SchedulerState.GasConfig({ + maxGasPrice: 100 gwei, + maxGasLimit: 1_000_000 + }); + + SchedulerState.SubscriptionParams memory params = SchedulerState + .SubscriptionParams({ + priceIds: priceIds, + readerWhitelist: emptyWhitelist, + whitelistEnabled: false, + updateCriteria: updateCriteria, + gasConfig: gasConfig + }); + + scheduler.addSubscription(params); + vm.stopPrank(); + + // Get active subscriptions - use owner who has admin rights + vm.prank(owner); + ( + uint256[] memory activeIds, + SchedulerState.SubscriptionParams[] memory activeParams + ) = scheduler.getActiveSubscriptions(); + + // Verify active subscriptions + assertEq(activeIds.length, 3, "Should have 3 active subscriptions"); + assertEq( + activeParams.length, + 3, + "Should have 3 active subscription params" + ); + + // Verify subscription params + for (uint i = 0; i < activeIds.length; i++) { + ( + SchedulerState.SubscriptionParams memory storedParams, + + ) = scheduler.getSubscription(activeIds[i]); + + assertEq( + activeParams[i].priceIds.length, + storedParams.priceIds.length, + "Price IDs length mismatch" + ); + + assertEq( + activeParams[i].updateCriteria.heartbeatSeconds, + storedParams.updateCriteria.heartbeatSeconds, + "Heartbeat seconds mismatch" + ); + } + } + + // Helper function to add a test subscription + function addTestSubscription() internal returns (uint256) { + bytes32[] memory priceIds = createPriceIds(); + address[] memory readerWhitelist = new address[](1); + readerWhitelist[0] = address(reader); + + SchedulerState.UpdateCriteria memory updateCriteria = SchedulerState + .UpdateCriteria({ + updateOnHeartbeat: true, + heartbeatSeconds: 60, + updateOnDeviation: true, + deviationThresholdBps: 100 + }); + + SchedulerState.GasConfig memory gasConfig = SchedulerState.GasConfig({ + maxGasPrice: 100 gwei, + maxGasLimit: 1_000_000 + }); + + SchedulerState.SubscriptionParams memory params = SchedulerState + .SubscriptionParams({ + priceIds: priceIds, + readerWhitelist: readerWhitelist, + whitelistEnabled: true, + updateCriteria: updateCriteria, + gasConfig: gasConfig + }); + + return scheduler.addSubscription(params); + } + + // Helper function to add a test subscription with variable number of feeds + function addTestSubscriptionWithFeeds( + uint256 numFeeds + ) internal returns (uint256) { + bytes32[] memory priceIds = createPriceIds(numFeeds); + address[] memory readerWhitelist = new address[](1); + readerWhitelist[0] = address(reader); + + SchedulerState.UpdateCriteria memory updateCriteria = SchedulerState + .UpdateCriteria({ + updateOnHeartbeat: true, + heartbeatSeconds: 60, + updateOnDeviation: true, + deviationThresholdBps: 100 + }); + + SchedulerState.GasConfig memory gasConfig = SchedulerState.GasConfig({ + maxGasPrice: 100 gwei, + maxGasLimit: 1_000_000 + }); + + SchedulerState.SubscriptionParams memory params = SchedulerState + .SubscriptionParams({ + priceIds: priceIds, + readerWhitelist: readerWhitelist, + whitelistEnabled: true, + updateCriteria: updateCriteria, + gasConfig: gasConfig + }); + + return scheduler.addSubscription(params); + } + + // Helper function to add a test subscription with specific update criteria + function addTestSubscriptionWithUpdateCriteria( + SchedulerState.UpdateCriteria memory updateCriteria + ) internal returns (uint256) { + bytes32[] memory priceIds = createPriceIds(); + address[] memory readerWhitelist = new address[](1); + readerWhitelist[0] = address(reader); + + SchedulerState.GasConfig memory gasConfig = SchedulerState.GasConfig({ + maxGasPrice: 100 gwei, + maxGasLimit: 1_000_000 + }); + + SchedulerState.SubscriptionParams memory params = SchedulerState + .SubscriptionParams({ + priceIds: priceIds, + readerWhitelist: readerWhitelist, + whitelistEnabled: true, + updateCriteria: updateCriteria, // Use provided criteria + gasConfig: gasConfig + }); + + return scheduler.addSubscription(params); + } + + // Required to receive ETH when withdrawing funds + receive() external payable {} +} diff --git a/target_chains/ethereum/contracts/forge-test/utils/PulseTestUtils.t.sol b/target_chains/ethereum/contracts/forge-test/utils/PulseTestUtils.t.sol index eb400aa3e4..1b96e961a3 100644 --- a/target_chains/ethereum/contracts/forge-test/utils/PulseTestUtils.t.sol +++ b/target_chains/ethereum/contracts/forge-test/utils/PulseTestUtils.t.sol @@ -67,6 +67,13 @@ abstract contract PulseTestUtils is Test { return priceIds; } + // Helper function to create a single mock price feed + function createSingleMockPriceFeed( + uint256 publishTime + ) internal pure returns (PythStructs.PriceFeed memory) { + return createMockPriceFeeds(publishTime, 1)[0]; + } + // Helper function to create mock price feeds with default 2 feeds function createMockPriceFeeds( uint256 publishTime