Skip to content

Commit 69bbdc1

Browse files
Internal change
PiperOrigin-RevId: 485690158
1 parent 66f7666 commit 69bbdc1

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

official/nlp/modeling/layers/mixing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import functools
3131
from typing import Callable, Tuple, Union
3232

33+
import gin
3334
import numpy as np
3435
from scipy import linalg
3536
import tensorflow as tf
@@ -41,6 +42,7 @@
4142
default_kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev=2e-2)
4243

4344

45+
@gin.constants_from_enum
4446
class MixingMechanism(enum.Enum):
4547
"""Determines the type of mixing layer.
4648

official/nlp/modeling/networks/fnet.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ class FNet(tf.keras.layers.Layer):
8989
layers. If set False, output of attention and intermediate dense layers is
9090
normalized.
9191
with_dense_inputs: Whether to accept dense embeddings as the input.
92+
num_dense_tokens: Length of the token dimension of dense inputs if dense
93+
inputs are used. This counts towards max_sequence_length.
9294
"""
9395

9496
def __init__(
@@ -113,6 +115,7 @@ def __init__(
113115
embedding_layer: Optional[tf.keras.layers.Layer] = None,
114116
norm_first: bool = False,
115117
with_dense_inputs: bool = False,
118+
num_dense_tokens: int = 0,
116119
**kwargs):
117120
super().__init__(**kwargs)
118121

@@ -142,6 +145,7 @@ def __init__(
142145
'embedding_layer': embedding_layer,
143146
'norm_first': norm_first,
144147
'with_dense_inputs': with_dense_inputs,
148+
'num_dense_tokens': num_dense_tokens,
145149
}
146150

147151
if embedding_layer is None:
@@ -220,20 +224,26 @@ def __init__(
220224
name='pooler_transform')
221225

222226
if with_dense_inputs:
227+
if max_sequence_length - num_dense_tokens < 0:
228+
raise ValueError(
229+
'FNet: `max_sequence_length` should include dense tokens, but got '
230+
'`max_sequence_length` - `num_dense_tokens` = {} - {} < 0.'.format(
231+
max_sequence_length, num_dense_tokens))
223232
self.inputs = dict(
224233
input_word_ids=tf.keras.Input(
225-
shape=(max_sequence_length,), dtype=tf.int32),
234+
shape=(max_sequence_length - num_dense_tokens,), dtype=tf.int32),
226235
input_mask=tf.keras.Input(
227-
shape=(max_sequence_length,), dtype=tf.int32),
236+
shape=(max_sequence_length - num_dense_tokens,), dtype=tf.int32),
228237
input_type_ids=tf.keras.Input(
229-
shape=(max_sequence_length,), dtype=tf.int32),
238+
shape=(max_sequence_length - num_dense_tokens,), dtype=tf.int32),
230239
dense_inputs=tf.keras.Input(
231-
shape=(max_sequence_length, embedding_width), dtype=tf.float32),
240+
shape=(num_dense_tokens, embedding_width), dtype=tf.float32),
232241
dense_mask=tf.keras.Input(
233-
shape=(max_sequence_length,), dtype=tf.int32),
242+
shape=(num_dense_tokens,), dtype=tf.int32),
234243
dense_type_ids=tf.keras.Input(
235-
shape=(max_sequence_length,), dtype=tf.int32),
244+
shape=(num_dense_tokens,), dtype=tf.int32),
236245
)
246+
237247
else:
238248
self.inputs = dict(
239249
input_word_ids=tf.keras.Input(

0 commit comments

Comments (0)