From a3929ba8e3a9ef53cbf863124fc33351ecb7f88b Mon Sep 17 00:00:00 2001 From: gabrielfior Date: Fri, 14 Feb 2025 11:58:11 -0300 Subject: [PATCH] Refactored Enum --- .../markets/seer/data_models.py | 18 +++++++++--------- .../markets/seer/test_seer_data_model.py | 6 +++--- .../markets/seer/test_seer_outcome.py | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/prediction_market_agent_tooling/markets/seer/data_models.py b/prediction_market_agent_tooling/markets/seer/data_models.py index c6f82071..daf51f60 100644 --- a/prediction_market_agent_tooling/markets/seer/data_models.py +++ b/prediction_market_agent_tooling/markets/seer/data_models.py @@ -48,22 +48,22 @@ class CreateCategoricalMarketsParams(BaseModel): class SeerOutcomeEnum(str, Enum): - POSITIVE = "positive" - NEGATIVE = "negative" - NEUTRAL = "neutral" + YES = "yes" + NO = "no" + INVALID = "invalid" @classmethod def from_bool(cls, value: bool) -> "SeerOutcomeEnum": - return cls.POSITIVE if value else cls.NEGATIVE + return cls.YES if value else cls.NO @classmethod def from_string(cls, value: str) -> "SeerOutcomeEnum": """Convert a string (case-insensitive) to an Outcome enum.""" normalized = value.strip().lower() patterns = { - r"^yes$": cls.POSITIVE, - r"^no$": cls.NEGATIVE, - r"^(invalid|invalid result)$": cls.NEUTRAL, + r"^yes$": cls.YES, + r"^no$": cls.NO, + r"^(invalid|invalid result)$": cls.INVALID, } # Search through patterns and return the first match @@ -108,7 +108,7 @@ def has_valid_answer(self) -> bool: # 2. Invalid payoutNumerator is 1. try: - self.outcome_as_enums[SeerOutcomeEnum.NEUTRAL] + self.outcome_as_enums[SeerOutcomeEnum.INVALID] except KeyError: raise ValueError( f"Market {self.id.hex()} has no invalid outcome. {self.outcomes}" @@ -183,7 +183,7 @@ def current_p_yes(self) -> Probability: ) price_data[idx] = price - yes_idx = self.outcome_as_enums[SeerOutcomeEnum.POSITIVE] + yes_idx = self.outcome_as_enums[SeerOutcomeEnum.YES] price_yes = price_data[yes_idx] / sum(price_data.values()) return Probability(price_yes) diff --git a/tests_integration/markets/seer/test_seer_data_model.py b/tests_integration/markets/seer/test_seer_data_model.py index 9b6882f2..a00934c2 100644 --- a/tests_integration/markets/seer/test_seer_data_model.py +++ b/tests_integration/markets/seer/test_seer_data_model.py @@ -18,15 +18,15 @@ def test_current_p_yes( market = seer_subgraph_handler_test.get_binary_markets( limit=100, sort_by=SortBy.HIGHEST_LIQUIDITY, filter_by=FilterBy.OPEN )[0] - yes_idx = market.outcome_as_enums[SeerOutcomeEnum.POSITIVE] + yes_idx = market.outcome_as_enums[SeerOutcomeEnum.YES] yes_token = market.wrapped_tokens[yes_idx] yes_price = market._get_price_for_token(Web3.to_checksum_address(yes_token)) - no_idx = market.outcome_as_enums[SeerOutcomeEnum.NEGATIVE] + no_idx = market.outcome_as_enums[SeerOutcomeEnum.NO] no_token = market.wrapped_tokens[no_idx] no_price = market._get_price_for_token(Web3.to_checksum_address(no_token)) - invalid_idx = market.outcome_as_enums[SeerOutcomeEnum.NEUTRAL] + invalid_idx = market.outcome_as_enums[SeerOutcomeEnum.INVALID] invalid_token = market.wrapped_tokens[invalid_idx] invalid_price = market._get_price_for_token(Web3.to_checksum_address(invalid_token)) diff --git a/tests_integration/markets/seer/test_seer_outcome.py b/tests_integration/markets/seer/test_seer_outcome.py index b923dfd0..9c8111e6 100644 --- a/tests_integration/markets/seer/test_seer_outcome.py +++ b/tests_integration/markets/seer/test_seer_outcome.py @@ -14,4 +14,4 @@ def test_seer_outcome(outcome: str) -> None: def test_seer_outcome_invalid() -> None: - assert SeerOutcomeEnum.from_string("Invalid result") == SeerOutcomeEnum.NEUTRAL + assert SeerOutcomeEnum.from_string("Invalid result") == SeerOutcomeEnum.INVALID