Skip to content

Commit bf50d28

Browse files
Merge remote-tracking branch 'origin/master'
2 parents dc4cb7a + 9b3129d commit bf50d28

File tree

4 files changed

+72
-8
lines changed

4 files changed

+72
-8
lines changed

loaders/loader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from loaders import decode_raw_file
2-
from parsers import ParserFormats, ContactInformationFormats
2+
from parsers import ParserFormats, BinaryFormats
33
from utils.exceptions import InvalidFormat
44
from utils import compress_data
55

@@ -10,7 +10,7 @@ def Loader(raw_file, input_format):
1010

1111
if raw_file is not None:
1212
try:
13-
if input_format != ContactInformationFormats.trROSETTA_NPZ.name:
13+
if input_format not in BinaryFormats.__members__.keys():
1414
decoded = decode_raw_file(raw_file)
1515
data_raw = ParserFormats.__dict__[input_format](decoded, input_format)
1616
else:
@@ -20,4 +20,4 @@ def Loader(raw_file, input_format):
2020
data = None
2121
invalid = True
2222

23-
return data, invalid
23+
return data, invalid

parsers/__init__.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ def NpzParser(*args, **kwargs):
9898
return NpzParser(*args, **kwargs)
9999

100100

101+
def AlphafoldParser(*args, **kwargs):
    """Forward the call to ``parsers.alphafold.AlphafoldParser``.

    The implementation module is imported inside the function body, so it is
    only loaded the first time this wrapper is actually called.
    """
    from parsers.alphafold import AlphafoldParser as _alphafold_parser

    parsed = _alphafold_parser(*args, **kwargs)
    return parsed
105+
106+
101107
class ParserFormats(Enum):
102108
TOPCONS = TopconsParser
103109
CONSURF = ConsurfParser
@@ -123,8 +129,9 @@ class ParserFormats(Enum):
123129
COLSTATS = CCMpredParser
124130
PDB = PDBParser
125131
CASPRR_MODE_2 = CASPRR2Parser
126-
trROSETTA_NPZ = NpzParser
132+
ROSETTA_NPZ = NpzParser
127133
MAPPRED = MappredParser
134+
ALPHAFOLD = AlphafoldParser
128135
A3M = A3mParser
129136

130137

@@ -150,8 +157,9 @@ class ContactInformationFormats(Enum):
150157
MAPALIGN = 18
151158
ALEIGEN = 19
152159
PDB = 20
153-
trROSETTA_NPZ = 21
160+
ROSETTA_NPZ = 21
154161
MAPPRED = 22
162+
ALPHAFOLD = 23
155163

156164

157165
class ContactMapFormats(Enum):
@@ -182,9 +190,14 @@ class StructuralInformationFormats(Enum):
182190

183191
class DistanceInformationFormats(Enum):
184192
CASPRR_MODE_2 = 1
185-
trROSETTA_NPZ = 2
193+
ROSETTA_NPZ = 2
186194
MAPPRED = 3
195+
ALPHAFOLD = 4
196+
187197

198+
class BinaryFormats(Enum):
    """Names of the input formats that are handled as binary uploads.

    The loader tests ``input_format`` against the member names of this enum
    to decide whether the raw file goes through ``decode_raw_file`` first.
    """

    # Member names must match the corresponding ParserFormats entries.
    ROSETTA_NPZ = 1
    ALPHAFOLD = 2
188201

189202
class MembraneStates(Enum):
190203
INSIDE = 1

parsers/alphafold.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import io
2+
import base64
3+
import numpy as np
4+
from scipy.special import softmax
5+
from utils.exceptions import InvalidFormat
6+
from parsers import get_unique_distances
7+
8+
9+
def parse_distogram(distogram):
    """Turn an AlphaFold distogram into a flat list of residue-pair records.

    Args:
        distogram: Mapping with a ``'logits'`` array whose last axis has 64
            entries and a ``'bin_edges'`` array with 63 entries spaced
            0.3125 apart. The first two axes of ``'logits'`` index the
            residue pair.

    Returns:
        A list of ``[res_1, res_2, contact_prob, distance_bin, distance_prob]``
        lists, one for every pair ``i < j`` with ``j - i >= 5`` (residue
        numbers are 1-based). ``contact_prob`` is the probability mass of the
        first 19 fine bins; ``distance_bin`` / ``distance_prob`` are the index
        and probability of the most likely of 10 coarse bins.

    Raises:
        InvalidFormat: If the distogram does not have the expected binning.
    """
    # Output are raw logits rather than probabilities, apply softmax
    probs = softmax(distogram['logits'], axis=-1)
    bin_edges = distogram['bin_edges']

    # Explicit checks instead of ``assert``: asserts vanish under ``python -O``
    # and an AssertionError would not be reported as an invalid file.
    if probs.shape[-1] != 64:
        raise InvalidFormat('Unable to parse alphafold pkl file')
    # Bin edges represent interval boundaries, which are half-open (open on
    # the left) (hence n-1 bin edges)
    if bin_edges.shape[-1] != 63:
        raise InvalidFormat('Unable to parse alphafold pkl file')
    # Compare the bin width with a tolerance; exact float equality is brittle.
    if abs(float(bin_edges[1] - bin_edges[0]) - 0.3125) > 1e-6:
        raise InvalidFormat('Unable to parse alphafold pkl file')

    # Probability mass of the first 19 fine bins
    # (presumably the < 8 A contact threshold -- TODO confirm).
    contacts = np.sum(probs[:, :, :19], axis=-1)
    n_residues = contacts.shape[0]

    # Collapse the 64 fine bins into 10 coarse ones: the first 6 fine bins,
    # eight consecutive groups derived from the 0.3125 bin width, and
    # everything from fine bin 57 onwards.
    coarse = [np.sum(probs[:, :, :6], axis=-1)]
    for x in range(1, 9):
        start = int((x * 2 + 0.0125) / 0.3125)
        stop = int((x * 2 + 2.0125) / 0.3125)
        coarse.append(np.sum(probs[:, :, start:stop], axis=-1))
    coarse.append(np.sum(probs[:, :, 57:], axis=-1))

    stacked = np.dstack(coarse)
    dist_bins = np.nanargmax(stacked, axis=2)
    dist_prob = np.amax(stacked, axis=2)

    # Only pairs separated by at least 5 positions are reported.
    return [
        [i + 1, j + 1, float(contacts[i, j]), int(dist_bins[i, j]), float(dist_prob[i, j])]
        for i in range(n_residues)
        for j in range(i + 5, n_residues)
    ]
27+
28+
29+
def AlphafoldParser(input, input_format=None):
    """Parse an uploaded AlphaFold result file into unique distance records.

    Args:
        input: Upload payload of the form ``'<content type>,<base64 data>'``
            with exactly one comma. The base64 part must decode to a file that
            ``numpy.load`` can read and that contains a ``'distogram'`` entry.
        input_format: Not used by this parser.

    Returns:
        The value returned by ``get_unique_distances`` for the parsed records.

    Raises:
        InvalidFormat: If the payload cannot be decoded, loaded or parsed, if
            it yields no records, or if the records fail the final sanity
            check.
    """
    import pickle  # local import: only needed to name UnpicklingError below

    try:
        # The content-type header before the comma is not needed.
        _, content_string = input.split(',')
        decoded = base64.b64decode(content_string)
        # SECURITY NOTE(review): allow_pickle=True unpickles arbitrary objects
        # from an uploaded file, which can execute code. Left unchanged here,
        # but this should only be fed trusted input -- confirm with callers.
        results = np.load(io.BytesIO(decoded), allow_pickle=True)
        distogram = results['distogram']
        tmp_output = parse_distogram(distogram)
    except (OSError, KeyError, IndexError, ValueError, TypeError, EOFError,
            AssertionError, pickle.UnpicklingError) as e:
        # Every decoding/loading failure is reported the same way, keeping
        # the original error attached as the cause for debugging.
        raise InvalidFormat('Unable to parse alphafold pkl file') from e

    output = []
    for res_1, res_2, *scores in tmp_output:
        # record = [res_1, res_2, raw_score, distance_bin, distance_score];
        # the residue pair is stored with the larger number first.
        pair = tuple(sorted((res_1, res_2), reverse=True))
        output.append((pair, *scores))

    if not output:
        raise InvalidFormat('Unable to parse alphafold pkl file')

    unique_contacts = get_unique_distances(output)
    # Sanity check on fields 3 and 4 of every record after the first one.
    # NOTE(review): presumably distance bin <= 9 and probability <= 1, with the
    # first element being a header -- confirm against get_unique_distances.
    if any(p[3] > 9 or p[4] > 1.01 for p in unique_contacts[1:]):
        raise InvalidFormat('Unable to parse alphafold pkl file')
    return unique_contacts

utils/plot_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def create_ConPlot(session_id, cache, trigger, selected_tracks, cmap_selection,
4242
figure = create_figure(display_settings.axis_range)
4343

4444
verbose_labels, additional_traces = add_additional_tracks(session_id, session, display_settings, figure, cache)
45-
contact_traces = add_contact_trace(session, display_settings, figure, verbose_labels)
45+
contact_traces = add_contact_trace(session, display_settings, verbose_labels)
4646

4747
figure.add_traces(contact_traces)
4848
figure.add_traces(additional_traces)
@@ -93,7 +93,7 @@ def add_additional_tracks(session_id, session, display_settings, figure, cache):
9393
return None, traces
9494

9595

96-
def add_contact_trace(session, display_settings, figure, verbose_labels):
96+
def add_contact_trace(session, display_settings, verbose_labels):
9797
if display_settings.superimpose and display_settings.heatmap:
9898
heat, hover, colorscale = heatmap_utils.superimpose_heatmaps(session, display_settings, verbose_labels)
9999
return heatmap_utils.create_heatmap_trace(hovertext=hover, distances=heat, colorscale=colorscale)

0 commit comments

Comments
 (0)