Skip to content

Commit

Permalink
Add param_dtype to AddPositionEmbs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615347924
  • Loading branch information
jpuigcerver authored and copybara-github committed Mar 13, 2024
1 parent 3d8e814 commit 1ff0176
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion vit_jax/models_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class AddPositionEmbs(nn.Module):
"""

posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]
param_dtype: Dtype = jnp.float32

@nn.compact
def __call__(self, inputs):
Expand All @@ -57,7 +58,8 @@ def __call__(self, inputs):
assert inputs.ndim == 3, ('Number of dimensions should be 3,'
' but it is: %d' % inputs.ndim)
pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])
pe = self.param('pos_embedding', self.posemb_init, pos_emb_shape)
pe = self.param(
'pos_embedding', self.posemb_init, pos_emb_shape, self.param_dtype)
return inputs + pe


Expand Down

0 comments on commit 1ff0176

Please sign in to comment.