diff --git a/sourcecode/scoring/constants.py b/sourcecode/scoring/constants.py index 3f525e52..acd071f7 100644 --- a/sourcecode/scoring/constants.py +++ b/sourcecode/scoring/constants.py @@ -1,5 +1,6 @@ from contextlib import contextmanager from dataclasses import dataclass +from enum import Enum import os import time from typing import Dict, Optional @@ -36,6 +37,8 @@ prescoringAllUnlockedNotesMaxCrhChurn = 0.04 finalUnlockedNotesWithNoNewRatingsMaxCrhChurn = 0.03 finalNotesWithNewRatingsMaxCrhChurn = 0.40 +finalNotesThatJustFlippedStatusMaxCrhChurn = 1e8 +finalNotesThatFlippedRecentlyMaxCrhChurn = 1e8 # Data Filenames scoredNotesOutputPath = "scoredNotes.tsv" @@ -58,6 +61,15 @@ numberOfTimesEarnedOutKey = "numberOfTimesEarnedOut" defaultIndexKey = "index" +# Scoring Groups +coreGroups = {1, 2, 3, 6, 8, 9, 10, 11, 13, 14, 19, 21, 25} +expansionGroups = ( + # Divide into 3 grouping aggregates to prepare for multi-group models, + # and a 4th group containing leftovers + {0, 15, 17, 24, 29, 30} | {4, 5, 7, 12, 26} | {27} | {16, 20, 22, 23, 28} +) +expansionPlusGroups = {18} + # TSV Values notHelpfulValueTsv = "NOT_HELPFUL" somewhatHelpfulValueTsv = "SOMEWHAT_HELPFUL" @@ -101,6 +113,7 @@ unlockedRatingStatusKey = "unlockedRatingStatus" metaScorerActiveRulesKey = "metaScorerActiveRules" decidedByKey = "decidedBy" +rescoringActiveRulesKey = "rescoringActiveRules" # Note Status Changes Columns noteFinalStatusChange = "finalStatusChange" @@ -243,8 +256,10 @@ def rater_factor_key(i): (1, "helpfulUnbiasedLanguage"), ] helpfulTagsTSVOrder = [tag for (tiebreakOrder, tag) in helpfulTagsAndTieBreakOrder] -helpfulTagsAndTypesTSVOrder = [(tag, np.int8) for tag in helpfulTagsTSVOrder] +helpfulTagBoolsAndTypesTSVOrder = [(tag, pd.Int8Dtype()) for tag in helpfulTagsTSVOrder] helpfulTagsTiebreakOrder = [tag for (tiebreakOrder, tag) in sorted(helpfulTagsAndTieBreakOrder)] +helpfulTagCountsAndTypesTSVOrder = [(tag, pd.Int64Dtype()) for tag in helpfulTagsTSVOrder] + # NOTE: Always add new tags to the end of this list, and *never* change the order of # elements which are already in the list to maintain compatibility with @@ -281,7 +296,8 @@ def rater_factor_key(i): (6, notHelpfulNoteNotNeededKey), ] notHelpfulTagsTSVOrder = [tag for (tiebreakOrder, tag) in notHelpfulTagsAndTieBreakOrder] -notHelpfulTagsAndTypesTSVOrder = [(tag, np.int8) for tag in notHelpfulTagsTSVOrder] +notHelpfulTagsAndTypesTSVOrder = [(tag, pd.Int8Dtype()) for tag in notHelpfulTagsTSVOrder] +notHelpfulTagCountsAndTypesTSVOrder = [(tag, pd.Int64Dtype()) for tag in notHelpfulTagsTSVOrder] notHelpfulTagsTiebreakOrder = [ tag for (tiebreakOrder, tag) in sorted(notHelpfulTagsAndTieBreakOrder) ] @@ -355,7 +371,7 @@ def rater_factor_key(i): "misleadingUnverifiedClaimAsFact", "misleadingSatire", ] -misleadingTagsAndTypes = [(tag, np.int64) for tag in misleadingTags] +misleadingTagsAndTypes = [(tag, pd.Int8Dtype()) for tag in misleadingTags] notMisleadingTags = [ "notMisleadingOther", @@ -364,7 +380,7 @@ def rater_factor_key(i): "notMisleadingClearlySatire", "notMisleadingPersonalOpinion", ] -notMisleadingTagsAndTypes = [(tag, np.int64) for tag in notMisleadingTags] +notMisleadingTagsAndTypes = [(tag, pd.Int8Dtype()) for tag in notMisleadingTags] noteTSVColumnsAndTypes = ( [ @@ -373,13 +389,13 @@ def rater_factor_key(i): (createdAtMillisKey, np.int64), (tweetIdKey, np.int64), (classificationKey, object), - ("believable", object), - ("harmful", object), - ("validationDifficulty", object), + ("believable", "category"), + ("harmful", "category"), + 
("validationDifficulty", "category"), ] + misleadingTagsAndTypes + notMisleadingTagsAndTypes - + [("trustworthySources", np.int64), (summaryKey, object), ("isMediaNote", np.int64)] + + [("trustworthySources", pd.Int8Dtype()), (summaryKey, object), ("isMediaNote", pd.Int8Dtype())] ) noteTSVColumns = [col for (col, dtype) in noteTSVColumnsAndTypes] noteTSVTypes = [dtype for (col, dtype) in noteTSVColumnsAndTypes] @@ -394,14 +410,14 @@ def rater_factor_key(i): (noteIdKey, np.int64), (raterParticipantIdKey, object), (createdAtMillisKey, np.int64), - (versionKey, np.int64), - (agreeKey, np.int64), - (disagreeKey, np.int64), - (helpfulKey, np.int64), - (notHelpfulKey, np.int64), + (versionKey, pd.Int8Dtype()), + (agreeKey, pd.Int8Dtype()), + (disagreeKey, pd.Int8Dtype()), + (helpfulKey, pd.Int8Dtype()), + (notHelpfulKey, pd.Int8Dtype()), (helpfulnessLevelKey, "category"), ] - + helpfulTagsAndTypesTSVOrder + + helpfulTagBoolsAndTypesTSVOrder + notHelpfulTagsAndTypesTSVOrder + [(ratedOnTweetIdKey, np.int64)] ) @@ -431,18 +447,18 @@ def rater_factor_key(i): (noteAuthorParticipantIdKey, object), (createdAtMillisKey, np.int64), (timestampMillisOfNoteFirstNonNMRLabelKey, np.double), # double because nullable. - (firstNonNMRLabelKey, object), + (firstNonNMRLabelKey, "category"), (timestampMillisOfNoteCurrentLabelKey, np.double), # double because nullable. - (currentLabelKey, object), + (currentLabelKey, "category"), (timestampMillisOfNoteMostRecentNonNMRLabelKey, np.double), # double because nullable. - (mostRecentNonNMRLabelKey, object), + (mostRecentNonNMRLabelKey, "category"), (timestampMillisOfStatusLockKey, np.double), # double because nullable. - (lockedStatusKey, object), + (lockedStatusKey, "category"), (timestampMillisOfRetroLockKey, np.double), # double because nullable. - (currentCoreStatusKey, object), - (currentExpansionStatusKey, object), - (currentGroupStatusKey, object), - (currentDecidedByKey, object), + (currentCoreStatusKey, "category"), + (currentExpansionStatusKey, "category"), + (currentGroupStatusKey, "category"), + (currentDecidedByKey, "category"), (currentModelingGroupKey, np.double), # TODO: int (timestampMillisOfMostRecentStatusChangeKey, np.double), # double because nullable. ] @@ -552,8 +568,8 @@ def rater_factor_key(i): (currentlyRatedNotHelpfulBoolKey, np.int8), (unlockedRatingStatusKey, str), ] - + helpfulTagsAndTypesTSVOrder - + notHelpfulTagsAndTypesTSVOrder + + helpfulTagCountsAndTypesTSVOrder + + notHelpfulTagCountsAndTypesTSVOrder + notHelpfulTagsAdjustedTSVColumnsAndTypes + notHelpfulTagsAdjustedRatioTSVColumnsAndTypes + incorrectFilterColumnsAndTypes @@ -641,6 +657,7 @@ def rater_factor_key(i): (expansionPlusNumFinalRoundRatingsKey, np.double), # double because nullable. (groupNumFinalRoundRatingsKey, np.double), # double because nullable. (topicNumFinalRoundRatingsKey, np.double), # double because nullable. 
+ (rescoringActiveRulesKey, str), ] noteModelOutputTSVColumns = [col for (col, dtype) in noteModelOutputTSVColumnsAndTypes] noteModelOutputTSVTypeMapping = {col: dtype for (col, dtype) in noteModelOutputTSVColumnsAndTypes} @@ -794,10 +811,7 @@ class PrescoringMetaOutput: @dataclass class SharedMemoryDataframeInfo: sharedMemoryName: str - columns: list - dataShape: tuple - dtypesDict: dict - npDtype: str + dataSize: int @dataclass @@ -861,8 +875,16 @@ class ModelResult: metaScores: Optional[PrescoringMetaScorerOutput] +class RescoringRuleID(Enum): + ALL_NOTES = 1 + NOTES_WITH_NEW_RATINGS = 2 + NOTES_FLIPPED_PREVIOUS_RUN = 3 + NEW_NOTES_NOT_RESCORED_RECENTLY_ENOUGH = 4 + RECENTLY_FLIPPED_NOTES_NOT_RESCORED_RECENTLY_ENOUGH = 5 + + @dataclass class NoteSubset: noteSet: Optional[set] maxCrhChurnRate: float - description: str + description: RescoringRuleID diff --git a/sourcecode/scoring/mf_base_scorer.py b/sourcecode/scoring/mf_base_scorer.py index 3deb7f51..2acdd4a0 100644 --- a/sourcecode/scoring/mf_base_scorer.py +++ b/sourcecode/scoring/mf_base_scorer.py @@ -1,4 +1,5 @@ -from typing import Dict, List, Optional, Tuple +import gc +from typing import Dict, List, Optional, Set, Tuple from . import ( constants as c, @@ -11,6 +12,7 @@ from .incorrect_filter import get_user_incorrect_ratio from .matrix_factorization.matrix_factorization import MatrixFactorization from .matrix_factorization.pseudo_raters import PseudoRatersRunner +from .pandas_utils import keep_columns from .reputation_matrix_factorization.diligence_model import ( fit_low_diligence_model_final, fit_low_diligence_model_prescoring, @@ -143,6 +145,10 @@ class MFBaseScorer(Scorer): def __init__( self, + includedTopics: Set[str] = set(), + includedGroups: Set[int] = set(), + includeUnassigned: bool = False, + captureThreshold: Optional[float] = None, seed: Optional[int] = None, pseudoraters: Optional[bool] = True, minNumRatingsPerRater: int = 10, @@ -182,6 +188,8 @@ def __init__( """Configure MatrixFactorizationScorer object. Args: + includedGroups: if set, filter ratings and results based on includedGroups + includedTopics: if set, filter ratings based on includedTopics seed: if not None, seed value to ensure deterministic execution pseudoraters: if True, compute optional pseudorater confidence intervals minNumRatingsPerRater: Minimum number of ratings which a rater must produce to be @@ -214,7 +222,14 @@ def __init__( maxFirstMFTrainError: maximum error allowed for the first MF training process maxFinalMFTrainError: maximum error allowed for the final MF training process """ - super().__init__(seed, threads) + super().__init__( + includedTopics=includedTopics, + includedGroups=includedGroups, + includeUnassigned=includeUnassigned, + captureThreshold=captureThreshold, + seed=seed, + threads=threads, + ) self._pseudoraters = pseudoraters self._minNumRatingsPerRater = minNumRatingsPerRater self._minNumRatersPerNote = minNumRatersPerNote @@ -492,7 +507,22 @@ def _prescore_notes_and_users( # Removes ratings where either (1) the note did not receive enough ratings, or # (2) the rater did not rate enough notes. 
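The RescoringRuleID enum above replaces the free-form description string on NoteSubset, and the new *MaxCrhChurn constants cap how many notes in each subset may flip CRH status in a single run. A hedged sketch of how a subset might be assembled (the note set here is a placeholder):

    from sourcecode.scoring import constants as c

    # Placeholder; in practice derived from which notes received new ratings since the last run.
    notesWithNewRatings: set = set()

    subset = c.NoteSubset(
        noteSet=notesWithNewRatings,
        maxCrhChurnRate=c.finalNotesWithNewRatingsMaxCrhChurn,  # 0.40 per the constant above
        description=c.RescoringRuleID.NOTES_WITH_NEW_RATINGS,
    )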
with self.time_block("Prepare ratings"): - ratingsForTraining = self._prepare_data_for_scoring(ratings) + ratingsForTraining = self._prepare_data_for_scoring( + ratings[ + [ + c.noteIdKey, + c.raterParticipantIdKey, + c.helpfulNumKey, + c.createdAtMillisKey, + c.helpfulnessLevelKey, + c.notHelpfulIncorrectTagKey, + c.notHelpfulIrrelevantSourcesTagKey, + c.notHelpfulSourcesMissingOrUnreliableTagKey, + c.notHelpfulSpamHarassmentOrAbuseTagKey, + c.notHelpfulOtherTagKey, + ] + ] + ) if self._saveIntermediateState: self.ratingsForTraining = ratingsForTraining @@ -502,12 +532,17 @@ def _prescore_notes_and_users( noteParamsUnfiltered, raterParamsUnfiltered, globalBias, - ) = self._run_stable_matrix_factorization(ratingsForTraining, userEnrollmentRaw) + ) = self._run_stable_matrix_factorization( + ratingsForTraining[[c.noteIdKey, c.raterParticipantIdKey, c.helpfulNumKey]], + userEnrollmentRaw[[c.participantIdKey, c.modelingGroupKey]], + ) if self._saveIntermediateState: self.noteParamsUnfiltered = noteParamsUnfiltered self.raterParamsUnfiltered = raterParamsUnfiltered self.globalBias = globalBias - self.assert_train_error_is_below_threshold(ratingsForTraining, self._maxFirstMFTrainError) + self.assert_train_error_is_below_threshold( + ratingsForTraining[[c.noteIdKey]], self._maxFirstMFTrainError + ) # If reputation is disabled, generate final intercepts, factors and note status # based on the first round scoring results. Disabling reputation can be desirable @@ -535,10 +570,36 @@ def _prescore_notes_and_users( # Get a dataframe of scored notes based on the algorithm results above with self.time_block("Compute scored notes"): scoredNotes = note_ratings.compute_scored_notes( - ratings, - noteParamsUnfiltered, - raterParamsUnfiltered, - noteStatusHistory, + ratings[ + [c.noteIdKey, c.raterParticipantIdKey, c.helpfulnessLevelKey, c.createdAtMillisKey] + + c.notHelpfulTagsTSVOrder + + c.helpfulTagsTSVOrder + ], + keep_columns( + noteParamsUnfiltered, + [ + c.noteIdKey, + c.internalNoteInterceptKey, + c.internalNoteFactor1Key, + ] + + c.noteParameterUncertaintyTSVColumns, + ), + raterParamsUnfiltered[ + [ + c.raterParticipantIdKey, + c.internalRaterFactor1Key, + ] + ], + noteStatusHistory[ + [ + c.noteIdKey, + c.createdAtMillisKey, + c.noteAuthorParticipantIdKey, + c.classificationKey, + c.currentLabelKey, + c.lockedStatusKey, + ] + ], minRatingsNeeded=self._minRatingsNeeded, crhThreshold=self._crhThreshold, crnhThresholdIntercept=self._crnhThresholdIntercept, @@ -557,8 +618,10 @@ def _prescore_notes_and_users( # Determine "valid" ratings with self.time_block("Compute valid ratings"): validRatings = note_ratings.get_valid_ratings( - ratings, - noteStatusHistory, + ratings[[c.noteIdKey, c.raterParticipantIdKey, c.helpfulNumKey, c.createdAtMillisKey]], + noteStatusHistory[ + [c.noteIdKey, c.createdAtMillisKey, c.timestampMillisOfNoteMostRecentNonNMRLabelKey] + ], scoredNotes[ [ c.noteIdKey, @@ -584,11 +647,13 @@ def _prescore_notes_and_users( c.internalNoteInterceptKey, ] ], - validRatings, + validRatings[ + [c.raterParticipantIdKey, c.ratingAgreesWithNoteStatusKey, c.ratingCountKey] + ], self._minMeanNoteScore, self._minCRHVsCRNHRatio, self._minRaterAgreeRatio, - ratingsForTraining, + ratingsForTraining[[c.noteIdKey, c.raterParticipantIdKey, c.helpfulNumKey]], ) ) if self._saveIntermediateState: @@ -599,7 +664,17 @@ def _prescore_notes_and_users( with self.time_block("Filtering by helpfulness score"): ratingsHelpfulnessScoreFilteredPreHarassmentFilter = ( 
helpfulness_scores.filter_ratings_by_helpfulness_scores( - ratingsForTraining, helpfulnessScoresPreHarassmentFilter + ratingsForTraining[ + [ + c.noteIdKey, + c.raterParticipantIdKey, + c.notHelpfulSpamHarassmentOrAbuseTagKey, + c.createdAtMillisKey, + c.helpfulnessLevelKey, + c.notHelpfulOtherTagKey, + ] + ], + helpfulnessScoresPreHarassmentFilter, ) ) @@ -612,10 +687,15 @@ def _prescore_notes_and_users( harassmentAbuseNoteParams, _, _ = tag_consensus.train_tag_model( ratingsHelpfulnessScoreFilteredPreHarassmentFilter, c.notHelpfulSpamHarassmentOrAbuseTagKey, - noteParamsUnfiltered, - raterParamsUnfiltered, + noteParamsUnfiltered[[c.noteIdKey, c.internalNoteInterceptKey, c.internalNoteFactor1Key]], + raterParamsUnfiltered[ + [c.raterParticipantIdKey, c.internalRaterInterceptKey, c.internalRaterFactor1Key] + ], name="harassment", ) + if not self._saveIntermediateState: + del ratingsHelpfulnessScoreFilteredPreHarassmentFilter + gc.collect() # Assigns contributor (author & rater) helpfulness bit based on (1) performance # authoring and reviewing previous and current notes, and (2) including an extra @@ -630,16 +710,21 @@ def _prescore_notes_and_users( c.internalNoteInterceptKey, ] ], - validRatings, + validRatings[ + [c.raterParticipantIdKey, c.ratingAgreesWithNoteStatusKey, c.ratingCountKey] + ], self._minMeanNoteScore, self._minCRHVsCRNHRatio, self._minRaterAgreeRatio, - ratings=ratingsForTraining, + ratings=ratingsForTraining[[c.noteIdKey, c.raterParticipantIdKey, c.helpfulNumKey]], tagConsensusHarassmentAbuseNotes=harassmentAbuseNoteParams, tagConsensusHarassmentHelpfulRatingPenalty=self.tagConsensusHarassmentHelpfulRatingPenalty, multiplyPenaltyByHarassmentScore=self.multiplyPenaltyByHarassmentScore, minimumHarassmentScoreToPenalize=self.minimumHarassmentScoreToPenalize, ) + if not self._saveIntermediateState: + del validRatings + gc.collect() if self._saveIntermediateState: self.helpfulnessScores = helpfulnessScores @@ -647,12 +732,26 @@ def _prescore_notes_and_users( # Filter ratings based on prev helpfulness scores with c.time_block("Final round MF"): finalRoundRatings = helpfulness_scores.filter_ratings_by_helpfulness_scores( - ratingsForTraining, helpfulnessScores + ratingsForTraining[ + [ + c.noteIdKey, + c.raterParticipantIdKey, + c.helpfulNumKey, + c.notHelpfulIncorrectTagKey, + c.notHelpfulSourcesMissingOrUnreliableTagKey, + c.notHelpfulIrrelevantSourcesTagKey, + ] + ], + helpfulnessScores[[c.raterParticipantIdKey, c.aboveHelpfulnessThresholdKey]], ) noteParams, raterParams, globalBias = self._mfRanker.run_mf( - ratings=finalRoundRatings, - noteInit=noteParamsUnfiltered, - userInit=raterParamsUnfiltered, + ratings=finalRoundRatings[[c.noteIdKey, c.raterParticipantIdKey, c.helpfulNumKey]], + noteInit=noteParamsUnfiltered[ + [c.noteIdKey, c.internalNoteInterceptKey, c.internalNoteFactor1Key] + ], + userInit=raterParamsUnfiltered[ + [c.raterParticipantIdKey, c.internalRaterInterceptKey, c.internalRaterFactor1Key] + ], ) # Run Diligence MF Prescoring, based on the final MF @@ -669,17 +768,52 @@ def _prescore_notes_and_users( diligenceRaterParams, diligenceGlobalIntercept, ) = fit_low_diligence_model_prescoring( - finalRoundRatings, raterInitStateDiligence=raterParamsDiligenceInit + finalRoundRatings[ + [ + c.noteIdKey, + c.raterParticipantIdKey, + c.notHelpfulIncorrectTagKey, + c.notHelpfulSourcesMissingOrUnreliableTagKey, + c.notHelpfulIrrelevantSourcesTagKey, + ] + ], + raterInitStateDiligence=raterParamsDiligenceInit, ) noteParams = noteParams.merge(diligenceNoteParams, 
on=c.noteIdKey) raterParams = raterParams.merge(diligenceRaterParams, on=c.raterParticipantIdKey) # Compute scored notes -- currently not returned; only used for downstream computation. scoredNotes = note_ratings.compute_scored_notes( - ratings, - noteParams, - raterParams, - noteStatusHistory, + ratings[ + [c.noteIdKey, c.raterParticipantIdKey, c.helpfulnessLevelKey, c.createdAtMillisKey] + + c.notHelpfulTagsTSVOrder + + c.helpfulTagsTSVOrder + ], + keep_columns( + noteParamsUnfiltered, + [ + c.noteIdKey, + c.internalNoteInterceptKey, + c.internalNoteFactor1Key, + ] + + c.noteParameterUncertaintyTSVColumns, + ), + raterParamsUnfiltered[ + [ + c.raterParticipantIdKey, + c.internalRaterFactor1Key, + ] + ], + noteStatusHistory[ + [ + c.noteIdKey, + c.createdAtMillisKey, + c.noteAuthorParticipantIdKey, + c.classificationKey, + c.currentLabelKey, + c.lockedStatusKey, + ] + ], minRatingsNeeded=self._minRatingsNeeded, crhThreshold=self._crhThreshold, crnhThresholdIntercept=self._crnhThresholdIntercept, @@ -699,14 +833,32 @@ def _prescore_notes_and_users( globalIntercept=globalBias, lowDiligenceGlobalIntercept=diligenceGlobalIntercept, tagFilteringThresholds=self.compute_tag_thresholds_for_percentile( - scoredNotes=noteParams.merge(scoredNotes, on=c.noteIdKey, suffixes=("", "_dup")), - raterParams=raterParams, - ratings=ratings, + scoredNotes=noteParams[[c.noteIdKey, c.internalNoteFactor1Key]].merge( + scoredNotes[[c.noteIdKey, c.currentlyRatedHelpfulBoolKey]], + on=c.noteIdKey, + suffixes=("", "_dup"), + ), + raterParams=raterParams[[c.raterParticipantIdKey, c.internalRaterFactor1Key]], + ratings=ratings[ + [ + c.noteIdKey, + c.raterParticipantIdKey, + ] + + c.notHelpfulTagsTSVOrder + ], ), ) # Compute user incorrect tag aggregates - userIncorrectTagUsageDf = get_user_incorrect_ratio(ratings) + userIncorrectTagUsageDf = get_user_incorrect_ratio( + ratings[ + [ + c.noteIdKey, + c.raterParticipantIdKey, + ] + + c.notHelpfulTagsTSVOrder + ] + ) raterModelOutput = raterParams.merge( helpfulnessScores[ @@ -728,7 +880,12 @@ def _prescore_notes_and_users( ) noteModelOutput = noteParams - + # Returning should remove references to these, but manually trigger GC just to reclaim + # resources as soon as possible. + del ratings + del ratingsForTraining + del finalRoundRatings + gc.collect() return noteModelOutput, raterModelOutput, metaOutput def _score_notes_and_users( diff --git a/sourcecode/scoring/mf_core_scorer.py b/sourcecode/scoring/mf_core_scorer.py index 39365003..9ba71eed 100644 --- a/sourcecode/scoring/mf_core_scorer.py +++ b/sourcecode/scoring/mf_core_scorer.py @@ -1,85 +1,8 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional from . import constants as c from .mf_base_scorer import MFBaseScorer -import numpy as np -import pandas as pd - - -_CORE_BOOL = "coreBool" -_TOTAL = "total" -_RATIO = "ratio" - - -def filter_core_input( - ratingsOrig: pd.DataFrame, - noteStatusHistoryOrig: pd.DataFrame, - userEnrollment: pd.DataFrame, -) -> Tuple[pd.DataFrame, pd.DataFrame]: - """Prune the contents of ratings and noteStatusHistory to scope model behavior. - - Filter ratings dataframe to only include ratings from CORE users. 
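The compute_scored_notes calls above select noteParamsUnfiltered through keep_columns (added to pandas_utils.py later in this diff) rather than plain bracket indexing, since keep_columns tolerates requested columns that happen to be absent, e.g. the noteParameterUncertaintyTSVColumns when pseudoraters did not run. A quick illustration of the helper:

    import pandas as pd
    from sourcecode.scoring.pandas_utils import keep_columns

    df = pd.DataFrame({"noteId": [1, 2], "internalNoteIntercept": [0.4, -0.1]})
    # Requested columns that are missing are silently skipped instead of raising a KeyError.
    subset = keep_columns(df, ["noteId", "internalNoteIntercept", "internalNoteFactor1"])
    print(list(subset.columns))   # ['noteId', 'internalNoteIntercept']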
- - Args: - ratings (pd.DataFrame): preprocessed ratings - noteStatusHistory (pd.DataFrame): one row per note; history of when note had each status - userEnrollment (pd.DataFrame): one row per user specifying enrollment properties - - Returns: - Tuple[pd.DataFrame, pd.DataFrame]: - ratings: ratings filtered to only contain rows of interest - noteStatusHistory: noteStatusHistory filtered to only contain rows of interest - """ - print("Identifying core notes and ratings") - # Prune ratings that aren't defined as core - ratings = ratingsOrig.merge( - userEnrollment[[c.participantIdKey, c.modelingPopulationKey]].rename( - columns={c.participantIdKey: c.raterParticipantIdKey} - ), - on=c.raterParticipantIdKey, - how="left", - ) - print( - f" Ratings from user without modelingPopulation: {pd.isna(ratings[c.modelingPopulationKey]).sum()}" - ) - ratings = ratings.fillna({c.modelingPopulationKey: c.core}) - ratings = ratings[ratings[c.modelingPopulationKey] == c.core] - print(f" Core ratings: {len(ratings)}") - return ratings.drop(columns=c.modelingPopulationKey), noteStatusHistoryOrig - - -def filter_core_output( - ratingsOrig: pd.DataFrame, - userEnrollment: pd.DataFrame, - noteScores: pd.DataFrame, - coreThreshold: float = 0.5, -) -> pd.DataFrame: - # Drop ExpansionPlus ratings before determining ratios - print("Filtering Core Output") - print(f" Original ratings length: {len(ratingsOrig)}") - # Separate CORE and EXPANSION notes. - userEnrollment[_CORE_BOOL] = userEnrollment[c.modelingPopulationKey] == c.core - userGroups = userEnrollment[[c.participantIdKey, _CORE_BOOL]].copy() - ratings = ratingsOrig.merge( - userGroups.rename(columns={c.participantIdKey: c.raterParticipantIdKey}), - on=c.raterParticipantIdKey, - how="left", - unsafeAllowed=_CORE_BOOL, - ) - print(f" Final ratings length: {len(ratings)}") - ratings = ratings.fillna({_CORE_BOOL: True}) - ratings[_CORE_BOOL] = ratings[_CORE_BOOL].astype(np.bool8) - ratios = ratings[[c.noteIdKey, _CORE_BOOL]].groupby(c.noteIdKey).mean().reset_index() - # Identify CORE notes. We define a CORE note to be any note which (1) has ratings, and - # (2) half or more of the ratings are from CORE users. This construction does mean that - # notes without ratings can avoid locking, but as soon as they get enough ratings to be - # captured and scored by CORE they will lock (if older than 2 weeks). 
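The filter_core_output logic removed above is what the new captureThreshold constructor argument generalizes: a note stays in a scorer's output when at least that fraction of its ratings come from raters in includedGroups. A minimal sketch of the ratio computation, with column names mirroring the constants but otherwise hypothetical:

    import pandas as pd

    def captured_notes(ratings: pd.DataFrame,
                       raterGroups: pd.DataFrame,
                       includedGroups: set,
                       captureThreshold: float = 0.5) -> pd.DataFrame:
        # Mark each rating by whether its rater belongs to one of the scorer's modeling groups.
        merged = ratings.merge(raterGroups, on="raterParticipantId", how="left")
        merged["inScope"] = merged["modelingGroup"].isin(includedGroups)
        # Keep notes where at least captureThreshold of the ratings are in scope.
        ratios = merged.groupby("noteId")["inScope"].mean().reset_index()
        return ratios[ratios["inScope"] >= captureThreshold][["noteId"]]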
- print(f" Original noteScores length: {len(noteScores)}") - noteScores = noteScores.merge(ratios[ratios[_CORE_BOOL] >= coreThreshold][[c.noteIdKey]]) - print(f" Final noteScores length: {len(noteScores)}") - return noteScores - class MFCoreScorer(MFBaseScorer): def __init__( @@ -98,8 +21,11 @@ def __init__( threads: number of threads to use for intra-op parallelism in pytorch """ super().__init__( - seed, - pseudoraters, + includedGroups=c.coreGroups, + includeUnassigned=True, + captureThreshold=0.5, + seed=seed, + pseudoraters=pseudoraters, useStableInitialization=useStableInitialization, saveIntermediateState=saveIntermediateState, threads=threads, @@ -153,23 +79,3 @@ def get_helpfulness_scores_cols(self) -> List[str]: c.raterAgreeRatioKey, c.aboveHelpfulnessThresholdKey, ] - - def _filter_input( - self, - noteTopics: pd.DataFrame, - ratingsOrig: pd.DataFrame, - noteStatusHistoryOrig: pd.DataFrame, - userEnrollment: pd.DataFrame, - ) -> Tuple[pd.DataFrame, pd.DataFrame]: - """Prune the contents of ratings and noteStatusHistory to scope model behavior.""" - return filter_core_input(ratingsOrig, noteStatusHistoryOrig, userEnrollment) - - def _postprocess_output( - self, - noteScores: pd.DataFrame, - userScores: pd.DataFrame, - ratings: pd.DataFrame, - noteStatusHistory: pd.DataFrame, - userEnrollment: pd.DataFrame, - ) -> Tuple[pd.DataFrame, pd.DataFrame]: - return filter_core_output(ratings, userEnrollment, noteScores), userScores diff --git a/sourcecode/scoring/mf_expansion_plus_scorer.py b/sourcecode/scoring/mf_expansion_plus_scorer.py index 203e4f5e..065f8177 100644 --- a/sourcecode/scoring/mf_expansion_plus_scorer.py +++ b/sourcecode/scoring/mf_expansion_plus_scorer.py @@ -19,7 +19,9 @@ def __init__( threads: number of threads to use for intra-op parallelism in pytorch """ super().__init__( - seed, + includedGroups=(c.coreGroups | c.expansionGroups | c.expansionPlusGroups), + includeUnassigned=True, + seed=seed, pseudoraters=False, useStableInitialization=useStableInitialization, saveIntermediateState=saveIntermediateState, diff --git a/sourcecode/scoring/mf_expansion_scorer.py b/sourcecode/scoring/mf_expansion_scorer.py index b5c04273..08251351 100644 --- a/sourcecode/scoring/mf_expansion_scorer.py +++ b/sourcecode/scoring/mf_expansion_scorer.py @@ -1,11 +1,8 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional from . import constants as c from .mf_base_scorer import MFBaseScorer -import numpy as np -import pandas as pd - _EXPANSION_BOOL = "expansionBool" @@ -25,13 +22,15 @@ def __init__( threads: number of threads to use for intra-op parallelism in pytorch """ super().__init__( - seed, + includedGroups=(c.coreGroups | c.expansionGroups), + includeUnassigned=True, + captureThreshold=0.5, + seed=seed, pseudoraters=False, useStableInitialization=useStableInitialization, saveIntermediateState=saveIntermediateState, threads=threads, ) - self._expansionThreshold = 0.5 def get_name(self): return "MFExpansionScorer" @@ -102,78 +101,3 @@ def _get_dropped_user_cols(self) -> List[str]: c.raterAgreeRatioKey, c.aboveHelpfulnessThresholdKey, ] - - def _filter_input( - self, - noteTopics: pd.DataFrame, - ratingsOrig: pd.DataFrame, - noteStatusHistoryOrig: pd.DataFrame, - userEnrollment: pd.DataFrame, - ) -> Tuple[pd.DataFrame, pd.DataFrame]: - """Prune the contents of ratings to scope model behavior. - - The MFExpansionScorer input is filtered to exclude notes and ratings from EXPANSION_PLUS - users. All other ratings are included. 
- - Args: - ratings (pd.DataFrame): preprocessed ratings - noteStatusHistory (pd.DataFrame): one row per note; history of when note had each status - userEnrollment (pd.DataFrame): one row per user specifying enrollment properties - - Returns: - Tuple[pd.DataFrame, pd.DataFrame]: - ratingsOrig: ratings filtered to only contain rows of interest - noteStatusHistoryOrig: noteStatusHistory filtered to only contain rows of interest - """ - print("Identifying expansion notes and ratings") - # Prune ratings to CORE and EXPANSION users. - print(f" Total ratings: {len(ratingsOrig)}") - ratings = ratingsOrig.merge( - userEnrollment[[c.participantIdKey, c.modelingPopulationKey]].rename( - columns={c.participantIdKey: c.raterParticipantIdKey} - ), - on=c.raterParticipantIdKey, - how="left", - ) - print( - f" Ratings from user without modelingPopulation: {pd.isna(ratings[c.modelingPopulationKey]).sum()}" - ) - ratings = ratings.fillna({c.modelingPopulationKey: c.expansion}) - ratings = ratings[ratings[c.modelingPopulationKey] != c.expansionPlus] - print(f" Ratings after EXPANSION_PLUS filter: {len(ratings)}") - - return ratings.drop(columns=c.modelingPopulationKey), noteStatusHistoryOrig - - def _postprocess_output( - self, - noteScores: pd.DataFrame, - userScores: pd.DataFrame, - ratings: pd.DataFrame, - noteStatusHistory: pd.DataFrame, - userEnrollment: pd.DataFrame, - ) -> Tuple[pd.DataFrame, pd.DataFrame]: - print("Filtering Expansion Output") - print(f" Original ratings length: {len(ratings)}") - # Separate CORE and EXPANSION notes from EXPANSION_PLUS. - userEnrollment[_EXPANSION_BOOL] = userEnrollment[c.modelingPopulationKey].isin( - {c.core, c.expansion} - ) - userGroups = userEnrollment[[c.participantIdKey, _EXPANSION_BOOL]].copy() - ratings = ratings.merge( - userGroups.rename(columns={c.participantIdKey: c.raterParticipantIdKey}), - on=c.raterParticipantIdKey, - how="left", - unsafeAllowed=_EXPANSION_BOOL, - ) - print(f" Final ratings length: {len(ratings)}") - ratings = ratings.fillna({_EXPANSION_BOOL: True}) - ratings[_EXPANSION_BOOL] = ratings[_EXPANSION_BOOL].astype(np.bool8) - ratios = ratings[[c.noteIdKey, _EXPANSION_BOOL]].groupby(c.noteIdKey).mean().reset_index() - # Identify EXPANSION notes. We define a EXPANSION note to be any note which (1) has ratings, and - # (2) half or more of the ratings are from EXPANSION/CORE users. - print(f" Original noteScores length: {len(noteScores)}") - noteScores = noteScores.merge( - ratios[ratios[_EXPANSION_BOOL] >= self._expansionThreshold][[c.noteIdKey]] - ) - print(f" Final noteScores length: {len(noteScores)}") - return noteScores, userScores diff --git a/sourcecode/scoring/mf_group_scorer.py b/sourcecode/scoring/mf_group_scorer.py index 62573ccf..69447f1e 100644 --- a/sourcecode/scoring/mf_group_scorer.py +++ b/sourcecode/scoring/mf_group_scorer.py @@ -104,8 +104,11 @@ def __init__( for the model to be active """ super().__init__( - seed, - pseudoraters, + includedGroups={groupNumber}, + includeUnassigned=False, + captureThreshold=groupThreshold, + seed=seed, + pseudoraters=pseudoraters, useStableInitialization=False, saveIntermediateState=saveIntermediateState, threads=_groupScorerParalleism.get(groupNumber, 4), @@ -135,7 +138,6 @@ def __init__( assert groupNumber > 0, "groupNumber must be positive. 0 is reserved for unassigned." assert groupNumber <= groupScorerCount, "groupNumber exceeds maximum expected groups." 
self._groupNumber = groupNumber - self._groupThreshold = groupThreshold self._groupNoteInterceptKey = f"{c.groupNoteInterceptKey}_{self._groupNumber}" self._groupNoteFactor1Key = f"{c.groupNoteFactor1Key}_{self._groupNumber}" self._groupRatingStatusKey = f"{c.groupRatingStatusKey}_{self._groupNumber}" @@ -219,43 +221,6 @@ def _get_dropped_user_cols(self) -> List[str]: c.aboveHelpfulnessThresholdKey, ] - def _filter_input( - self, - noteTopics: pd.DataFrame, - ratings: pd.DataFrame, - noteStatusHistory: pd.DataFrame, - userEnrollment: pd.DataFrame, - ) -> Tuple[pd.DataFrame, pd.DataFrame]: - """Prune the contents of ratings to only include ratings from users in the modeling group. - - This function identifies the subset of ratings to include in group model scoring. - To improve modeling within the group, we only include ratings from users in the modeling - group. However, we place no restriction on which notes to include in the model and instead - include ratings on any note. Including ratings on any note increases the amount of data - available during training about each user, in effect also increasing the number of users - and notes we are able to include in the model. - - Including notes by users outside of the modeling group means that the model will issue - scores for notes which do not meet group modeling criteria (i.e. >80% of ratings are - from users in the modeling group, and the author is also from the modeling group). We - enforce these criteria *after* scoring in _postprocess_output so that the maximum amount - of ratings are available during scoring. - - Args: - ratings (pd.DataFrame): preprocessed ratings - noteStatusHistory (pd.DataFrame): one row per note; history of when note had each status - userEnrollment (pd.DataFrame): one row per user specifying enrollment properties - - Returns: - Tuple[pd.DataFrame, pd.DataFrame]: - ratings: ratings filtered to only contain rows of interest - noteStatusHistory: noteStatusHistory filtered to only contain rows of interest - """ - userEnrollment = userEnrollment.rename(columns={c.participantIdKey: c.raterParticipantIdKey}) - userEnrollment = userEnrollment[userEnrollment[c.modelingGroupKey] == self._groupNumber] - ratings = ratings.merge(userEnrollment[[c.raterParticipantIdKey]].drop_duplicates()) - return ratings, noteStatusHistory - def _postprocess_output( self, noteScores: pd.DataFrame, @@ -283,17 +248,9 @@ def _postprocess_output( noteScores: filtered and updated note scoring output userScores: filtered and updated user scoring output """ - # Identify notes with enough ratings from within the modeling group. - ratings = ratings.merge( - userEnrollment[[c.participantIdKey, c.modelingGroupKey]].rename( - columns={c.participantIdKey: c.raterParticipantIdKey} - ), - how="left", + noteScores, userScores = super()._postprocess_output( + noteScores, userScores, ratings, noteStatusHistory, userEnrollment ) - ratings["inGroup"] = ratings[c.modelingGroupKey] == self._groupNumber - ratios = ratings[[c.noteIdKey, "inGroup"]].groupby(c.noteIdKey).mean().reset_index() - notesAboveThreshold = ratios[ratios["inGroup"] >= self._groupThreshold][[c.noteIdKey]] - noteScores = noteScores.merge(notesAboveThreshold) # Note that even though ratings were restricted to the modeling group, users outside of # the modeling group may still have authored a note which was rated and may consequently # appear in the userScores. 
Accordingly, we drop any user which was outside of the diff --git a/sourcecode/scoring/mf_topic_scorer.py b/sourcecode/scoring/mf_topic_scorer.py index cfb8b9c1..22c04f8d 100644 --- a/sourcecode/scoring/mf_topic_scorer.py +++ b/sourcecode/scoring/mf_topic_scorer.py @@ -77,8 +77,9 @@ def __init__( pseudoraters: if True, compute optional pseudorater confidence intervals """ super().__init__( - seed, - pseudoraters, + includedTopics={topicName}, + seed=seed, + pseudoraters=pseudoraters, useStableInitialization=False, saveIntermediateState=saveIntermediateState, threads=4, @@ -175,31 +176,6 @@ def _get_dropped_user_cols(self) -> List[str]: c.raterParticipantIdKey, ] - def _filter_input( - self, - noteTopics: pd.DataFrame, - ratings: pd.DataFrame, - noteStatusHistory: pd.DataFrame, - userEnrollment: pd.DataFrame, - ) -> Tuple[pd.DataFrame, pd.DataFrame]: - """Prune the contents of ratings to only include ratings from notes on this topic. - - Args: - noteTopics: DF pairing notes and topics - ratings (pd.DataFrame): preprocessed ratings - noteStatusHistory (pd.DataFrame): one row per note; history of when note had each status - userEnrollment (pd.DataFrame): one row per user specifying enrollment properties - - Returns: - Tuple[pd.DataFrame, pd.DataFrame]: - ratings: ratings filtered to only contain rows of interest - noteStatusHistory: noteStatusHistory filtered to only contain rows of interest - """ - notes = noteTopics[noteTopics[c.noteTopicKey] == self._topicName][[c.noteIdKey]] - ratings = ratings.merge(notes) - noteStatusHistory = noteStatusHistory.merge(notes) - return ratings, noteStatusHistory - def _postprocess_output( self, noteScores: pd.DataFrame, diff --git a/sourcecode/scoring/pandas_utils.py b/sourcecode/scoring/pandas_utils.py index c2e5a17d..98893420 100644 --- a/sourcecode/scoring/pandas_utils.py +++ b/sourcecode/scoring/pandas_utils.py @@ -22,6 +22,9 @@ """ from collections import Counter +from dataclasses import dataclass +from enum import Enum +import re import sys from threading import Lock import traceback @@ -31,6 +34,11 @@ import pandas as pd +def keep_columns(df: pd.DataFrame, cols: List[str]): + cols = [col for col in cols if col in df] + return df[cols] + + class TypeErrorCounter(object): def __init__(self): self._callCounts: Dict[Tuple[str, str], int] = dict() @@ -61,367 +69,510 @@ def get_summary(self): return "\n".join(lines) -def get_check(fail: Any, lines: List[str], unsafeAllowed: Set[str]) -> Callable: - """Return a function which will either assert a condition or conditionally log.""" - - def _check(columns: Any, condition: bool, msg: str): - if isinstance(columns, str): - failDisabled = columns in unsafeAllowed - elif isinstance(columns, List): - failDisabled = all(col in unsafeAllowed for col in columns) - else: - # Note there are multiple circumstances where the type of Columns may not be a str - # or List[str], including when we are concatenating a Series (column name will be - # set to None), when there are mulit-level column names (column name will be a tuple) - # or when Pandas has set column names to a RangeIndex. 
- failDisabled = False - if fail and not failDisabled: - assert condition, msg - elif not condition: - if failDisabled: - lines.append(f"{msg} (allowed)") - else: - lines.append(f"{msg} (UNALLOWED)") - - return _check - - -def _log_errors(method: str, callsite: str, lines: List[str], counter: TypeErrorCounter) -> None: - if not lines: - return - counter.log_errors(method, callsite, lines) - errorLines = "\n".join([f" PandasTypeError: {l}" for l in lines]) - msg = f"\n{method} ERROR AT: {callsite}" f"{method} ERRORS:\n{errorLines}\n" - print(msg, file=sys.stderr) +class LogLevel(Enum): + # Raise an error if the expecatation is violated + FATAL = 1 + # Log to stderr when the expectation is violated + ERROR = 2 + # Log to stderr any time the column is observed + INFO = 3 -def safe_concat(fail: bool, counter: TypeErrorCounter) -> Callable: - """Return a modified concat function that checks type stability. +@dataclass +class TypeExpectation: + dtype: type + logLevel: LogLevel - Args: - fail: If True, unexpected type conversions should trigger a failed assert. If False, - unexpected conversions should log to stderr. - counter: Tracker for summarizing problematic calls at the end of execution. - """ - original = pd.concat - def _safe_concat(*args, **kwargs): - """Wrapper around pd.concat +class PandasPatcher(object): + def __init__(self, fail: bool, typeOverrides: Dict[str, TypeExpectation] = dict()): + """Initialize a PandasPatcher with particular failure and type expectations. Args: - args: non-keyword arguments to pass through to merge. - kwargs: keyword arguments to pass through to merge. + fail: Whether to raise errors or log to stderr when expectations are violated. + expectations: Type expecatations for select columns. """ - lines = [] + self._fail = fail + self._counter = TypeErrorCounter() + self._origConcat = pd.concat + self._origJoin = pd.DataFrame.join + self._origMerge = pd.DataFrame.merge + self._origApply = pd.DataFrame.apply + self._origInit = pd.DataFrame.__init__ + self._origGetItem = pd.DataFrame.__getitem__ + self._origSetItem = pd.DataFrame.__setitem__ + self._origLocGetItem = pd.core.indexing._LocationIndexer.__getitem__ + self._origLocSetItem = pd.core.indexing._LocationIndexer.__setitem__ + self._expectations: Dict[str, TypeExpectation] = dict() + for column, expectation in typeOverrides.items(): + self._expectations[column] = expectation + + def get_summary(self) -> str: + return f"\nTYPE WARNING SUMMARY\n{self._counter.get_summary()}" + + def _log_errors(self, method: str, callsite: str, lines: List[str]) -> None: + if not lines: + return + self._counter.log_errors(method, callsite, lines) + errorLines = "\n".join([f" PandasTypeError: {l}" for l in lines]) + msg = f"\n{method} ERROR(S) AT: {callsite}\n{errorLines}\n" + print(msg, file=sys.stderr) + + def _get_check(self, lines: List[str], kwargs: Dict) -> Callable: + """Return a function which will either assert a condition or append to a list of errors. 
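A hedged sketch of how the new PandasPatcher might be wired up; the column names, expectations, and explicit monkey-patching below are illustrative only (the entry point that actually installs the wrappers is outside this hunk):

    import numpy as np
    import pandas as pd
    from sourcecode.scoring.pandas_utils import LogLevel, PandasPatcher, TypeExpectation

    # Per-column dtype expectations: FATAL asserts when fail=True, ERROR logs violations to
    # stderr, INFO logs every observation of the column.
    patcher = PandasPatcher(
        fail=False,
        typeOverrides={
            "noteId": TypeExpectation(np.int64, LogLevel.ERROR),
            "helpfulnessLevel": TypeExpectation(pd.CategoricalDtype, LogLevel.INFO),
        },
    )
    pd.concat = patcher.safe_concat()         # install wrappers for the APIs to monitor
    pd.DataFrame.merge = patcher.safe_merge()
    pd.DataFrame.join = patcher.safe_join()

    # ... run scoring ...
    print(patcher.get_summary())              # aggregated type warnings, grouped by callsite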
+ + Note that this function does not actually log to stderr, but rather appends to a list so + that all + """ + unsafeAllowed = set() if "unsafeAllowed" in kwargs: - unsafeAllowed = kwargs["unsafeAllowed"] - if isinstance(unsafeAllowed, str): - unsafeAllowed = {unsafeAllowed} + unsafeAllowedArg = kwargs["unsafeAllowed"] + if isinstance(unsafeAllowedArg, str): + unsafeAllowed = {unsafeAllowedArg} + elif isinstance(unsafeAllowedArg, List): + unsafeAllowed = set(unsafeAllowedArg) + else: + assert isinstance(unsafeAllowedArg, Set) + unsafeAllowed = unsafeAllowedArg del kwargs["unsafeAllowed"] - else: - unsafeAllowed: Set[str] = set() - check = get_check(fail, lines, unsafeAllowed) - # Validate that all objects being concatenated are either Series or DataFrames - objs = args[0] - assert type(objs) == list, f"expected first argument to be a list: type={type(objs)}" - assert all(type(obj) == pd.Series for obj in objs) or all( - type(obj) == pd.DataFrame for obj in objs - ), f"Expected concat args to be either pd.Series or pd.DataFrame: {[type(obj) for obj in objs]}" - if type(objs[0]) == pd.Series: - if "axis" in kwargs and kwargs["axis"] == 1: - # Since the call is concatenating Series as columns in a DataFrame, validate that the sequence - # of Series dtypes matches the sequence of column dtypes in the dataframe. - result = original(*args, **kwargs) - objDtypes = [obj.dtype for obj in objs] - assert len(objDtypes) == len( - result.dtypes - ), f"dtype length mismatch: {len(objDtypes)} vs {len(result.dtypes)}" - for col, seriesType, colType in zip(result.columns, objDtypes, result.dtypes): - check(col, seriesType == colType, f"Series concat on {col}: {seriesType} vs {colType}") + + def _check(columns: Any, condition: bool, msg: str): + if isinstance(columns, str): + failDisabled = columns in unsafeAllowed + elif isinstance(columns, List): + failDisabled = all(col in unsafeAllowed for col in columns) else: - # If Series, validate that all series were same type and return - seriesTypes = set(obj.dtype for obj in objs) - check(None, len(seriesTypes) == 1, f"More than 1 unique Series type: {seriesTypes}") - result = original(*args, **kwargs) + # Note there are multiple circumstances where the type of Columns may not be a str + # or List[str], including when we are concatenating a Series (column name will be + # set to None), when there are mulit-level column names (column name will be a tuple) + # or when Pandas has set column names to a RangeIndex. + failDisabled = False + if self._fail and not failDisabled: + assert condition, msg + elif not condition: + if failDisabled: + lines.append(f"{msg} (allowed)") + else: + lines.append(f"{msg} (UNALLOWED)") + + return _check + + def _get_callsite(self) -> str: + """Return the file, function, line numer and pandas API call on a single line.""" + for line in traceback.format_stack()[::-1]: + path = line.split(",")[0] + if "/pandas_utils.py" in path: + continue + if "/pandas/" in path: + continue + break + # Handle paths resulting from bazel invocation + match = re.match(r'^ File ".*?/site-packages(/.*?)", (.*?), (.*?)\n (.*)\n$', line) + if match: + return f"{match.group(1)}, {match.group(3)}, at {match.group(2)}: {match.group(4)}" + # Handle paths fresulting from pytest invocation + match = re.match(r'^ File ".*?/src/(test|main)/python(/.*?)", (.*?), (.*?)\n (.*)\n$', line) + if match: + return f"{match.group(2)}, {match.group(4)}, at {match.group(3)}: {match.group(5)}" + # Handle other paths (e.g. 
notebook, public code) + match = re.match(r'^ File "(.*?)", (.*?), (.*?)\n (.*)\n$', line) + if match: + return f"{match.group(1)}, {match.group(3)}, at {match.group(2)}: {match.group(4)}" else: - # If DataFrame, validate that all input columns with matching names have the same type - # and build expectation for output column types - assert type(objs[0]) == pd.DataFrame - colTypes: Dict[str, List[type]] = dict() - for df in objs: - for col, dtype in df.reset_index(drop=False).dtypes.items(): - if col not in colTypes: - colTypes[col] = [] - colTypes[col].append(dtype) - # Perform concatenation and validate that there weren't any type changes - result = original(*args, **kwargs) - for col, outputType in result.reset_index(drop=False).dtypes.items(): - check( - col, - all(inputType == outputType for inputType in colTypes[col]), - f"DataFrame concat on {col}: output={outputType} inputs={colTypes[col]}", - ) - _log_errors("CONCAT", traceback.format_stack()[-2], lines, counter) - return result - - return _safe_concat + stack = "\n\n".join(traceback.format_stack()[::-1]) + print(f"parsing error:\n{stack}", file=sys.stderr) + return "parsing error. callsite unknown." + def _check_dtype(self, dtype: Any, expected: type) -> bool: + """Return True IFF dtype corresponds to expected. -def safe_merge(fail: bool, counter: TypeErrorCounter) -> Callable: - """Return a modified merge function that checks type stability. - - Args: - fail: If True, unexpected type conversions should trigger a failed assert. If False, - unexpected conversions should log to stderr. - counter: Tracker for summarizing problematic calls at the end of execution. - """ - original = pd.DataFrame.merge - - def _safe_merge(*args, **kwargs): - """Wrapper around pd.DataFrame.merge. - - Args: - args: non-keyword arguments to pass through to merge. - kwargs: keyword arguments to pass through to merge. + Note that for non-nullable columns, dtype may equal type (e.g. np.int64), but for nullable + columns the column type is actually an instance of a pandas dtype (e.g. 
pd.Int64Dtype) """ - lines = [] - if "unsafeAllowed" in kwargs: - unsafeAllowed = kwargs["unsafeAllowed"] - if isinstance(unsafeAllowed, str): - unsafeAllowed = {unsafeAllowed} - del kwargs["unsafeAllowed"] - else: - unsafeAllowed: Set[str] = set() - check = get_check(fail, lines, unsafeAllowed) - leftFrame = args[0] - rightFrame = args[1] - # Validate that argument types are as expected - assert type(leftFrame) is pd.DataFrame - assert type(rightFrame) is pd.DataFrame - # Store dtypes and validate that any common columns have the same type - leftDtypes = dict(leftFrame.reset_index(drop=False).dtypes) - rightDtypes = dict(rightFrame.reset_index(drop=False).dtypes) - for col in set(leftDtypes) & set(rightDtypes): - check( - col, - leftDtypes[col] == rightDtypes[col], - f"Input mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}", + assert expected != object, "expectation must be more specific than object" + return dtype == expected or isinstance(dtype, expected) + + def _check_name_and_type(self, name: str, dtype: Any) -> List[str]: + """Returns a list of type mismatches if any are found, or raises an error.""" + if name not in self._expectations: + return [] + typeExpectation = self._expectations[name] + msg = f"Type expectation mismatch on {name}: found={dtype} expected={typeExpectation.dtype.__name__}" + match = self._check_dtype(dtype, typeExpectation.dtype) + if typeExpectation.logLevel == LogLevel.INFO: + return ( + [msg] + if not match + else [ + f"Type expectation match on {name}: found={dtype} expected={typeExpectation.dtype.__name__}" + ] ) - # Identify the columns we are merging on, if left_on and right_on are unset - if "on" in kwargs and type(kwargs["on"]) == str: - onCols = set([kwargs["on"]]) - elif "on" in kwargs and type(kwargs["on"]) == list: - onCols = set(kwargs["on"]) - elif "left_on" in kwargs: - assert "on" not in kwargs, "not expecting both on and left_on" - assert "right_on" in kwargs, "expecting both left_on and right_on to be set" - onCols = set() + elif typeExpectation.logLevel == LogLevel.ERROR or not self._fail: + return [msg] if not match else [] else: - assert "on" not in kwargs, f"""unexpected type for on: {type(kwargs["on"])}""" - onCols = set(leftFrame.columns) & set(rightFrame.columns) - # Validate that merge columns have matching types - if "left_on" in kwargs: - assert "right_on" in kwargs - left_on = kwargs["left_on"] - right_on = kwargs["right_on"] - check( - [left_on, right_on], - leftDtypes[left_on] == rightDtypes[right_on], - f"Merge key mismatch on type({left_on})={leftDtypes[left_on]} vs type({right_on})={rightDtypes[right_on]}", - ) + assert typeExpectation.logLevel == LogLevel.FATAL + assert self._fail + assert match, msg + return [] + + def _validate_series(self, series: pd.Series) -> List[str]: + assert isinstance(series, pd.Series), f"unexpected type: {type(series)}" + return self._check_name_and_type(series.name, series.dtype) + + def _validate_dataframe(self, df: pd.DataFrame) -> List[str]: + """Returns a list of type mismatches if any are found, or raises an error.""" + assert isinstance(df, pd.DataFrame), f"unexpected type: {type(df)}" + lines = [] + # Check index types + if type(df.index) == pd.MultiIndex: + for name, dtype in df.index.dtypes.to_dict().items(): + lines.extend(self._check_name_and_type(name, dtype)) + elif type(df.index) == pd.RangeIndex or df.index.name is None: + # Index is uninteresting - none was specified by the caller. 
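The distinction _check_dtype draws above between plain numpy dtypes and nullable extension dtypes, illustrated:

    import numpy as np
    import pandas as pd

    s_plain = pd.Series([1, 2, 3], dtype=np.int64)
    s_nullable = pd.Series([1, 2, None], dtype=pd.Int64Dtype())

    # Non-nullable columns: the dtype compares equal to the numpy type itself.
    print(s_plain.dtype == np.int64)                    # True
    # Nullable columns: the dtype is an instance of the extension dtype class, so the
    # equality check fails and isinstance is needed instead.
    print(s_nullable.dtype == pd.Int64Dtype)            # False
    print(isinstance(s_nullable.dtype, pd.Int64Dtype))  # True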
+ pass else: - assert len(onCols), "expected onCols to be defined since left_on was not" - assert "right_on" not in kwargs, "did not expect onCols and right_on" - for col in onCols: + lines.extend(self._check_name_and_type(df.index.name, df.index.dtype)) + # Check column types + for name, dtype in df.dtypes.to_dict().items(): + lines.extend(self._check_name_and_type(name, dtype)) + return lines + + def safe_init(self) -> Callable: + """Return a modified __init__ function that checks type expectations.""" + + def _safe_init(*args, **kwargs): + """Wrapper around pd.concat + + Args: + args: non-keyword arguments to pass through to merge. + kwargs: keyword arguments to pass through to merge. + """ + df = args[0] + assert isinstance(df, pd.DataFrame), f"unexpected type: {type(df)}" + retVal = self._origInit(*args, **kwargs) + assert retVal is None + lines = self._validate_dataframe(df) + self._log_errors("INIT", self._get_callsite(), lines) + return retVal + + return _safe_init + + def safe_concat(self) -> Callable: + """Return a modified concat function that checks type stability.""" + + def _safe_concat(*args, **kwargs): + """Wrapper around pd.concat + + Args: + args: non-keyword arguments to pass through to merge. + kwargs: keyword arguments to pass through to merge. + """ + lines = [] + check = self._get_check(lines, kwargs) + # Validate that all objects being concatenated are either Series or DataFrames + objs = args[0] + assert type(objs) == list, f"expected first argument to be a list: type={type(objs)}" + assert ( + all(type(obj) == pd.Series for obj in objs) + or all(type(obj) == pd.DataFrame for obj in objs) + ), f"Expected concat args to be either pd.Series or pd.DataFrame: {[type(obj) for obj in objs]}" + if type(objs[0]) == pd.Series: + if "axis" in kwargs and kwargs["axis"] == 1: + # Since the call is concatenating Series as columns in a DataFrame, validate that the sequence + # of Series dtypes matches the sequence of column dtypes in the dataframe. 
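With the patcher installed (as in the earlier sketch), the concat wrapper surfaces silent dtype widening that stock pandas performs without complaint; a small self-contained example:

    import pandas as pd
    from sourcecode.scoring.pandas_utils import PandasPatcher

    patcher = PandasPatcher(fail=False)
    pd.concat = patcher.safe_concat()

    a = pd.Series([1, 2], name="score")   # int64
    b = pd.Series([0.5], name="score")    # float64
    # Stock pandas silently upcasts the result to float64; the wrapper logs
    # "More than 1 unique Series type" to stderr so the widening is visible.
    combined = pd.concat([a, b])
    print(combined.dtype)                 # float64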
+ result = self._origConcat(*args, **kwargs) + objDtypes = [obj.dtype for obj in objs] + assert len(objDtypes) == len( + result.dtypes + ), f"dtype length mismatch: {len(objDtypes)} vs {len(result.dtypes)}" + for col, seriesType, colType in zip(result.columns, objDtypes, result.dtypes): + check( + col, + seriesType == colType, + f"Series concat on {col}: {seriesType} vs {colType}", + ) + else: + # If Series, validate that all series were same type and return + seriesTypes = set(obj.dtype for obj in objs) + check(None, len(seriesTypes) == 1, f"More than 1 unique Series type: {seriesTypes}") + result = self._origConcat(*args, **kwargs) + else: + # If DataFrame, validate that all input columns with matching names have the same type + # and build expectation for output column types + assert type(objs[0]) == pd.DataFrame + colTypes: Dict[str, List[type]] = dict() + for df in objs: + for col, dtype in df.reset_index(drop=False).dtypes.items(): + if col not in colTypes: + colTypes[col] = [] + colTypes[col].append(dtype) + # Perform concatenation and validate that there weren't any type changes + result = self._origConcat(*args, **kwargs) + for col, outputType in result.reset_index(drop=False).dtypes.items(): + check( + col, + all(inputType == outputType for inputType in colTypes[col]), + f"DataFrame concat on {col}: output={outputType} inputs={colTypes[col]}", + ) + if isinstance(result, pd.DataFrame): + lines.extend(self._validate_dataframe(result)) + elif isinstance(result, pd.Series): + lines.extend(self._validate_series(result)) + self._log_errors("CONCAT", self._get_callsite(), lines) + return result + + return _safe_concat + + def safe_apply(self) -> Callable: + """Return a modified apply function that checks type stability.""" + + def _safe_apply(*args, **kwargs): + """Wrapper around pd.DataFrame.apply + + Args: + args: non-keyword arguments to pass through to merge. + kwargs: keyword arguments to pass through to merge. + """ + # TODO: Flesh this out with additional expectatoins around input and output types + result = self._origApply(*args, **kwargs) + if isinstance(result, pd.DataFrame): + self._log_errors("APPLY", self._get_callsite(), self._validate_dataframe(result)) + elif isinstance(result, pd.Series): + self._log_errors("APPLY", self._get_callsite(), self._validate_series(result)) + return result + + return _safe_apply + + def safe_merge(self) -> Callable: + """Return a modified merge function that checks type stability.""" + + def _safe_merge(*args, **kwargs): + """Wrapper around pd.DataFrame.merge. + + Args: + args: non-keyword arguments to pass through to merge. + kwargs: keyword arguments to pass through to merge. 
+ """ + lines = [] + check = self._get_check(lines, kwargs) + leftFrame = args[0] + rightFrame = args[1] + # Validate that argument types are as expected + assert type(leftFrame) is pd.DataFrame + assert type(rightFrame) is pd.DataFrame + # Store dtypes and validate that any common columns have the same type + leftDtypes = dict(leftFrame.reset_index(drop=False).dtypes) + rightDtypes = dict(rightFrame.reset_index(drop=False).dtypes) + for col in set(leftDtypes) & set(rightDtypes): check( col, leftDtypes[col] == rightDtypes[col], - f"Merge key mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}", + f"Input mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}", ) - # Compute expected column types - leftSuffix, rightSuffix = kwargs.get("suffixes", ("_x", "_y")) - commonCols = set(leftFrame.columns) & set(rightFrame.columns) - expectedColTypes = dict() - for col in set(leftFrame.columns) | set(rightFrame.columns): - if col in onCols: - # Note that we check above whether leftDtypes[col] == rightDtypes[col] and either raise an - # error or log as appropriate if there is a mismatch. - if leftDtypes[col] == rightDtypes[col]: + # Identify the columns we are merging on, if left_on and right_on are unset + if "on" in kwargs and type(kwargs["on"]) == str: + onCols = set([kwargs["on"]]) + elif "on" in kwargs and type(kwargs["on"]) == list: + onCols = set(kwargs["on"]) + elif "left_on" in kwargs: + assert "on" not in kwargs, "not expecting both on and left_on" + assert "right_on" in kwargs, "expecting both left_on and right_on to be set" + onCols = set() + else: + assert "on" not in kwargs, f"""unexpected type for on: {type(kwargs["on"])}""" + onCols = set(leftFrame.columns) & set(rightFrame.columns) + # Validate that merge columns have matching types + if "left_on" in kwargs: + assert "right_on" in kwargs + left_on = kwargs["left_on"] + right_on = kwargs["right_on"] + check( + [left_on, right_on], + leftDtypes[left_on] == rightDtypes[right_on], + f"Merge key mismatch on type({left_on})={leftDtypes[left_on]} vs type({right_on})={rightDtypes[right_on]}", + ) + else: + assert len(onCols), "expected onCols to be defined since left_on was not" + assert "right_on" not in kwargs, "did not expect onCols and right_on" + for col in onCols: + check( + col, + leftDtypes[col] == rightDtypes[col], + f"Merge key mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}", + ) + # Compute expected column types + leftSuffix, rightSuffix = kwargs.get("suffixes", ("_x", "_y")) + commonCols = set(leftFrame.columns) & set(rightFrame.columns) + expectedColTypes = dict() + for col in set(leftFrame.columns) | set(rightFrame.columns): + if col in onCols: + # Note that we check above whether leftDtypes[col] == rightDtypes[col] and either raise an + # error or log as appropriate if there is a mismatch. 
+ if leftDtypes[col] == rightDtypes[col]: + expectedColTypes[col] = leftDtypes[col] + else: + # Set expectation to None since we don't know what will happen, but do want to log an + # error later + expectedColTypes[col] = None + elif col in commonCols: + expectedColTypes[f"{col}{leftSuffix}"] = leftDtypes[col] + expectedColTypes[f"{col}{rightSuffix}"] = rightDtypes[col] + elif col in leftDtypes: + assert col not in rightDtypes expectedColTypes[col] = leftDtypes[col] else: - # Set expectation to None since we don't know what will happen, but do want to log an - # error later - expectedColTypes[col] = None - elif col in commonCols: - expectedColTypes[f"{col}{leftSuffix}"] = leftDtypes[col] - expectedColTypes[f"{col}{rightSuffix}"] = rightDtypes[col] - elif col in leftDtypes: - assert col not in rightDtypes - expectedColTypes[col] = leftDtypes[col] + expectedColTypes[col] = rightDtypes[col] + # Perform merge and validate results + result = self._origMerge(*args, **kwargs) + resultDtypes = dict(result.dtypes) + for col in resultDtypes: + check( + col, + resultDtypes[col] == expectedColTypes[col], + f"Output mismatch on {col}: result={resultDtypes[col]} expected={expectedColTypes[col]}", + ) + lines.extend(self._validate_dataframe(result)) + self._log_errors("MERGE", self._get_callsite(), lines) + return result + + return _safe_merge + + def safe_join(self) -> Callable: + """Return a modified merge function that checks type stability.""" + + def _safe_join(*args, **kwargs): + """Wrapper around pd.DataFrame.merge. + + Args: + args: non-keyword arguments to pass through to merge. + kwargs: keyword arguments to pass through to merge. + """ + lines = [] + check = self._get_check(lines, kwargs) + leftFrame = args[0] + rightFrame = args[1] + # Validate arguments are as expected + assert type(leftFrame) is pd.DataFrame + assert type(rightFrame) is pd.DataFrame + assert len(set(kwargs) - {"lsuffix", "rsuffix", "how"}) == 0, f"unexpected kwargs: {kwargs}" + # Validate the assumption that columns used as the join key in the index have the same type. + # This is analogous to validating that onCols match and have the same types in _safe_merge. + if len(leftFrame.index.names) == 1 and len(rightFrame.index.names) == 1: + match = leftFrame.index.dtype == rightFrame.index.dtype + elif len(leftFrame.index.names) == 1 and len(rightFrame.index.names) > 1: + indexTypes = dict(rightFrame.index.dtypes) + name = leftFrame.index.names[0] + assert name in indexTypes, f"{name} not found in {indexTypes}" + match = indexTypes[name] == leftFrame.index.dtype + elif len(leftFrame.index.names) > 1 and len(rightFrame.index.names) == 1: + indexTypes = dict(leftFrame.index.dtypes) + name = rightFrame.index.names[0] + assert name in indexTypes, f"{name} not found in {indexTypes}" + match = indexTypes[name] == rightFrame.index.dtype else: - expectedColTypes[col] = rightDtypes[col] - # Perform merge and validate results - result = original(*args, **kwargs) - resultDtypes = dict(result.dtypes) - for col in resultDtypes: - check( - col, - resultDtypes[col] == expectedColTypes[col], - f"Output mismatch on {col}: result={resultDtypes[col]} expected={expectedColTypes[col]}", - ) - _log_errors("MERGE", traceback.format_stack()[-2], lines, counter) - return result - - return _safe_merge - - -def safe_join(fail: bool, counter: TypeErrorCounter) -> Callable: - """Return a modified merge function that checks type stability. - - Args: - fail: If True, unexpected type conversions should trigger a failed assert. 
If False, - unexpected conversions should log to stderr. - counter: Tracker for summarizing problematic calls at the end of execution. - """ - original = pd.DataFrame.join - - def _safe_join(*args, **kwargs): - """Wrapper around pd.DataFrame.merge. - - Args: - args: non-keyword arguments to pass through to merge. - kwargs: keyword arguments to pass through to merge. - """ - lines = [] - if "unsafeAllowed" in kwargs: - unsafeAllowed = kwargs["unsafeAllowed"] - if isinstance(unsafeAllowed, str): - unsafeAllowed = {unsafeAllowed} - del kwargs["unsafeAllowed"] - else: - unsafeAllowed: Set[str] = set() - check = get_check(fail, lines, unsafeAllowed) - leftFrame = args[0] - rightFrame = args[1] - # Validate arguments are as expected - assert type(leftFrame) is pd.DataFrame - assert type(rightFrame) is pd.DataFrame - assert len(set(kwargs) - {"lsuffix", "rsuffix", "how"}) == 0, f"unexpected kwargs: {kwargs}" - # Validate the assumption that columns used as the join key in the index have the same type. - # This is analogous to validating that onCols match and have the same types in _safe_merge. - if len(leftFrame.index.names) == 1 and len(rightFrame.index.names) == 1: - match = leftFrame.index.dtype == rightFrame.index.dtype - elif len(leftFrame.index.names) == 1 and len(rightFrame.index.names) > 1: - indexTypes = dict(rightFrame.index.dtypes) - name = leftFrame.index.names[0] - assert name in indexTypes, f"{name} not found in {indexTypes}" - match = indexTypes[name] == leftFrame.index.dtype - elif len(leftFrame.index.names) > 1 and len(rightFrame.index.names) == 1: - indexTypes = dict(leftFrame.index.dtypes) - name = rightFrame.index.names[0] - assert name in indexTypes, f"{name} not found in {indexTypes}" - match = indexTypes[name] == rightFrame.index.dtype - else: + assert ( + len(leftFrame.index.names) > 1 + ), f"unexpected left: {type(leftFrame.index)}, {leftFrame.index}" + assert ( + len(rightFrame.index.names) > 1 + ), f"unexpected right: {type(rightFrame.index)}, {rightFrame.index}" + leftIndexTypes = dict(leftFrame.index.dtypes) + rightIndexTypes = dict(rightFrame.index.dtypes) + match = True + for col in set(leftIndexTypes) & set(rightIndexTypes): + match = match & (leftIndexTypes[col] == rightIndexTypes[col]) + assert match, f"Join index mismatch:\n{leftFrame.index}\nvs\n{rightFrame.index}" + # Validate that input columns with the same name have the same types + leftDtypes = dict(leftFrame.dtypes) + rightDtypes = dict(rightFrame.dtypes) + for col in set(leftDtypes) & set(rightDtypes): + check( + col, + leftDtypes[col] == rightDtypes[col], + f"Input mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}", + ) + # Validate that none of the columns in an index have the same name as a non-index column + # in the opposite dataframe assert ( - len(leftFrame.index.names) > 1 - ), f"unexpected left: {type(leftFrame.index)}, {leftFrame.index}" + len(set(leftFrame.index.names) & set(rightFrame.columns)) == 0 + ), f"left index: {set(leftFrame.index.names)}; right columns {set(rightFrame.columns)}" assert ( - len(rightFrame.index.names) > 1 - ), f"unexpected right: {type(rightFrame.index)}, {rightFrame.index}" - leftIndexTypes = dict(leftFrame.index.dtypes) - rightIndexTypes = dict(rightFrame.index.dtypes) - match = True - for col in set(leftIndexTypes) & set(rightIndexTypes): - match = match & (leftIndexTypes[col] == rightIndexTypes[col]) - assert match, f"Join index mismatch:\n{leftFrame.index}\nvs\n{rightFrame.index}" - # Validate that input columns with the same name have the 
same types - leftDtypes = dict(leftFrame.dtypes) - rightDtypes = dict(rightFrame.dtypes) - for col in set(leftDtypes) & set(rightDtypes): - check( - col, - leftDtypes[col] == rightDtypes[col], - f"Input mismatch on {col}: left={leftDtypes[col]} vs right={rightDtypes[col]}", - ) - # Validate that none of the columns in an index have the same name as a non-index column - # in the opposite dataframe - assert ( - len(set(leftFrame.index.names) & set(rightFrame.columns)) == 0 - ), f"left index: {set(leftFrame.index.names)}; right columns {set(rightFrame.columns)}" - assert ( - len(set(rightFrame.index.names) & set(leftFrame.columns)) == 0 - ), f"right index: {set(rightFrame.index.names)}; left columns {set(leftFrame.columns)}" - # Compute expected types for output columns - commonCols = set(leftFrame.columns) & set(rightFrame.columns) - expectedColTypes = dict() - leftSuffix = kwargs.get("lsuffix", "") - rightSuffix = kwargs.get("rsuffix", "") - for col in set(leftFrame.columns) | set(rightFrame.columns): - if col in commonCols: - expectedColTypes[f"{col}{leftSuffix}"] = leftDtypes[col] - expectedColTypes[f"{col}{rightSuffix}"] = rightDtypes[col] - elif col in leftDtypes: - assert col not in rightDtypes - expectedColTypes[col] = leftDtypes[col] - else: - expectedColTypes[col] = rightDtypes[col] - # Compute expected types for index columns - leftIndexCols = set(leftFrame.index.names) - rightIndexCols = set(rightFrame.index.names) - if len(leftIndexCols) > 1: - leftDtypes = dict(leftFrame.index.dtypes) - else: - leftDtypes = {leftFrame.index.name: leftFrame.index.dtype} - if len(rightIndexCols) > 1: - rightDtypes = dict(rightFrame.index.dtypes) - else: - rightDtypes = {rightFrame.index.name: rightFrame.index.dtype} - for col in leftIndexCols & rightIndexCols: - # For columns in both indices, type should not change if input types agree. If input types - # disagree, then we have no expectation. 
-      if leftDtypes[col] == rightDtypes[col]:
-        expectedColTypes[col] = leftDtypes[col]
+        len(set(rightFrame.index.names) & set(leftFrame.columns)) == 0
+      ), f"right index: {set(rightFrame.index.names)}; left columns {set(leftFrame.columns)}"
+      # Compute expected types for output columns
+      commonCols = set(leftFrame.columns) & set(rightFrame.columns)
+      expectedColTypes = dict()
+      leftSuffix = kwargs.get("lsuffix", "")
+      rightSuffix = kwargs.get("rsuffix", "")
+      for col in set(leftFrame.columns) | set(rightFrame.columns):
+        if col in commonCols:
+          expectedColTypes[f"{col}{leftSuffix}"] = leftDtypes[col]
+          expectedColTypes[f"{col}{rightSuffix}"] = rightDtypes[col]
+        elif col in leftDtypes:
+          assert col not in rightDtypes
+          expectedColTypes[col] = leftDtypes[col]
+        else:
+          expectedColTypes[col] = rightDtypes[col]
+      # Compute expected types for index columns
+      leftIndexCols = set(leftFrame.index.names)
+      rightIndexCols = set(rightFrame.index.names)
+      if len(leftIndexCols) > 1:
+        leftDtypes = dict(leftFrame.index.dtypes)
       else:
-        expectedColTypes[col] = None
-    for col in (leftIndexCols | rightIndexCols) - (leftIndexCols & rightIndexCols):
-      # For columns in exactly one index, the expected output type should match the input column type
-      # and the column name should not change because we have validated that the column does not
-      # appear in the other dataframe
-      if col in leftDtypes:
-        assert col not in rightDtypes, f"unexpected column: {col}"
-        expectedColTypes[col] = leftDtypes[col]
+        leftDtypes = {leftFrame.index.name: leftFrame.index.dtype}
+      if len(rightIndexCols) > 1:
+        rightDtypes = dict(rightFrame.index.dtypes)
       else:
-        expectedColTypes[col] = rightDtypes[col]
-    # Perform join and validate results. Note that we already validated that the indices had the
-    # same columns and types, and that the "on" argument is unset, so now we only need to check
-    # the non-index columns.
-    result = original(*args, **kwargs)
-    # Note that we must reset index to force any NaNs in the index to emerge as float types.
-    # See example below.
-    # left = pd.DataFrame({"idx0": [1, 2], "idx1": [11, 12], "val1": [4, 5]}).set_index(["idx0", "idx1"])
-    # right = pd.DataFrame({"idx0": [1, 2, 3], "idx2": [21, 22, 23], "val2": [7, 8, 9]}).set_index(["idx0", "idx2"])
-    # print(dict(left.join(right, how="outer").index.dtypes))
-    # print(dict(left.join(right, how="outer").reset_index(drop=False).dtypes))
-    # $> {'idx0': dtype('int64'), 'idx1': dtype('int64'), 'idx2': dtype('int64')}
-    # $> {'idx0': dtype('int64'), 'idx1': dtype('float64'), 'idx2': dtype('int64'), 'val1': dtype('float64'), 'val2': dtype('int64')}
-    resultDtypes = dict(result.reset_index(drop=False).dtypes)
-    # Add default type for index
-    if "index" not in expectedColTypes:
-      expectedColTypes["index"] = np.int64
-    for col, dtype in resultDtypes.items():
-      if len(col) == 2 and col[1] == "":
-        col = col[0]
-      check(
-        col,
-        dtype == expectedColTypes[col],
-        f"Output mismatch on {col}: result={dtype} expected={expectedColTypes[col]}",
-      )
-    _log_errors("JOIN", traceback.format_stack()[-2], lines, counter)
-    return result
+        rightDtypes = {rightFrame.index.name: rightFrame.index.dtype}
+      for col in leftIndexCols & rightIndexCols:
+        # For columns in both indices, type should not change if input types agree. If input types
+        # disagree, then we have no expectation.
+        if leftDtypes[col] == rightDtypes[col]:
+          expectedColTypes[col] = leftDtypes[col]
+        else:
+          expectedColTypes[col] = None
+      for col in (leftIndexCols | rightIndexCols) - (leftIndexCols & rightIndexCols):
+        # For columns in exactly one index, the expected output type should match the input column type
+        # and the column name should not change because we have validated that the column does not
+        # appear in the other dataframe
+        if col in leftDtypes:
+          assert col not in rightDtypes, f"unexpected column: {col}"
+          expectedColTypes[col] = leftDtypes[col]
+        else:
+          expectedColTypes[col] = rightDtypes[col]
+      # Perform join and validate results. Note that we already validated that the indices had the
+      # same columns and types, and that the "on" argument is unset, so now we only need to check
+      # the non-index columns.
+      result = self._origJoin(*args, **kwargs)
+      # Note that we must reset index to force any NaNs in the index to emerge as float types.
+      # See example below.
+      # left = pd.DataFrame({"idx0": [1, 2], "idx1": [11, 12], "val1": [4, 5]}).set_index(["idx0", "idx1"])
+      # right = pd.DataFrame({"idx0": [1, 2, 3], "idx2": [21, 22, 23], "val2": [7, 8, 9]}).set_index(["idx0", "idx2"])
+      # print(dict(left.join(right, how="outer").index.dtypes))
+      # print(dict(left.join(right, how="outer").reset_index(drop=False).dtypes))
+      # $> {'idx0': dtype('int64'), 'idx1': dtype('int64'), 'idx2': dtype('int64')}
+      # $> {'idx0': dtype('int64'), 'idx1': dtype('float64'), 'idx2': dtype('int64'), 'val1': dtype('float64'), 'val2': dtype('int64')}
+      resultDtypes = dict(result.reset_index(drop=False).dtypes)
+      # Add default type for index
+      if "index" not in expectedColTypes:
+        expectedColTypes["index"] = np.int64
+      for col, dtype in resultDtypes.items():
+        if len(col) == 2 and col[1] == "":
+          col = col[0]
+        check(
+          col,
+          dtype == expectedColTypes[col],
+          f"Output mismatch on {col}: result={dtype} expected={expectedColTypes[col]}",
+        )
+      lines.extend(self._validate_dataframe(result))
+      self._log_errors("JOIN", self._get_callsite(), lines)
+      return result
-  return _safe_join
+    return _safe_join
+# TODO: restore original functionality before return
+# TODO: make enforce_types an explicit argument so this is less error prone
 def patch_pandas(main: Callable) -> Callable:
   """Return a decorator for wrapping main with pandas patching and logging
@@ -447,17 +598,19 @@ def _inner(*args, **kwargs) -> Any:
     assert len(kwargs) == 0, f"expected kwargs to be empty, but found {len(kwargs)}"
     clArgs = args[0]
     # Apply patches, configured based on whether types should be enforced or logged
-    counter = TypeErrorCounter()
-    pd.concat = safe_concat(clArgs.enforce_types, counter)
+    patcher = PandasPatcher(clArgs.enforce_types)
+    pd.concat = patcher.safe_concat()
     # Note that this will work when calling df1.merge(df2) because the first argument
     # to "merge" is df1 (i.e. self).
-    pd.DataFrame.merge = safe_merge(clArgs.enforce_types, counter)
-    pd.DataFrame.join = safe_join(clArgs.enforce_types, counter)
+    pd.DataFrame.merge = patcher.safe_merge()
+    pd.DataFrame.join = patcher.safe_join()
+    pd.DataFrame.apply = patcher.safe_apply()
+    pd.DataFrame.__init__ = patcher.safe_init()
     # Run main
    retVal = main(*args, **kwargs)
     # Log type error summary
     if hasattr(clArgs, "parallel") and not clArgs.parallel:
-      print(f"\nTYPE WARNING SUMMARY\n{counter.get_summary()}", file=sys.stderr)
+      print(patcher.get_summary(), file=sys.stderr)
     else:
       # Don't show type summary because counters will be inaccurate due to scorers running
       # in their own process.
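For context on what the patched merge and join wrappers guard against, here is a minimal standalone sketch (illustrative only, not part of this patch; the column names are made up) of the silent dtype drift that a type-stability check on merge output surfaces:

import numpy as np
import pandas as pd

left = pd.DataFrame({"noteId": np.array([1, 2, 3], dtype=np.int64)})
right = pd.DataFrame(
  {"noteId": np.array([1, 2], dtype=np.int64), "ratingCount": np.array([5, 7], dtype=np.int64)}
)
merged = left.merge(right, on="noteId", how="left")
# The unmatched noteId becomes NaN, so pandas silently widens int64 to float64.
# Comparing expected vs. observed output dtypes is what turns this into a logged
# warning or a failed assert, depending on enforce_types.
print(right["ratingCount"].dtype, "->", merged["ratingCount"].dtype)  # int64 -> float64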
diff --git a/sourcecode/scoring/post_selection_similarity.py b/sourcecode/scoring/post_selection_similarity.py index ea150a11..ca93d20f 100644 --- a/sourcecode/scoring/post_selection_similarity.py +++ b/sourcecode/scoring/post_selection_similarity.py @@ -1,3 +1,4 @@ +import gc from typing import Dict from . import constants as c @@ -188,6 +189,7 @@ def aggregate_into_cliques(graphDf): cliqueToUserMap[sourceDestClique].append(userId) userToCliqueMap[userId] = sourceDestClique del cliqueToUserMap[oldTargetCliqueToDel] + gc.collect() else: # source in map; target not. add target to source's clique diff --git a/sourcecode/scoring/process_data.py b/sourcecode/scoring/process_data.py index 00e82921..b1909377 100644 --- a/sourcecode/scoring/process_data.py +++ b/sourcecode/scoring/process_data.py @@ -45,6 +45,8 @@ def tsv_parser( mapping: Dict[str, type], columns: List[str], header: bool, + useCols: Optional[List[str]] = None, + chunkSize: Optional[int] = None, convertNAToNone: bool = True, ) -> pd.DataFrame: """Parse a TSV input and raise an Exception if the input is not formatted as expected. @@ -54,6 +56,8 @@ def tsv_parser( mapping: Dict mapping column names to types columns: List of column names header: bool indicating whether the input will have a header + useCols: Optional list of columns to return + chunkSize: Optional number of rows to read at a time when returning a subset of columns Returns: pd.DataFrame containing parsed data @@ -64,14 +68,28 @@ def tsv_parser( if num_fields != len(columns): raise ValueError(f"Expected {len(columns)} columns, but got {num_fields}") - data = pd.read_csv( - StringIO(rawTSV), - sep="\t", - names=columns, - dtype=mapping, - header=0 if header else None, - index_col=[], - ) + if useCols and chunkSize: + textParser = pd.read_csv( + StringIO(rawTSV), + sep="\t", + names=columns, + dtype=mapping, + header=0 if header else None, + index_col=[], + usecols=useCols, + chunksize=chunkSize, + ) + data = pd.concat(textParser, ignore_index=True) + else: + data = pd.read_csv( + StringIO(rawTSV), + sep="\t", + names=columns, + dtype=mapping, + header=0 if header else None, + index_col=[], + usecols=useCols, + ) if convertNAToNone: # float types will be nan if missing; newer nullable types like "StringDtype" or "Int64Dtype" will by default # be pandas._libs.missing.NAType if missing. Set those to None and change the dtype back to object. 
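The new useCols/chunkSize branch in tsv_parser reads only the requested columns and concatenates fixed-size chunks. A minimal sketch of that pattern, assuming a small in-memory TSV and illustrative column names (the real function also validates the field count and applies convertNAToNone):

from io import StringIO
import pandas as pd

rawTSV = "noteId\traterParticipantId\thelpful\n12\tabc\t1\n34\tdef\t0\n"
columns = ["noteId", "raterParticipantId", "helpful"]
mapping = {"noteId": "int64", "raterParticipantId": "object", "helpful": "Int8"}

chunks = pd.read_csv(
  StringIO(rawTSV),
  sep="\t",
  names=columns,
  dtype=mapping,
  header=0,
  usecols=["noteId", "helpful"],
  chunksize=1,
)
# Each chunk holds only the requested columns, so peak memory is bounded by the
# chunk size and the column subset rather than by the full file.
subset = pd.concat(chunks, ignore_index=True)
print(subset.dtypes)  # noteId: int64, helpful: Int8

The cost is an extra concat at the end; the saving comes from never materializing the unused columns.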
@@ -92,7 +110,7 @@ def tsv_reader_single( ): """Read a single TSV file.""" with open(path, "r", encoding="utf-8") as handle: - return tsv_parser(handle.read(), mapping, columns, header, convertNAToNone) + return tsv_parser(handle.read(), mapping, columns, header, convertNAToNone=convertNAToNone) def tsv_reader( @@ -102,14 +120,21 @@ def tsv_reader( if os.path.isdir(path): dfs = [ tsv_reader_single( - os.path.join(path, filename), mapping, columns, header, parser, convertNAToNone + os.path.join(path, filename), + mapping, + columns, + header, + parser, + convertNAToNone=convertNAToNone, ) for filename in os.listdir(path) if filename.endswith(".tsv") ] return pd.concat(dfs, ignore_index=True) else: - return tsv_reader_single(path, mapping, columns, header, parser, convertNAToNone) + return tsv_reader_single( + path, mapping, columns, header, parser, convertNAToNone=convertNAToNone + ) def read_from_tsv( @@ -321,6 +346,7 @@ def preprocess_data( noteStatusHistory: pd.DataFrame, shouldFilterNotMisleadingNotes: bool = True, logging: bool = True, + ratingsOnly: bool = False, ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """Populate helpfulNumKey, a unified column that merges the helpfulness answers from the V1 and V2 rating forms together, as described in @@ -334,6 +360,7 @@ def preprocess_data( noteStatusHistory (pd.DataFrame) shouldFilterNotMisleadingNotes (bool, optional): Defaults to True. logging (bool, optional): Defaults to True. + ratingsOnly (bool, optional): Defaults to False Returns: notes (pd.DataFrame) @@ -345,13 +372,13 @@ def preprocess_data( "Timestamp of latest rating in data: ", pd.to_datetime(ratings[c.createdAtMillisKey], unit="ms").max(), ) - print( - "Timestamp of latest note in data: ", - pd.to_datetime(notes[c.createdAtMillisKey], unit="ms").max(), - ) + if not ratingsOnly: + print( + "Timestamp of latest note in data: ", + pd.to_datetime(notes[c.createdAtMillisKey], unit="ms").max(), + ) ratings = remove_duplicate_ratings(ratings) - notes = remove_duplicate_notes(notes) ratings.loc[:, c.helpfulNumKey] = np.nan ratings.loc[ratings[c.helpfulKey] == 1, c.helpfulNumKey] = 1 @@ -361,6 +388,11 @@ def preprocess_data( ratings.loc[ratings[c.helpfulnessLevelKey] == c.helpfulValueTsv, c.helpfulNumKey] = 1 ratings = ratings.loc[~pd.isna(ratings[c.helpfulNumKey])] + if ratingsOnly: + return pd.DataFrame(), ratings, pd.DataFrame() + + notes = remove_duplicate_notes(notes) + notes[c.tweetIdKey] = notes[c.tweetIdKey].astype(str) noteStatusHistory = note_status_history.merge_note_info(noteStatusHistory, notes) diff --git a/sourcecode/scoring/reputation_scorer.py b/sourcecode/scoring/reputation_scorer.py index 2fd52dfb..a39497ec 100644 --- a/sourcecode/scoring/reputation_scorer.py +++ b/sourcecode/scoring/reputation_scorer.py @@ -3,13 +3,12 @@ from . import constants as c from .matrix_factorization.matrix_factorization import MatrixFactorization from .mf_base_scorer import get_ratings_for_stable_init -from .mf_core_scorer import filter_core_input, filter_core_output from .process_data import filter_ratings from .reputation_matrix_factorization.helpfulness_model import ( get_helpfulness_reputation_results_final, get_helpfulness_reputation_results_prescoring, ) -from .scorer import Scorer +from .scorer import EmptyRatingException, Scorer import pandas as pd import torch @@ -37,7 +36,14 @@ def __init__( in scoring. Notes with fewer ratings are removed. 
threads: number of threads to use for intra-op parallelism in pytorch """ - super().__init__(seed, threads) + super().__init__( + includedTopics=set(), + includedGroups=c.coreGroups, + includeUnassigned=True, + captureThreshold=0.5, + seed=seed, + threads=threads, + ) self._minNumRatingsPerRater = minNumRatingsPerRater self._minNumRatersPerNote = minNumRatersPerNote self._crhThreshold = crhThreshold @@ -95,23 +101,13 @@ def _get_user_col_mapping(self) -> Dict[str, str]: c.internalRaterReputationKey: c.raterHelpfulnessReputationKey, } - def _filter_input( - self, - noteTopics: pd.DataFrame, - ratings: pd.DataFrame, - noteStatusHistory: pd.DataFrame, - userEnrollment: pd.DataFrame, - ) -> Tuple[pd.DataFrame, pd.DataFrame]: - ratings, noteStatusHistory = filter_core_input(ratings, noteStatusHistory, userEnrollment) - ratings = filter_ratings(ratings, self._minNumRatingsPerRater, self._minNumRatersPerNote) - return ratings, noteStatusHistory - def _prescore_notes_and_users( self, ratings: pd.DataFrame, noteStatusHistory: pd.DataFrame, userEnrollmentRaw: pd.DataFrame ) -> Tuple[pd.DataFrame, pd.DataFrame, c.PrescoringMetaScorerOutput]: if self._seed is not None: print(f"seeding with {self._seed}") torch.manual_seed(self._seed) + ratings = filter_ratings(ratings, self._minNumRatingsPerRater, self._minNumRatersPerNote) # Calculate initialization factors if necessary noteParamsInit = None raterParamsInit = None @@ -156,6 +152,9 @@ def _score_notes_and_users( if self._seed is not None: print(f"seeding with {self._seed}") torch.manual_seed(self._seed) + ratings = filter_ratings(ratings, self._minNumRatingsPerRater, self._minNumRatersPerNote) + if len(ratings) == 0: + raise EmptyRatingException() # Apply model # Note: we use the low diligence global intercept here as a temporary hack, since the prod scorer's @@ -178,13 +177,3 @@ def _score_notes_and_users( noteStats = noteStats.merge(noteStatusHistory[[c.noteIdKey]].drop_duplicates(), how="outer") assert len(noteStats) == len(noteStatusHistory) return noteStats, raterStats - - def _postprocess_output( - self, - noteScores: pd.DataFrame, - userScores: pd.DataFrame, - ratings: pd.DataFrame, - noteStatusHistory: pd.DataFrame, - userEnrollment: pd.DataFrame, - ) -> Tuple[pd.DataFrame, pd.DataFrame]: - return filter_core_output(ratings, userEnrollment, noteScores), userScores diff --git a/sourcecode/scoring/run_scoring.py b/sourcecode/scoring/run_scoring.py index b39023ea..39e70f05 100644 --- a/sourcecode/scoring/run_scoring.py +++ b/sourcecode/scoring/run_scoring.py @@ -6,6 +6,8 @@ """ import concurrent.futures import copy +import gc +import io from itertools import chain import multiprocessing from multiprocessing import shared_memory # type: ignore @@ -158,6 +160,7 @@ def _merge_results( + [f"{c.modelingGroupKey}_{group}" for group in range(groupScorerCount, 0, -1)] + [f"{c.topicNoteConfidentKey}_{topic.name}" for topic in Topics] + [f"{c.groupNumFinalRoundRatingsKey}_{group}" for group in range(groupScorerCount, 0, -1)] + + [f"{c.topicNumFinalRoundRatingsKey}_{topic.name}" for topic in Topics] ) scoredNotes = scoredNotes.merge( modelScoredNotes, @@ -278,7 +281,7 @@ def _run_scorer_parallelizable( # Run scoring scorerStartTime = time.perf_counter() if type(scoringArgs) == PrescoringArgs: - scoringResults = scorer.prescore(scoringArgs) + scoringResults = scorer.prescore(scoringArgs, preserveRatings=not runParallel) elif type(scoringArgs) == FinalScoringArgs: scoringResults = scorer.score_final(scoringArgs) else: @@ -294,19 +297,15 @@ def 
save_df_to_shared_memory(df: pd.DataFrame, shms: List) -> c.SharedMemoryData and returns the info needed to access it, as well as appends it to the list of shared memory objects so it's not garbage collected and can be closed later. """ - cols = df.columns - data = df.to_numpy() - df_dtypes_dict = dict(list(zip(df.columns, df.dtypes))) - shm = shared_memory.SharedMemory(create=True, size=data.nbytes) - np_array = np.ndarray(data.shape, dtype=data.dtype, buffer=shm.buf) - np_array[:] = data[:] + with io.BytesIO() as buf: + df.to_parquet(buf, compression="gzip", engine="pyarrow") + size = len(buf.getvalue()) + shm = shared_memory.SharedMemory(create=True, size=size) + shm.buf[:size] = buf.getvalue() shms.append(shm) # save the shared memory object so we can close it later return c.SharedMemoryDataframeInfo( sharedMemoryName=shm.name, - columns=cols, - dataShape=data.shape, - dtypesDict=df_dtypes_dict, - npDtype=np_array.dtype, + dataSize=size, ) @@ -316,12 +315,9 @@ def get_df_from_shared_memory(sharedMemoryDfInfo: c.SharedMemoryDataframeInfo) - Read a dataframe from shared memory and return it. """ existing_shm = shared_memory.SharedMemory(name=sharedMemoryDfInfo.sharedMemoryName) - np_array = np.ndarray( - sharedMemoryDfInfo.dataShape, buffer=existing_shm.buf, dtype=sharedMemoryDfInfo.npDtype - ) - df = pd.DataFrame(np_array, columns=sharedMemoryDfInfo.columns) - df = df.astype(sharedMemoryDfInfo.dtypesDict) - return df + size = sharedMemoryDfInfo.dataSize + with io.BytesIO(existing_shm.buf[:size]) as buf: + return pd.read_parquet(buf) def _save_dfs_to_shared_memory( @@ -1030,6 +1026,7 @@ def run_prescoring( ) print(f"Post Selection Similarity Prescoring: {len(ratings)} ratings remaining.") del pss + gc.collect() scorers = _get_scorers( seed=seed, @@ -1067,13 +1064,19 @@ def run_prescoring( ) del prescoringModelResultsFromAllScorers del scorers + gc.collect() # Prescoring itself is now done. We will not run final_note_scoring to check note status flips. if checkFlips: - # Rescore all notes. TODO: in the future, consider only rescoring a subset, e.g. unlocked notes. - ratingsToRescore = ratings - notesToRescore = notes - noteStatusHistoryToRescore = noteStatusHistory + # Rescore a smaller set of notes, since we are only using these note statuses to check for flips. + # Rescore a only unlocked notes. (In the future, we could randomly sample a subset of these) + noteStatusHistoryToRescore = noteStatusHistory[ + noteStatusHistory[c.timestampMillisOfStatusLockKey].isna() + ] + + notesToRescoreSet = set(noteStatusHistoryToRescore[c.noteIdKey]) + ratingsToRescore = ratings[ratings["noteId"].isin(notesToRescoreSet)].copy() + notesToRescore = notes[notes["noteId"].isin(notesToRescoreSet)].copy() scoredNotes, _, _ = run_final_note_scoring( notes=notesToRescore, @@ -1138,15 +1141,31 @@ def determine_which_notes_to_rescore( noteStatusHistory: pd.DataFrame, previousRatingCutoffTimestampMillis: Optional[int] = None, scoreRecentNotesMinimumFrequencyMillis: Optional[int] = 1000 * 60 * 60 * 24, # 1 day - recentNotesAgeCutoffMillis: Optional[int] = 1000 * 60 * 60 * 24 * 14, # 14 days -) -> Tuple[Optional[List[c.NoteSubset]], set]: + recentNotesAgeCutoffMillis: Optional[int] = 1000 * 60 * 60 * 24 * 14, # 14 days, + scoreRecentlyFlippedNotesMinimumFrequencyMillis: Optional[int] = 1000 * 60 * 60 * 1, # 1 hour + recentlyFlippedNoteAgeCutoffMillis: Optional[int] = 1000 * 60 * 60 * 24, # 1 day +) -> Tuple[List[c.NoteSubset], set]: + notesToRescoreSet = set() + noteSubsets = [] + # 1. 
Rescore all notes with a new rating since last scoring run. if previousRatingCutoffTimestampMillis is not None: notesWithNewRatings = set( ratings.loc[ratings[c.createdAtMillisKey] > previousRatingCutoffTimestampMillis, c.noteIdKey] ) + print( + f"1. Num notes with new ratings since last scoring run (ts: {previousRatingCutoffTimestampMillis}): {len(notesWithNewRatings)}" + ) + notesToRescoreSet.update(notesWithNewRatings) else: notesWithNewRatings = set() + noteSubsets.append( + c.NoteSubset( + noteSet=notesWithNewRatings, + maxCrhChurnRate=c.finalNotesWithNewRatingsMaxCrhChurn, + description=c.RescoringRuleID.NOTES_WITH_NEW_RATINGS, + ) + ) currentMillis = int(time.time() * 1000) @@ -1162,38 +1181,82 @@ def determine_which_notes_to_rescore( newNotesNotRescoredRecentlyEnough = set( noteStatusHistory.loc[noteCreatedRecently & noteNotRescoredRecently, c.noteIdKey] ) + print("2. Rescore all recently created notes if not rescored at the minimum frequency.") + print("Num notes created recently:", noteCreatedRecently.sum()) # Remove notes with new ratings from this set. newNotesNotRescoredRecentlyEnough = newNotesNotRescoredRecentlyEnough.difference( notesWithNewRatings ) + notesToRescoreSet.update(newNotesNotRescoredRecentlyEnough) else: newNotesNotRescoredRecentlyEnough = set() - - # TODO: 3. Recently-flipped notes. - - noteSubsets = [ - c.NoteSubset( - noteSet=notesWithNewRatings, - maxCrhChurnRate=c.finalNotesWithNewRatingsMaxCrhChurn, - description="notesWithNewRatings", - ), + noteSubsets.append( c.NoteSubset( noteSet=newNotesNotRescoredRecentlyEnough, maxCrhChurnRate=c.finalUnlockedNotesWithNoNewRatingsMaxCrhChurn, - description="newNotesNotRescoredRecentlyEnough", - ), - ] + description=c.RescoringRuleID.NEW_NOTES_NOT_RESCORED_RECENTLY_ENOUGH, + ) + ) - notesToRescoreSet = set() - for noteSubset in noteSubsets: - if noteSubset.noteSet is not None: - notesToRescoreSet.update(noteSubset.noteSet) + # 3. Rescore all notes that flipped status in the previous scoring run. + justFlippedNotes = set( + noteStatusHistory.loc[ + ( + noteStatusHistory[c.timestampMillisOfMostRecentStatusChangeKey] + == noteStatusHistory[c.timestampMillisOfNoteCurrentLabelKey] + ), + c.noteIdKey, + ] + ).difference(notesWithNewRatings) + print( + "3. Rescore all notes that flipped status in the previous scoring run.", len(justFlippedNotes) + ) + notesToRescoreSet.update(justFlippedNotes) + noteSubsets.append( + c.NoteSubset( + noteSet=justFlippedNotes, + maxCrhChurnRate=c.finalNotesThatJustFlippedStatusMaxCrhChurn, + description=c.RescoringRuleID.NOTES_FLIPPED_PREVIOUS_RUN, + ) + ) + + # 4. Rescore all recently-flipped notes if not rescored at the minimum frequency. + if ( + recentlyFlippedNoteAgeCutoffMillis is not None + and scoreRecentlyFlippedNotesMinimumFrequencyMillis is not None + ): + noteFlippedRecently = ( + noteStatusHistory[c.timestampMillisOfMostRecentStatusChangeKey] + > currentMillis - recentlyFlippedNoteAgeCutoffMillis + ) + noteNotRescoredRecently = ( + noteStatusHistory[c.timestampMillisOfNoteCurrentLabelKey] + < currentMillis - scoreRecentlyFlippedNotesMinimumFrequencyMillis + ) + print("4. 
Rescore all recently-flipped notes if not rescored at the minimum frequency.") + print("Num notes flipped recently:", noteFlippedRecently.sum()) + print("Num notes not rescored recently enough:", noteNotRescoredRecently.sum()) + recentlyFlippedNotesNotRescoredRecentlyEnough = set( + noteStatusHistory.loc[noteFlippedRecently & noteNotRescoredRecently, c.noteIdKey] + ) + notesToRescoreSet.update(recentlyFlippedNotesNotRescoredRecentlyEnough) + else: + recentlyFlippedNotesNotRescoredRecentlyEnough = set() + noteSubsets.append( + c.NoteSubset( + noteSet=recentlyFlippedNotesNotRescoredRecentlyEnough, + maxCrhChurnRate=c.finalNotesThatFlippedRecentlyMaxCrhChurn, + description=c.RescoringRuleID.RECENTLY_FLIPPED_NOTES_NOT_RESCORED_RECENTLY_ENOUGH, + ) + ) print( - f"""Notes to rescore: - {len(notesWithNewRatings)} notes with new ratings since last scoring run. - {len(newNotesNotRescoredRecentlyEnough)} notes created recently and not rescored recently enough. - Total: {len(notesToRescoreSet)} notes to rescore, out of {len(notes)} total.""" + f"""----\nNotes to rescore: + * {len(notesWithNewRatings)} notes with new ratings since last scoring run. + * {len(newNotesNotRescoredRecentlyEnough)} notes created recently and not rescored recently enough. + * {len(justFlippedNotes)} notes that flipped status in the previous scoring run. + * {len(recentlyFlippedNotesNotRescoredRecentlyEnough)} notes that flipped status recently and not rescored recently enough. + Overall: {len(notesToRescoreSet)} notes to rescore, out of {len(notes)} total.\n----""" ) return noteSubsets, notesToRescoreSet @@ -1235,11 +1298,11 @@ def run_final_note_scoring( print("No previous scored notes passed; scoring all notes.") notesToRescoreSet: Set[int] = set() scoredNotesPassthrough = None - noteSubsets: Optional[List[c.NoteSubset]] = [ + noteSubsets: List[c.NoteSubset] = [ c.NoteSubset( noteSet=None, maxCrhChurnRate=c.prescoringAllUnlockedNotesMaxCrhChurn, - description="allNotes", + description=c.RescoringRuleID.ALL_NOTES, ) ] else: @@ -1251,9 +1314,6 @@ def run_final_note_scoring( notes, ratings, noteStatusHistory, previousRatingCutoffTimestampMillis ) - for item in notesToRescoreSet: - print(f"notesToRescoreSet: {item}, {type(item)}, len: {len(notesToRescoreSet)}") - break scoredNotesPassthrough = previousScoredNotes[ ~previousScoredNotes[c.noteIdKey].isin(notesToRescoreSet) ] @@ -1321,18 +1381,16 @@ def run_final_note_scoring( scoredNotes, auxiliaryNoteInfo = combine_final_scorer_results(modelResults, noteStatusHistory) - if not checkFlips: - noteSubsets = None - scoredNotes, newNoteStatusHistory, auxiliaryNoteInfo = post_note_scoring( scorers, scoredNotes, auxiliaryNoteInfo, ratings, noteStatusHistory, + noteSubsets, enabledScorers, strictColumns, - noteSubsets, + checkFlips, ) # Concat final scoring results for newly-scored notes with the results for old notes not scores. 
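Each rescoring rule above pairs a set of notes with a maximum CRH churn rate (the c.NoteSubset fields noteSet, maxCrhChurnRate, and description visible in this diff). As a rough sketch of the idea, using a simplified stand-in for the real check in note_status_history.check_flips (the production logic operates on merged old/new note statuses and is more involved):

from dataclasses import dataclass
from typing import Optional, Set

@dataclass
class NoteSubsetSketch:
  """Hypothetical stand-in for c.NoteSubset; field names mirror the diff."""
  noteSet: Optional[Set[int]]  # None means "all notes"
  maxCrhChurnRate: float
  description: str

def check_crh_churn(oldCrh: Set[int], newCrh: Set[int], subset: NoteSubsetSketch) -> None:
  """Simplified churn check: flips within the subset, relative to its old CRH count."""
  scope = subset.noteSet if subset.noteSet is not None else oldCrh | newCrh
  flips = len((oldCrh ^ newCrh) & scope)  # notes whose CRH status changed within the subset
  base = max(len(oldCrh & scope), 1)
  churn = flips / base
  assert churn <= subset.maxCrhChurnRate, (
    f"{subset.description}: churn {churn:.3f} exceeds cap {subset.maxCrhChurnRate}"
  )

# Example: a cap of 0.40 mirrors finalNotesWithNewRatingsMaxCrhChurn.
check_crh_churn({1, 2, 3}, {1, 2, 3, 4}, NoteSubsetSketch({1, 2, 3, 4}, 0.40, "NOTES_WITH_NEW_RATINGS"))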
@@ -1345,8 +1403,10 @@ def run_final_note_scoring( continue if scoredNotes[column].dtype != targetDtype: scoredNotes[column] = scoredNotes[column].astype(targetDtype) + scoredNotesPassthrough[c.rescoringActiveRulesKey] = "" scoredNotes = pd.concat( [scoredNotes, scoredNotesPassthrough], + unsafeAllowed=[c.topicNoteConfidentKey], # concat 'O' with BooleanDtype ) # Convert auxiliaryNoteInfo dtypes to match auxiliaryNoteInfoPassthrough @@ -1369,9 +1429,10 @@ def post_note_scoring( auxiliaryNoteInfo: pd.DataFrame, ratings: pd.DataFrame, noteStatusHistory: pd.DataFrame, + noteSubsetsAndMaxFlipRates: List[c.NoteSubset], enabledScorers: Optional[Set[Scorers]] = None, strictColumns: bool = True, - noteSubsetsAndMaxFlipRates: Optional[List[c.NoteSubset]] = None, + checkFlips: bool = True, ): """ Apply individual scoring models and obtained merged result. @@ -1409,14 +1470,27 @@ def post_note_scoring( noteStatusHistory ), "noteStatusHistory should be complete, and all notes should be scored." - # Merge scoring results into noteStatusHistory. + # Merge scoring results into noteStatusHistory, check flip rates, and set rescoringActiveRules. with c.time_block("Post-scorers: Update note status history"): mergedNoteStatuses = note_status_history.merge_old_and_new_note_statuses( noteStatusHistory, scoredNotes ) - if noteSubsetsAndMaxFlipRates is not None: - for noteSubset in noteSubsetsAndMaxFlipRates: + + scoredNotes[c.rescoringActiveRulesKey] = "" + for noteSubset in noteSubsetsAndMaxFlipRates: + if checkFlips: note_status_history.check_flips(mergedNoteStatuses, noteSubset=noteSubset) + if noteSubset.noteSet is not None: + noteInSetMask = scoredNotes[c.noteIdKey].isin(noteSubset.noteSet) + else: + noteInSetMask = scoredNotes[c.noteIdKey].notnull() # All notes by default. + scoredNotes.loc[noteInSetMask, c.rescoringActiveRulesKey] = scoredNotes.loc[ + noteInSetMask, c.rescoringActiveRulesKey + ].apply( + lambda rescoringActiveRules: rescoringActiveRules + noteSubset.description.name + if len(rescoringActiveRules) == 0 + else f"{rescoringActiveRules},{noteSubset.description.name}" + ) newNoteStatusHistory = note_status_history.update_note_status_history(mergedNoteStatuses) assert len(newNoteStatusHistory) == len( diff --git a/sourcecode/scoring/scorer.py b/sourcecode/scoring/scorer.py index a9becd53..4c9e505b 100644 --- a/sourcecode/scoring/scorer.py +++ b/sourcecode/scoring/scorer.py @@ -1,16 +1,25 @@ from abc import ABC, abstractmethod from contextlib import contextmanager +import gc import time -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Set, Tuple from . import constants as c from .constants import FinalScoringArgs, ModelResult, PrescoringArgs +from .pandas_utils import keep_columns import numpy as np import pandas as pd import torch +_IN_GROUP = "inGroup" + + +class EmptyRatingException(Exception): + """Exception rasied when no ratings are available""" + + class Scorer(ABC): """Base class which all other scorers must extend. @@ -20,12 +29,24 @@ class Scorer(ABC): exactly which columns are output and which are dropped. """ - def __init__(self, seed: Optional[int] = None, threads: int = c.defaultNumThreads) -> None: + def __init__( + self, + includedTopics: Set[str] = set(), + includedGroups: Set[int] = set(), + includeUnassigned: bool = False, + captureThreshold: Optional[float] = None, + seed: Optional[int] = None, + threads: int = c.defaultNumThreads, + ) -> None: """Configure a new Scorer object. 
Args: seed (int, optional): if not None, seed value to ensure deterministic execution """ + self._includedTopics = includedTopics + self._includedGroups = includedGroups + self._includeUnassigned = includeUnassigned + self._captureThreshold = captureThreshold self._seed = seed self._threads = threads @@ -90,6 +111,30 @@ def _filter_input( ratings: ratings filtered to only contain rows of interest noteStatusHistory: noteStatusHistory filtered to only contain rows of interest """ + if (not self._includedGroups) and (not self._includedTopics): + return ratings, noteStatusHistory + print(f"Filtering ratings for {self.get_name()}. Original rating length: {len(ratings)}") + # Apply topic filter + if self._includedTopics: + notes = noteTopics[noteTopics[c.noteTopicKey].isin(self._includedTopics)][[c.noteIdKey]] + ratings = ratings.merge(notes) + noteStatusHistory = noteStatusHistory.merge(notes) + print(f" Ratings after topic filter: {len(ratings)}") + # Apply group filter + if self._includedGroups: + userEnrollment = userEnrollment[[c.participantIdKey, c.modelingGroupKey]].rename( + columns={c.participantIdKey: c.raterParticipantIdKey} + ) + userEnrollment.loc[:, _IN_GROUP] = ( + userEnrollment[c.modelingGroupKey].isin(self._includedGroups).astype(pd.BooleanDtype()) + ) + ratings = ratings.merge( + userEnrollment[[c.raterParticipantIdKey, _IN_GROUP]], on=c.raterParticipantIdKey, how="left" + ) + print(f" Ratings without assigned group: {ratings[_IN_GROUP].isna().sum()}") + ratings = ratings.fillna({_IN_GROUP: self._includeUnassigned}) + ratings = ratings[ratings[_IN_GROUP]].drop(columns=[_IN_GROUP]) + print(f" Ratings after group filter: {len(ratings)}") return ratings, noteStatusHistory def _postprocess_output( @@ -119,6 +164,27 @@ def _postprocess_output( noteScores: note scoring output from _score_notes_and_users userScores: user scoring output from _score_notes_and_users """ + if self._captureThreshold is None: + return noteScores, userScores + # Identify notes with enough ratings from within the modeling group. + print(f"Postprocessing output for {self.get_name()}") + assert self._includedGroups, "includedGroups must be set" + userEnrollment = userEnrollment[[c.participantIdKey, c.modelingGroupKey]].rename( + columns={c.participantIdKey: c.raterParticipantIdKey} + ) + userEnrollment.loc[:, _IN_GROUP] = ( + userEnrollment[c.modelingGroupKey].isin(self._includedGroups).astype(pd.BooleanDtype()) + ) + ratings = ratings.merge( + userEnrollment[[c.raterParticipantIdKey, _IN_GROUP]], on=c.raterParticipantIdKey, how="left" + ) + ratings = ratings.fillna({_IN_GROUP: self._includeUnassigned}) + ratios = ratings[[c.noteIdKey, _IN_GROUP]].groupby(c.noteIdKey).mean().reset_index() + print(f" Original noteScores length: {len(noteScores)}") + noteScores = noteScores.merge( + ratios[ratios[_IN_GROUP] >= self._captureThreshold][[c.noteIdKey]] + ) + print(f" Final noteScores length: {len(noteScores)}") return noteScores, userScores def _get_note_col_mapping(self) -> Dict[str, str]: @@ -175,7 +241,7 @@ def _score_notes_and_users( userScores pd.DataFrame: one row per user containing a column for each helpfulness score. """ - def prescore(self, scoringArgs: PrescoringArgs) -> ModelResult: + def prescore(self, scoringArgs: PrescoringArgs, preserveRatings: bool = True) -> ModelResult: """ Runs initial rounds of the matrix factorization scoring algorithm and returns intermediate output that can be used to initialize and reduce the runtime of final scoring. 
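The group filtering and capture threshold added to Scorer above reduce to a per-note fraction of in-group ratings. A condensed sketch with toy data (values are hypothetical; the real code reads modelingGroup from userEnrollment and stores the flag as a nullable boolean):

import pandas as pd

# Toy ratings: True when the rater belongs to the scorer's modeling group.
ratings = pd.DataFrame(
  {"noteId": [1, 1, 1, 2, 2], "inGroup": [True, True, False, False, False]}
)
captureThreshold = 0.5
ratios = ratings.groupby("noteId")["inGroup"].mean().reset_index()
keptNotes = ratios.loc[ratios["inGroup"] >= captureThreshold, ["noteId"]]
# noteScores.merge(keptNotes) would then drop notes dominated by out-of-group raters.
print(keptNotes["noteId"].tolist())  # [1]: note 1 kept (2/3 in group), note 2 dropped (0/2)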
@@ -189,10 +255,27 @@ def prescore(self, scoringArgs: PrescoringArgs) -> ModelResult: with self.time_block("Filter input"): ratings, noteStatusHistory = self._filter_input( scoringArgs.noteTopics, - scoringArgs.ratings, + keep_columns( + scoringArgs.ratings, + [ + c.noteIdKey, + c.raterParticipantIdKey, + c.helpfulNumKey, + c.helpfulnessLevelKey, + c.createdAtMillisKey, + ] + + c.notHelpfulTagsTSVOrder + + c.helpfulTagsTSVOrder, + ), scoringArgs.noteStatusHistory, scoringArgs.userEnrollment, ) + if not preserveRatings: + # Only remove ratings if we're running in parallel, since otherwise later scorers will + # need the ratings. + del scoringArgs.ratings + gc.collect() + # If there are no ratings left after filtering, then return empty dataframes. if len(ratings) == 0: return ModelResult( @@ -215,6 +298,10 @@ def prescore(self, scoringArgs: PrescoringArgs) -> ModelResult: ratings, noteStatusHistory, scoringArgs.userEnrollment ) + # Returning should remove references to ratings, but manually trigger GC just to reclaim + # resources as soon as possible. + del ratings + gc.collect() # Return dataframes with specified columns in specified order # Reindex fills required columns with NaN if they aren't present in the original df. return ModelResult( @@ -293,13 +380,16 @@ def score_final(self, scoringArgs: FinalScoringArgs) -> ModelResult: if len(ratings) == 0: return self._return_empty_final_scores() - noteScores, userScores = self._score_notes_and_users( - ratings=ratings, - noteStatusHistory=noteStatusHistory, - prescoringNoteModelOutput=prescoringNoteModelOutput, - prescoringRaterModelOutput=prescoringRaterModelOutput, - prescoringMetaScorerOutput=prescoringMetaScorerOutput, - ) + try: + noteScores, userScores = self._score_notes_and_users( + ratings=ratings, + noteStatusHistory=noteStatusHistory, + prescoringNoteModelOutput=prescoringNoteModelOutput, + prescoringRaterModelOutput=prescoringRaterModelOutput, + prescoringMetaScorerOutput=prescoringMetaScorerOutput, + ) + except EmptyRatingException: + return self._return_empty_final_scores() with self.time_block("Postprocess output"): # Only some subclasses do any postprocessing. diff --git a/sourcecode/scoring/scoring_rules.py b/sourcecode/scoring/scoring_rules.py index e4b3d49c..bc63de76 100644 --- a/sourcecode/scoring/scoring_rules.py +++ b/sourcecode/scoring/scoring_rules.py @@ -241,8 +241,11 @@ def score_notes( print(f"CRH notes prior to tag filtering: {len(crhStats)}") # Identify impacted notes. - impactedNotes = pd.DataFrame.from_dict({c.noteIdKey: [], c.activeFilterTagsKey: []}).astype( - {c.noteIdKey: np.int64} + impactedNotes = pd.DataFrame.from_dict( + { + c.noteIdKey: pd.Series([], dtype=np.int64), + c.activeFilterTagsKey: pd.Series([], dtype=object), + } ) print("Checking note tags:") for tag in c.notHelpfulTagsTSVOrder: @@ -534,7 +537,6 @@ def score_notes( ) -> Tuple[pd.DataFrame, pd.DataFrame]: """Sets Top Tags inplace on noteStats, returns notes on track for CRH / CRNH with insufficient to receive NMR status.""" - noteStats[c.noteIdKey] = noteStats[c.noteIdKey].astype(np.int64) noteStats[c.firstTagKey] = noteStats[c.firstTagKey].astype(object) noteStats[c.secondTagKey] = noteStats[c.secondTagKey].astype(object) @@ -578,14 +580,6 @@ def score_notes( noteStats.loc[:, c.firstTagKey] = topTags[c.firstTagKey] noteStats.loc[:, c.secondTagKey] = topTags[c.secondTagKey] - # The "apply" above converts the noteId column to a float. This cast - # guarantees that the type of the noteId column remains int64. 
Note that the cast will fail - # if the noteId column includes nan values. - # - # See links below for more context: - # https://stackoverflow.com/questions/40251948/stop-pandas-from-converting-int-to-float-due-to-an-insertion-in-another-column - # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.Series.convert_dtypes.html - noteStats[c.noteIdKey] = noteStats[c.noteIdKey].astype(np.int64) noteStats[c.firstTagKey] = noteStats[c.firstTagKey].astype(object) noteStats[c.secondTagKey] = noteStats[c.secondTagKey].astype(object) @@ -813,13 +807,13 @@ def apply_scoring_rules( """ # Initialize empty dataframes to store labels for each note and which rules impacted # scoring for each note. - noteLabels = pd.DataFrame.from_dict({c.noteIdKey: [], statusColumn: []}).astype( - {c.noteIdKey: np.int64} + noteLabels = pd.DataFrame.from_dict( + {c.noteIdKey: pd.Series([], dtype=np.int64), statusColumn: pd.Series([], dtype=object)} ) - noteRules = pd.DataFrame.from_dict({c.noteIdKey: [], ruleColumn: []}).astype( - {c.noteIdKey: np.int64} + noteRules = pd.DataFrame.from_dict( + {c.noteIdKey: pd.Series([], dtype=np.int64), ruleColumn: pd.Series([], dtype=object)} ) - noteColumns = pd.DataFrame.from_dict({c.noteIdKey: []}).astype({c.noteIdKey: np.int64}) + noteColumns = pd.DataFrame.from_dict({c.noteIdKey: pd.Series([], dtype=np.int64)}) # Establish state to enforce rule dependencies. ruleIDs: Set[RuleID] = set()
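The empty-DataFrame changes in scoring_rules.py replace after-the-fact astype casts with explicitly typed empty Series, so the frames carry the intended dtypes from the moment they are created. A small illustration of the difference (the dtype of a column built from a bare empty list varies across pandas versions, which is exactly what the explicit Series construction avoids):

import numpy as np
import pandas as pd

# Old pattern: build from bare empty lists, then cast one column.
viaAstype = pd.DataFrame.from_dict({"noteId": [], "status": []}).astype({"noteId": np.int64})
# New pattern: pin every dtype up front with typed empty Series.
viaSeries = pd.DataFrame.from_dict(
  {"noteId": pd.Series([], dtype=np.int64), "status": pd.Series([], dtype=object)}
)
print(viaAstype.dtypes)  # "status" dtype depends on the pandas version when built from a bare list
print(viaSeries.dtypes)  # noteId is int64 and status is object regardless of version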