@@ -89,6 +89,8 @@ class FNet(tf.keras.layers.Layer):
89
89
layers. If set False, output of attention and intermediate dense layers is
90
90
normalized.
91
91
with_dense_inputs: Whether to accept dense embeddings as the input.
92
+ num_dense_tokens: Length of the token dimension of dense inputs if dense
93
+ inputs are used. This counts towards max_sequence_length.
92
94
"""
93
95
94
96
def __init__ (
@@ -113,6 +115,7 @@ def __init__(
113
115
embedding_layer : Optional [tf .keras .layers .Layer ] = None ,
114
116
norm_first : bool = False ,
115
117
with_dense_inputs : bool = False ,
118
+ num_dense_tokens : int = 0 ,
116
119
** kwargs ):
117
120
super ().__init__ (** kwargs )
118
121
@@ -142,6 +145,7 @@ def __init__(
142
145
'embedding_layer' : embedding_layer ,
143
146
'norm_first' : norm_first ,
144
147
'with_dense_inputs' : with_dense_inputs ,
148
+ 'num_dense_tokens' : num_dense_tokens ,
145
149
}
146
150
147
151
if embedding_layer is None :
@@ -220,20 +224,26 @@ def __init__(
220
224
name = 'pooler_transform' )
221
225
222
226
if with_dense_inputs :
227
+ if max_sequence_length - num_dense_tokens < 0 :
228
+ raise ValueError (
229
+ 'FNet: `max_sequence_length` should include dense tokens, but got '
230
+ '`max_sequence_length` - `num_dense_tokens` = {} - {} < 0.' .format (
231
+ max_sequence_length , num_dense_tokens ))
223
232
self .inputs = dict (
224
233
input_word_ids = tf .keras .Input (
225
- shape = (max_sequence_length ,), dtype = tf .int32 ),
234
+ shape = (max_sequence_length - num_dense_tokens ,), dtype = tf .int32 ),
226
235
input_mask = tf .keras .Input (
227
- shape = (max_sequence_length ,), dtype = tf .int32 ),
236
+ shape = (max_sequence_length - num_dense_tokens ,), dtype = tf .int32 ),
228
237
input_type_ids = tf .keras .Input (
229
- shape = (max_sequence_length ,), dtype = tf .int32 ),
238
+ shape = (max_sequence_length - num_dense_tokens ,), dtype = tf .int32 ),
230
239
dense_inputs = tf .keras .Input (
231
- shape = (max_sequence_length , embedding_width ), dtype = tf .float32 ),
240
+ shape = (num_dense_tokens , embedding_width ), dtype = tf .float32 ),
232
241
dense_mask = tf .keras .Input (
233
- shape = (max_sequence_length ,), dtype = tf .int32 ),
242
+ shape = (num_dense_tokens ,), dtype = tf .int32 ),
234
243
dense_type_ids = tf .keras .Input (
235
- shape = (max_sequence_length ,), dtype = tf .int32 ),
244
+ shape = (num_dense_tokens ,), dtype = tf .int32 ),
236
245
)
246
+
237
247
else :
238
248
self .inputs = dict (
239
249
input_word_ids = tf .keras .Input (
0 commit comments