Skip to content

Commit

Permalink
global expansion
Browse files Browse the repository at this point in the history
  • Loading branch information
Brad Miller committed Dec 13, 2024
1 parent cfe3c71 commit b111aaf
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 8 deletions.
4 changes: 2 additions & 2 deletions sourcecode/scoring/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@

# Scoring Groups
coreGroups: Set[int] = {1, 2, 3, 6, 8, 9, 10, 11, 13, 14, 19, 21, 25}
expansionGroups: Set[int] = {0, 4, 5, 7, 12, 16, 18, 20, 22, 23, 24, 26, 27, 28}
expansionPlusGroups: Set[int] = {15, 17, 29, 30}
expansionGroups: Set[int] = {0, 4, 5, 7, 12, 15, 16, 18, 20, 22, 23, 26, 27, 28, 29}
expansionPlusGroups: Set[int] = {17, 24, 30, 31, 32}

# TSV Values
notHelpfulValueTsv = "NOT_HELPFUL"
Expand Down
1 change: 1 addition & 0 deletions sourcecode/scoring/pandas_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,7 @@ def _inner(*args, **kwargs) -> Any:
clArgs = kwargs["args"]
else:
# Handle the following, which expect args as the second positional argument:
# birdwatch/scoring/src/main/python/run_post_selection_similarity.py
# birdwatch/scoring/src/main/python/run_prescoring.py
# birdwatch/scoring/src/main/python/run_final_scoring.py
# birdwatch/scoring/src/main/python/run_contributor_scoring.py
Expand Down
19 changes: 14 additions & 5 deletions sourcecode/scoring/run_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,11 +1039,21 @@ def _validate_contributor_scoring_output(helpfulnessScores: pd.DataFrame) -> pd.
return helpfulnessScores


def run_post_selection_similarity(notes: pd.DataFrame, ratings: pd.DataFrame) -> pd.DataFrame:
  """Compute Post Selection Similarity values from the given notes and ratings.

  Builds a PostSelectionSimilarity object from the inputs, asks it for its
  similarity values, then discards the object and forces a garbage collection
  pass before returning.  The work is wrapped in the
  "Compute Post Selection Similarity" time block.

  Args:
    notes: notes DataFrame handed to PostSelectionSimilarity.
    ratings: ratings DataFrame handed to PostSelectionSimilarity.

  Returns:
    The result of PostSelectionSimilarity.get_post_selection_similarity_values().
  """
  with c.time_block("Compute Post Selection Similarity"):
    similarityComputer = PostSelectionSimilarity(notes, ratings)
    similarityValues = similarityComputer.get_post_selection_similarity_values()
    # Drop the only reference to the helper object and collect right away;
    # presumably it holds large intermediate state -- TODO confirm.
    del similarityComputer
    gc.collect()
  return similarityValues


def run_prescoring(
notes: pd.DataFrame,
ratings: pd.DataFrame,
noteStatusHistory: pd.DataFrame,
userEnrollment: pd.DataFrame,
postSelectionSimilarityValues: pd.DataFrame,
seed: Optional[int] = None,
enabledScorers: Optional[Set[Scorers]] = None,
runParallel: bool = True,
Expand Down Expand Up @@ -1081,16 +1091,12 @@ def run_prescoring(
logger.info(
f"ratings summary before PSS: {get_df_fingerprint(ratings, [c.noteIdKey, c.raterParticipantIdKey])}"
)
with c.time_block("Compute Post Selection Similarity"):
pss = PostSelectionSimilarity(notes, ratings)
postSelectionSimilarityValues = pss.get_post_selection_similarity_values()
with c.time_block("Filter ratings by Post Selection Similarity"):
logger.info(f"Post Selection Similarity Prescoring: begin with {len(ratings)} ratings.")
ratings = filter_ratings_by_post_selection_similarity(
notes, ratings, postSelectionSimilarityValues
)
logger.info(f"Post Selection Similarity Prescoring: {len(ratings)} ratings remaining.")
del pss
gc.collect()
logger.info(
f"ratings summary after PSS: {get_df_fingerprint(ratings, [c.noteIdKey, c.raterParticipantIdKey])}"
)
Expand Down Expand Up @@ -1868,6 +1874,8 @@ def run_scoring(
filterPrescoringInputToSimulateDelayInHours,
)

postSelectionSimilarityValues = run_post_selection_similarity(notes=notes, ratings=ratings)

(
prescoringNoteModelOutput,
prescoringRaterModelOutput,
Expand All @@ -1880,6 +1888,7 @@ def run_scoring(
ratings=prescoringRatingsInput,
noteStatusHistory=noteStatusHistory,
userEnrollment=userEnrollment,
postSelectionSimilarityValues=postSelectionSimilarityValues,
seed=seed,
enabledScorers=enabledScorers,
runParallel=runParallel,
Expand Down
2 changes: 1 addition & 1 deletion sourcecode/scoring/scoring_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class RuleID(Enum):

# Rules used in _meta_score.
META_INITIAL_NMR = RuleAndVersion("MetaInitialNMR", "1.0", False)
EXPANSION_MODEL = RuleAndVersion("ExpansionModel", "1.1", False)
EXPANSION_MODEL = RuleAndVersion("ExpansionModel", "1.1", True)
EXPANSION_PLUS_MODEL = RuleAndVersion("ExpansionPlusModel", "1.1", False)
CORE_MODEL = RuleAndVersion("CoreModel", "1.1", True)
COVERAGE_MODEL = RuleAndVersion("CoverageModel", "1.1", False)
Expand Down

0 comments on commit b111aaf

Please sign in to comment.