|
| 1 | +// Copyright (c) Zhongkai Fu. All rights reserved. |
| 2 | +// https://github.com/zhongkaifu/Seq2SeqSharp |
| 3 | +// |
| 4 | +// This file is part of Seq2SeqSharp. |
| 5 | +// |
| 6 | +// Seq2SeqSharp is licensed under the BSD-3-Clause license found in the LICENSE file in the root directory of this source tree. |
| 7 | +// |
| 8 | +// Seq2SeqSharp is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 9 | +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD-3-Clause License for more details. |
| 10 | + |
| 11 | +using System; |
| 12 | +using System.Collections.Generic; |
| 13 | +using System.IO; |
| 14 | +using AdvUtils; |
| 15 | +using System.Runtime.Caching; |
| 16 | +using Seq2SeqSharp.Enums; |
| 17 | +using Seq2SeqSharp.Corpus; |
| 18 | +using Seq2SeqSharp.Layers; |
| 19 | +using Seq2SeqSharp.Models; |
| 20 | +using Seq2SeqSharp.Tools; |
| 21 | +using Seq2SeqSharp.Utils; |
| 22 | +using TensorSharp; |
| 23 | +using ManagedCuda.BasicTypes; |
| 24 | + |
namespace Seq2SeqSharp.Applications
{
    /// <summary>
    /// DPO training application built on <see cref="BaseSeq2SeqFramework{T}"/>.
    /// It keeps two parallel sets of decoder-only networks: the policy set (<c>m_*</c>), whose
    /// trainability follows the options, and the reference set (<c>ref_m_*</c>), which is always
    /// created as non-trainable. Both sets are fed to <c>Decoder.DPODecoderTrainer</c> to compute the loss.
    /// </summary>
    public class DPO : BaseSeq2SeqFramework<Seq2SeqModel>
    {
        // Policy networks and tensors (trainable when the task is DPO, see CreateDPOModel)
        private MultiProcessorNetworkWrapper<IWeightTensor> m_tgtEmbedding = null; //The embeddings over devices for target
        private MultiProcessorNetworkWrapper<IDecoder> m_decoder = null; //The decoders over devices
        private MultiProcessorNetworkWrapper<IFeedForwardLayer> m_decoderFFLayer = null; //The feed forward layers over devices after all layers in decoder
        private MultiProcessorNetworkWrapper<IWeightTensor> m_segmentEmbedding = null;
        private MultiProcessorNetworkWrapper<IWeightTensor> m_posEmbedding = null;

        // Reference networks and tensors (always created with isTrainable: false, see CreateRefModel)
        private MultiProcessorNetworkWrapper<IWeightTensor> ref_m_tgtEmbedding = null; //The reference embeddings over devices for target
        private MultiProcessorNetworkWrapper<IDecoder> ref_m_decoder = null; //The reference decoders over devices
        private MultiProcessorNetworkWrapper<IFeedForwardLayer> ref_m_decoderFFLayer = null; //The reference feed forward layers over devices after all layers in decoder
        private MultiProcessorNetworkWrapper<IWeightTensor> ref_m_segmentEmbedding = null;
        private MultiProcessorNetworkWrapper<IWeightTensor> ref_m_posEmbedding = null;

        private readonly PaddingEnums m_paddingType = PaddingEnums.AllowPadding;
        readonly Seq2SeqOptions m_options = null;

        // Part of the public surface; this class itself never raises it.
        public event EventHandler KVCacheRemoveWatcher;

        /// <summary>
        /// Creates the DPO application. If the model file named in <paramref name="options"/> exists it is
        /// loaded; otherwise a new model is created from the options and <paramref name="tgtVocab"/>.
        /// </summary>
        /// <param name="options">Options that configure the base framework and the networks created here.</param>
        /// <param name="tgtVocab">Target vocabulary for a new model. Must be null when the model file already exists.</param>
        /// <exception cref="ArgumentException">The model file exists and <paramref name="tgtVocab"/> is not null.</exception>
        public DPO(Seq2SeqOptions options, Vocab tgtVocab = null)
            : base(deviceIds: options.DeviceIds, processorType: options.ProcessorType, modelFilePath: options.ModelFilePath, memoryUsageRatio: options.MemoryUsageRatio,
                  compilerOptions: options.CompilerOptions, runValidEveryUpdates: options.RunValidEveryUpdates, updateFreq: options.UpdateFreq,
                  startToRunValidAfterUpdates: options.StartValidAfterUpdates, maxDegressOfParallelism: options.TaskParallelism, mklInstructions: options.MKLInstructions, weightsUpdateCount: options.WeightsUpdateCount,
                  enableTensorCore: options.EnableTensorCore, cudaMemoryAllocatorType: options.CudaMemoryAllocatorType, elementType: options.AMP ? DType.Float16 : DType.Float32, randomSeed: options.RandomSeed,
                  saveModelEveryUpdats: options.SaveModelEveryUpdates, saveGPUMemoryLevel: options.SaveGPUMemoryLevel, initLossScaling: options.InitLossScaling, autoCheckTensorCorruption: options.CheckTensorCorrupted,
                  attentionType: options.AttentionType)
        {
            m_paddingType = options.PaddingType;
            m_options = options;

            // Check if options are valid.
            m_options.ValidateOptions();
            if (File.Exists(m_options.ModelFilePath))
            {
                if (tgtVocab != null)
                {
                    throw new ArgumentException($"Model '{m_options.ModelFilePath}' exists and it includes vocabulary, so input vocabulary must be null.");
                }

                // Model file exists, so we load it from file.
                m_modelMetaData = LoadModel();
            }
            else
            {
                // Model doesn't exist, so we create it and initialize parameters.
                m_modelMetaData = new Seq2SeqModel(options, null, tgtVocab);

                // Initialize weights of both the policy and the reference networks.
                CreateTrainableParameters(m_modelMetaData);
            }

            // This application is decoder-only: no encoder, GPT-style decoder.
            m_modelMetaData.EncoderType = EncoderTypeEnums.None;
            m_modelMetaData.DecoderType = DecoderTypeEnums.GPTDecoder;
            m_modelMetaData.ShowModelInfo();
        }

        /// <summary>
        /// Replaces the target vocabulary in the model meta data (when <paramref name="tgtVocab"/> is not null)
        /// and saves the model with the ".updatevocab" suffix.
        /// </summary>
        /// <param name="tgtVocab">The new target vocabulary, or null to keep the current one.</param>
        public void UpdateVocabs(Vocab tgtVocab)
        {
            if (tgtVocab != null)
            {
                m_modelMetaData.TgtVocab = tgtVocab;
            }

            // The model is saved even when the vocabulary was left untouched.
            SaveModel(createBackupPrevious: true, suffix: ".updatevocab");
        }

        /// <summary>
        /// Copies the VQ type from the options into the model meta data and saves the model
        /// with a suffix made of that VQ type.
        /// </summary>
        public void VQModel()
        {
            m_modelMetaData.VQType = m_options.VQType;
            SaveModel(createBackupPrevious: true, suffix: $".{m_modelMetaData.VQType.ToString()}");
        }

        /// <summary>
        /// Loads the model through the base load routine, using <see cref="CreateTrainableParameters"/>
        /// to build the networks.
        /// </summary>
        /// <param name="suffix">Suffix forwarded to the base load routine.</param>
        protected override Seq2SeqModel LoadModel(string suffix = "") => base.LoadModelRoutine<Model_4_ProtoBufSerializer>(CreateTrainableParameters, Seq2SeqModel.Create, suffix);

        /// <summary>
        /// Creates the policy networks first and then the reference networks for the given model.
        /// </summary>
        /// <returns>Always true.</returns>
        private bool CreateTrainableParameters(IModel model)
        {
            CreateDPOModel(model);
            CreateRefModel(model);

            return true;
        }

        /// <summary>
        /// (Re)creates the policy networks: decoders, the output feed forward layer, the auxiliary
        /// (position and segment) embeddings and the target embeddings. Any previously created
        /// wrappers are disposed first.
        /// </summary>
        /// <returns>Always true.</returns>
        private bool CreateDPOModel(IModel model)
        {
            // Release every wrapper this method is about to overwrite, including the position
            // embeddings, which would otherwise be leaked when the networks are re-created.
            m_decoder?.Dispose();
            m_decoderFFLayer?.Dispose();
            m_segmentEmbedding?.Dispose();
            m_posEmbedding?.Dispose();
            m_tgtEmbedding?.Dispose();

            Logger.WriteLine(Logger.Level.debug, "Creating decoders...");

            var raDeviceIds = new RoundArray<int>(DeviceIds);
            DType elementType = m_options.AMP ? DType.Float16 : DType.Float32;

            // The policy networks are only trainable when the task is DPO.
            bool isDPOTask = (m_options.Task == ModeEnums.DPO);

            // NOTE: the statements below share raDeviceIds, so their order is kept as is.
            m_decoder = Decoder.CreateDecoders(model, m_options, raDeviceIds, isTrainable: m_options.IsDecoderTrainable && isDPOTask, elementType: elementType);
            m_decoderFFLayer = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(new FeedForwardLayer("FeedForward_Decoder_0", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(),
                isTrainable: isDPOTask, learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType), DeviceIds);

            (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength), model, elementType,
                isTrainable: isDPOTask, createAPE: (model.PEType == PositionEmbeddingEnums.APE));
            m_tgtEmbedding = CreateTgtEmbeddings(model, raDeviceIds, m_options.IsTgtEmbeddingTrainable && isDPOTask, m_options.DecoderStartLearningRateFactor, elementType);

            return true;
        }

        /// <summary>
        /// (Re)creates the reference networks. They mirror the policy networks (same layer name and
        /// dimensions) but are always created with <c>isTrainable: false</c>; the decoders, the feed
        /// forward layer and the target embeddings are additionally created as non-savable.
        /// Any previously created wrappers are disposed first.
        /// </summary>
        /// <returns>Always true.</returns>
        private bool CreateRefModel(IModel model)
        {
            // Release every wrapper this method is about to overwrite, including the position
            // embeddings, which would otherwise be leaked when the networks are re-created.
            ref_m_decoder?.Dispose();
            ref_m_decoderFFLayer?.Dispose();
            ref_m_segmentEmbedding?.Dispose();
            ref_m_posEmbedding?.Dispose();
            ref_m_tgtEmbedding?.Dispose();

            Logger.WriteLine(Logger.Level.debug, "Creating decoders...");

            var raDeviceIds = new RoundArray<int>(DeviceIds);
            DType elementType = m_options.AMP ? DType.Float16 : DType.Float32;

            // NOTE: the statements below share raDeviceIds, so their order is kept as is.
            ref_m_decoder = Decoder.CreateDecoders(model, m_options, raDeviceIds, isTrainable: false, elementType: elementType, isSavable: false);
            ref_m_decoderFFLayer = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(new FeedForwardLayer("FeedForward_Decoder_0", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(),
                isTrainable: false, learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType), DeviceIds, savableWeights: false);

            // NOTE(review): unlike the other reference networks, no "savable" flag is passed here - confirm
            // whether Misc.CreateAuxEmbeddings needs one so reference embeddings are not written to the model file.
            (ref_m_posEmbedding, ref_m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength), model, elementType,
                isTrainable: false, createAPE: (model.PEType == PositionEmbeddingEnums.APE));
            ref_m_tgtEmbedding = CreateTgtEmbeddings(model, raDeviceIds, false, m_options.DecoderStartLearningRateFactor, elementType, isSavable: false);

            return true;
        }

        /// <summary>
        /// Get the policy networks on a specific device.
        /// </summary>
        /// <param name="deviceId">The device id; it is mapped to a device index through <c>TensorAllocator.GetDeviceIdIndex</c>.</param>
        /// <returns>Decoder, feed forward layer, target embedding, segment embedding and position embedding,
        /// in that order. The last two are null when the corresponding wrapper is null.</returns>
        private (IDecoder, IFeedForwardLayer, IWeightTensor, IWeightTensor, IWeightTensor) GetNetworksOnDeviceAt(int deviceId)
        {
            var deviceIdIdx = TensorAllocator.GetDeviceIdIndex(deviceId);
            return (m_decoder.GetNetworkOnDevice(deviceIdIdx),
                    m_decoderFFLayer.GetNetworkOnDevice(deviceIdIdx),
                    m_tgtEmbedding.GetNetworkOnDevice(deviceIdIdx),
                    m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx),
                    m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx));
        }

        /// <summary>
        /// Get the reference networks on a specific device.
        /// </summary>
        /// <param name="deviceId">The device id; it is mapped to a device index through <c>TensorAllocator.GetDeviceIdIndex</c>.</param>
        /// <returns>Decoder, feed forward layer, target embedding, segment embedding and position embedding,
        /// in that order. The last two are null when the corresponding wrapper is null.</returns>
        private (IDecoder, IFeedForwardLayer, IWeightTensor, IWeightTensor, IWeightTensor) GetRefNetworksOnDeviceAt(int deviceId)
        {
            var deviceIdIdx = TensorAllocator.GetDeviceIdIndex(deviceId);
            return (ref_m_decoder.GetNetworkOnDevice(deviceIdIdx),
                    ref_m_decoderFFLayer.GetNetworkOnDevice(deviceIdIdx),
                    ref_m_tgtEmbedding.GetNetworkOnDevice(deviceIdIdx),
                    ref_m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx),
                    ref_m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx));
        }

        /// <summary>
        /// Run the DPO forward part on a single device.
        /// </summary>
        /// <param name="computeGraph">The computing graph for the current device; its DeviceId selects the networks to use.</param>
        /// <param name="sntPairBatch">A batch of sentence pairs: the source side is used as the chosen sentences
        /// and the target side as the rejected sentences.</param>
        /// <param name="decodingOptions">Not used by this method.</param>
        /// <param name="isTraining">Must be true; DPO only supports training mode.</param>
        /// <returns>A list with a single result carrying the loss, the chosen rewards and the rejected rewards.</returns>
        /// <exception cref="ArgumentException"><paramref name="isTraining"/> is false.</exception>
        /// <exception cref="InvalidOperationException">The policy or the reference decoder is not a <c>GPTDecoder</c>.</exception>
        public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, IPairBatch sntPairBatch, DecodingOptions decodingOptions, bool isTraining)
        {
            if (!isTraining)
            {
                throw new ArgumentException("The DPO is only for training mode.");
            }

            (var decoder, var decoderFFLayer, var tgtEmbedding, var segmentEmbedding, var posEmbeddings) = GetNetworksOnDeviceAt(computeGraph.DeviceId);
            (var refDecoder, var refDecoderFFLayer, var refTgtEmbedding, var refSegmentEmbedding, var refPosEmbeddings) = GetRefNetworksOnDeviceAt(computeGraph.DeviceId);

            // The trainer needs GPT decoders. Fail with a clear message here rather than handing
            // it a null produced by an unchecked 'as' cast.
            var gptDecoder = decoder as GPTDecoder;
            var refGptDecoder = refDecoder as GPTDecoder;
            if (gptDecoder == null || refGptDecoder == null)
            {
                throw new InvalidOperationException("DPO requires both the policy decoder and the reference decoder to be GPT decoders.");
            }

            // Index of the configured DPO masked token; it is the marker handed to BuildMaskAfter below.
            int messageTokenId = m_modelMetaData.TgtVocab.GetWordIndex(m_options.DPOMaskedToken, logUnk: true);

            // Chosen sentences come from the source side of the batch.
            var chosenSnts = sntPairBatch.GetSrcTokens();
            var chosenTokensList = m_modelMetaData.TgtVocab.GetWordIndex(chosenSnts);
            var chosenMask = computeGraph.BuildMaskAfter(chosenTokensList, messageTokenId, tgtEmbedding.ElementType);

            // Rejected sentences come from the target side of the batch.
            var rejectedSnts = sntPairBatch.GetTgtTokens();
            var rejectedTokensList = m_modelMetaData.TgtVocab.GetWordIndex(rejectedSnts);
            var rejectedMask = computeGraph.BuildMaskAfter(rejectedTokensList, messageTokenId, tgtEmbedding.ElementType);

            // Only the policy decoder is reset here, sized by the number of chosen sentences.
            decoder.Reset(computeGraph.GetWeightFactory(), chosenSnts.Count);

            (var loss, var chosenRewards, var rejectedRewards) = Decoder.DPODecoderTrainer(chosenTokensList, rejectedTokensList, computeGraph, gptDecoder, refGptDecoder,
                decoderFFLayer, refDecoderFFLayer,
                tgtEmbedding, refTgtEmbedding,
                m_modelMetaData.TgtVocab, m_paddingType, m_options.DropoutRatio,
                segmentEmbedding, refSegmentEmbedding,
                m_options.AMP,
                posEmbeddings, refPosEmbeddings,
                LossScaling, m_options.PaddingAlignmentFactor, lossSmooth: m_options.LossSmooth, beta: m_options.DPOBeta, chosenMasks: chosenMask, rejectedMasks: rejectedMask);

            NetworkResult nr = new NetworkResult();
            nr.Status = NetworkResultStatus.SUCCEED;
            nr.Cost = loss;
            nr.ChosenRewards = chosenRewards;
            nr.RejectedRewards = rejectedRewards;
            nr.Output = null;

            return new List<NetworkResult> { nr };
        }
    }
}
0 commit comments