diff --git a/.gitignore b/.gitignore index 13c515a..201fd37 100644 --- a/.gitignore +++ b/.gitignore @@ -124,3 +124,4 @@ Temporary Items /Tools/SeqDictMatchConsole/bin/Debug/net8.0 /Tools/SeqMedical/obj /Tools/SeqMedical/bin/Debug/net9.0 +/Tools/SeqMedical/bin diff --git a/Seq2SeqSharp/Applications/DPO.cs b/Seq2SeqSharp/Applications/DPO.cs new file mode 100644 index 0000000..a927f19 --- /dev/null +++ b/Seq2SeqSharp/Applications/DPO.cs @@ -0,0 +1,274 @@ +// Copyright (c) Zhongkai Fu. All rights reserved. +// https://github.com/zhongkaifu/Seq2SeqSharp +// +// This file is part of Seq2SeqSharp. +// +// Seq2SeqSharp is licensed under the BSD-3-Clause license found in the LICENSE file in the root directory of this source tree. +// +// Seq2SeqSharp is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD-3-Clause License for more details. + +using System; +using System.Collections.Generic; +using System.IO; +using AdvUtils; +using System.Runtime.Caching; +using Seq2SeqSharp.Enums; +using Seq2SeqSharp.Corpus; +using Seq2SeqSharp.Layers; +using Seq2SeqSharp.Models; +using Seq2SeqSharp.Tools; +using Seq2SeqSharp.Utils; +using TensorSharp; +using ManagedCuda.BasicTypes; + +namespace Seq2SeqSharp.Applications +{ + public class DPO : BaseSeq2SeqFramework + { + // Trainable parameters including networks and tensors + private MultiProcessorNetworkWrapper m_tgtEmbedding = null; //The embeddings over devices for source + private MultiProcessorNetworkWrapper m_decoder = null; //The decoders over devices + private MultiProcessorNetworkWrapper m_decoderFFLayer = null; //The feed forward layers over devices after all layers in decoder + private MultiProcessorNetworkWrapper m_segmentEmbedding = null; + private MultiProcessorNetworkWrapper m_posEmbedding = null; + + + + + + + private MultiProcessorNetworkWrapper ref_m_tgtEmbedding = null; //The embeddings over devices for source + private MultiProcessorNetworkWrapper ref_m_decoder = null; //The decoders over devices + private MultiProcessorNetworkWrapper ref_m_decoderFFLayer = null; //The feed forward layers over devices after all layers in decoder + private MultiProcessorNetworkWrapper ref_m_segmentEmbedding = null; + private MultiProcessorNetworkWrapper ref_m_posEmbedding = null; + + + + + private readonly PaddingEnums m_paddingType = PaddingEnums.AllowPadding; + readonly Seq2SeqOptions m_options = null; + + public event EventHandler KVCacheRemoveWatcher; + + public DPO(Seq2SeqOptions options, Vocab tgtVocab = null) + : base(deviceIds: options.DeviceIds, processorType: options.ProcessorType, modelFilePath: options.ModelFilePath, memoryUsageRatio: options.MemoryUsageRatio, + compilerOptions: options.CompilerOptions, runValidEveryUpdates: options.RunValidEveryUpdates, updateFreq: options.UpdateFreq, + startToRunValidAfterUpdates: options.StartValidAfterUpdates, maxDegressOfParallelism: options.TaskParallelism, mklInstructions: options.MKLInstructions, weightsUpdateCount: options.WeightsUpdateCount, + enableTensorCore: options.EnableTensorCore, cudaMemoryAllocatorType: options.CudaMemoryAllocatorType, elementType: options.AMP ? DType.Float16 : DType.Float32, randomSeed: options.RandomSeed, + saveModelEveryUpdats: options.SaveModelEveryUpdates, saveGPUMemoryLevel: options.SaveGPUMemoryLevel, initLossScaling: options.InitLossScaling, autoCheckTensorCorruption: options.CheckTensorCorrupted, + attentionType: options.AttentionType) + { + m_paddingType = options.PaddingType; + m_options = options; + + // Check if options are valided. + m_options.ValidateOptions(); + if (File.Exists(m_options.ModelFilePath)) + { + if (tgtVocab != null) + { + throw new ArgumentException($"Model '{m_options.ModelFilePath}' exists and it includes vocabulary, so input vocabulary must be null."); + } + + // Model file exists, so we load it from file. + m_modelMetaData = LoadModel(); + } + else + { + // Model doesn't exist, we create it and initlaize parameters + m_modelMetaData = new Seq2SeqModel(options, null, tgtVocab); + + //Initializng weights in encoders and decoders + CreateTrainableParameters(m_modelMetaData); + } + + m_modelMetaData.EncoderType = EncoderTypeEnums.None; + m_modelMetaData.DecoderType = DecoderTypeEnums.GPTDecoder; + m_modelMetaData.ShowModelInfo(); + } + + public void UpdateVocabs(Vocab tgtVocab) + { + if (tgtVocab != null) + { + m_modelMetaData.TgtVocab = tgtVocab; + } + + SaveModel(createBackupPrevious: true, suffix: ".updatevocab"); + } + + public void VQModel() + { + m_modelMetaData.VQType = m_options.VQType; + SaveModel(createBackupPrevious: true, suffix: $".{m_modelMetaData.VQType.ToString()}"); + + } + + protected override Seq2SeqModel LoadModel(string suffix = "") => base.LoadModelRoutine(CreateTrainableParameters, Seq2SeqModel.Create, suffix); + + private bool CreateTrainableParameters(IModel model) + { + CreateDPOModel(model); + CreateRefModel(model); + + return true; + } + + private bool CreateDPOModel(IModel model) + { + if (m_decoder != null) + { + m_decoder.Dispose(); + } + if (m_decoderFFLayer != null) + { + m_decoderFFLayer.Dispose(); + } + + if (m_segmentEmbedding != null) + { + m_segmentEmbedding.Dispose(); + } + + if (m_tgtEmbedding != null) + { + m_tgtEmbedding.Dispose(); + } + + Logger.WriteLine(Logger.Level.debug, $"Creating decoders..."); + + var raDeviceIds = new RoundArray(DeviceIds); + + DType elementType = m_options.AMP ? DType.Float16 : DType.Float32; + + m_decoder = Decoder.CreateDecoders(model, m_options, raDeviceIds, isTrainable: m_options.IsDecoderTrainable && (m_options.Task == ModeEnums.DPO), elementType: elementType); + m_decoderFFLayer = new MultiProcessorNetworkWrapper(new FeedForwardLayer("FeedForward_Decoder_0", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(), + isTrainable: (m_options.Task == ModeEnums.DPO), learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType), DeviceIds); + + (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength), model, elementType, + isTrainable: (m_options.Task == ModeEnums.DPO), createAPE: (model.PEType == PositionEmbeddingEnums.APE)); + m_tgtEmbedding = CreateTgtEmbeddings(model, raDeviceIds, m_options.IsTgtEmbeddingTrainable && (m_options.Task == ModeEnums.DPO), m_options.DecoderStartLearningRateFactor, elementType); + + return (true); + } + + + private bool CreateRefModel(IModel model) + { + if (ref_m_decoder != null) + { + ref_m_decoder.Dispose(); + } + if (ref_m_decoderFFLayer != null) + { + ref_m_decoderFFLayer.Dispose(); + } + + if (ref_m_segmentEmbedding != null) + { + ref_m_segmentEmbedding.Dispose(); + } + + if (ref_m_tgtEmbedding != null) + { + ref_m_tgtEmbedding.Dispose(); + } + + Logger.WriteLine(Logger.Level.debug, $"Creating decoders..."); + + var raDeviceIds = new RoundArray(DeviceIds); + + DType elementType = m_options.AMP ? DType.Float16 : DType.Float32; + + ref_m_decoder = Decoder.CreateDecoders(model, m_options, raDeviceIds, isTrainable: false, elementType: elementType, isSavable: false); + ref_m_decoderFFLayer = new MultiProcessorNetworkWrapper(new FeedForwardLayer("FeedForward_Decoder_0", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(), + isTrainable: false, learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType), DeviceIds, savableWeights: false); + + (ref_m_posEmbedding, ref_m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength), model, elementType, + isTrainable: false, createAPE: (model.PEType == PositionEmbeddingEnums.APE)); + ref_m_tgtEmbedding = CreateTgtEmbeddings(model, raDeviceIds, false, m_options.DecoderStartLearningRateFactor, elementType, isSavable: false); + + return (true); + } + /// + /// Get networks on specific devices + /// + private (IDecoder, IFeedForwardLayer, IWeightTensor, IWeightTensor, IWeightTensor) GetNetworksOnDeviceAt(int deviceId) + { + var deviceIdIdx = TensorAllocator.GetDeviceIdIndex(deviceId); + return (m_decoder.GetNetworkOnDevice(deviceIdIdx), + m_decoderFFLayer.GetNetworkOnDevice(deviceIdIdx), + m_tgtEmbedding.GetNetworkOnDevice(deviceIdIdx), + m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx), + m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx)); + } + + + private (IDecoder, IFeedForwardLayer, IWeightTensor, IWeightTensor, IWeightTensor) GetRefNetworksOnDeviceAt(int deviceId) + { + var deviceIdIdx = TensorAllocator.GetDeviceIdIndex(deviceId); + return (ref_m_decoder.GetNetworkOnDevice(deviceIdIdx), + ref_m_decoderFFLayer.GetNetworkOnDevice(deviceIdIdx), + ref_m_tgtEmbedding.GetNetworkOnDevice(deviceIdIdx), + ref_m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx), + ref_m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx)); + } + + /// + /// Run forward part on given single device + /// + /// The computing graph for current device. It gets created and passed by the framework + /// A batch of output tokenized sentences in target side + /// The index of current device + /// The cost of forward part + public override List RunForwardOnSingleDevice(IComputeGraph computeGraph, IPairBatch sntPairBatch, DecodingOptions decodingOptions, bool isTraining) + { + if (isTraining == false) + { + throw new ArgumentException("The DPO is only for training mode."); + } + + (var decoder, var decoderFFLayer, var tgtEmbedding, var segmentEmbedding, var posEmbeddings) = GetNetworksOnDeviceAt(computeGraph.DeviceId); + (var ref_decoder, var ref_decoderFFLayer, var ref_tgtEmbedding, var ref_segmentEmbedding, var ref_posEmbeddings) = GetRefNetworksOnDeviceAt(computeGraph.DeviceId); + + List nrs = new List(); + int messageTokenId = m_modelMetaData.TgtVocab.GetWordIndex(m_options.DPOMaskedToken, logUnk: true); + + // Generate output decoder sentences + var chosenSnts = sntPairBatch.GetSrcTokens(); + int batchSize = chosenSnts.Count; + var chosenTokensList = m_modelMetaData.TgtVocab.GetWordIndex(chosenSnts); + var chosenMask = computeGraph.BuildMaskAfter(chosenTokensList, messageTokenId, tgtEmbedding.ElementType); + + + var rejectedSnts = sntPairBatch.GetTgtTokens(); + //int batchSize = preferredSnts.Count; + var rejectedTokensList = m_modelMetaData.TgtVocab.GetWordIndex(rejectedSnts); + var rejectedMask = computeGraph.BuildMaskAfter(rejectedTokensList, messageTokenId, tgtEmbedding.ElementType); + + NetworkResult nr = new NetworkResult(); + nr.Status = NetworkResultStatus.SUCCEED; + + decoder.Reset(computeGraph.GetWeightFactory(), chosenSnts.Count); + //decoder.Reset(computeGraph.GetWeightFactory(), nonPreferredSnts.Count); + + (var loss, var cr, var rr) = Decoder.DPODecoderTrainer(chosenTokensList, rejectedTokensList, computeGraph, decoder as GPTDecoder, ref_decoder as GPTDecoder, + decoderFFLayer, ref_decoderFFLayer, + tgtEmbedding, ref_tgtEmbedding, + m_modelMetaData.TgtVocab, m_paddingType, m_options.DropoutRatio, + segmentEmbedding, ref_segmentEmbedding, + m_options.AMP, + posEmbeddings, ref_posEmbeddings, + LossScaling, m_options.PaddingAlignmentFactor, lossSmooth: m_options.LossSmooth, beta: m_options.DPOBeta, chosenMasks: chosenMask, rejectedMasks: rejectedMask); + nr.Cost = loss; + nr.ChosenRewards = cr; + nr.RejectedRewards = rr; + nr.Output = null; + + nrs.Add(nr); + return nrs; + } + } +} diff --git a/Seq2SeqSharp/Applications/Decoder.cs b/Seq2SeqSharp/Applications/Decoder.cs index e00d74c..e07e11f 100644 --- a/Seq2SeqSharp/Applications/Decoder.cs +++ b/Seq2SeqSharp/Applications/Decoder.cs @@ -19,12 +19,13 @@ using Seq2SeqSharp.Enums; using ProtoBuf; using System.Xml.Linq; +using System.IO; namespace Seq2SeqSharp.Applications { public class Decoder { - public static MultiProcessorNetworkWrapper CreateDecoders(IModel model, Seq2SeqOptions options, RoundArray raDeviceIds, DType elementType = DType.Float32) + public static MultiProcessorNetworkWrapper CreateDecoders(IModel model, Seq2SeqOptions options, RoundArray raDeviceIds, bool isTrainable, bool isSavable = true, DType elementType = DType.Float32) { MultiProcessorNetworkWrapper decoder; if (model.DecoderType == DecoderTypeEnums.AttentionLSTM) @@ -32,22 +33,22 @@ public static MultiProcessorNetworkWrapper CreateDecoders(IModel model decoder = new MultiProcessorNetworkWrapper( new AttentionDecoder("AttnLSTMDecoder", model.HiddenDim, model.DecoderEmbeddingDim, model.HiddenDim, options.DropoutRatio, model.DecoderLayerDepth, raDeviceIds.GetNextItem(), model.EnableCoverageModel, - isTrainable: options.IsDecoderTrainable && (options.Task == ModeEnums.Train), elementType: elementType), raDeviceIds.ToArray()); + isTrainable: isTrainable, elementType: elementType), raDeviceIds.ToArray(), savableWeights: isSavable); } else if (model.DecoderType == DecoderTypeEnums.GPTDecoder) { decoder = new MultiProcessorNetworkWrapper( new GPTDecoder("GPTDecoder", model.MultiHeadNum, model.HiddenDim, model.IntermediateDim, model.DecoderEmbeddingDim, model.DecoderLayerDepth, options.DropoutRatio, raDeviceIds.GetNextItem(), - isTrainable: options.IsDecoderTrainable && (options.Task == ModeEnums.Train), learningRateFactor: options.DecoderStartLearningRateFactor, activateFunc: model.ActivateFunc, expertNum: model.ExpertNum, + isTrainable: isTrainable, learningRateFactor: options.DecoderStartLearningRateFactor, activateFunc: model.ActivateFunc, expertNum: model.ExpertNum, expertsPerTokenFactor: model.ExpertsPerTokenFactor, elementType: elementType, peType:model.PEType, normType: model.NormType, attentionType: options.AttentionType, multiHeadAttentionType: model.MultiHeadAttentionType, - KVGroupNum: model.KVGroupNum), raDeviceIds.ToArray()); + KVGroupNum: model.KVGroupNum), raDeviceIds.ToArray(), savableWeights: isSavable); } else { decoder = new MultiProcessorNetworkWrapper( new TransformerDecoder("TransformerDecoder", model.MultiHeadNum, model.HiddenDim, model.IntermediateDim, model.DecoderEmbeddingDim, model.DecoderLayerDepth, options.DropoutRatio, raDeviceIds.GetNextItem(), - isTrainable: options.IsDecoderTrainable && (options.Task == ModeEnums.Train), learningRateFactor: options.DecoderStartLearningRateFactor, activateFunc: model.ActivateFunc, expertNum: model.ExpertNum, - expertsPerTokenFactor: model.ExpertsPerTokenFactor, elementType: elementType, peType:model.PEType, normType: model.NormType, attentionType: options.AttentionType), raDeviceIds.ToArray()); + isTrainable: isTrainable, learningRateFactor: options.DecoderStartLearningRateFactor, activateFunc: model.ActivateFunc, expertNum: model.ExpertNum, + expertsPerTokenFactor: model.ExpertsPerTokenFactor, elementType: elementType, peType:model.PEType, normType: model.NormType, attentionType: options.AttentionType), raDeviceIds.ToArray(), savableWeights: isSavable); } return decoder; @@ -517,6 +518,133 @@ public static (float, List>) DecodeTransformer(List> chosenSeqs, List> rejectedSeqs, IComputeGraph g, GPTDecoder decoder, GPTDecoder refDecoder, IFeedForwardLayer decoderFFLayer, IFeedForwardLayer refDecoderFFLayer, + IWeightTensor tgtEmbedding, IWeightTensor refTgtEmbedding, Vocab tgtVocab, PaddingEnums paddingType, float dropoutRatio, IWeightTensor segmentEmbeddings = null, IWeightTensor refSegmentEmbeddings = null, bool amp = true, + IWeightTensor posEmbeddings = null, IWeightTensor refPosEmbeddings = null, float lossScaling = 1.0f, int paddingAligmentFactor = 0, float lossSmooth = 0.0f, float beta = 0.5f, IWeightTensor chosenMasks = null, IWeightTensor rejectedMasks = null) + { + var chosenProbs = DecoderLogist(chosenSeqs, g, decoder, decoderFFLayer, tgtEmbedding, tgtVocab, paddingType, dropoutRatio, segmentEmbeddings, amp, posEmbeddings, paddingAligmentFactor); + var rejectedProbs = DecoderLogist(rejectedSeqs, g, decoder, decoderFFLayer, tgtEmbedding, tgtVocab, paddingType, dropoutRatio, segmentEmbeddings, amp, posEmbeddings, paddingAligmentFactor); + + var refChosenProbs = DecoderLogist(chosenSeqs, g, refDecoder, refDecoderFFLayer, refTgtEmbedding, tgtVocab, paddingType, dropoutRatio, refSegmentEmbeddings, amp, refPosEmbeddings, paddingAligmentFactor); + var refRejectedProbs = DecoderLogist(rejectedSeqs, g, refDecoder, refDecoderFFLayer, refTgtEmbedding, tgtVocab, paddingType, dropoutRatio, refSegmentEmbeddings, amp, refPosEmbeddings, paddingAligmentFactor); + + + var batchSize = chosenSeqs.Count; + var seqLen = chosenSeqs[0].Count; + var dim = chosenProbs.Sizes[^1]; + chosenMasks = g.View(chosenMasks, new long[] { batchSize, seqLen, 1 }); + chosenMasks = g.AsContiguous(g.Expand(chosenMasks, new long[] { batchSize, seqLen, dim })); + chosenMasks = g.View(chosenMasks, chosenProbs.Sizes); + + if (lossSmooth > 0) + { + chosenProbs = g.Add(chosenProbs, lossSmooth); + } + chosenProbs = g.Log(chosenProbs); + chosenProbs = g.EltMul(chosenProbs, chosenMasks); + + chosenProbs = g.View(chosenProbs, dims: new long[] { batchSize, seqLen, -1 }); + chosenProbs = g.Sum(chosenProbs, dim: 1); + chosenProbs = g.View(chosenProbs, dims: new long[] { batchSize, -1 }); + + if (lossSmooth > 0) + { + refChosenProbs = g.Add(refChosenProbs, lossSmooth); + } + refChosenProbs = g.Log(refChosenProbs); + refChosenProbs = g.EltMul(refChosenProbs, chosenMasks); + + refChosenProbs = g.View(refChosenProbs, dims: new long[] { batchSize, seqLen, -1 }); + refChosenProbs = g.Sum(refChosenProbs, dim: 1); + refChosenProbs = g.View(refChosenProbs, dims: new long[] { batchSize, -1 }); + + + seqLen = rejectedSeqs[0].Count; + rejectedMasks = g.View(rejectedMasks, new long[] { batchSize, seqLen, 1 }); + rejectedMasks = g.AsContiguous(g.Expand(rejectedMasks, new long[] { batchSize, seqLen, dim })); + rejectedMasks = g.View(rejectedMasks, rejectedProbs.Sizes); + + if (lossSmooth > 0) + { + rejectedProbs = g.Add(rejectedProbs, lossSmooth); + } + rejectedProbs = g.Log(rejectedProbs); + rejectedProbs = g.EltMul(rejectedProbs, rejectedMasks); + + + rejectedProbs = g.View(rejectedProbs, dims: new long[] { batchSize, seqLen, -1 }); + rejectedProbs = g.Sum(rejectedProbs, dim: 1); + rejectedProbs = g.View(rejectedProbs, dims: new long[] { batchSize, -1 }); + + if (lossSmooth > 0) + { + refRejectedProbs = g.Add(refRejectedProbs, lossSmooth); + } + refRejectedProbs = g.Log(refRejectedProbs); + refRejectedProbs = g.EltMul(refRejectedProbs, rejectedMasks); + + refRejectedProbs = g.View(refRejectedProbs, dims: new long[] { batchSize, seqLen, -1 }); + refRejectedProbs = g.Sum(refRejectedProbs, dim: 1); + refRejectedProbs = g.View(refRejectedProbs, dims: new long[] { batchSize, -1 }); + + + (var lossValue, var chosen_rewards, var rejected_rewards) = g.DPOLoss(chosenProbs, rejectedProbs, refChosenProbs, refRejectedProbs, lossScaling, lossSmooth, beta); + + return (lossValue, chosen_rewards, rejected_rewards); + } + + + public static IWeightTensor DecoderLogist(List> tgtSeqs, IComputeGraph g, GPTDecoder decoder, IFeedForwardLayer decoderFFLayer, + IWeightTensor tgtEmbedding, Vocab tgtVocab, PaddingEnums paddingType, float dropoutRatio, IWeightTensor segmentEmbeddings = null, bool amp = true, + IWeightTensor posEmbeddings = null, int paddingAligmentFactor = 0) + { + int eosTokenId = tgtVocab.GetWordIndex(BuildInTokens.EOS, logUnk: true); + int batchSize = tgtSeqs.Count; + var tgtOriginalLengths = BuildInTokens.PadSentences(tgtSeqs, eosTokenId, alignmentFactor: paddingAligmentFactor); + int tgtSeqLen = tgtSeqs[0].Count; + IWeightTensor tgtSelfTriMask = null; + IWeightTensor inputEmbs = null; + + + if (decoder.AttentionType == AttentionTypeEnums.Classic) + { + if (paddingType == PaddingEnums.NoPadding || paddingType == PaddingEnums.NoPaddingInTgt || batchSize == 1) + { + tgtSelfTriMask = g.BuildTriMask(tgtSeqLen, batchSize, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32); + tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { 1, 1, tgtSeqLen, tgtSeqLen }); + } + else + { + tgtSelfTriMask = g.BuildSelfTriMask(tgtSeqLen, tgtOriginalLengths, amp ? TensorSharp.DType.Float16 : TensorSharp.DType.Float32); + tgtSelfTriMask = g.View(tgtSelfTriMask, new long[] { batchSize, 1, tgtSeqLen, tgtSeqLen }); + } + } + + inputEmbs = TensorUtils.CreateTokensEmbeddings(tgtSeqs, g, tgtEmbedding, segmentEmbeddings, tgtVocab, scaleFactor: (float)Math.Sqrt(tgtEmbedding.Columns), amp: amp); + if (posEmbeddings != null) + { + inputEmbs = PositionEmbedding.AddPositionEmbedding(g, posEmbeddings, batchSize, inputEmbs, dropoutRatio); //Output Shape: [batchSize * seqLen, hidden_dim] + } + + + IWeightTensor decOutput; + (decOutput, _) = decoder.Decode(inputEmbs, tgtSelfTriMask, batchSize, g); + IWeightTensor ffLayer = decoderFFLayer.Process(decOutput, batchSize, g); + + if (amp) + { + var tmp = ffLayer; + ffLayer = g.Half2Float(ffLayer); + tmp.ReleaseWeight(); + } + + IWeightTensor probs = g.Softmax(ffLayer, inPlace: true); + + return probs; + } + + + public static (float, List>) GPTDecode(List> tgtSeqs, IComputeGraph g, GPTDecoder decoder, IFeedForwardLayer decoderFFLayer, IWeightTensor tgtEmbedding, Vocab tgtVocab, PaddingEnums paddingType, float dropoutRatio, DecodingOptions decodingOptions, bool isTraining = true, bool outputSentScore = true, List previousBeamSearchResults = null, Dictionary contextTensors = null, diff --git a/Seq2SeqSharp/Applications/GPT.cs b/Seq2SeqSharp/Applications/GPT.cs index 9921908..0d9c279 100644 --- a/Seq2SeqSharp/Applications/GPT.cs +++ b/Seq2SeqSharp/Applications/GPT.cs @@ -121,7 +121,7 @@ private bool CreateTrainableParameters(IModel model) DType elementType = m_options.AMP ? DType.Float16 : DType.Float32; - m_decoder = Decoder.CreateDecoders(model, m_options, raDeviceIds, elementType); + m_decoder = Decoder.CreateDecoders(model, m_options, raDeviceIds, isTrainable: m_options.IsDecoderTrainable && (m_options.Task == ModeEnums.Train), elementType: elementType); m_decoderFFLayer = new MultiProcessorNetworkWrapper(new FeedForwardLayer("FeedForward_Decoder_0", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(), isTrainable: (m_options.Task == ModeEnums.Train), learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType), DeviceIds); diff --git a/Seq2SeqSharp/Applications/Options.cs b/Seq2SeqSharp/Applications/Options.cs index 7091689..46ef63d 100644 --- a/Seq2SeqSharp/Applications/Options.cs +++ b/Seq2SeqSharp/Applications/Options.cs @@ -207,7 +207,7 @@ public class Options public int PaddingAlignmentFactor = 0; [Arg("Task to execute. It supports Train, Valid, Test, DumpVocab, UpdateVocab and Help", nameof(Task))] - [RegularExpression("Train|Valid|Test|Alignment|DumpVocab|UpdateVocab|VQModel|Help")] + [RegularExpression("Train|Valid|Test|Alignment|DumpVocab|UpdateVocab|VQModel|DPO|Help")] public ModeEnums Task = ModeEnums.Help; [Arg("How to deal with too long sequence. It can be Ignore or Truncation", nameof(TooLongSequence))] @@ -345,8 +345,14 @@ public class Options public Logger.Level LogLevel = (Logger.Level.err | Logger.Level.warn | Logger.Level.info | Logger.Level.debug); [Arg("It indicates if checking tensor corrupted is enabled. Default is disabled.", nameof(CheckTensorCorrupted))] - public bool CheckTensorCorrupted = false; + public bool CheckTensorCorrupted = false; + [Arg("The beta value for DPO loss calulcation", nameof(DPOBeta))] + [Range(0.0f, 1.0f)] + public float DPOBeta = 0.5f; + + [Arg("The token should be masked and all content of it", nameof(DPOMaskedToken))] + public string DPOMaskedToken = "[message]"; public void ValidateOptions() { if (AMP == true && ProcessorType != ProcessorTypeEnums.GPU) diff --git a/Seq2SeqSharp/Applications/Seq2Seq.cs b/Seq2SeqSharp/Applications/Seq2Seq.cs index 16d3f83..af90e4b 100644 --- a/Seq2SeqSharp/Applications/Seq2Seq.cs +++ b/Seq2SeqSharp/Applications/Seq2Seq.cs @@ -103,7 +103,7 @@ private bool CreateTrainableParameters(IModel model) DType elementType = m_options.AMP ? DType.Float16 : DType.Float32; m_encoder = Encoder.CreateEncoders(model, m_options, raDeviceIds, elementType: elementType); - m_decoder = Decoder.CreateDecoders(model, m_options, raDeviceIds, elementType: elementType); + m_decoder = Decoder.CreateDecoders(model, m_options, raDeviceIds, isTrainable: m_options.IsDecoderTrainable && (m_options.Task == ModeEnums.Train), elementType: elementType); m_decoderFFLayer = new MultiProcessorNetworkWrapper(new FeedForwardLayer("FeedForward_Decoder_0", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(), isTrainable: true, learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType: elementType), DeviceIds); (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(Math.Max(m_options.MaxSrcSentLength, m_options.MaxValidSrcSentLength), Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength)), model, diff --git a/Seq2SeqSharp/Corpus/BuildInTokens.cs b/Seq2SeqSharp/Corpus/BuildInTokens.cs index 197a062..1983d47 100644 --- a/Seq2SeqSharp/Corpus/BuildInTokens.cs +++ b/Seq2SeqSharp/Corpus/BuildInTokens.cs @@ -99,23 +99,5 @@ public static float[] PadSentences(List> s, int tokenToPad, int maxLen return originalLengths; } - - public static List> LeftShiftSnts(List> input, string lastTokenToPad) - { - List> r = new List>(); - - foreach (var seq in input) - { - List rseq = new List(); - - rseq.AddRange(seq); - rseq.RemoveAt(0); - rseq.Add(lastTokenToPad); - - r.Add(rseq); - } - - return r; - } } } diff --git a/Seq2SeqSharp/Corpus/DPOCorpus.cs b/Seq2SeqSharp/Corpus/DPOCorpus.cs new file mode 100644 index 0000000..9d8e5f4 --- /dev/null +++ b/Seq2SeqSharp/Corpus/DPOCorpus.cs @@ -0,0 +1,49 @@ +// Copyright (c) Zhongkai Fu. All rights reserved. +// https://github.com/zhongkaifu/Seq2SeqSharp +// +// This file is part of Seq2SeqSharp. +// +// Seq2SeqSharp is licensed under the BSD-3-Clause license found in the LICENSE file in the root directory of this source tree. +// +// Seq2SeqSharp is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD-3-Clause License for more details. + +using Seq2SeqSharp.Tools; +using Seq2SeqSharp.Utils; +using System; + +namespace Seq2SeqSharp.Corpus +{ + public class DPOCorpus : DPOPairCorpus + { + + public DPOCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, PaddingEnums paddingEnums = PaddingEnums.AllowPadding, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null, int startBatchId = 0, string dataPassword = "") + : base(corpusFilePath, srcLangName, tgtLangName, maxTokenSizePerBatch, maxSrcSentLength, maxTgtSentLength, paddingEnums: paddingEnums, tooLongSequence: tooLongSequence, indexedFilePath: indexedFilePath, startBatchId: startBatchId, dataPassword: dataPassword) + { + + } + + /// + /// Build vocabulary from training corpus + /// + /// + public (Vocab, Vocab) BuildVocabs(int srcVocabSize = 45000, int tgtVocabSize = 45000, bool sharedVocab = false, int minFreq = 1) + { + if (sharedVocab && (srcVocabSize != tgtVocabSize)) + { + throw new ArgumentException($"Vocab size must be equal if sharedVocab is true. Src Vocab Size = '{srcVocabSize}', Tgt Vocab Size = '{tgtVocabSize}'"); + } + + (CorpusBatch.s_ds, CorpusBatch.t_ds) = CountTokenFreqs(); + + CorpusBatch.ReduceSrcTokensToSingleGroup(); + if (sharedVocab) + { + CorpusBatch.MergeTokensCountSrcTgt(0, 0); + } + + (var srcVocabs, var tgtVocabs) = CorpusBatch.GenerateVocabs(srcVocabSize, tgtVocabSize, minFreq); + return (srcVocabs[0], tgtVocabs[0]); + } + } +} diff --git a/Seq2SeqSharp/Corpus/DPOPairCorpus.cs b/Seq2SeqSharp/Corpus/DPOPairCorpus.cs new file mode 100644 index 0000000..e11c551 --- /dev/null +++ b/Seq2SeqSharp/Corpus/DPOPairCorpus.cs @@ -0,0 +1,607 @@ +// Copyright (c) Zhongkai Fu. All rights reserved. +// https://github.com/zhongkaifu/Seq2SeqSharp +// +// This file is part of Seq2SeqSharp. +// +// Seq2SeqSharp is licensed under the BSD-3-Clause license found in the LICENSE file in the root directory of this source tree. +// +// Seq2SeqSharp is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD-3-Clause License for more details. + +using AdvUtils; +using Seq2SeqSharp.Corpus; +using Seq2SeqSharp.Utils; +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.IO.MemoryMappedFiles; +using System.Linq; +using System.Threading; + +namespace Seq2SeqSharp.Tools +{ + public class DPOPairCorpus : ICorpus where T : ISntPairBatch, new() + { + internal int m_maxSrcTokenSize = 32; + internal int m_maxTgtTokenSize = 32; + internal int m_maxTokenSizePerBatch = 1; + internal List m_srcFileList; + internal List m_tgtFileList; + internal PaddingEnums m_paddingEnums; + + private bool m_showTokenDist = true; + + private readonly Random rnd = new Random(DateTime.Now.Millisecond); + + public string CorpusName { get; set; } + + private TooLongSequence m_tooLongSequence = TooLongSequence.Ignore; + + private string m_sortedIndexedDataSetFilePath = ""; + private int m_batchNumInTotal = 0; + private int m_startBatchId = 0; + private string m_dataPassword = String.Empty; + + public (List>, List>) CountTokenFreqs() + { + List> sd = new List>(); + List> td = new List>(); + + for (int i = 0; i < m_srcFileList.Count; i++) + { + Logger.WriteLine(Logger.Level.debug, $"Start to count token frequency in '{m_srcFileList[i]}' and '{m_tgtFileList[i]}'."); + + StreamReader srSrc = new StreamReader(m_srcFileList[i]); + StreamReader srTgt = new StreamReader(m_tgtFileList[i]); + + while (true) + { + if (srSrc.EndOfStream && srTgt.EndOfStream) + { + break; + } + + string srcLine = srSrc.ReadLine(); + string tgtLine = srTgt.ReadLine(); + + if (srcLine.IsNullOrEmpty() && tgtLine.IsNullOrEmpty()) + { + break; + } + + string[] srcGroups = srcLine.Split('\t'); + string[] tgtGroups = tgtLine.Split('\t'); + + if (srcGroups.Length != tgtGroups.Length) + { + throw new InvalidDataException("Inconsistent group size between source side and target side."); + } + + if (sd.Count == 0) + { + for (int j = 0; j < srcGroups.Length; j++) + { + sd.Add(new Dictionary()); + td.Add(new Dictionary()); + } + } + + for (int j = 0; j < srcGroups.Length; j++) + { + string[] srcTokens = srcGroups[j].Split(' '); + string[] tgtTokens = tgtGroups[j].Split(' '); + + + foreach (var srcToken in srcTokens) + { + if (sd[j].ContainsKey(srcToken) == false) + { + sd[j].Add(srcToken, 0); + } + sd[j][srcToken]++; + } + + foreach (var tgtToken in tgtTokens) + { + if (td[j].ContainsKey(tgtToken) == false) + { + td[j].Add(tgtToken, 0); + } + td[j][tgtToken]++; + } + + } + } + } + +#if DEBUG + for (int j = 0; j < sd.Count; j++) + { + Logger.WriteLine(Logger.Level.debug, $"Original token size at group '{j}' source = '{sd[j].Count}' target = '{td[j].Count}'"); + } +#endif + return (sd, td); + } + + + private (Dictionary>, Dictionary, string) BuildIndex() + { + Logger.WriteLine(Logger.Level.debug, $"Start to build index for data set."); + + SortedDictionary dictSrcLenDist = new SortedDictionary(); + SortedDictionary dictTgtLenDist = new SortedDictionary(); + int corpusSize = 0; + int tooLongSrcSntCnt = 0; + int tooLongTgtSntCnt = 0; + string randomFileName = Path.GetRandomFileName(); + Logger.WriteLine($"Loading and shuffling corpus from '{m_srcFileList.Count}' files."); + + string binaryDataSetFilePath = randomFileName + ".tmp"; + BinaryWriter bw = new BinaryWriter(new FileStream(binaryDataSetFilePath, FileMode.Create)); + + Dictionary> len2offsets = new Dictionary>(); + Dictionary len2lengths = new Dictionary(); + + for (int i = 0; i < m_srcFileList.Count; i++) + { + StreamReader srSrc = new StreamReader(m_srcFileList[i]); + StreamReader srTgt = new StreamReader(m_tgtFileList[i]); + + while (true) + { + if (srSrc.EndOfStream && srTgt.EndOfStream) + { + break; + } + + RawSntPair rawSntPair = new RawSntPair(srSrc.ReadLine(), srTgt.ReadLine(), m_maxSrcTokenSize, m_maxTgtTokenSize, m_tooLongSequence == TooLongSequence.Truncation); + if (rawSntPair.IsEmptyPair()) + { + break; + } + + if (String.IsNullOrEmpty(rawSntPair.SrcSnt)) + { + throw new InvalidDataException($"Source Line is empty. The data set is corrupted. SourceLine = '{rawSntPair.SrcSnt}', TargetLine = '{rawSntPair.TgtSnt}'"); + } + + if (String.IsNullOrEmpty(rawSntPair.TgtSnt)) + { + throw new InvalidDataException($"Target Line is empty. The data set is corrupted. SourceLine = '{rawSntPair.SrcSnt}', TargetLine = '{rawSntPair.TgtSnt}'"); + } + + if (m_showTokenDist) + { + if (dictSrcLenDist.ContainsKey(rawSntPair.SrcTokenSize / 100) == false) + { + dictSrcLenDist.Add(rawSntPair.SrcTokenSize / 100, 0); + } + dictSrcLenDist[rawSntPair.SrcTokenSize / 100]++; + + if (dictTgtLenDist.ContainsKey(rawSntPair.TgtTokenSize / 100) == false) + { + dictTgtLenDist.Add(rawSntPair.TgtTokenSize / 100, 0); + } + dictTgtLenDist[rawSntPair.TgtTokenSize / 100]++; + } + + bool hasTooLongSent = false; + if (rawSntPair.SrcTokenSize > m_maxSrcTokenSize) + { + Interlocked.Increment(ref tooLongSrcSntCnt); + hasTooLongSent = true; + } + + if (rawSntPair.TgtTokenSize > m_maxTgtTokenSize) + { + Interlocked.Increment(ref tooLongTgtSntCnt); + hasTooLongSent = true; + } + + if (hasTooLongSent) + { + continue; + } + + long offset = bw.BaseStream.Position; + bw.Write(String.Join("\n", new string[] { rawSntPair.SrcSnt, rawSntPair.TgtSnt })); + + long length = 0; + if (m_paddingEnums == PaddingEnums.NoPaddingInSrc) + { + length = rawSntPair.SrcGroupLenId; + } + else if (m_paddingEnums == PaddingEnums.NoPadding) + { + length = rawSntPair.GroupLenId; + } + else if (m_paddingEnums == PaddingEnums.NoPaddingInTgt) + { + length = rawSntPair.TgtGroupLenId; + } + else + { + // Completely random shuffle + length = 0; + } + + if (len2offsets.ContainsKey(length) == false) + { + len2offsets.Add(length, new LinkedList()); + len2lengths.Add(length, 0); + } + len2offsets[length].AddLast(offset); + len2lengths[length]++; + + Interlocked.Increment(ref corpusSize); + } + + srSrc.Close(); + srTgt.Close(); + } + + bw.Close(); + + Logger.WriteLine(Logger.Level.debug, $"Shuffled '{corpusSize}' sentence pairs."); + + if (tooLongSrcSntCnt > 0) + { + Logger.WriteLine(Logger.Level.warn, ConsoleColor.Yellow, $"Found {tooLongSrcSntCnt} source sentences are longer than '{m_maxSrcTokenSize}' tokens, ignore them."); + } + + if (tooLongTgtSntCnt > 0) + { + Logger.WriteLine(Logger.Level.warn, ConsoleColor.Yellow, $"Found {tooLongTgtSntCnt} target sentences are longer than '{m_maxTgtTokenSize}' tokens, ignore them."); + } + + if (m_showTokenDist) + { + //TODO(Zho): executed even if nothing is printed + { + Logger.WriteLine(Logger.Level.debug, $"AggregateSrcLength = '{m_paddingEnums}'"); + Logger.WriteLine(Logger.Level.debug, $"Src token length distribution"); + } + + int srcTotalNum = 0; + foreach (var pair in dictSrcLenDist) + { + srcTotalNum += pair.Value; + } + + int srcAccNum = 0; + foreach (var pair in dictSrcLenDist) + { + srcAccNum += pair.Value; + + Logger.WriteLine(Logger.Level.debug, $"{pair.Key * 100} ~ {(pair.Key + 1) * 100}: {pair.Value} (acc: {100.0f * (float)srcAccNum / (float)srcTotalNum:F}%)"); + } + + Logger.WriteLine(Logger.Level.debug, $"Tgt token length distribution"); + + int tgtTotalNum = 0; + foreach (var pair in dictTgtLenDist) + { + tgtTotalNum += pair.Value; + } + + int tgtAccNum = 0; + + foreach (var pair in dictTgtLenDist) + { + tgtAccNum += pair.Value; + + Logger.WriteLine(Logger.Level.debug, $"{pair.Key * 100} ~ {(pair.Key + 1) * 100}: {pair.Value} (acc: {100.0f * (float)tgtAccNum / (float)tgtTotalNum:F}%)"); + } + + m_showTokenDist = false; + } + + Logger.WriteLine(Logger.Level.debug, $"Finished to build index for data set."); + + return (len2offsets, len2lengths, binaryDataSetFilePath); + } + + + public long GetNextLength(Dictionary len2counts, long totalRecordsNum) + { + long rndItems = rnd.NextInt64(totalRecordsNum); + long totalItems = 0; + foreach (var pair in len2counts) + { + long length = pair.Value; + if (totalItems <= rndItems && totalItems + length >= rndItems) + { + return pair.Key; + } + totalItems += length; + } + + return -1; + } + + public void PrepareDataSet() + { + try + { + m_batchNumInTotal = 0; + (var length2offsets, var length2counts, string tmpDataSetFilePath) = BuildIndex(); + + long totalRecordsNum = 0; + foreach (var pair in length2offsets) + { + totalRecordsNum += length2counts[pair.Key]; + } + + Logger.WriteLine(Logger.Level.debug, $"Start to sort and shuffle data set by length."); + + m_sortedIndexedDataSetFilePath = tmpDataSetFilePath + ".sorted"; + +#if DEBUG + string tmp_sortedIndexedDataSetFilePath = tmpDataSetFilePath + ".sorted.txt"; + using (StreamWriter bwt = new StreamWriter(new FileStream(tmp_sortedIndexedDataSetFilePath, FileMode.Create, FileAccess.Write, FileShare.None, 40960000))) +#endif + using (BinaryWriter bw = new BinaryWriter(new FileStream(m_sortedIndexedDataSetFilePath, FileMode.Create, FileAccess.Write, FileShare.None, 40960000))) + using (MemoryMappedFile mmf = MemoryMappedFile.CreateFromFile(tmpDataSetFilePath)) + using (MemoryMappedViewStream mms = mmf.CreateViewStream()) + { + using (BinaryReader br = new BinaryReader(mms)) + { + while (length2offsets.Count > 0) + { + long length = GetNextLength(length2counts, totalRecordsNum); + LinkedList offsets = length2offsets[length]; + + int totalSrcTokenSize = 0; + int totalTgtTokenSize = 0; + int sentSize = 0; + List srcLines = new List(); + List tgtLines = new List(); + while (totalSrcTokenSize + totalTgtTokenSize < m_maxTokenSizePerBatch && offsets.Any()) + { + long offset = offsets.First.Value; + offsets.RemoveFirst(); + length2counts[length]--; + totalRecordsNum--; + + br.BaseStream.Seek(offset, SeekOrigin.Begin); + + string[] srcTgtLine = br.ReadString().Split("\n"); + string srcLine = srcTgtLine[0]; + string tgtLine = srcTgtLine[1]; + + var srcTokens = srcLine.Split(' ').ToList(); + var tgtTokens = tgtLine.Split(' ').ToList(); + + //if (srcTokens.Count > tgtTokens.Count) + //{ + // srcTokens = srcTokens.GetRange(0, tgtTokens.Count); + // srcLine = String.Join(" ", srcTokens); + //} + //else + //{ + // tgtTokens = tgtTokens.GetRange(0, srcTokens.Count); + // tgtLine = String.Join(" ", tgtTokens); + //} + + totalSrcTokenSize += srcTokens.Count; + totalTgtTokenSize += tgtTokens.Count; + + srcLines.Add(srcLine); + tgtLines.Add(tgtLine); + + + sentSize++; + } + + bw.Write(sentSize); + bw.Write(String.Join("\n", srcLines)); + bw.Write(String.Join("\n", tgtLines)); + +#if DEBUG + bwt.WriteLine(sentSize); + bwt.WriteLine(String.Join("\n", srcLines)); + bwt.WriteLine(String.Join("\n", tgtLines)); +#endif + + m_batchNumInTotal++; + if (m_batchNumInTotal % 10000 == 0) + { + Logger.WriteLine($"Batch '{m_batchNumInTotal}' has been processed."); + } + + + if (offsets.Any() == false) + { + length2offsets.Remove(length); + length2counts.Remove(length); + } + } + + bw.Write(-1); + } + } + + File.Delete(tmpDataSetFilePath); + + Logger.WriteLine($"Finished to sort and shuffle data set by length. Total batch size = '{m_batchNumInTotal}'"); + } + catch (Exception err) + { + Logger.WriteLine(Logger.Level.err, $"Failed to prepare data set: '{err.Message}'."); + Logger.WriteLine(Logger.Level.debug, $"Call Stack = '{err.StackTrace}'"); + } + } + + public IEnumerator GetEnumerator() + { + if (String.IsNullOrEmpty(m_sortedIndexedDataSetFilePath) || File.Exists(m_sortedIndexedDataSetFilePath) == false) + { + PrepareDataSet(); + } + else + { + Logger.WriteLine(Logger.Level.debug, $"Use existing sorted indexed data set file '{m_sortedIndexedDataSetFilePath}'"); + } + + int batchIdx = 0; + int currentBatchPercent = 0; + MemoryMappedFile mmf = null; + MemoryMappedViewStream mms = null; + ZipDecompressor decompressor = null; + if (m_sortedIndexedDataSetFilePath.ToLower().EndsWith(".zip")) + { + Logger.WriteLine($"The data set is a zip archive."); + decompressor = new ZipDecompressor(m_sortedIndexedDataSetFilePath, m_dataPassword); + mms = decompressor.GetMemoryMappedViewStream(); + } + else + { + mmf = MemoryMappedFile.CreateFromFile(m_sortedIndexedDataSetFilePath); + mms = mmf.CreateViewStream(); + } + + using (BinaryReader br = new BinaryReader(mms)) + { + while (true) + { + int sizeInBatch = br.ReadInt32(); + if (sizeInBatch < 0) + { + break; + } + + List outputs = new List(); + + string[] srcLines = br.ReadString().Split("\n"); + string[] tgtLines = br.ReadString().Split("\n"); + batchIdx++; + + if (batchIdx < m_startBatchId) + { + continue; + } + + if (batchIdx % 10000 == 0) + { + Logger.WriteLine(Logger.Level.debug, $"Processing batch '{batchIdx}'"); + } + + T batch; + int currentTokenCountsInBatch = 0; + for (int i = 0; i < sizeInBatch; i++) + { + var srcLine = srcLines[i]; + var tgtLine = tgtLines[i]; + + if (m_batchNumInTotal > 0) + { + if ((100 * batchIdx / m_batchNumInTotal) > currentBatchPercent) + { + Logger.WriteLine($"Processing batch '{batchIdx}/{m_batchNumInTotal}'."); // The '{i}th' record in this batch is: Source = '{srcLine}' Target = '{tgtLine}'"); + currentBatchPercent++; + } + } + + IPair sntPair = new SntPair(srcLine, tgtLine); + currentTokenCountsInBatch += (sntPair.GetTgtTokenCount() + sntPair.GetSrcTokenCount()); + outputs.Add(sntPair); + + if (currentTokenCountsInBatch >= m_maxTokenSizePerBatch) + { + batch = new T(); + batch.CreateBatch(outputs); + yield return batch; + + outputs = new List(); + currentTokenCountsInBatch = 0; + } + } + + if (outputs.Count > 0) + { + batch = new T(); + batch.CreateBatch(outputs); + yield return batch; + } + } + } + + if (mms != null) + { + mms.Dispose(); + } + if (mmf != null) + { + mmf.Dispose(); + } + + if (decompressor != null) + { + decompressor.Dispose(); + } + + File.Delete(m_sortedIndexedDataSetFilePath); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public DPOPairCorpus() + { + + } + + public DPOPairCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, PaddingEnums paddingEnums = PaddingEnums.AllowPadding, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null, int startBatchId = 0, string dataPassword = "") + { + Logger.WriteLine($"Loading parallel corpus from '{corpusFilePath}' for source side '{srcLangName}' and target side '{tgtLangName}' MaxSrcSentLength = '{maxSrcSentLength}', MaxTgtSentLength = '{maxTgtSentLength}', Token Paading Type = '{paddingEnums}', TooLongSequence = '{tooLongSequence}', Encrypted data set = '{!String.IsNullOrEmpty(dataPassword)}'"); + m_maxTokenSizePerBatch = maxTokenSizePerBatch; + m_maxSrcTokenSize = maxSrcSentLength; + m_maxTgtTokenSize = maxTgtSentLength; + m_tooLongSequence = tooLongSequence; + m_paddingEnums = paddingEnums; + CorpusName = corpusFilePath; + m_sortedIndexedDataSetFilePath = indexedFilePath; + m_dataPassword = dataPassword; + + m_srcFileList = new List(); + m_tgtFileList = new List(); + string[] files = Directory.GetFiles(corpusFilePath, $"*.*", SearchOption.TopDirectoryOnly); + + Dictionary srcKey2FileName = new Dictionary(); + Dictionary tgtKey2FileName = new Dictionary(); + + string srcSuffix = $".{srcLangName}.snt"; + string tgtSuffix = $".{tgtLangName}.snt"; + + foreach (string file in files) + { + if (file.EndsWith(srcSuffix, StringComparison.InvariantCultureIgnoreCase)) + { + string srcKey = file.Substring(0, file.Length - srcSuffix.Length); + srcKey2FileName.Add(srcKey, file); + + Logger.WriteLine($"Add source file '{file}' to key '{srcKey}'"); + } + + if (file.EndsWith(tgtSuffix, StringComparison.InvariantCultureIgnoreCase)) + { + string tgtKey = file.Substring(0, file.Length - tgtSuffix.Length); + tgtKey2FileName.Add(tgtKey, file); + + + Logger.WriteLine($"Add target file '{file}' to key '{tgtKey}'"); + } + } + + foreach (var pair in srcKey2FileName) + { + m_srcFileList.Add(pair.Value); + m_tgtFileList.Add(tgtKey2FileName[pair.Key]); + } + m_startBatchId = startBatchId; + } + } +} diff --git a/Seq2SeqSharp/MultiProcessorNetworkWrapper.cs b/Seq2SeqSharp/MultiProcessorNetworkWrapper.cs index 92e07b9..07f762d 100644 --- a/Seq2SeqSharp/MultiProcessorNetworkWrapper.cs +++ b/Seq2SeqSharp/MultiProcessorNetworkWrapper.cs @@ -23,17 +23,19 @@ public class MultiProcessorNetworkWrapper : IMultiProcessorNetworkWrapper whe private readonly int m_defaultDeviceId; // private readonly T m_networkOnDefaultDevice; private readonly bool m_isStaticWeights; + private readonly bool m_savableWeights; private bool m_weightsSynced; private Dictionary m_weightName2DefaultDeviceId = new Dictionary(); private Dictionary m_deviceId2Network = new Dictionary(); - public MultiProcessorNetworkWrapper(T networkOnDefaultDevice, int[] deviceIds, bool isStaticWeights = false) + public MultiProcessorNetworkWrapper(T networkOnDefaultDevice, int[] deviceIds, bool isStaticWeights = false, bool savableWeights = true) { m_networks = new T[deviceIds.Length]; m_defaultDeviceId = networkOnDefaultDevice.GetDeviceId(); // m_networkOnDefaultDevice = networkOnDefaultDevice; m_isStaticWeights = isStaticWeights; + m_savableWeights = savableWeights; m_weightsSynced = false; for (int i = 0; i < deviceIds.Length; i++) @@ -226,7 +228,7 @@ public void ReleaseGradientsOnAllDevices() /// public void Save(IModel model) { - if (m_isStaticWeights == false) + if (m_isStaticWeights == false && m_savableWeights) { m_networks[0].Save(model); } diff --git a/Seq2SeqSharp/Seq2SeqSharp.csproj b/Seq2SeqSharp/Seq2SeqSharp.csproj index 397a9ce..8bda03f 100644 --- a/Seq2SeqSharp/Seq2SeqSharp.csproj +++ b/Seq2SeqSharp/Seq2SeqSharp.csproj @@ -15,7 +15,7 @@ AnyCPU false bin\ - 2.8.20 + 2.8.21 Seq2SeqSharp is a tensor based fast & flexible encoder-decoder deep neural network framework written by .NET (C#). It can be used for sequence-to-sequence task, sequence-labeling task and sequence-classification task and other NLP tasks. Seq2SeqSharp supports both CPUs (x86, x64 and ARM64) and GPUs. It's powered by .NET core, so Seq2SeqSharp can run on both Windows and Linux without any modification and recompilation. README.md Seq2SeqSharp diff --git a/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs b/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs index bb9cd93..565b196 100644 --- a/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs +++ b/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs @@ -43,6 +43,8 @@ public enum NetworkResultStatus public class NetworkResult { public float Cost; + public float ChosenRewards = 0.0f; + public float RejectedRewards = 0.0f; public List>> Output; // (beam_size, batch_size, seq_len) public List>> Alignments; // (beam_size, batch_size, seq_len) public List>> AlignmentScores; // (beam_size, batch_size, seq_len) @@ -380,12 +382,12 @@ protected T LoadModelRoutine(Func initializeParametersFunc, return (srcEmbeddings, tgtEmbeddings); } - internal MultiProcessorNetworkWrapper CreateTgtEmbeddings(IModel modelMetaData, RoundArray raDeviceIds, bool isTgtEmbeddingTrainable, float decoderStartLearningRateFactor, DType elementType = DType.Float32) + internal MultiProcessorNetworkWrapper CreateTgtEmbeddings(IModel modelMetaData, RoundArray raDeviceIds, bool isTgtEmbeddingTrainable, float decoderStartLearningRateFactor, DType elementType = DType.Float32, bool isSavable = true) { Logger.WriteLine(Logger.Level.debug, $"Creating embeddings for target side. Shape = '({modelMetaData.TgtVocab.Count} ,{modelMetaData.DecoderEmbeddingDim})'"); var tgtEmbeddings = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { modelMetaData.TgtVocab.Count, modelMetaData.DecoderEmbeddingDim }, - raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "TgtEmbeddings", isTrainable: isTgtEmbeddingTrainable, learningRateFactor: decoderStartLearningRateFactor, dtype: elementType), DeviceIds); + raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "TgtEmbeddings", isTrainable: isTgtEmbeddingTrainable, learningRateFactor: decoderStartLearningRateFactor, dtype: elementType), DeviceIds, savableWeights: isSavable); return tgtEmbeddings; } @@ -450,9 +452,11 @@ internal void TrainOneEpoch(int ep, ICorpus trainCorpus, ICorpus trainCorpus, ICorpus trainCorpus, ICorpus trainCorpus, ICorpus sntPairBatchs, int batchSplit return batchSplitFactor; } - private (float, int, int, int) RunNetwork(Func> ForwardOnSingleDevice, List sntPairBatchs, int batchSplitFactor, DecodingOptions decodingOptions, bool isTraining) + private (float, float, float, int, int, int) RunNetwork(Func> ForwardOnSingleDevice, List sntPairBatchs, int batchSplitFactor, DecodingOptions decodingOptions, bool isTraining) { float cost = 0.0f; + float chosenRewards = 0.0f; + float rejecteRewards = 0.0f; int processedLine = 0; int srcWordCnts = 0; int tgtWordCnts = 0; @@ -745,6 +758,8 @@ private int TryToSplitBatchFactor(List sntPairBatchs, int batchSplit foreach (var nr in nrs) { cost += nr.Cost; + chosenRewards += nr.ChosenRewards; + rejecteRewards += nr.RejectedRewards; } srcWordCnts += sntPairBatch_i.SrcTokenCount; @@ -802,7 +817,7 @@ private int TryToSplitBatchFactor(List sntPairBatchs, int batchSplit } }); - return (cost / processedLine, srcWordCnts, tgtWordCnts, processedLine); + return (cost / processedLine, chosenRewards / processedLine, rejecteRewards / processedLine, srcWordCnts, tgtWordCnts, processedLine); } private void CreateCheckPoint(ICorpus[] validCorpusList, Dictionary> taskId2metrics, DecodingOptions decodingOptions, Func> forwardOnSingleDevice, double avgCostPerWordInTotal) diff --git a/Seq2SeqSharp/Tools/ComputeGraphTensor.cs b/Seq2SeqSharp/Tools/ComputeGraphTensor.cs index 5a4c78f..e268f1c 100644 --- a/Seq2SeqSharp/Tools/ComputeGraphTensor.cs +++ b/Seq2SeqSharp/Tools/ComputeGraphTensor.cs @@ -15,6 +15,7 @@ using Seq2SeqSharp.Utils; using TensorSharp.Cpu; using AdvUtils; +using static System.Reflection.Metadata.BlobBuilder; /// /// Tensor based computing graph written by Zhongkai Fu. @@ -783,6 +784,45 @@ void backward() return res; } + public IWeightTensor UpdateGradientByMask(IWeightTensor w1, IWeightTensor mask1) + { + WeightTensor w = w1 as WeightTensor; + WeightTensor mask = mask1 as WeightTensor; + var res = w.CopyWeightsRef("UpdateGradientByMask", w.NeedGradient, this); + + if (m_needsBackprop) + { + Tensor maskW = null; + if (w.NeedGradient) + { + maskW = mask.TWeight.CopyRef(); + } + + void backward() + { + res.ReleaseWeight(); + if (w.NeedGradient) + { + w.AddMulGradient(res.TGradient, maskW); + maskW.Dispose(); + if (m_autoCheckCorruption) + { + if (w.IsGradientCorrupted()) + { + throw new WeightsCorruptedException($"Gradient '{w.Name}' is corrupted."); + } + } + } + + res.Dispose(); + } + m_backprop.Add(backward); + } + + return res; + } + + public IWeightTensor EltMul(IWeightTensor w1, IWeightTensor w2) { WeightTensor m1 = w1 as WeightTensor; @@ -1032,10 +1072,10 @@ void backward() } - public IWeightTensor Log(IWeightTensor w) + public IWeightTensor Log(IWeightTensor w, bool needGradient = true) { WeightTensor m = w as WeightTensor; - WeightTensor res = m_weightTensorFactory.CreateWeightTensor(m.Sizes, m_deviceId, name: $"{GetHashString(m.Name)}.Log", graphToBind: this, needGradient: m.NeedGradient, dtype: m.ElementType); + WeightTensor res = m_weightTensorFactory.CreateWeightTensor(m.Sizes, m_deviceId, name: $"{GetHashString(m.Name)}.Log", graphToBind: this, needGradient: m.NeedGradient && needGradient, dtype: m.ElementType); Ops.Log(res.TWeight, m.TWeight); if (m_autoCheckCorruption) @@ -1046,7 +1086,7 @@ public IWeightTensor Log(IWeightTensor w) } } - if (m_needsBackprop) + if (m_needsBackprop && needGradient) { Tensor mTWeight = null; //if (m_saveGPUMemoryMode) @@ -1282,11 +1322,11 @@ void backward() } - public IWeightTensor Sub(IWeightTensor w0, IWeightTensor w1) + public IWeightTensor Sub(IWeightTensor w0, IWeightTensor w1, bool needGradient = true) { WeightTensor m0 = w0 as WeightTensor; WeightTensor m1 = w1 as WeightTensor; - WeightTensor res = m_weightTensorFactory.CreateWeightTensor(m1.Sizes, m_deviceId, name: $"{GetHashString(w0.Name)}_{GetHashString(w1.Name)}.SubTT", graphToBind: this, needGradient: m0.NeedGradient || m1.NeedGradient, dtype: m0.ElementType); + WeightTensor res = m_weightTensorFactory.CreateWeightTensor(m1.Sizes, m_deviceId, name: $"{GetHashString(w0.Name)}_{GetHashString(w1.Name)}.SubTT", graphToBind: this, needGradient: needGradient && (m0.NeedGradient || m1.NeedGradient), dtype: m0.ElementType); VisualizeNodes(new IWeightTensor[] { w1 }, res); @@ -1299,7 +1339,7 @@ public IWeightTensor Sub(IWeightTensor w0, IWeightTensor w1) } } - if (m_needsBackprop) + if (m_needsBackprop && needGradient) { void backward() { @@ -3948,22 +3988,23 @@ void backward() return res; } - public IWeightTensor BuildMaskUntil(List> paddedTokensList, int maskEndId, DType elementType = DType.Float32) + public IWeightTensor BuildMaskAfter(List> paddedTokensList, int maskTokenId, DType elementType = DType.Float32) { int batchSize = paddedTokensList.Count; int seqLength = paddedTokensList[0].Count; float[] buf = new float[batchSize * seqLength]; - Array.Fill(buf, 1.0f); + Array.Fill(buf, 0.0f); for (int batchIdx = 0; batchIdx < batchSize; batchIdx++) { for (int tokenIdx = 0; tokenIdx < seqLength; tokenIdx++) { int token = paddedTokensList[batchIdx][tokenIdx]; - if (token == maskEndId) + if (token == maskTokenId && seqLength - tokenIdx - 1 > 0) { - Array.Fill(buf, 0.0f, batchIdx * seqLength, tokenIdx); + Array.Fill(buf, 1.0f, batchIdx * seqLength + tokenIdx + 1, seqLength - tokenIdx - 1); + break; } } } @@ -3994,6 +4035,40 @@ void backward() } + public (float, float, float) DPOLoss(IWeightTensor policy_chosen_logps, IWeightTensor policy_rejected_logps, IWeightTensor reference_chosen_logps, IWeightTensor reference_rejected_logps, float graident = 1.0f, float smooth = 0.0f, float beta = 0.5f) + { + float num_classes = policy_chosen_logps.Sizes[1]; + float N = policy_chosen_logps.Sizes[0]; + + var pi_logratios = Sub(policy_chosen_logps, policy_rejected_logps); + var ref_logratios = Sub(reference_chosen_logps, reference_rejected_logps); + + var logits = Sub(pi_logratios, ref_logratios); + + IWeightTensor loss = null; + //DPO Loss + if (smooth > 0.0f) + { + loss = Mul(Log(Add(Sigmoid(Mul(logits, beta)), smooth)), -1.0f); + } + else + { + loss = Mul(Log(Sigmoid(Mul(logits, beta))), -1.0f); + } + + var lossValue = loss.ToWeightArray().Sum() / (N * num_classes); + + loss.FillGradient(graident); + + var chosen_rewards = Sub(policy_chosen_logps, reference_chosen_logps, needGradient: false); + float cr = chosen_rewards.ToWeightArray().Sum() / (N * num_classes); + + var rejected_rewards = Sub(policy_rejected_logps, reference_rejected_logps, needGradient: false); + float rr = rejected_rewards.ToWeightArray().Sum() / (N * num_classes); + + return (lossValue, cr, rr); + } + private (float, IWeightTensor) CalculateEntropyLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, float smooth, float labelSmooth) { diff --git a/Seq2SeqSharp/Tools/IComputeGraph.cs b/Seq2SeqSharp/Tools/IComputeGraph.cs index a96b477..e92f58d 100644 --- a/Seq2SeqSharp/Tools/IComputeGraph.cs +++ b/Seq2SeqSharp/Tools/IComputeGraph.cs @@ -84,7 +84,7 @@ public interface IComputeGraph : IDisposable IWeightTensor ScatterAdd(IWeightTensor source, IWeightTensor indices, int dim, params long[] shape); (IWeightTensor, IWeightTensor) TopK(IWeightTensor src, int k); - IWeightTensor Sub(IWeightTensor w0, IWeightTensor w1); + IWeightTensor Sub(IWeightTensor w0, IWeightTensor w1, bool needGradient = true); IWeightTensor Sub(float v, IWeightTensor w1); #region Operations for masking @@ -94,6 +94,10 @@ public interface IComputeGraph : IDisposable IWeightTensor BuildSelfTriMask(int paddedLength, float[] originalLengths, DType elementType = DType.Float32); + IWeightTensor BuildMaskAfter(List> paddedTokensList, int maskEndId, DType elementType = DType.Float32); + + IWeightTensor UpdateGradientByMask(IWeightTensor w1, IWeightTensor mask1); + #endregion IWeightTensor LeftShiftTokens(List> input, int lastTokenToPad); @@ -103,7 +107,7 @@ public interface IComputeGraph : IDisposable IWeightTensor Sum(IWeightTensor w, int dim); IWeightTensor Mean(IWeightTensor w, int dim); - IWeightTensor Log(IWeightTensor w); + IWeightTensor Log(IWeightTensor w, bool needGradient = true); IWeightTensor Rsqrt(IWeightTensor w); @@ -118,6 +122,8 @@ public interface IComputeGraph : IDisposable float CrossEntropyLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, IWeightTensor graident, float smooth = 0.0f, float labelSmooth = 0.0f); float NLLLoss(IWeightTensor probs, IWeightTensor truthTgtSeqs, float graident = 1.0f, float smooth = 0.0f); + (float, float, float) DPOLoss(IWeightTensor preferredLogist, IWeightTensor nonPreferredLogist, IWeightTensor refPreferredLogist, IWeightTensor refNonPreferredLogist, float graident = 1.0f, float smooth = 0.0f, float beta = 0.5f); + IWeightTensor CreateUniformRandomTensor(long[] sizes, float minVal, float maxVal, DType dtype); IWeightTensor LogSoftmax(IWeightTensor x); diff --git a/Seq2SeqSharp/Utils/CostEventArg.cs b/Seq2SeqSharp/Utils/CostEventArg.cs index f42072c..ee25c25 100644 --- a/Seq2SeqSharp/Utils/CostEventArg.cs +++ b/Seq2SeqSharp/Utils/CostEventArg.cs @@ -8,6 +8,9 @@ public class CostEventArg : EventArgs { public double AvgCostInTotal { get; set; } + public double AvgChosenRewardInTotal { get; set; } + public double AvgRejectedRewardInTotal { get; set; } + public int Epoch { get; set; } public int Update { get; set; } diff --git a/Seq2SeqSharp/Utils/Misc.cs b/Seq2SeqSharp/Utils/Misc.cs index 04e8527..abbf589 100644 --- a/Seq2SeqSharp/Utils/Misc.cs +++ b/Seq2SeqSharp/Utils/Misc.cs @@ -108,7 +108,27 @@ public static void Ss_StatusUpdateWatcher(object sender, EventArgs e) wordPerSec = ep.ProcessedWordsInTotal / ts.TotalSeconds; } - Logger.WriteLine($"Update = {ep.Update}, Epoch = {ep.Epoch}, LR = {ep.LearningRate.ToString("e4")}, AvgCost = {ep.AvgCostInTotal.ToString("e4")}, LossScaling = {ep.LossScaling:F}, Sent = {ep.ProcessedSentencesInTotal}, SentPerMin = {sentPerMin:F}, WordPerSec = {wordPerSec:F}"); + Logger.WriteLine($"Update = {ep.Update}, Epoch = {ep.Epoch}, LR = {ep.LearningRate.ToString("e4")}, Cost = {ep.AvgCostInTotal.ToString("e4")}, LossScaling = {ep.LossScaling:F}, Sent = {ep.ProcessedSentencesInTotal}, SentPerMin = {sentPerMin:F}, WordPerSec = {wordPerSec:F}"); + } + + public static void Ss_StatusUpdateWatcherDPO(object sender, EventArgs e) + { + CostEventArg ep = e as CostEventArg; + + TimeSpan ts = DateTime.Now - ep.StartDateTime; + double sentPerMin = 0; + double wordPerSec = 0; + if (ts.TotalMinutes > 0) + { + sentPerMin = ep.ProcessedSentencesInTotal / ts.TotalMinutes; + } + + if (ts.TotalSeconds > 0) + { + wordPerSec = ep.ProcessedWordsInTotal / ts.TotalSeconds; + } + + Logger.WriteLine($"Update = {ep.Update}, Epoch = {ep.Epoch}, LR = {ep.LearningRate.ToString("e4")}, Cost = {ep.AvgCostInTotal.ToString("e4")}, ChosenReward = {ep.AvgChosenRewardInTotal.ToString("e4")}, RejectedReward = {ep.AvgRejectedRewardInTotal.ToString("e4")}, Margin = {(ep.AvgChosenRewardInTotal - ep.AvgRejectedRewardInTotal).ToString("e4")}, LossScaling = {ep.LossScaling:F}, Snt = {ep.ProcessedSentencesInTotal}, SentPerMin = {sentPerMin:F}, WordPerSec = {wordPerSec:F}"); } public static IOptimizer CreateOptimizer(Options opts) diff --git a/Seq2SeqSharp/Utils/ModeEnums.cs b/Seq2SeqSharp/Utils/ModeEnums.cs index c726289..8f32871 100644 --- a/Seq2SeqSharp/Utils/ModeEnums.cs +++ b/Seq2SeqSharp/Utils/ModeEnums.cs @@ -19,6 +19,7 @@ public enum ModeEnums DumpVocab, UpdateVocab, VQModel, + DPO, Help } diff --git a/Tools/GPTConsole/Program.cs b/Tools/GPTConsole/Program.cs index 9dcf6ee..541d33e 100644 --- a/Tools/GPTConsole/Program.cs +++ b/Tools/GPTConsole/Program.cs @@ -130,6 +130,44 @@ private static void Main(string[] args) // Kick off training ss.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpusList: null, learningRate: learningRate, optimizer: optimizer, metrics: null, decodingOptions: decodingOptions); } + else if (opts.Task == ModeEnums.DPO) + { + Logger.WriteLine($"Starting to run DPO against model '{opts.ModelFilePath}'"); + + + if (opts.ModelFilePath.IsNullOrEmpty() || !File.Exists(opts.ModelFilePath)) + { + Logger.WriteLine(Logger.Level.err, $"Model '{opts.ModelFilePath}' doesn't exist."); + return; + } + // Load train corpus + var trainCorpus = new DPOCorpus(corpusFilePath: opts.TrainCorpusPath, srcLangName: opts.SrcLang, tgtLangName: opts.TgtLang, maxTokenSizePerBatch: opts.MaxTokenSizePerBatch, + maxSrcSentLength: opts.MaxSrcSentLength, maxTgtSentLength: opts.MaxTgtSentLength, paddingEnums: opts.PaddingType, tooLongSequence: opts.TooLongSequence, indexedFilePath: opts.IndexedCorpusPath, + startBatchId: opts.StartBatchId, dataPassword: opts.DataPassword); + + // Create learning rate + ILearningRate learningRate = null; + + if (opts.LearningRateType == LearningRateTypeEnums.CosineDecay) + { + learningRate = new CosineDecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.LearningRateDecaySteps, opts.WeightsUpdateCount); + } + else + { + learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount, opts.LearningRateStepDownFactor, opts.UpdateNumToStepDownLearningRate); + } + + // Create optimizer + IOptimizer optimizer = Misc.CreateOptimizer(opts); + + DPO trainer = new DPO(opts); + + // Add event handler for monitoring + trainer.StatusUpdateWatcher += Misc.Ss_StatusUpdateWatcherDPO; + trainer.EvaluationWatcher += Ss_EvaluationWatcher; + + trainer.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpusList: null, learningRate: learningRate, optimizer: optimizer, metrics: null, decodingOptions: decodingOptions); + } else if (opts.Task == ModeEnums.Test) { if (File.Exists(opts.OutputFile))