refactor: distribution and tests
vgorkavenko committed Feb 13, 2025
1 parent 1aa7228 commit e091d12
Showing 7 changed files with 426 additions and 336 deletions.
173 changes: 102 additions & 71 deletions src/modules/csm/csm.py
@@ -12,8 +12,8 @@
 )
 from src.metrics.prometheus.duration_meter import duration_meter
 from src.modules.csm.checkpoint import FrameCheckpointProcessor, FrameCheckpointsIterator, MinStepIsNotReached
-from src.modules.csm.log import FramePerfLog
-from src.modules.csm.state import State, Frame
+from src.modules.csm.log import FramePerfLog, OperatorFrameSummary
+from src.modules.csm.state import State, Frame, AttestationsAccumulator
 from src.modules.csm.tree import Tree
 from src.modules.csm.types import ReportData, Shares
 from src.modules.submodules.consensus import ConsensusModule
@@ -29,13 +29,12 @@
     SlotNumber,
     StakingModuleAddress,
     StakingModuleId,
-    ValidatorIndex,
 )
 from src.utils.blockstamp import build_blockstamp
 from src.utils.cache import global_lru_cache as lru_cache
 from src.utils.slot import get_next_non_missed_slot, get_reference_blockstamp
 from src.utils.web3converter import Web3Converter
-from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator
+from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator, LidoValidator
 from src.web3py.types import Web3

 logger = logging.getLogger(__name__)
@@ -102,15 +101,15 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
         if (prev_cid is None) != (prev_root == ZERO_HASH):
             raise InconsistentData(f"Got inconsistent previous tree data: {prev_root=} {prev_cid=}")

-        distributed, shares, logs = self.calculate_distribution(blockstamp)
+        total_distributed, total_rewards, logs = self.calculate_distribution(blockstamp)

-        if distributed != sum(shares.values()):
-            raise InconsistentData(f"Invalid distribution: {sum(shares.values())=} != {distributed=}")
+        if total_distributed != sum(total_rewards.values()):
+            raise InconsistentData(f"Invalid distribution: {sum(total_rewards.values())=} != {total_distributed=}")

         log_cid = self.publish_log(logs)

-        if not distributed and not shares:
-            logger.info({"msg": "No shares distributed in the current frame"})
+        if not total_distributed and not total_rewards:
+            logger.info({"msg": "No rewards distributed in the current frame"})
         return ReportData(
             self.get_consensus_version(blockstamp),
             blockstamp.ref_slot,
@@ -123,11 +122,11 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
         if prev_cid and prev_root != ZERO_HASH:
             # Update cumulative amount of shares for all operators.
             for no_id, acc_shares in self.get_accumulated_shares(prev_cid, prev_root):
-                shares[no_id] += acc_shares
+                total_rewards[no_id] += acc_shares
         else:
             logger.info({"msg": "No previous distribution. Nothing to accumulate"})

-        tree = self.make_tree(shares)
+        tree = self.make_tree(total_rewards)
         tree_cid = self.publish_tree(tree)

         return ReportData(
@@ -136,7 +135,7 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
             tree_root=tree.root,
             tree_cid=tree_cid,
             log_cid=log_cid,
-            distributed=distributed,
+            distributed=total_distributed,
         ).as_tuple()

     def is_main_data_submitted(self, blockstamp: BlockStamp) -> bool:
@@ -232,26 +231,36 @@ def calculate_distribution(
         """Computes distribution of fee shares at the given timestamp"""
         operators_to_validators = self.module_validators_by_node_operators(blockstamp)

-        distributed = 0
-        # Calculate share of each CSM node operator.
-        shares = defaultdict[NodeOperatorId, int](int)
+        total_distributed = 0
+        total_rewards = defaultdict[NodeOperatorId, int](int)
         logs: list[FramePerfLog] = []

-        for frame in self.state.data:
+        for frame in self.state.frames:
             from_epoch, to_epoch = frame
             logger.info({"msg": f"Calculating distribution for frame [{from_epoch};{to_epoch}]"})

             frame_blockstamp = blockstamp
             if to_epoch != blockstamp.ref_epoch:
                 frame_blockstamp = self._get_ref_blockstamp_for_frame(blockstamp, to_epoch)
-            distributed_in_frame, shares_in_frame, log = self._calculate_distribution_in_frame(
-                frame_blockstamp, operators_to_validators, frame, distributed
+
+            total_rewards_to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(frame_blockstamp.block_hash)
+            rewards_to_distribute_in_frame = total_rewards_to_distribute - total_distributed
+
+            rewards_in_frame, log = self._calculate_distribution_in_frame(
+                frame, frame_blockstamp, rewards_to_distribute_in_frame, operators_to_validators
             )
-            distributed += distributed_in_frame
-            for no_id, share in shares_in_frame.items():
-                shares[no_id] += share
+            distributed_in_frame = sum(rewards_in_frame.values())
+
+            total_distributed += distributed_in_frame
+            if total_distributed > total_rewards_to_distribute:
+                raise CSMError(f"Invalid distribution: {total_distributed=} > {total_rewards_to_distribute=}")
+
+            for no_id, rewards in rewards_in_frame.items():
+                total_rewards[no_id] += rewards

             logs.append(log)

-        return distributed, shares, logs
+        return total_distributed, total_rewards, logs

     def _get_ref_blockstamp_for_frame(
         self, blockstamp: ReferenceBlockStamp, frame_ref_epoch: EpochNumber
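Note on the hunk above: the distributable amount is now re-read from the fee distributor at each frame's reference block, and each frame is handed only the remainder that has not been paid out in earlier frames, with a guard against over-distribution. A minimal standalone sketch of that accounting (the `shares_to_distribute` and `calc_frame` callables below are hypothetical stand-ins for the `CSFeeDistributor` read and `_calculate_distribution_in_frame`):

from collections import defaultdict
from typing import Callable

def distribute_over_frames(
    frames: list,
    shares_to_distribute: Callable[[object], int],  # stand-in for the contract read at a frame's ref block
    calc_frame: Callable[[object, int], dict[int, int]],  # stand-in for the per-frame calculation
) -> tuple[int, dict[int, int]]:
    total_distributed = 0
    total_rewards: defaultdict[int, int] = defaultdict(int)
    for frame in frames:
        # Only the not-yet-distributed remainder is available in this frame.
        to_distribute_in_frame = shares_to_distribute(frame) - total_distributed
        rewards_in_frame = calc_frame(frame, to_distribute_in_frame)
        total_distributed += sum(rewards_in_frame.values())
        if total_distributed > shares_to_distribute(frame):
            raise ValueError("over-distribution")  # raised as CSMError in the module
        for no_id, rewards in rewards_in_frame.items():
            total_rewards[no_id] += rewards
    return total_distributed, total_rewards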
@@ -266,63 +275,85 @@ def _get_ref_blockstamp_for_frame(

     def _calculate_distribution_in_frame(
         self,
-        blockstamp: ReferenceBlockStamp,
-        operators_to_validators: ValidatorsByNodeOperator,
         frame: Frame,
-        distributed: int,
+        blockstamp: ReferenceBlockStamp,
+        rewards_to_distribute: int,
+        operators_to_validators: ValidatorsByNodeOperator
     ):
-        network_perf = self.state.get_network_aggr(frame).perf
-        threshold = network_perf - self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS
-
-        # Build the map of the current distribution operators.
-        distribution: dict[NodeOperatorId, int] = defaultdict(int)
-        stuck_operators = self.stuck_operators(blockstamp)
+        threshold = self._get_performance_threshold(frame, blockstamp)
         log = FramePerfLog(blockstamp, frame, threshold)
+
+        participation_shares: defaultdict[NodeOperatorId, int] = defaultdict(int)
+
+        stuck_operators = self.stuck_operators(blockstamp)
         for (_, no_id), validators in operators_to_validators.items():
+            log_operator = log.operators[no_id]
             if no_id in stuck_operators:
-                log.operators[no_id].stuck = True
+                log_operator.stuck = True
                 continue
+            for validator in validators:
+                duty = self.state.data[frame].get(validator.index)
+                self.process_validator_duty(validator, duty, threshold, participation_shares, log_operator)
+
+        rewards_distribution = self.calc_rewards_distribution_in_frame(participation_shares, rewards_to_distribute)
+
+        for no_id, no_rewards in rewards_distribution.items():
+            log.operators[no_id].distributed = no_rewards
+
+        log.distributable = rewards_to_distribute
+
+        return rewards_distribution, log
+
+    def _get_performance_threshold(self, frame: Frame, blockstamp: ReferenceBlockStamp) -> float:
+        network_perf = self.state.get_network_aggr(frame).perf
+        perf_leeway = self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS
+        threshold = network_perf - perf_leeway
+        return threshold
+
+    @staticmethod
+    def process_validator_duty(
+        validator: LidoValidator,
+        attestation_duty: AttestationsAccumulator | None,
+        threshold: float,
+        participation_shares: defaultdict[NodeOperatorId, int],
+        log_operator: OperatorFrameSummary
+    ):
+        if attestation_duty is None:
+            # It's possible that the validator is not assigned to any duty, hence its performance
+            # is not present in the aggregates (e.g. exited, pending activation, etc.).
+            # TODO: check `sync_aggr` to strike (in case of bad sync performance) after validator exit
+            return
+
+        log_validator = log_operator.validators[validator.index]
+
+        if validator.validator.slashed is True:
+            # The validator was active during the frame and was slashed before reaching its exit
+            # epoch, so it should not be counted towards the operator's share.
+            log_validator.slashed = True
+            return
+
+        if attestation_duty.perf > threshold:
+            # The count of assigned attestations is used as a metric of the time
+            # the validator was active in the current frame.
+            participation_shares[validator.lido_id.operatorIndex] += attestation_duty.assigned
+
+        log_validator.attestation_duty = attestation_duty
+
+    @staticmethod
+    def calc_rewards_distribution_in_frame(
+        participation_shares: dict[NodeOperatorId, int],
+        rewards_to_distribute: int,
+    ) -> dict[NodeOperatorId, int]:
+        rewards_distribution: dict[NodeOperatorId, int] = defaultdict(int)
+        total_participation = sum(participation_shares.values())
+
+        for no_id, no_participation_share in participation_shares.items():
+            if no_participation_share == 0:
+                # Skip operators with zero participation
+                continue
+            rewards_distribution[no_id] = rewards_to_distribute * no_participation_share // total_participation

-            for v in validators:
-                aggr = self.state.data[frame].get(ValidatorIndex(int(v.index)))
-
-                if aggr is None:
-                    # It's possible that the validator is not assigned to any duty, hence it's performance
-                    # is not presented in the aggregates (e.g. exited, pending for activation etc).
-                    continue
-
-                if v.validator.slashed is True:
-                    # It means that validator was active during the frame and got slashed and didn't meet the exit
-                    # epoch, so we should not count such validator for operator's share.
-                    log.operators[no_id].validators[v.index].slashed = True
-                    continue
-
-                if aggr.perf > threshold:
-                    # Count of assigned attestations used as a metrics of time
-                    # the validator was active in the current frame.
-                    distribution[no_id] += aggr.assigned
-
-                log.operators[no_id].validators[v.index].perf = aggr
-
-        # Calculate share of each CSM node operator.
-        shares = defaultdict[NodeOperatorId, int](int)
-        total = sum(p for p in distribution.values())
-        to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(blockstamp.block_hash) - distributed
-        log.distributable = to_distribute
-
-        if not total:
-            return 0, shares, log
-
-        for no_id, no_share in distribution.items():
-            if no_share:
-                shares[no_id] = to_distribute * no_share // total
-                log.operators[no_id].distributed = shares[no_id]
-
-        distributed = sum(s for s in shares.values())
-        if distributed > to_distribute:
-            raise CSMError(f"Invalid distribution: {distributed=} > {to_distribute=}")
-        return distributed, shares, log
+        return rewards_distribution

     def get_accumulated_shares(self, cid: CID, root: HexBytes) -> Iterator[tuple[NodeOperatorId, Shares]]:
         logger.info({"msg": "Fetching tree by CID from IPFS", "cid": repr(cid)})
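The rewards math extracted into `calc_rewards_distribution_in_frame` above is a pro-rata split with integer floor division, so rounding dust stays undistributed rather than risking paying out more than the frame's budget. A small self-contained sketch with made-up numbers:

from collections import defaultdict

def calc_rewards(participation_shares: dict[int, int], rewards_to_distribute: int) -> dict[int, int]:
    rewards: dict[int, int] = defaultdict(int)
    total_participation = sum(participation_shares.values())
    for no_id, share in participation_shares.items():
        if share == 0:
            continue  # operators with zero participation receive nothing
        rewards[no_id] = rewards_to_distribute * share // total_participation  # floor division
    return rewards

# Hypothetical input: 1000 shares to distribute, operators 0/1/2 with participation 3/3/1.
print(dict(calc_rewards({0: 3, 1: 3, 2: 1}, 1000)))  # {0: 428, 1: 428, 2: 142} -> 998 paid, 2 left as dust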
3 changes: 1 addition & 2 deletions src/modules/csm/log.py
@@ -12,8 +12,7 @@ class LogJSONEncoder(json.JSONEncoder): ...

 @dataclass
 class ValidatorFrameSummary:
-    # TODO: Should be renamed. Perf means different things in different contexts
-    perf: AttestationsAccumulator = field(default_factory=AttestationsAccumulator)
+    attestation_duty: AttestationsAccumulator = field(default_factory=AttestationsAccumulator)
     slashed: bool = False

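With the rename, each validator's entry in the frame log carries the raw attestation-duty accumulator under a name that no longer collides with other uses of "perf". A hypothetical construction (the `assigned`/`included` fields of `AttestationsAccumulator` are inferred from their usage in `csm.py` and `state.py`, not confirmed by this diff):

# Hypothetical: the AttestationsAccumulator field names below are an assumption.
summary = ValidatorFrameSummary(
    attestation_duty=AttestationsAccumulator(assigned=32, included=30),
    slashed=False,
)
assert summary.attestation_duty.assigned == 32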
11 changes: 7 additions & 4 deletions src/modules/csm/state.py
@@ -110,6 +110,10 @@ def unprocessed_epochs(self) -> set[EpochNumber]:
     def is_fulfilled(self) -> bool:
         return not self.unprocessed_epochs

+    @property
+    def frames(self):
+        return self.calculate_frames(self._epochs_to_process, self._epochs_per_frame)
+
     @staticmethod
     def calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]:
         """Split epochs to process into frames of `epochs_per_frame` length"""
@@ -127,11 +131,10 @@ def clear(self) -> None:
         assert self.is_empty

     def find_frame(self, epoch: EpochNumber) -> Frame:
-        frames = self.data.keys()
-        for epoch_range in frames:
+        for epoch_range in self.frames:
             if epoch_range[0] <= epoch <= epoch_range[1]:
                 return epoch_range
-        raise ValueError(f"Epoch {epoch} is out of frames range: {frames}")
+        raise ValueError(f"Epoch {epoch} is out of frames range: {self.frames}")

     def increment_duty(self, frame: Frame, val_index: ValidatorIndex, included: bool) -> None:
         if frame not in self.data:
@@ -160,7 +163,7 @@ def init_or_migrate(
         frames_data: StateData = {frame: defaultdict(AttestationsAccumulator) for frame in frames}

         if not self.is_empty:
-            cached_frames = self.calculate_frames(self._epochs_to_process, self._epochs_per_frame)
+            cached_frames = self.frames
             if cached_frames == frames:
                 logger.info({"msg": "No need to migrate duties data cache"})
                 return
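The new `frames` property recomputes the frame split from the configured epochs instead of relying on the keys of `self.data`, so `find_frame` and `init_or_migrate` share one source of truth. A standalone sketch of what `calculate_frames` plausibly does, assuming it chunks the epoch tuple into consecutive `epochs_per_frame`-sized ranges (rejecting a partial trailing chunk is also an assumption):

def calculate_frames(epochs_to_process: tuple[int, ...], epochs_per_frame: int) -> list[tuple[int, int]]:
    frames = []
    for i in range(0, len(epochs_to_process), epochs_per_frame):
        chunk = epochs_to_process[i:i + epochs_per_frame]
        if len(chunk) < epochs_per_frame:
            raise ValueError("Insufficient epochs to form a frame")  # assumed behavior
        frames.append((chunk[0], chunk[-1]))
    return frames

print(calculate_frames(tuple(range(10)), 5))  # [(0, 4), (5, 9)]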
4 changes: 2 additions & 2 deletions src/providers/execution/contracts/cs_fee_distributor.py
@@ -3,7 +3,7 @@
 from eth_typing import ChecksumAddress
 from hexbytes import HexBytes
 from web3 import Web3
-from web3.types import BlockIdentifier
+from web3.types import BlockIdentifier, Wei

 from ..base_interface import ContractInterface

@@ -26,7 +26,7 @@ def oracle(self, block_identifier: BlockIdentifier = "latest") -> ChecksumAddress:
         )
         return Web3.to_checksum_address(resp)

-    def shares_to_distribute(self, block_identifier: BlockIdentifier = "latest") -> int:
+    def shares_to_distribute(self, block_identifier: BlockIdentifier = "latest") -> Wei:
         """Returns the amount of shares that are pending to be distributed"""

         resp = self.functions.pendingSharesToDistribute().call(block_identifier=block_identifier)
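`Wei` in web3.py is a `NewType` over `int`, so the tightened return annotation documents the unit without changing runtime behavior. A minimal illustration:

from typing import NewType

Wei = NewType("Wei", int)  # mirrors web3.types.Wei

def shares_to_distribute_stub() -> Wei:
    return Wei(10**18)  # hypothetical value

assert isinstance(shares_to_distribute_stub(), int)  # NewType has no runtime cost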
