Skip to content

Commit ba75383

Browse files
authored
v2.3.2
2 parents 15eb3e0 + cf096a6 commit ba75383

File tree

6 files changed

+33
-29
lines changed

6 files changed

+33
-29
lines changed

alphafold/model/data.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,9 @@
2121
# Internal import (7716).
2222

2323

24-
def get_model_haiku_params(model_name: str, data_dir: str, fuse: bool = True) -> hk.Params:
24+
def get_model_haiku_params(model_name: str,
25+
data_dir: str, fuse: bool = True, to_jnp: bool = True) -> hk.Params:
2526
"""Get the Haiku parameters from a model name."""
26-
2727
path = os.path.join(data_dir, 'params', f'params_{model_name}.npz')
28-
2928
params = np.load(path, allow_pickle=False)
30-
31-
return utils.flat_params_to_haiku(params, fuse=fuse)
29+
return utils.flat_params_to_haiku(params, fuse=fuse, to_jnp=to_jnp)

alphafold/model/model.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,14 +208,14 @@ def predict(self,
208208
while r < num_iters:
209209
if self.multimer_mode:
210210
sub_feat = feat
211-
sub_feat["iter"] = np.array(r)
212211
else:
213212
s = r * num_ensemble
214213
e = (r+1) * num_ensemble
215214
sub_feat = jax.tree_map(lambda x:x[s:e], feat)
216215

217216
sub_feat["prev"] = result["prev"]
218-
result = self.apply(self.params, key, sub_feat)
217+
key, sub_key = jax.random.split(key)
218+
result = self.apply(self.params, sub_key, sub_feat)
219219
seq_mask = feat["seq_mask"] if self.multimer_mode else feat["seq_mask"][0]
220220
confidences = get_confidence_metrics(result, mask=seq_mask, rank_by=self.config.model.rank_by)
221221

@@ -235,13 +235,16 @@ def predict(self,
235235
stop = True
236236
prev_pos = result["prev"]["prev_pos"][:,ca_idx]
237237

238+
result["pae"] = result.pop("predicted_aligned_error")
238239
result.update(confidences)
239-
if prediction_callback is not None: prediction_callback(result, r)
240+
241+
if prediction_callback is not None:
242+
prediction_callback(result, r)
240243

241244
if verbose:
242245
print_line = f"recycle={r} plddt={confidences['mean_plddt']:.3g}"
243246
for k in ["ptm","iptm","diff"]:
244-
if k in confidences: print_line += f" {k}:{confidences[k]:.3g}"
247+
if k in confidences: print_line += f" {k}={confidences[k]:.3g}"
245248
print(print_line)
246249
r += 1
247250
if stop: break

alphafold/model/modules.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,14 @@ def get_prev(ret):
183183
}
184184
return new_prev
185185

186-
prev = batch.pop("prev")
187-
batch = jax.tree_map(lambda x:x[0], batch)
186+
prev = batch.pop("prev",None)
187+
if batch["aatype"].ndim == 2:
188+
batch = jax.tree_map(lambda x:x[0], batch)
189+
if prev is None:
190+
L = batch["aatype"].shape[0]
191+
prev = {'prev_msa_first_row': jnp.zeros([L,256]),
192+
'prev_pair': jnp.zeros([L,L,128]),
193+
'prev_pos': jnp.zeros([L,37,3])}
188194
ret = impl(batch={**batch, **prev}, is_training=is_training)
189195
ret["prev"] = get_prev(ret)
190196
if not return_representations:
@@ -413,8 +419,15 @@ def slice_recycle_idx(x):
413419
compute_loss=compute_loss,
414420
ensemble_representations=ensemble_representations)
415421

416-
emb_config = self.config.embeddings_and_evoformer
417-
ret = do_call(prev=batch.pop("prev"), recycle_idx=0)
422+
emb_config = self.config.embeddings_and_evoformer
423+
prev = batch.pop("prev",None)
424+
if prev is None:
425+
L = num_residues
426+
prev = {'prev_msa_first_row': jnp.zeros([L,256]),
427+
'prev_pair': jnp.zeros([L,L,128]),
428+
'prev_pos': jnp.zeros([L,37,3])}
429+
430+
ret = do_call(prev=prev, recycle_idx=0)
418431
ret["prev"] = get_prev(ret)
419432

420433
if compute_loss:
@@ -2222,4 +2235,4 @@ def map_fn(batch):
22222235
# No gradients if no templates.
22232236
embedding *= (jnp.sum(template_mask) > 0.).astype(embedding.dtype)
22242237

2225-
return embedding
2238+
return embedding

alphafold/model/modules_multimer.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -442,17 +442,7 @@ def apply_network(prev, safe_key):
442442
batch=recycled_batch,
443443
is_training=is_training,
444444
safe_key=safe_key)
445-
446-
#########################################
447-
num_iter = c.num_recycle
448-
def key_body(i, k):
449-
k_ = jax.random.split(k[0])
450-
o = jax.lax.cond(i==num_iter, lambda _:k[0], lambda _:k_[1], None)
451-
return [k_[0],o]
452-
k = safe_key.get()
453-
safe_key = prng.SafeKey(jax.lax.fori_loop(0,batch.pop("iter")+1, key_body, [k,k])[1])
454-
##########################################
455-
445+
456446
ret = apply_network(prev=batch.pop("prev"), safe_key=safe_key)
457447
ret["prev"] = get_prev(ret)
458448

alphafold/model/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,14 @@ def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10):
9292
(jnp.sum(mask, axis=axis) * broadcast_factor + eps))
9393

9494

95-
def flat_params_to_haiku(params, fuse=True):
95+
def flat_params_to_haiku(params, fuse=True, to_jnp=True):
9696
"""Convert a dictionary of NumPy arrays to Haiku parameters."""
9797
P = {}
9898
for path, array in params.items():
9999
scope, name = path.split('//')
100100
if scope not in P:
101101
P[scope] = {}
102-
P[scope][name] = jnp.array(array)
102+
P[scope][name] = jnp.array(array) if to_jnp else array
103103
for a in ["evoformer_iteration",
104104
"extra_msa_stack",
105105
"template_embedding/single_template_embedding/template_embedding_iteration",
@@ -113,7 +113,7 @@ def flat_params_to_haiku(params, fuse=True):
113113
R = P.pop(f"{k}/right_{c}")
114114
P[f"{k}/{c}"] = {}
115115
for d in ["bias","weights"]:
116-
P[f"{k}/{c}"][d] = jnp.concatenate([L[d],R[d]],-1)
116+
P[f"{k}/{c}"][d] = jnp.concatenate([L[d],R[d]],-1) if to_jnp else np.concatenate([L[d],R[d]],-1)
117117
P[f"{k}/center_norm"] = P.pop(f"{k}/center_layer_norm")
118118
P[f"{k}/left_norm_input"] = P.pop(f"{k}/layer_norm_input")
119119

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
setup(
2020
name='alphafold-colabfold',
21-
version='2.3.1',
21+
version='2.3.2',
2222
long_description_content_type='text/markdown',
2323
description='An implementation of the inference pipeline of AlphaFold v2.0.'
2424
'This is a completely new model that was entered as AlphaFold2 in CASP14 '

0 commit comments

Comments
 (0)