Skip to content

Commit 18d6279

Browse files
remenberl and tensorflower-gardener
authored and committed
Internal change
PiperOrigin-RevId: 395322666
1 parent 336c813 commit 18d6279

File tree

5 files changed

+409
-0
lines changed

5 files changed

+409
-0
lines changed

official/nlp/configs/experiment_configs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@
1717
from official.nlp.configs import finetuning_experiments
1818
from official.nlp.configs import pretraining_experiments
1919
from official.nlp.configs import wmt_transformer_experiments
20+
from official.nlp.projects.teams import teams_experiments
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Lint as: python3
16+
# pylint: disable=g-doc-return-or-yield,line-too-long
17+
"""TEAMS experiments."""
18+
import dataclasses
19+
from official.core import config_definitions as cfg
20+
from official.core import exp_factory
21+
from official.modeling import optimization
22+
from official.nlp.data import pretrain_dataloader
23+
from official.nlp.projects.teams import teams_task
24+
25+
26+
AdamWeightDecay = optimization.AdamWeightDecayConfig
27+
PolynomialLr = optimization.PolynomialLrConfig
28+
PolynomialWarmupConfig = optimization.PolynomialWarmupConfig
29+
30+
31+
@dataclasses.dataclass
class TeamsOptimizationConfig(optimization.OptimizationConfig):
  """TEAMS optimization config.

  Each default is produced by a `default_factory` rather than by a config
  instance created once at class-definition time: Python 3.11+ rejects
  unhashable dataclass defaults with a `ValueError`, and a factory also hands
  every `TeamsOptimizationConfig` its own nested config objects.

  Attributes:
    optimizer: "adamw" with `weight_decay_rate=0.01`, `epsilon=1e-6` and
      `exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]`.
    learning_rate: "polynomial" schedule with `initial_learning_rate=1e-4`,
      `decay_steps=1000000` and `end_learning_rate=0.0`.
    warmup: "polynomial" warmup with `warmup_steps=10000`.
  """
  optimizer: optimization.OptimizerConfig = dataclasses.field(
      default_factory=lambda: optimization.OptimizerConfig(
          type="adamw",
          adamw=AdamWeightDecay(
              weight_decay_rate=0.01,
              exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
              epsilon=1e-6)))
  learning_rate: optimization.LrConfig = dataclasses.field(
      default_factory=lambda: optimization.LrConfig(
          type="polynomial",
          polynomial=PolynomialLr(
              initial_learning_rate=1e-4,
              decay_steps=1000000,
              end_learning_rate=0.0)))
  warmup: optimization.WarmupConfig = dataclasses.field(
      default_factory=lambda: optimization.WarmupConfig(
          type="polynomial",
          polynomial=PolynomialWarmupConfig(warmup_steps=10000)))
48+
49+
50+
@exp_factory.register_config_factory("teams/pretraining")
def teams_pretrain() -> cfg.ExperimentConfig:
  """Builds the experiment config registered as "teams/pretraining".

  Returns:
    An `ExperimentConfig` whose task is a `TeamsPretrainTaskConfig` reading
    `BertPretrainDataConfig` data for both training and validation, and whose
    trainer uses `TeamsOptimizationConfig` for 1,000,000 train steps.
  """
  task_config = teams_task.TeamsPretrainTaskConfig(
      train_data=pretrain_dataloader.BertPretrainDataConfig(),
      validation_data=pretrain_dataloader.BertPretrainDataConfig(
          is_training=False))
  trainer_config = cfg.TrainerConfig(
      optimizer_config=TeamsOptimizationConfig(), train_steps=1000000)
  # Restriction expressions are stored on the experiment config verbatim.
  restrictions = [
      "task.train_data.is_training != None",
      "task.validation_data.is_training != None"
  ]
  return cfg.ExperimentConfig(
      task=task_config, trainer=trainer_config, restrictions=restrictions)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Lint as: python3
16+
"""Tests for teams_experiments."""
17+
18+
from absl.testing import parameterized
19+
import tensorflow as tf
20+
21+
# pylint: disable=unused-import
22+
from official.common import registry_imports
23+
# pylint: enable=unused-import
24+
from official.core import config_definitions as cfg
25+
from official.core import exp_factory
26+
27+
28+
class TeamsExperimentsTest(tf.test.TestCase, parameterized.TestCase):
  """Checks that the registered TEAMS experiment configs can be built."""

  @parameterized.parameters(('teams/pretraining',))
  def test_teams_experiments(self, config_name):
    """Looks up `config_name` and verifies the types of the result."""
    experiment = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(experiment, cfg.ExperimentConfig)
    train_data = experiment.task.train_data
    self.assertIsInstance(train_data, cfg.DataConfig)
35+
36+
37+
if __name__ == '__main__':
38+
tf.test.main()
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""TEAMS pretraining task (Joint Masked LM, Replaced Token Detection and )."""
16+
17+
import dataclasses
18+
import tensorflow as tf
19+
20+
from official.core import base_task
21+
from official.core import config_definitions as cfg
22+
from official.core import task_factory
23+
from official.modeling import tf_utils
24+
from official.nlp.data import pretrain_dataloader
25+
from official.nlp.modeling import layers
26+
from official.nlp.projects.teams import teams
27+
from official.nlp.projects.teams import teams_pretrainer
28+
29+
30+
@dataclasses.dataclass
class TeamsPretrainTaskConfig(cfg.TaskConfig):
  """Configuration of the TEAMS pretraining task.

  Defaults are created through `default_factory` so that no config instance is
  built at class-definition time: Python 3.11+ rejects unhashable dataclass
  defaults with a `ValueError`, and a factory gives every task config its own
  nested objects.

  Attributes:
    model: config of the TEAMS pretrainer (generator and discriminator).
    train_data: config of the training input pipeline.
    validation_data: config of the validation input pipeline.
  """
  model: teams.TeamsPretrainerConfig = dataclasses.field(
      default_factory=teams.TeamsPretrainerConfig)
  train_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig)
  validation_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig)
36+
37+
38+
def _get_generator_hidden_layers(discriminator_network, num_hidden_layers,
                                 num_shared_layers):
  """Returns the hidden-layer specification for the generator encoder.

  Args:
    discriminator_network: object exposing a `hidden_layers` sequence. It is
      only read when `num_shared_layers` is positive.
    num_hidden_layers: number of entries the result is padded up to.
    num_shared_layers: how many leading discriminator layers to reuse. Values
      less than or equal to zero mean that nothing is shared.

  Returns:
    A new list holding the first `num_shared_layers` entries of
    `discriminator_network.hidden_layers`, followed by one
    `layers.Transformer` (the class itself, not an instance) for every
    remaining index below `num_hidden_layers`.
  """
  if num_shared_layers <= 0:
    num_shared_layers = 0
    hidden_layers = []
  else:
    # Copy into a plain list so the result is always appendable (the source
    # may be a tuple) and never aliases the discriminator's own container.
    hidden_layers = list(
        discriminator_network.hidden_layers[:num_shared_layers])
  for _ in range(num_shared_layers, num_hidden_layers):
    hidden_layers.append(layers.Transformer)
  return hidden_layers
48+
49+
50+
def _build_pretrainer(
    config: teams.TeamsPretrainerConfig) -> teams_pretrainer.TeamsPretrainer:
  """Instantiates a TeamsPretrainer from the config.

  Args:
    config: pretrainer config providing the generator and discriminator
      encoder configs, the sharing options and the candidate size.

  Returns:
    A `TeamsPretrainer` wrapping the generator and discriminator encoders.
  """
  generator_cfg = config.generator
  discriminator_network = teams.get_encoder(config.discriminator)

  # Leading generator layers may be reused from the discriminator.
  generator_kwargs = {
      'hidden_layers': _get_generator_hidden_layers(
          discriminator_network, generator_cfg.num_layers,
          config.num_shared_generator_hidden_layers)
  }
  if config.tie_embeddings:
    # Copy discriminator's embeddings to generator for easier model
    # serialization.
    generator_kwargs['embedding_network'] = (
        discriminator_network.embedding_network)
  generator_network = teams.get_encoder(generator_cfg, **generator_kwargs)

  mlm_activation = tf_utils.get_activation(generator_cfg.hidden_activation)
  mlm_initializer = tf.keras.initializers.TruncatedNormal(
      stddev=generator_cfg.initializer_range)
  return teams_pretrainer.TeamsPretrainer(
      generator_network=generator_network,
      discriminator_mws_network=discriminator_network,
      num_discriminator_task_agnostic_layers=(
          config.num_discriminator_task_agnostic_layers),
      vocab_size=generator_cfg.vocab_size,
      candidate_size=config.candidate_size,
      mlm_activation=mlm_activation,
      mlm_initializer=mlm_initializer)
80+
81+
82+
@task_factory.register_task_cls(TeamsPretrainTaskConfig)
class TeamsPretrainTask(base_task.Task):
  """TEAMS Pretrain Task (Masked LM + RTD + MWS)."""

  def build_model(self):
    """Builds the TEAMS pretrainer described by `task_config.model`."""
    return _build_pretrainer(self.task_config.model)

  def build_losses(self,
                   labels,
                   model_outputs,
                   metrics,
                   aux_losses=None) -> tf.Tensor:
    """Combines the MLM, RTD and MWS losses into the total training loss.

    Args:
      labels: dictionary providing 'masked_lm_ids', 'masked_lm_weights' and
        'input_mask'.
      model_outputs: dictionary providing 'lm_outputs', 'disc_rtd_logits',
        'disc_rtd_label', 'disc_mws_logits' and 'disc_mws_label'.
      metrics: iterable of metric objects; it must contain metrics named
        'masked_lm_loss', 'replaced_token_detection_loss',
        'multiword_selection_loss' and 'total_loss', which are updated here.
      aux_losses: optional list of extra loss tensors added to the total.

    Returns:
      The weighted sum of the three losses plus any auxiliary losses.
    """
    with tf.name_scope('TeamsPretrainTask/losses'):
      metrics = {metric.name: metric for metric in metrics}
      model_config = self.task_config.model

      # Generator MLM loss: weighted mean over the masked positions.
      lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
          labels['masked_lm_ids'],
          tf.cast(model_outputs['lm_outputs'], tf.float32),
          from_logits=True)
      lm_label_weights = labels['masked_lm_weights']
      lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
      lm_denominator_loss = tf.reduce_sum(lm_label_weights)
      mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
      metrics['masked_lm_loss'].update_state(mlm_loss)
      total_loss = model_config.generator_loss_weight * mlm_loss

      # Discriminator RTD loss: mean over the positions kept by `input_mask`.
      # Logits are cast to float32 like the generator logits above, so the
      # loss does not mix dtypes with the float32 labels and mask when the
      # model emits reduced-precision outputs.
      rtd_logits = tf.cast(model_outputs['disc_rtd_logits'], tf.float32)
      rtd_labels = tf.cast(model_outputs['disc_rtd_label'], tf.float32)
      input_mask = tf.cast(labels['input_mask'], tf.float32)
      rtd_ind_loss = tf.nn.sigmoid_cross_entropy_with_logits(
          logits=rtd_logits, labels=rtd_labels)
      rtd_numerator = tf.reduce_sum(input_mask * rtd_ind_loss)
      rtd_denominator = tf.reduce_sum(input_mask)
      rtd_loss = tf.math.divide_no_nan(rtd_numerator, rtd_denominator)
      metrics['replaced_token_detection_loss'].update_state(rtd_loss)
      total_loss = (
          total_loss + model_config.discriminator_rtd_loss_weight * rtd_loss)

      # Discriminator MWS loss: weighted with the same masked-position
      # weights as the MLM loss. Cast for the same reason as the RTD logits.
      mws_logits = tf.cast(model_outputs['disc_mws_logits'], tf.float32)
      mws_labels = model_outputs['disc_mws_label']
      mws_loss = tf.keras.losses.sparse_categorical_crossentropy(
          mws_labels, mws_logits, from_logits=True)
      mws_numerator_loss = tf.reduce_sum(mws_loss * lm_label_weights)
      mws_denominator_loss = tf.reduce_sum(lm_label_weights)
      mws_loss = tf.math.divide_no_nan(mws_numerator_loss, mws_denominator_loss)
      metrics['multiword_selection_loss'].update_state(mws_loss)
      total_loss = (
          total_loss + model_config.discriminator_mws_loss_weight * mws_loss)

      if aux_losses:
        total_loss += tf.add_n(aux_losses)

      metrics['total_loss'].update_state(total_loss)
      return total_loss

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for pretraining.

    Args:
      params: data config. When `params.input_path` equals 'dummy', an
        endless dataset of all-zero features is produced instead of reading
        real data.
      input_context: optional input context forwarded to the data loader.

    Returns:
      A `tf.data.Dataset` of feature dictionaries.
    """
    if params.input_path == 'dummy':

      def dummy_data(_):
        """Creates one all-zero example with a leading dimension of 1."""
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        dummy_lm = tf.zeros((1, params.max_predictions_per_seq), dtype=tf.int32)
        return dict(
            input_word_ids=dummy_ids,
            input_mask=dummy_ids,
            input_type_ids=dummy_ids,
            masked_lm_positions=dummy_lm,
            masked_lm_ids=dummy_lm,
            masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32))

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    return pretrain_dataloader.BertPretrainDataLoader(params).load(
        input_context)

  def build_metrics(self, training=None):
    """Creates the accuracy and mean-loss metrics of the three objectives.

    Args:
      training: unused.

    Returns:
      A list with one accuracy and one loss metric for each of MLM, RTD and
      MWS, followed by the 'total_loss' metric.
    """
    del training
    metrics = [
        tf.keras.metrics.SparseCategoricalAccuracy(name='masked_lm_accuracy'),
        tf.keras.metrics.Mean(name='masked_lm_loss'),
        tf.keras.metrics.SparseCategoricalAccuracy(
            name='replaced_token_detection_accuracy'),
        tf.keras.metrics.Mean(name='replaced_token_detection_loss'),
        tf.keras.metrics.SparseCategoricalAccuracy(
            name='multiword_selection_accuracy'),
        tf.keras.metrics.Mean(name='multiword_selection_loss'),
        tf.keras.metrics.Mean(name='total_loss'),
    ]
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
    """Updates whichever accuracy metrics are present in `metrics`.

    Args:
      metrics: iterable of metric objects, looked up by name.
      labels: dictionary providing 'masked_lm_ids', 'masked_lm_weights' and
        'input_mask'.
      model_outputs: dictionary providing 'lm_outputs', 'disc_rtd_logits',
        'disc_rtd_label', 'disc_mws_logits' and 'disc_mws_label'.
    """
    with tf.name_scope('TeamsPretrainTask/process_metrics'):
      metrics = {metric.name: metric for metric in metrics}
      if 'masked_lm_accuracy' in metrics:
        metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'],
                                                   model_outputs['lm_outputs'],
                                                   labels['masked_lm_weights'])

      if 'replaced_token_detection_accuracy' in metrics:
        # Expand the single RTD logit x into two-class logits [-x, x] so the
        # sparse categorical accuracy can be reused: class 1 wins iff x > 0.
        rtd_logits_expanded = tf.expand_dims(model_outputs['disc_rtd_logits'],
                                             -1)
        rtd_full_logits = tf.concat(
            [-1.0 * rtd_logits_expanded, rtd_logits_expanded], -1)
        metrics['replaced_token_detection_accuracy'].update_state(
            model_outputs['disc_rtd_label'], rtd_full_logits,
            labels['input_mask'])

      if 'multiword_selection_accuracy' in metrics:
        metrics['multiword_selection_accuracy'].update_state(
            model_outputs['disc_mws_label'], model_outputs['disc_mws_logits'],
            labels['masked_lm_weights'])

  def train_step(self, inputs, model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer, metrics):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    with tf.GradientTape() as tape:
      outputs = model(inputs, training=True)
      # Computes per-replica loss.
      loss = self.build_losses(
          labels=inputs,
          model_outputs=outputs,
          metrics=metrics,
          aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    self.process_metrics(metrics, inputs, outputs)
    # The unscaled loss is what gets reported.
    return {self.loss: loss}

  def validation_step(self, inputs, model: tf.keras.Model, metrics):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    outputs = model(inputs, training=False)
    loss = self.build_losses(
        labels=inputs,
        model_outputs=outputs,
        metrics=metrics,
        aux_losses=model.losses)
    self.process_metrics(metrics, inputs, outputs)
    return {self.loss: loss}

0 commit comments

Comments
 (0)