# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TEAMS pretraining task (Joint Masked LM, Replaced Token Detection and
Multi-Word Selection)."""

import dataclasses
import tensorflow as tf

from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import tf_utils
from official.nlp.data import pretrain_dataloader
from official.nlp.modeling import layers
from official.nlp.projects.teams import teams
from official.nlp.projects.teams import teams_pretrainer


@dataclasses.dataclass
class TeamsPretrainTaskConfig(cfg.TaskConfig):
  """The TEAMS pretraining task config."""
  model: teams.TeamsPretrainerConfig = teams.TeamsPretrainerConfig()
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()
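
# Example (illustrative, not part of the library): overriding the default loss
# weights. The field names below are the ones read in `build_losses`; the
# values here are arbitrary.
#
#   task_config = TeamsPretrainTaskConfig(
#       model=teams.TeamsPretrainerConfig(
#           generator_loss_weight=1.0,
#           discriminator_rtd_loss_weight=5.0,
#           discriminator_mws_loss_weight=2.0))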


def _get_generator_hidden_layers(discriminator_network, num_hidden_layers,
                                 num_shared_layers):
  """Shares the bottom discriminator layers with the generator.

  For example, with `num_hidden_layers=4` and `num_shared_layers=2`, the
  generator reuses the discriminator's first two Transformer layers and gets
  two newly constructed ones.
  """
  if num_shared_layers <= 0:
    num_shared_layers = 0
    hidden_layers = []
  else:
    hidden_layers = discriminator_network.hidden_layers[:num_shared_layers]
  for _ in range(num_shared_layers, num_hidden_layers):
    hidden_layers.append(layers.Transformer)
  return hidden_layers


def _build_pretrainer(
    config: teams.TeamsPretrainerConfig) -> teams_pretrainer.TeamsPretrainer:
  """Instantiates TeamsPretrainer from the config."""
  generator_encoder_cfg = config.generator
  discriminator_encoder_cfg = config.discriminator
  discriminator_network = teams.get_encoder(discriminator_encoder_cfg)
  hidden_layers = _get_generator_hidden_layers(
      discriminator_network, generator_encoder_cfg.num_layers,
      config.num_shared_generator_hidden_layers)
  if config.tie_embeddings:
    # Share the discriminator's embeddings with the generator for easier model
    # serialization.
    generator_network = teams.get_encoder(
        generator_encoder_cfg,
        embedding_network=discriminator_network.embedding_network,
        hidden_layers=hidden_layers)
  else:
    generator_network = teams.get_encoder(
        generator_encoder_cfg, hidden_layers=hidden_layers)

  return teams_pretrainer.TeamsPretrainer(
      generator_network=generator_network,
      discriminator_mws_network=discriminator_network,
      num_discriminator_task_agnostic_layers=(
          config.num_discriminator_task_agnostic_layers),
      vocab_size=generator_encoder_cfg.vocab_size,
      candidate_size=config.candidate_size,
      mlm_activation=tf_utils.get_activation(
          generator_encoder_cfg.hidden_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=generator_encoder_cfg.initializer_range))


@task_factory.register_task_cls(TeamsPretrainTaskConfig)
class TeamsPretrainTask(base_task.Task):
  """TEAMS Pretrain Task (Masked LM + RTD + MWS)."""

  def build_model(self):
    return _build_pretrainer(self.task_config.model)

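  # The TEAMS objective combines three weighted terms (notation illustrative):
  #   total = w_mlm * L_mlm + w_rtd * L_rtd + w_mws * L_mws (+ auxiliary
  # losses), with the weights read from `TeamsPretrainerConfig`.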
  def build_losses(self,
                   labels,
                   model_outputs,
                   metrics,
                   aux_losses=None) -> tf.Tensor:
    with tf.name_scope('TeamsPretrainTask/losses'):
      metrics = {metric.name: metric for metric in metrics}

      # Generator MLM loss.
      lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
          labels['masked_lm_ids'],
          tf.cast(model_outputs['lm_outputs'], tf.float32),
          from_logits=True)
      lm_label_weights = labels['masked_lm_weights']
      lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
      lm_denominator_loss = tf.reduce_sum(lm_label_weights)
      mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
      metrics['masked_lm_loss'].update_state(mlm_loss)
      weight = self.task_config.model.generator_loss_weight
      total_loss = weight * mlm_loss

      # Discriminator RTD loss.
      rtd_logits = model_outputs['disc_rtd_logits']
      rtd_labels = tf.cast(model_outputs['disc_rtd_label'], tf.float32)
      input_mask = tf.cast(labels['input_mask'], tf.float32)
      rtd_ind_loss = tf.nn.sigmoid_cross_entropy_with_logits(
          logits=rtd_logits, labels=rtd_labels)
      rtd_numerator = tf.reduce_sum(input_mask * rtd_ind_loss)
      rtd_denominator = tf.reduce_sum(input_mask)
      rtd_loss = tf.math.divide_no_nan(rtd_numerator, rtd_denominator)
      metrics['replaced_token_detection_loss'].update_state(rtd_loss)
      weight = self.task_config.model.discriminator_rtd_loss_weight
      total_loss = total_loss + weight * rtd_loss

      # Discriminator MWS loss.
      mws_logits = model_outputs['disc_mws_logits']
      mws_labels = model_outputs['disc_mws_label']
      mws_loss = tf.keras.losses.sparse_categorical_crossentropy(
          mws_labels, mws_logits, from_logits=True)
      mws_numerator_loss = tf.reduce_sum(mws_loss * lm_label_weights)
      mws_denominator_loss = tf.reduce_sum(lm_label_weights)
      mws_loss = tf.math.divide_no_nan(mws_numerator_loss,
                                       mws_denominator_loss)
      metrics['multiword_selection_loss'].update_state(mws_loss)
      weight = self.task_config.model.discriminator_mws_loss_weight
      total_loss = total_loss + weight * mws_loss

      if aux_losses:
        total_loss += tf.add_n(aux_losses)

      metrics['total_loss'].update_state(total_loss)
      return total_loss

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for pretraining."""
    if params.input_path == 'dummy':

      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        dummy_lm = tf.zeros((1, params.max_predictions_per_seq), dtype=tf.int32)
        return dict(
            input_word_ids=dummy_ids,
            input_mask=dummy_ids,
            input_type_ids=dummy_ids,
            masked_lm_positions=dummy_lm,
            masked_lm_ids=dummy_lm,
            masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32))

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset
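
    # For real data (illustrative): `params` is typically a
    # `pretrain_dataloader.BertPretrainDataConfig`, e.g. with
    # `input_path='/path/to/pretrain*.tfrecord'` (hypothetical path),
    # `seq_length=512` and `max_predictions_per_seq=76`.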
    return pretrain_dataloader.BertPretrainDataLoader(params).load(
        input_context)

  def build_metrics(self, training=None):
    del training
    metrics = [
        tf.keras.metrics.SparseCategoricalAccuracy(name='masked_lm_accuracy'),
        tf.keras.metrics.Mean(name='masked_lm_loss'),
        tf.keras.metrics.SparseCategoricalAccuracy(
            name='replaced_token_detection_accuracy'),
        tf.keras.metrics.Mean(name='replaced_token_detection_loss'),
        tf.keras.metrics.SparseCategoricalAccuracy(
            name='multiword_selection_accuracy'),
        tf.keras.metrics.Mean(name='multiword_selection_loss'),
        tf.keras.metrics.Mean(name='total_loss'),
    ]
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
    with tf.name_scope('TeamsPretrainTask/process_metrics'):
      metrics = {metric.name: metric for metric in metrics}
      if 'masked_lm_accuracy' in metrics:
        metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'],
                                                   model_outputs['lm_outputs'],
                                                   labels['masked_lm_weights'])

      if 'replaced_token_detection_accuracy' in metrics:
        # Expand the binary RTD logits into two-class logits so the sparse
        # categorical accuracy metric can consume them.
        rtd_logits_expanded = tf.expand_dims(model_outputs['disc_rtd_logits'],
                                             -1)
        rtd_full_logits = tf.concat(
            [-1.0 * rtd_logits_expanded, rtd_logits_expanded], -1)
        metrics['replaced_token_detection_accuracy'].update_state(
            model_outputs['disc_rtd_label'], rtd_full_logits,
            labels['input_mask'])

      if 'multiword_selection_accuracy' in metrics:
        metrics['multiword_selection_accuracy'].update_state(
            model_outputs['disc_mws_label'], model_outputs['disc_mws_logits'],
            labels['masked_lm_weights'])

  def train_step(self, inputs, model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer, metrics):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    with tf.GradientTape() as tape:
      outputs = model(inputs, training=True)
      # Computes per-replica loss.
      loss = self.build_losses(
          labels=inputs,
          model_outputs=outputs,
          metrics=metrics,
          aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    self.process_metrics(metrics, inputs, outputs)
    return {self.loss: loss}

  def validation_step(self, inputs, model: tf.keras.Model, metrics):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    outputs = model(inputs, training=False)
    loss = self.build_losses(
        labels=inputs,
        model_outputs=outputs,
        metrics=metrics,
        aux_losses=model.losses)
    self.process_metrics(metrics, inputs, outputs)
    return {self.loss: loss}
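

# Smoke-test sketch (illustrative, not part of the library): build the task and
# run one train step on the dummy dataset. Assumes a data config that carries
# `seq_length` and `max_predictions_per_seq`, e.g.
# `pretrain_dataloader.BertPretrainDataConfig`.
#
#   data_cfg = pretrain_dataloader.BertPretrainDataConfig(
#       input_path='dummy', global_batch_size=2, seq_length=8,
#       max_predictions_per_seq=2)
#   task = TeamsPretrainTask(TeamsPretrainTaskConfig(train_data=data_cfg))
#   model = task.build_model()
#   metrics = task.build_metrics()
#   dataset = task.build_inputs(data_cfg)
#   logs = task.train_step(
#       next(iter(dataset)), model, tf.keras.optimizers.Adam(), metrics)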