
Commit 44d3791

On my laptop takes 16.5s rather than 54.5s
1 parent 1aaffe1 · commit 44d3791


pangolin/scripts/pangolearn.py

Lines changed: 13 additions & 18 deletions
@@ -13,6 +13,8 @@
 import argparse
 import os
 
+import json, sys
+
 def parse_args():
     parser = argparse.ArgumentParser(description='pangoLEARN.')
     parser.add_argument("--header-file", action="store", type=str, dest="header_file")
@@ -111,8 +113,14 @@ def readInAndFormatData(sequencesFile, indiciesToKeep, blockSize=1000):
 # loading the list of headers the model needs.
 model_headers = joblib.load(args.header_file)
 indiciesToKeep = model_headers[1:]
+with open("/tmp/model_headers.csv", "w") as f:
+    print("\n".join(map(str,indiciesToKeep)), file=f)
 
 referenceSeq = findReferenceSeq()
+# possible nucleotide symbols
+categories = ['-','A', 'C', 'G', 'T']
+columns = [f"{i}_{c}" for i in indiciesToKeep for c in categories]
+refRow = {f"{i}_{c}": 1 for i,c in zip(indiciesToKeep, encodeSeq(referenceSeq, indiciesToKeep))}
 
 print("loading model " + datetime.now().strftime("%m/%d/%Y, %H:%M:%S"))
 loaded_model = joblib.load(args.model_file)
@@ -125,30 +133,17 @@ def readInAndFormatData(sequencesFile, indiciesToKeep, blockSize=1000):
         len(seqList), datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
     ))
 
+    rows = [{f"{i}_{c}": 1 for i,c in zip(indiciesToKeep, row)} for row in seqList]
     # the reference seq must be added to everry block to make sure that the
     # spots in the reference have Ns are in the dataframe to guarentee that
     # the correct number of columns is created when get_dummies is called
+    rows.append(refRow)
     idList.append(referenceId)
-    seqList.append(encodeSeq(referenceSeq, indiciesToKeep))
 
     # create a data from from the seqList
-    df = pd.DataFrame(seqList, columns=indiciesToKeep)
-
-    # possible nucleotide symbols
-    categories = ['A', 'C', 'G', 'T', '-']
-
-    # add extra rows to ensure all of the categories are represented, as otherwise
-    # not enough columns will be created when we call get_dummies
-    extra_rows = [[i] * len(indiciesToKeep) for i in categories]
-    df = pd.concat([df, pd.DataFrame(extra_rows, columns = indiciesToKeep)], ignore_index=True)
-
-    # get one-hot encoding
-    df = pd.get_dummies(df, columns=indiciesToKeep)
-
-    headers = list(df)
-
-    # get rid of the fake data we just added
-    df.drop(df.tail(len(categories)).index,inplace=True)
+    df = pd.DataFrame.from_records(rows, columns=columns)
+    df.fillna(0, inplace=True)
+    df = df.astype('uint8')
 
     predictions = loaded_model.predict_proba(df)
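
For context, here is a minimal standalone sketch of the two encoding strategies this diff swaps. It is not part of the commit: the positions and sequences below are invented toy values, and the real script builds seqList from FASTA input via its encodeSeq helper. The old path one-hot encodes with pd.get_dummies and has to pad the frame with fake rows so every nucleotide symbol is represented; the new path builds one dict per sequence keyed by "position_symbol" and hands pd.DataFrame.from_records a fixed column list, so no padding, categorical conversion, or drop step is needed.

import pandas as pd

# Toy stand-ins (hypothetical; the real script derives these from the model header file)
indiciesToKeep = [100, 250, 1059]             # genome positions the model uses
categories = ['-', 'A', 'C', 'G', 'T']        # possible nucleotide symbols
seqList = [['A', 'C', 'T'], ['A', '-', 'G']]  # per-sequence symbols at those positions

# Old approach: categorical frame, pad with fake rows, get_dummies, then drop the padding
df_old = pd.DataFrame(seqList, columns=indiciesToKeep)
extra_rows = [[c] * len(indiciesToKeep) for c in categories]
df_old = pd.concat([df_old, pd.DataFrame(extra_rows, columns=indiciesToKeep)], ignore_index=True)
df_old = pd.get_dummies(df_old, columns=indiciesToKeep)
df_old.drop(df_old.tail(len(categories)).index, inplace=True)

# New approach: one sparse dict per sequence, fixed "position_symbol" column list up front
columns = [f"{i}_{c}" for i in indiciesToKeep for c in categories]
rows = [{f"{i}_{c}": 1 for i, c in zip(indiciesToKeep, row)} for row in seqList]
df_new = pd.DataFrame.from_records(rows, columns=columns)
df_new.fillna(0, inplace=True)
df_new = df_new.astype('uint8')

# Both approaches yield one indicator column per (position, symbol) pair
print(sorted(df_old.columns) == sorted(df_new.columns))  # True

The speedup quoted in the commit message presumably comes from skipping the categorical machinery on every block: each input row touches only len(indiciesToKeep) of the roughly 5 × len(indiciesToKeep) indicator columns, so building sparse dicts and letting from_records fill the gaps with NaN (then 0, as uint8) does far less work than running get_dummies over the whole padded frame and dropping the fake rows afterwards.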
