Skip to content

Commit 24649c5

Browse files
authored
Merge pull request #133 from aunderwo/bt5_a_bit_faster
Run quicker for small number of sequences
2 parents 1aaffe1 + e015893 commit 24649c5

File tree

2 files changed

+13
-34
lines changed

2 files changed

+13
-34
lines changed

pangolin/scripts/pangolearn.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python3
22

33
import pandas as pd
4+
import numpy as np
45
from sklearn.model_selection import train_test_split
56
from sklearn.linear_model import LogisticRegression
67
from sklearn import metrics
@@ -113,6 +114,10 @@ def readInAndFormatData(sequencesFile, indiciesToKeep, blockSize=1000):
113114
indiciesToKeep = model_headers[1:]
114115

115116
referenceSeq = findReferenceSeq()
117+
# possible nucleotide symbols
118+
categories = ['-','A', 'C', 'G', 'T']
119+
columns = [f"{i}_{c}" for i in indiciesToKeep for c in categories]
120+
refRow = [r==c for r in encodeSeq(referenceSeq, indiciesToKeep) for c in categories]
116121

117122
print("loading model " + datetime.now().strftime("%m/%d/%Y, %H:%M:%S"))
118123
loaded_model = joblib.load(args.model_file)
@@ -125,30 +130,16 @@ def readInAndFormatData(sequencesFile, indiciesToKeep, blockSize=1000):
125130
len(seqList), datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
126131
))
127132

133+
rows = [[r==c for r in row for c in categories] for row in seqList]
128134
# the reference seq must be added to every block to make sure that the
129135
# spots where the reference has Ns are in the dataframe to guarantee that
130136
# the correct number of columns is created when get_dummies is called
137+
rows.append(refRow)
131138
idList.append(referenceId)
132-
seqList.append(encodeSeq(referenceSeq, indiciesToKeep))
133139

134140
# create a data frame from the seqList
135-
df = pd.DataFrame(seqList, columns=indiciesToKeep)
136-
137-
# possible nucleotide symbols
138-
categories = ['A', 'C', 'G', 'T', '-']
139-
140-
# add extra rows to ensure all of the categories are represented, as otherwise
141-
# not enough columns will be created when we call get_dummies
142-
extra_rows = [[i] * len(indiciesToKeep) for i in categories]
143-
df = pd.concat([df, pd.DataFrame(extra_rows, columns = indiciesToKeep)], ignore_index=True)
144-
145-
# get one-hot encoding
146-
df = pd.get_dummies(df, columns=indiciesToKeep)
147-
148-
headers = list(df)
149-
150-
# get rid of the fake data we just added
151-
df.drop(df.tail(len(categories)).index,inplace=True)
141+
d = np.array(rows, np.uint8)
142+
df = pd.DataFrame(d, columns=columns)
152143

153144
predictions = loaded_model.predict_proba(df)
154145

pangolin/scripts/report_results.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,10 @@ def make_objects(background_data, lineages_present):
3535
lineages_to_taxa = defaultdict(list)
3636
lin_obj_dict = {}
3737

38-
with open(background_data,newline="") as f:
39-
reader = csv.DictReader(f)
40-
for row in reader:
41-
name = row["sequence_name"]
42-
lin_string = row["lineage"]
43-
date = row["sample_date"]
44-
country = row["country"]
45-
46-
tax_name = f"{name}|{country}|{date}"
47-
48-
if lin_string in lineages_present:
49-
new_taxon = classes.taxon(tax_name, lin_string)
50-
taxa.append(new_taxon)
51-
52-
lineages_to_taxa[lin_string].append(new_taxon)
53-
38+
background_df = pd.read_csv(background_data).query("lineage in @lineages_present")
39+
background_df['taxa'] = background_df.apply(lambda r: classes.taxon(f"{r['sequence_name']}|{r['country']}|{r['sample_date']}", r['lineage']), axis=1)
40+
lineages_to_taxa = background_df.groupby("lineage")["taxa"].apply(list).to_dict()
41+
taxa = list(background_df['taxa'])
5442

5543
for lin, taxa in lineages_to_taxa.items():
5644
l_o = classes.lineage(lin, taxa)

0 commit comments

Comments
 (0)