Skip to content

Commit 9ee120f

Browse files
authored
Merge pull request #176 from twitter/jbaxter/2023_11_29_cleanup
Cleanup after expansion-plus launch
2 parents 6f46290 + 336d3d3 commit 9ee120f

File tree

5 files changed

+14
-68
lines changed

5 files changed

+14
-68
lines changed

sourcecode/scoring/constants.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -410,21 +410,12 @@ def rater_factor_key(i):
410410
(timestampOfLastStateChange, np.int64),
411411
(timestampOfLastEarnOut, np.double), # double because nullable.
412412
(modelingPopulationKey, str),
413+
(modelingGroupKey, np.float64),
413414
]
414415
userEnrollmentTSVColumns = [col for (col, _) in userEnrollmentTSVColumnsAndTypes]
415416
userEnrollmentTSVTypes = [dtype for (_, dtype) in userEnrollmentTSVColumnsAndTypes]
416417
userEnrollmentTSVTypeMapping = {col: dtype for (col, dtype) in userEnrollmentTSVColumnsAndTypes}
417418

418-
# TODO: delete expanded user enrollment definition once modeling group is fully rolled out
419-
userEnrollmentExpandedTSVColumnsAndTypes = userEnrollmentTSVColumnsAndTypes + [
420-
(modelingGroupKey, np.float64)
421-
]
422-
userEnrollmentExpandedTSVColumns = [col for (col, _) in userEnrollmentExpandedTSVColumnsAndTypes]
423-
userEnrollmentExpandedTSVTypes = [dtype for (_, dtype) in userEnrollmentExpandedTSVColumnsAndTypes]
424-
userEnrollmentExpandedTSVTypeMapping = {
425-
col: dtype for (col, dtype) in userEnrollmentExpandedTSVColumnsAndTypes
426-
}
427-
428419
noteInterceptMaxKey = "internalNoteIntercept_max"
429420
noteInterceptMinKey = "internalNoteIntercept_min"
430421
noteParameterUncertaintyTSVMainColumnsAndTypes = [

sourcecode/scoring/mf_core_scorer.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,10 +173,4 @@ def _filter_input(
173173
noteStatusHistory = noteStatusHistory[noteStatusHistory[c.noteIdKey].isin(coreNotes)]
174174
print(f" Core ratings: {len(ratings)}")
175175

176-
# Guarantee ordering of ratings and noteStatusHistory remains the same relative to the
177-
# original ordering. This code exists to stabilize system test results and can be removed
178-
# once we're confident the rest of the implementation is correct.
179-
ratings = ratings.sort_values([c.noteIdKey, c.raterParticipantIdKey])
180-
noteStatusHistory = noteStatusHistory.sort_values(c.noteIdKey)
181-
182176
return ratings, noteStatusHistory

sourcecode/scoring/mf_expansion_scorer.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -153,20 +153,6 @@ def _filter_input(
153153
)
154154
print(f" Ratings after EXPANSION_PLUS notes filter: {len(ratings)}")
155155

156-
# Guarantee ordering of ratings and noteStatusHistory remains the same relative to the
157-
# original ordering. This code exists to stabilize system test results and can be removed
158-
# once we're confident the rest of the implementation is correct.
159-
ratingOrder = ratingsOrig[[c.noteIdKey, c.raterParticipantIdKey]].reset_index(drop=False)
160-
numRatings = len(ratings)
161-
ratings = ratings.merge(ratingOrder, on=[c.noteIdKey, c.raterParticipantIdKey], how="inner")
162-
assert len(ratings) == numRatings, f"mismatch: {len(ratings)} != {numRatings}"
163-
ratings = ratings.sort_values("index").drop(columns="index")
164-
nshOrder = noteStatusHistoryOrig[[c.noteIdKey]].reset_index(drop=False)
165-
numNotes = len(noteStatusHistory)
166-
noteStatusHistory = noteStatusHistory.merge(nshOrder, on=c.noteIdKey, how="inner")
167-
assert len(noteStatusHistory) == numNotes, f"mismatch: {len(noteStatusHistory)} != {numNotes}"
168-
noteStatusHistory = noteStatusHistory.sort_values("index").drop(columns="index")
169-
170156
return ratings.drop(columns=_EXPANSION_PLUS_BOOL), noteStatusHistory.drop(
171157
columns=_EXPANSION_PLUS_BOOL
172158
)

sourcecode/scoring/process_data.py

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -71,34 +71,9 @@ def tsv_parser(
7171
raise ValueError(f"Invalid input: {e}")
7272

7373

74-
# TODO: remove this function once modelingGroup column is fully launched
75-
def user_enrollment_parser(rawTSV: str, header: bool) -> pd.DataFrame:
76-
"""Parse user enrollment TSV and optinoally tolerate the modelingGroup column.
77-
78-
Args:
79-
rawTSV: str containing entire TSV input
80-
header: bool indicating whether the input will have a header
81-
82-
Returns:
83-
pd.DataFrame containing parsed data
84-
"""
85-
try:
86-
df = tsv_parser(rawTSV, c.userEnrollmentTSVTypeMapping, c.userEnrollmentTSVColumns, header)
87-
df[c.modelingGroupKey] = 0
88-
except ValueError:
89-
df = tsv_parser(
90-
rawTSV, c.userEnrollmentExpandedTSVTypeMapping, c.userEnrollmentExpandedTSVColumns, header
91-
)
92-
return df
93-
94-
95-
# TODO: remove support for specifying a custom parser once modelingGroup is fully rolled out
96-
def tsv_reader(path: str, mapping, columns, header=False, parser=tsv_parser):
74+
def tsv_reader(path: str, mapping, columns, header=False):
9775
with open(path, "r") as handle:
98-
if parser == tsv_parser:
99-
return parser(handle.read(), mapping, columns, header)
100-
else:
101-
return parser(handle.read(), header)
76+
return tsv_parser(handle.read(), mapping, columns, header)
10277

10378

10479
def read_from_tsv(
@@ -159,13 +134,13 @@ def read_from_tsv(
159134
userEnrollment = None
160135
else:
161136
userEnrollment = tsv_reader(
162-
userEnrollmentPath, None, None, header=headers, parser=user_enrollment_parser
137+
userEnrollmentPath, c.userEnrollmentTSVTypeMapping, c.userEnrollmentTSVColumns, header=headers
163138
)
164-
assert len(userEnrollment.columns.values) <= len(c.userEnrollmentExpandedTSVColumns) and (
165-
len(set(userEnrollment.columns) - set(c.userEnrollmentExpandedTSVColumns)) == 0
139+
assert len(userEnrollment.columns.values) == len(c.userEnrollmentTSVColumns) and all(
140+
userEnrollment.columns == c.userEnrollmentTSVColumns
166141
), (
167-
f"userEnrollment columns don't match: \n{[col for col in userEnrollment.columns if not col in c.userEnrollmentExpandedTSVColumns]} are extra columns, "
168-
+ f"\n{[col for col in c.userEnrollmentExpandedTSVColumns if not col in userEnrollment.columns]} are missing."
142+
f"userEnrollment columns don't match: \n{[col for col in userEnrollment.columns if not col in c.userEnrollmentTSVColumns]} are extra columns, "
143+
+ f"\n{[col for col in c.userEnrollmentTSVColumns if not col in userEnrollment.columns]} are missing."
169144
)
170145

171146
return notes, ratings, noteStatusHistory, userEnrollment

sourcecode/scoring/run_scoring.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,15 @@ def _get_scorers(
5050

5151
if enabledScorers is None or Scorers.MFCoreScorer in enabledScorers:
5252
scorers[Scorers.MFCoreScorer] = [
53-
MFCoreScorer(seed, pseudoraters, useStableInitialization=useStableInitialization, threads=16)
53+
MFCoreScorer(seed, pseudoraters, useStableInitialization=useStableInitialization, threads=12)
5454
]
5555
if enabledScorers is None or Scorers.MFExpansionScorer in enabledScorers:
5656
scorers[Scorers.MFExpansionScorer] = [
57-
MFExpansionScorer(seed, useStableInitialization=useStableInitialization, threads=16)
57+
MFExpansionScorer(seed, useStableInitialization=useStableInitialization, threads=12)
5858
]
5959
if enabledScorers is None or Scorers.MFExpansionPlusScorer in enabledScorers:
6060
scorers[Scorers.MFExpansionPlusScorer] = [
61-
MFExpansionPlusScorer(seed, useStableInitialization=useStableInitialization, threads=16)
61+
MFExpansionPlusScorer(seed, useStableInitialization=useStableInitialization, threads=12)
6262
]
6363
if enabledScorers is None or Scorers.MFGroupScorer in enabledScorers:
6464
# Note that index 0 is reserved, corresponding to no group assigned, so scoring group
@@ -651,10 +651,10 @@ def run_scoring(
651651
maxReruns,
652652
runParallel=runParallel,
653653
dataLoader=dataLoader,
654-
# Restrict parallelism to 4 processes. Memory usage scales linearly with the number of
655-
# processes and 4 is enough that the limiting factor continues to be the longest running
654+
# Restrict parallelism to 6 processes. Memory usage scales linearly with the number of
655+
# processes and 6 is enough that the limiting factor continues to be the longest running
656656
# scorer (i.e. we would not finish faster with >6 worker processes.)
657-
maxWorkers=4,
657+
maxWorkers=6,
658658
)
659659

660660
postScoringStartTime = time.time()

0 commit comments

Comments (0)