Skip to content

Commit

Permalink
Merge pull request #133 from aunderwo/bt5_a_bit_faster
Browse files Browse the repository at this point in the history
Run quicker for small number of sequences
  • Loading branch information
aineniamh authored Feb 5, 2021
2 parents 1aaffe1 + e015893 commit 24649c5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 34 deletions.
27 changes: 9 additions & 18 deletions pangolin/scripts/pangolearn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn import metrics
Expand Down Expand Up @@ -113,6 +114,10 @@ def readInAndFormatData(sequencesFile, indiciesToKeep, blockSize=1000):
indiciesToKeep = model_headers[1:]

referenceSeq = findReferenceSeq()
# possible nucleotide symbols
categories = ['-','A', 'C', 'G', 'T']
columns = [f"{i}_{c}" for i in indiciesToKeep for c in categories]
refRow = [r==c for r in encodeSeq(referenceSeq, indiciesToKeep) for c in categories]

print("loading model " + datetime.now().strftime("%m/%d/%Y, %H:%M:%S"))
loaded_model = joblib.load(args.model_file)
Expand All @@ -125,30 +130,16 @@ def readInAndFormatData(sequencesFile, indiciesToKeep, blockSize=1000):
len(seqList), datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
))

rows = [[r==c for r in row for c in categories] for row in seqList]
# the reference seq must be added to everry block to make sure that the
# spots in the reference have Ns are in the dataframe to guarentee that
# the correct number of columns is created when get_dummies is called
rows.append(refRow)
idList.append(referenceId)
seqList.append(encodeSeq(referenceSeq, indiciesToKeep))

# create a data from from the seqList
df = pd.DataFrame(seqList, columns=indiciesToKeep)

# possible nucleotide symbols
categories = ['A', 'C', 'G', 'T', '-']

# add extra rows to ensure all of the categories are represented, as otherwise
# not enough columns will be created when we call get_dummies
extra_rows = [[i] * len(indiciesToKeep) for i in categories]
df = pd.concat([df, pd.DataFrame(extra_rows, columns = indiciesToKeep)], ignore_index=True)

# get one-hot encoding
df = pd.get_dummies(df, columns=indiciesToKeep)

headers = list(df)

# get rid of the fake data we just added
df.drop(df.tail(len(categories)).index,inplace=True)
d = np.array(rows, np.uint8)
df = pd.DataFrame(d, columns=columns)

predictions = loaded_model.predict_proba(df)

Expand Down
20 changes: 4 additions & 16 deletions pangolin/scripts/report_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,10 @@ def make_objects(background_data, lineages_present):
lineages_to_taxa = defaultdict(list)
lin_obj_dict = {}

with open(background_data,newline="") as f:
reader = csv.DictReader(f)
for row in reader:
name = row["sequence_name"]
lin_string = row["lineage"]
date = row["sample_date"]
country = row["country"]

tax_name = f"{name}|{country}|{date}"

if lin_string in lineages_present:
new_taxon = classes.taxon(tax_name, lin_string)
taxa.append(new_taxon)

lineages_to_taxa[lin_string].append(new_taxon)

background_df = pd.read_csv(background_data).query("lineage in @lineages_present")
background_df['taxa'] = background_df.apply(lambda r: classes.taxon(f"{r['sequence_name']}|{r['country']}|{r['sample_date']}", r['lineage']), axis=1)
lineages_to_taxa = background_df.groupby("lineage")["taxa"].apply(list).to_dict()
taxa = list(background_df['taxa'])

for lin, taxa in lineages_to_taxa.items():
l_o = classes.lineage(lin, taxa)
Expand Down

0 comments on commit 24649c5

Please sign in to comment.