Skip to content

Commit fcfcad7

Browse files
author
David Pellow
committed
Rename
1 parent 9a437a3 commit fcfcad7

13 files changed

+269
-6
lines changed

Diff for: classify_fasta.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
# Provide a command line script to classify sequences in a fasta file
33
###
44

5-
from classification import classifier_utils as utils
6-
from classification import classifier
5+
from plasclass import plasclass_utils as utils
6+
from plasclass import plasclass
77

88
import argparse
99

@@ -38,7 +38,7 @@ def main(args):
3838
else: outfile = infile + '.probs.out'
3939
n_procs = args.num_processes
4040

41-
c = classifier.classifier(n_procs)
41+
c = plasclass.plasclass(n_procs)
4242
seq_names = []
4343
seqs = []
4444
print "Reading {} in batches of 100k sequences".format(infile)

Diff for: plasclass/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
name = "plasclass"

Diff for: plasclass/data/m1000

86.3 KB
Binary file not shown.

Diff for: plasclass/data/m10000

86.3 KB
Binary file not shown.

Diff for: plasclass/data/m100000

86.3 KB
Binary file not shown.

Diff for: plasclass/data/m500000

86.3 KB
Binary file not shown.

Diff for: plasclass/data/s1000

257 KB
Binary file not shown.

Diff for: plasclass/data/s10000

257 KB
Binary file not shown.

Diff for: plasclass/data/s100000

257 KB
Binary file not shown.

Diff for: plasclass/data/s500000

257 KB
Binary file not shown.

Diff for: plasclass/plasclass.py

+130
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
###
2+
# Define the classifier class and provide a set of functions to enable classification
3+
###
4+
5+
import numpy as np
6+
import os
7+
from sklearn.linear_model import LogisticRegression
8+
from sklearn.preprocessing import StandardScaler
9+
import itertools
10+
from joblib import load
11+
12+
import multiprocessing as mp
13+
from multiprocessing import Manager
14+
15+
import plasclass_utils as utils
16+
17+
class plasclass():
18+
def __init__(self,n_procs = 1):
19+
self._scales = [1000,10000,100000,500000]
20+
self._ks = [3,4,5,6,7]
21+
self._compute_kmer_inds()
22+
self._load_classifiers()
23+
self._n_procs = n_procs
24+
25+
26+
def classify(self,seq):
27+
'''Classify the sequence(s), return the probability of the sequence(s) being a plasmid.
28+
Assumes seq is either an individual string or a list of strings
29+
Returns either an individual plasmid probability for seq or a list of
30+
plasmid probabilities for each sequence in seq
31+
'''
32+
if isinstance(seq, basestring): # single sequence
33+
print "Counting k-mers for sequence of length {}".format(len(seq))
34+
kmer_freqs = [0]
35+
scale = self._get_scale(len(seq))
36+
utils.count_kmers([0, seq, self._ks, self._kmer_inds, self._kmer_count_lens, kmer_freqs])
37+
kmer_freqs = np.array(kmer_freqs)
38+
standardized_freqs = self._standardize(kmer_freqs, scale)
39+
print "Classifying"
40+
return self.classifiers[scale]['clf'].predict_proba(standardized_freqs)[0,1]
41+
42+
elif isinstance(seq, list): # list of sequences
43+
print "{} sequences to classify. Classifying in batches of 100k".format(len(seq))
44+
results = []
45+
seq_ind = 0
46+
pool = mp.Pool(self._n_procs)
47+
48+
while seq_ind < len(seq):
49+
print "Starting new batch"
50+
seq_batch = seq[seq_ind:seq_ind + 100000]
51+
print "Partitioning by length"
52+
scales = [self._get_scale(len(s)) for s in seq_batch]
53+
scale_partitions = {s: [seq_batch[i] for i,v in enumerate(scales) if v == s] for s in self._scales}
54+
55+
partitioned_classifications = {}
56+
for scale in self._scales: #scale_partitions:
57+
part_seqs = scale_partitions[scale]
58+
if len(part_seqs) <= 0: continue
59+
print "Getting kmer frequencies for partition length {}".format(scale)
60+
shared_list=Manager().list()
61+
for cur in np.arange(len(part_seqs)):
62+
shared_list.append(0)
63+
pool.map(utils.count_kmers, [[ind, s, self._ks, self._kmer_inds, self._kmer_count_lens, shared_list] for ind,s in enumerate(part_seqs)])
64+
kmer_freqs_mat = np.array(shared_list)
65+
standardized_freqs = self._standardize(kmer_freqs_mat, scale)
66+
print "Classifying sequences of length scale {}".format(scale)
67+
partitioned_classifications[scale] = self.classifiers[scale]['clf'].predict_proba(standardized_freqs)[:,1]
68+
69+
# recollate the results:
70+
scale_inds = {s:0 for s in self._scales}
71+
for s in scales:
72+
results.append(partitioned_classifications[s][scale_inds[s]])
73+
scale_inds[s] += 1
74+
75+
seq_ind += 100000
76+
77+
# pool.close() TODO: is this needed?
78+
return np.array(results)
79+
80+
else:
81+
raise TypeError('Can only classify strings or lists of strings')
82+
83+
84+
def _load_classifiers(self):
85+
''' Load the multi-scale classifiers and scalers
86+
'''
87+
curr_path = os.path.dirname(os.path.abspath(__file__))
88+
data_path = os.path.join(curr_path,'data')
89+
self.classifiers = {}
90+
for i in self._scales:
91+
print "Loading classifier " + str(i)
92+
self.classifiers[i] = {'clf': load(os.path.join(data_path,'m'+str(i))), 'scaler': load(os.path.join(data_path,'s'+str(i)))}
93+
94+
95+
def _get_scale(self, length):
96+
''' Choose which length scale to use for the sequence
97+
'''
98+
if length <= self._scales[0]: return self._scales[0]
99+
for i,l in enumerate(self._scales[:-1]):
100+
if length <= float(l + self._scales[i+1])/2.0:
101+
return l
102+
return self._scales[-1]
103+
104+
def _standardize(self, freqs, scale):
105+
''' Use sklearn's standard scaler to standardize
106+
Choose the appropriate scaler based on sequence length
107+
'''
108+
return self.classifiers[scale]['scaler'].transform(freqs)
109+
110+
def _compute_kmer_inds(self):
111+
''' Compute the indeces of each canonical kmer in the kmer count vectors
112+
'''
113+
114+
self._kmer_inds = {k: {} for k in self._ks}
115+
self._kmer_count_lens = {k: 0 for k in self._ks}
116+
117+
alphabet = 'ACGT'
118+
for k in self._ks:
119+
all_kmers = [''.join(kmer) for kmer in itertools.product(alphabet,repeat=k)]
120+
all_kmers.sort()
121+
ind = 0
122+
for kmer in all_kmers:
123+
bit_mer = utils.mer2bits(kmer)
124+
rc_bit_mer = utils.mer2bits(utils.get_rc(kmer))
125+
if rc_bit_mer in self._kmer_inds[k]:
126+
self._kmer_inds[k][bit_mer] = self._kmer_inds[k][rc_bit_mer]
127+
else:
128+
self._kmer_inds[k][bit_mer] = ind
129+
self._kmer_count_lens[k] += 1
130+
ind += 1

Diff for: plasclass/plasclass_utils.py

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Utility functions for the classifier module
2+
3+
complements = {'A':'T', 'C':'G', 'G':'C', 'T':'A'}
4+
nt_bits = {'A':0,'C':1,'G':2,'T':3}
5+
6+
import numpy as np
7+
8+
9+
def readfq(fp): # this is a generator function
10+
''' Adapted from https://github.com/lh3/readfq
11+
'''
12+
last = None # this is a buffer keeping the last unprocessed line
13+
while True: # mimic closure; is it a bad idea?
14+
if not last: # the first record or a record following a fastq
15+
for l in fp: # search for the start of the next record
16+
if l[0] in '>@': # fasta/q header line
17+
last = l[:-1] # save this line
18+
break
19+
if not last: break
20+
name, seqs, last = last[1:].partition(" ")[0], [], None
21+
for l in fp: # read the sequence
22+
if l[0] in '@+>':
23+
last = l[:-1]
24+
break
25+
seqs.append(l[:-1])
26+
if not last or last[0] != '+': # this is a fasta record
27+
yield name, ''.join(seqs), None # yield a fasta record
28+
if not last: break
29+
else: # this is a fastq record
30+
seq, leng, seqs = ''.join(seqs), 0, []
31+
for l in fp: # read the quality
32+
seqs.append(l[:-1])
33+
leng += len(l) - 1
34+
if leng >= len(seq): # have read enough quality
35+
last = None
36+
yield name, seq, ''.join(seqs); # yield a fastq record
37+
break
38+
if last: # reach EOF before reading enough quality
39+
yield name, seq, None # yield a fasta record instead
40+
break
41+
42+
43+
def get_rc(seq):
44+
''' Return the reverse complement of seq
45+
'''
46+
rev = reversed(seq)
47+
return "".join([complements.get(i,i) for i in rev])
48+
49+
50+
def mer2bits(kmer):
51+
''' convert kmer to bit representation
52+
'''
53+
bit_mer=nt_bits[kmer[0]]
54+
for c in kmer[1:]:
55+
bit_mer = (bit_mer << 2) | nt_bits[c]
56+
return bit_mer
57+
58+
59+
def count_kmers(args_array):
60+
''' Count the k-mers in the sequence
61+
Return a dictionary of counts
62+
Assumes ks is sorted
63+
'''
64+
ret_ind, seq, ks, kmer_inds, vec_lens, shared_list = args_array
65+
66+
kmer_counts = {k:np.zeros(vec_lens[k]) for k in ks}
67+
68+
k_masks = [2**(2*k)-1 for k in ks]
69+
ind=0
70+
bit_mers = [0 for k in ks]
71+
72+
# get the first set of kmers
73+
while True:
74+
found = True
75+
for i,k in enumerate(ks):
76+
try:
77+
bit_mers[i] = mer2bits(seq[ind:ind+k])
78+
kmer_counts[k][kmer_inds[k][bit_mers[i]]] += 1.
79+
except:
80+
ind += 1
81+
found = False
82+
break
83+
if found == True:
84+
break
85+
86+
# count all other kmers
87+
while ind<len(seq)-ks[-1]: # iterate through sequence until last k-mer for largest k
88+
for i,k in enumerate(ks):
89+
try:
90+
c = nt_bits[seq[ind+k]]
91+
bit_mers[i] = ((bit_mers[i]<<2)|c)&k_masks[i]
92+
kmer_counts[k][kmer_inds[k][bit_mers[i]]] += 1.
93+
except: # out of alphabet
94+
ind += 2 # pass it and move on to the next
95+
# get the next set of legal kmers
96+
while ind<=len(seq)-ks[-1]:
97+
found = True
98+
for i2,k2 in enumerate(ks):
99+
try:
100+
bit_mers[i2] = mer2bits(seq[ind:ind+k2])
101+
kmer_counts[k2][kmer_inds[k2][bit_mers[i2]]] += 1.
102+
except:
103+
ind += 1
104+
found = False
105+
break
106+
if found == True:
107+
ind -= 1 # in next step increment ind
108+
break
109+
ind += 1 # move on to next letter in sequence
110+
111+
# count the last few kmers
112+
end = len(ks)-1
113+
for i in range(len(seq)-ks[-1]+1,len(seq)-ks[0]+1):
114+
for k in ks[:end]:
115+
kmer = seq[i:i+k]
116+
try:
117+
kmer_counts[k][kmer_inds[k][mer2bits(kmer)]] += 1.
118+
except:
119+
pass
120+
end -= 1
121+
122+
#normalise counts
123+
kmer_freqs = np.zeros(sum([vec_lens[k] for k in ks]))
124+
ind = 0
125+
for k in ks:
126+
counts_sum = np.sum(kmer_counts[k])
127+
if counts_sum != 0:
128+
kmer_counts[k] = kmer_counts[k]/float(counts_sum)
129+
kmer_freqs[ind:ind+vec_lens[k]] = kmer_counts[k]
130+
ind += vec_lens[k]
131+
132+
shared_list[ret_ind] = kmer_freqs

Diff for: setup.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
long_description = fh.read()
55

66
setuptools.setup(
7-
name="classification-dpellow",
7+
name="plasclass-dpellow",
88
version="0.1",
99
author="David Pellow",
1010
author_email="[email protected]",
1111
description="Classification of plasmid sequences",
1212
long_description=long_description,
1313
long_description_content_type="text/markdown",
14-
url="https://github.com/dpellow/classification",
15-
packages=['classification'],
14+
url="https://github.com/dpellow/plasclass",
15+
packages=['plasclass'],
1616
classifiers=[
1717
"Programming Language :: Python :: 2.7",
1818
"License :: OSI Approved :: MIT License",

0 commit comments

Comments
 (0)