-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathperform_stats.py
executable file
·255 lines (221 loc) · 9.8 KB
/
perform_stats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
#!/software/python/3.3.3/bin/python3.3
############################################################################
import argparse
import sys
import re
#import Bio
#from Bio import SeqIO
#from Bio.SeqRecord import SeqRecord
#from Bio.Seq import Seq
#def read_repeat_file(rep_file):
# """ Takes the repeat file generated by repeatmasker and reads each line
# describing a repeat. Returns a list of tuples (start, end) each corresponding
# to a repeat R where start is the index of the chromosome sequence where R
# begins and end is the index of the chromosome sequence where R ends"""
# return indices
#def get_repeats(rep_file):
# """ Takes the repeat file generated by repeatmasker and the original
# sequence file. Gets a list of the start and end indices of each repeat from
# masker file and then uses these indices to create a list of the actual
# repeat sequences. Returns this list and the length of the original sequence."""
# rf = open(rep_file, 'r')
# rf.readline()
# rf.readline()
# rf.readline()
# info = []
# while True:
# line = rf.readline()
# if not line:
# break
# A = re.split("\s+", line.strip())
# start = int(A[5]) - 1
# end = int(A[6])
# info.append((start, end))
# return info
def get_sequence_length(rep_file):
    """ Returns the total number of bases in sequence.
    The length is read from the first repeat line of the repeatmasker file
    (the line after the three header lines) as that repeat's end position
    plus its "(left)" column, the count of sequence bases after the repeat.
    NOTE(review): assumes the file annotates a single query sequence -- only
    the first repeat line is consulted; confirm against the inputs used.
    Keyword arguments:
    rep_file -- the repeat file generated by repeatmasker
    Raises ValueError if the first repeat line is missing or its end/left
    columns cannot be parsed as integers.
    """
    with open(rep_file, 'r') as rf:
        # Skip the three header lines; the first repeat line follows them.
        for _ in range(3):
            rf.readline()
        line = rf.readline()
    fields = re.split(r"\s+", line.strip())
    try:
        # sequence length = end (fields[6]) + left (fields[7], written "(N)")
        return int(fields[6]) + int(fields[7].strip("()"))
    except (IndexError, ValueError) as err:
        raise ValueError(
            "%s: cannot read sequence length from first repeat line %r"
            % (rep_file, line)) from err
def repeat_bounds_generator(rep_file, exclusion_set=None):
    """ Generator function. Yields the (start, end) bounds of each repeat in turn.
    Bounds are half-open: start is the repeat's "begin" column minus one and
    end is its "end" column, so end - start is the number of bases covered.
    Blank lines are ignored. The file is closed once the generator is
    exhausted or closed.
    Keyword arguments:
    rep_file -- the repeat file generated by repeatmasker
    exclusion_set -- criteria for a repeat to be skipped: a line is skipped if
                     any of these strings occurs anywhere in it
    """
    with open(rep_file, 'r') as rf:
        # The first three lines are the column header, not repeats.
        for _ in range(3):
            rf.readline()
        for line in rf:
            if not line.strip():
                # A blank (e.g. trailing) line has no columns to parse.
                continue
            if exclusion_set and any(s in line for s in exclusion_set):
                continue
            fields = re.split(r"\s+", line.strip())
            start = int(fields[5]) - 1
            end = int(fields[6])
            yield start, end
# Called once one of the two repeat streams is exhausted: whatever remains in
# the other stream can no longer be matched and is classified wholesale.
def add_remaining_mismatches(repeat_bounds_generator, mismatch_list):
    """ Drains the remaining repeat bounds into mismatch_list.
    Every (start, end) pair still produced by the iterator is appended, in
    order, to mismatch_list, which is modified in place.
    Keyword arguments:
    repeat_bounds_generator -- iterator yielding the remaining (start, end) bounds
    mismatch_list -- list to which every remaining pair of bounds is appended
    Returns the total number of bases (sum of end - start) that were appended.
    """
    appended_bases = 0
    for bounds in repeat_bounds_generator:
        start, end = bounds
        mismatch_list.append((start, end))
        appended_bases += end - start
    return appended_bases
def get_stats(real_indices, gen_indices, gen_length):
    """ Generates performance statistics regarding the tool's ability to
    classify bases as part of repeats or not part of repeats.
    The two iterators are swept in parallel, merge-style. Each yields
    half-open (start, end) bounds. NOTE(review): the sweep assumes each stream
    is sorted by start and has no internally overlapping repeats -- confirm
    for the repeatmasker files used.
    Keyword arguments:
    real_indices -- iterator over real (expected) repeat bounds
    gen_indices -- iterator over repeat bounds found by tool
    gen_length -- total number of bases in original sequence file
    Returns:
    (tp, fp, fn, tn) base counts, the (tpr, tnr, ppv, npv, fpr, fdr) tuple
    computed by stats(), and the (tps, fps, fns) lists of classified bounds.
    """
    tps, fns, fps = [], [], []
    tp, fn, fp = 0, 0, 0
    # A "leftover" is the not-yet-classified tail of a repeat that must be
    # compared again against the next repeat from the other stream.
    leftover_real, leftover_gen = False, False
    s1, e1, s2, e2 = 0, 0, 0, 0
    while True:
        # Only get new indices if nothing leftover to evaluate
        if not leftover_real:
            try:
                s1, e1 = next(real_indices)
            except StopIteration:
                # No real repeats left: every remaining generated base is a FP.
                # Any pending generated repeat exists only as a leftover,
                # since the generated stream has not been read this iteration.
                if leftover_gen:
                    fps.append((s2, e2))
                    fp += e2 - s2
                fp += add_remaining_mismatches(gen_indices, fps)
                break
        if not leftover_gen:
            try:
                s2, e2 = next(gen_indices)
            except StopIteration:
                # No generated repeats left. (s1, e1) is always still pending
                # here, either as a leftover tail or as a real repeat fetched
                # just above, so it is a FN along with every remaining one.
                fns.append((s1, e1))
                fn += e1 - s1
                fn += add_remaining_mismatches(real_indices, fns)
                break
        leftover_real = False
        leftover_gen = False
        # Half-open intervals share a base only if each starts before the
        # other ends; merely touching (s2 == e1 or s1 == e2) is no overlap.
        if s2 < e1 and s1 < e2:
            start_match = max(s1, s2)
            end_match = min(e1, e2)
            tps.append((start_match, end_match))
            tp += end_match - start_match
            if s2 > s1:
                # Missed the real repeat from s1...s2
                fns.append((s1, s2))
                fn += s2 - s1
            elif s2 < s1:
                # Falsely identified s2...s1 as a repeat
                fps.append((s2, s1))
                fp += s1 - s2
            if e2 < e1:
                # Next generated repeat might match e2...e1 of real repeat
                leftover_real = True
                s1 = e2
            elif e2 > e1:
                # e1...e2 of generated repeat might match different real repeat
                leftover_gen = True
                s2 = e1
        # If real repeat ends before generated one starts, classify miss as FN
        elif e1 <= s2:
            leftover_gen = True
            fns.append((s1, e1))
            fn += e1 - s1
        # Otherwise generated repeat ends before real one starts: FP
        else:
            leftover_real = True
            fps.append((s2, e2))
            fp += e2 - s2
    # Every base not classified above lies outside both sets of repeats.
    tn = gen_length - fp - tp - fn
    tpr, tnr, ppv, npv, fpr, fdr = stats(tp, fp, fn, tn)
    return (tp, fp, fn, tn), (tpr, tnr, ppv, npv, fpr, fdr), (tps, fps, fns)
def stats(tp, fp, fn, tn):
    """ Converts base classification counts into rate statistics.
    Keyword arguments:
    tp -- number of bases correctly classified as part of repeat
    fp -- number of bases incorrectly classified as part of repeat
    fn -- number of bases incorrectly classified as not part of repeat
    tn -- number of bases correctly classified as not part of repeat
    Returns (tpr, tnr, ppv, npv, fpr, fdr). A rate whose denominator is not
    positive is reported as -1, and fpr/fdr are -1 whenever the rate they
    complement (tnr/ppv) is not greater than -1.
    """
    def _rate(numerator, denominator):
        # -1 is the sentinel for an undefined rate (empty denominator).
        if denominator > 0:
            return numerator / float(denominator)
        return -1

    sensitivity = _rate(tp, tp + fn)
    specificity = _rate(tn, fp + tn)
    precision = _rate(tp, tp + fp)
    neg_predictive = _rate(tn, tn + fn)

    if specificity > -1:
        fall_out = 1 - specificity
    else:
        fall_out = -1
    if precision > -1:
        false_discovery = 1 - precision
    else:
        false_discovery = -1
    return (sensitivity, specificity, precision, neg_predictive,
            fall_out, false_discovery)
def perform_stats(real_repeats, tool_output, exclusions):
    """ Calculates the tool's performance statistics and returns them.
    Returns the result of get_stats: the (tp, fp, fn, tn) counts, the
    (tpr, tnr, ppv, npv, fpr, fdr) statistics and the (tps, fps, fns) lists.
    Keyword arguments:
    real_repeats -- repeatmasker file for known repeats; the sequence length
                    is also read from this file
    tool_output -- repeatmasker file for repeats found by tool
    exclusions -- path of a file of whitespace-separated exclusion criteria
                  for repeats to be skipped, or a falsy value for none
    """
    seq_length = get_sequence_length(real_repeats)
    exclusion_set = None
    if exclusions:
        with open(exclusions, "r") as ef:
            # str.split() drops empty tokens. Splitting a blank or indented
            # line with a regex would add "", and since "" is a substring of
            # every line, that would silently exclude every repeat.
            exclusion_set = {s for line in ef for s in line.split()}
    real_bounds = repeat_bounds_generator(real_repeats, exclusion_set)
    gen_bounds = repeat_bounds_generator(tool_output, exclusion_set)
    return get_stats(real_bounds, gen_bounds, seq_length)
def generate_output_orig(output_file, print_reps, counts, stats, lists):
    """ Writes performance statistics to specified output file.
    Keyword arguments:
    output_file -- path of the file to write, or "-" for standard output
                   (standard output is written to but never closed)
    print_reps -- whether or not should print out classified repeat bounds
    counts -- tuple of classification counts (TP, FP, FN, TN), the order
              returned by get_stats
    stats -- tuple of classification statistics (TPR, TNR, PPV, NPV, FPR, FDR)
    lists -- tuple of lists of classified repeat bounds (TPs, FPs, FNs)
    """
    to_stdout = output_file == "-"
    f = sys.stdout if to_stdout else open(output_file, 'w')
    try:
        f.write("TP: %d\nFP: %d\nFN: %d\nTN: %d\n" % tuple(counts))
        f.write("TPR: %f\nTNR: %f\nPPV: %f\nNPV: %f\nFPR: %f\nFDR: %f\n\n" % tuple(stats))
        if print_reps:
            f.write("\nThese bases were correctly identified (true positives):\n")
            f.write('\n'.join('\t(%s %s)' % x for x in lists[0]))
            f.write("\nThese bases were incorrectly identified (false positives):\n")
            f.write('\n'.join('\t(%s %s)' % x for x in lists[1]))
            f.write("\nThese bases were missed (false negatives):\n")
            f.write('\n'.join('\t(%s %s)' % x for x in lists[2]))
    finally:
        # Only close what was opened here; closing sys.stdout would break any
        # later output from the process.
        if not to_stdout:
            f.close()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description = "Generate Sensitivity and Specificity Stats")
    parser.add_argument('-p', '--print_reps', action = "store_true", help = "Print out the repeats", default= False)
    parser.add_argument('-e', '--exclusion_file', help = "File of repeats to be ignored during analysis", default = None)
    parser.add_argument("repeat_file", help = "Repeats file")
    parser.add_argument("masker_output", help = "Masker output using consensus sequence and sequence file")
    parser.add_argument("output_file", help = "Statistics output file")
    args = parser.parse_args()

    # One left-aligned 14-character column per value: 4 counts, then 6 rates.
    print_str = "{:<14}" * 10 + "\n"
    results = None
    with open(args.output_file, "w") as out:
        # Column names follow get_stats' order: counts (tp, fp, fn, tn) and
        # rates (tpr, tnr, ppv, npv, fpr, fdr).
        out.write(print_str.format("tp", "fp", "fn", "tn", "tpr", "tnr", "ppv", "npv", "fpr", "fdr"))
        try:
            results = perform_stats(args.repeat_file, args.masker_output, args.exclusion_file)
        except Exception as err:
            # Best effort: mark the row incomplete, but report the cause
            # instead of swallowing it silently.
            sys.stderr.write("perform_stats failed: %s: %s\n" % (type(err).__name__, err))
            out.write("\t".join(["NA", "INCOMPLETE\n"]))
        else:
            counts, rates, _ = results
            out.write(print_str.format(*(list(counts) + [round(x, 5) for x in rates])))
    # -p/--print_reps: print the classified repeat bounds to standard output.
    if results is not None and args.print_reps:
        generate_output_orig("-", args.print_reps, *results)