Skip to content

Commit 9eb2bd2

Browse files
committed
1. Support reference model for reinforcement learning
2. Support DPO (direct preference optimization) training
1 parent 2c66bc7 commit 9eb2bd2

18 files changed

+1263
-56
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -124,3 +124,4 @@ Temporary Items
124124
/Tools/SeqDictMatchConsole/bin/Debug/net8.0
125125
/Tools/SeqMedical/obj
126126
/Tools/SeqMedical/bin/Debug/net9.0
127+
/Tools/SeqMedical/bin

Seq2SeqSharp/Applications/DPO.cs

Lines changed: 274 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,274 @@
1+
// Copyright (c) Zhongkai Fu. All rights reserved.
2+
// https://github.com/zhongkaifu/Seq2SeqSharp
3+
//
4+
// This file is part of Seq2SeqSharp.
5+
//
6+
// Seq2SeqSharp is licensed under the BSD-3-Clause license found in the LICENSE file in the root directory of this source tree.
7+
//
8+
// Seq2SeqSharp is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
9+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD-3-Clause License for more details.
10+
11+
using System;
12+
using System.Collections.Generic;
13+
using System.IO;
14+
using AdvUtils;
15+
using System.Runtime.Caching;
16+
using Seq2SeqSharp.Enums;
17+
using Seq2SeqSharp.Corpus;
18+
using Seq2SeqSharp.Layers;
19+
using Seq2SeqSharp.Models;
20+
using Seq2SeqSharp.Tools;
21+
using Seq2SeqSharp.Utils;
22+
using TensorSharp;
23+
using ManagedCuda.BasicTypes;
24+
25+
namespace Seq2SeqSharp.Applications
26+
{
27+
public class DPO : BaseSeq2SeqFramework<Seq2SeqModel>
28+
{
29+
// Policy networks and tensors, wrapped so that each configured device has its own instance.
// They are (re)built by CreateDPOModel.
private MultiProcessorNetworkWrapper<IWeightTensor> m_tgtEmbedding;         // Target-side token embeddings over devices
private MultiProcessorNetworkWrapper<IDecoder> m_decoder;                   // The decoders over devices
private MultiProcessorNetworkWrapper<IFeedForwardLayer> m_decoderFFLayer;   // The feed forward layers over devices, applied after all decoder layers
private MultiProcessorNetworkWrapper<IWeightTensor> m_segmentEmbedding;     // Auxiliary embedding returned by Misc.CreateAuxEmbeddings
private MultiProcessorNetworkWrapper<IWeightTensor> m_posEmbedding;         // Auxiliary embedding returned by Misc.CreateAuxEmbeddings

// Reference-model counterparts of the fields above.
// CreateRefModel builds every one of them with isTrainable set to false.
private MultiProcessorNetworkWrapper<IWeightTensor> ref_m_tgtEmbedding;
private MultiProcessorNetworkWrapper<IDecoder> ref_m_decoder;
private MultiProcessorNetworkWrapper<IFeedForwardLayer> ref_m_decoderFFLayer;
private MultiProcessorNetworkWrapper<IWeightTensor> ref_m_segmentEmbedding;
private MultiProcessorNetworkWrapper<IWeightTensor> ref_m_posEmbedding;

// Padding policy; the constructor overwrites this default with options.PaddingType.
private readonly PaddingEnums m_paddingType = PaddingEnums.AllowPadding;

// The options this instance was constructed with.
private readonly Seq2SeqOptions m_options;

// NOTE(review): nothing in this class raises this event - confirm whether it is still needed.
public event EventHandler KVCacheRemoveWatcher;
54+
55+
/// <summary>
/// Creates a DPO instance. When the model file named in the options exists it is loaded;
/// otherwise a new model is created and its networks are initialized.
/// </summary>
/// <param name="options">Options for this instance; also forwarded to the base framework constructor.</param>
/// <param name="tgtVocab">Vocabulary for a newly created model. Must be null when the model file already exists.</param>
/// <exception cref="ArgumentException">The model file exists and <paramref name="tgtVocab"/> is not null.</exception>
public DPO(Seq2SeqOptions options, Vocab tgtVocab = null)
    : base(deviceIds: options.DeviceIds,
           processorType: options.ProcessorType,
           modelFilePath: options.ModelFilePath,
           memoryUsageRatio: options.MemoryUsageRatio,
           compilerOptions: options.CompilerOptions,
           runValidEveryUpdates: options.RunValidEveryUpdates,
           updateFreq: options.UpdateFreq,
           startToRunValidAfterUpdates: options.StartValidAfterUpdates,
           maxDegressOfParallelism: options.TaskParallelism,
           mklInstructions: options.MKLInstructions,
           weightsUpdateCount: options.WeightsUpdateCount,
           enableTensorCore: options.EnableTensorCore,
           cudaMemoryAllocatorType: options.CudaMemoryAllocatorType,
           elementType: options.AMP ? DType.Float16 : DType.Float32,
           randomSeed: options.RandomSeed,
           saveModelEveryUpdats: options.SaveModelEveryUpdates,
           saveGPUMemoryLevel: options.SaveGPUMemoryLevel,
           initLossScaling: options.InitLossScaling,
           autoCheckTensorCorruption: options.CheckTensorCorrupted,
           attentionType: options.AttentionType)
{
    m_paddingType = options.PaddingType;
    m_options = options;

    // Check that the options are valid before any model state is touched.
    m_options.ValidateOptions();

    bool modelFileExists = File.Exists(m_options.ModelFilePath);
    if (modelFileExists && tgtVocab != null)
    {
        // An existing model already carries its vocabulary, so a second one is rejected.
        throw new ArgumentException($"Model '{m_options.ModelFilePath}' exists and it includes vocabulary, so input vocabulary must be null.");
    }

    if (modelFileExists)
    {
        // The model file exists, so load it from disk.
        m_modelMetaData = LoadModel();
    }
    else
    {
        // No model file yet: create the model, then initialize the network weights.
        m_modelMetaData = new Seq2SeqModel(options, null, tgtVocab);
        CreateTrainableParameters(m_modelMetaData);
    }

    // Force the model metadata to "no encoder, GPT decoder" on both paths.
    m_modelMetaData.EncoderType = EncoderTypeEnums.None;
    m_modelMetaData.DecoderType = DecoderTypeEnums.GPTDecoder;
    m_modelMetaData.ShowModelInfo();
}
91+
92+
/// <summary>
/// Replaces the model's target vocabulary when one is supplied, then saves the model
/// with the ".updatevocab" suffix, backing up the previous file.
/// </summary>
/// <param name="tgtVocab">The new target vocabulary. When null the vocabulary is left untouched, but the model is still saved.</param>
public void UpdateVocabs(Vocab tgtVocab)
{
    bool hasReplacement = tgtVocab != null;
    if (hasReplacement)
    {
        m_modelMetaData.TgtVocab = tgtVocab;
    }

    SaveModel(createBackupPrevious: true, suffix: ".updatevocab");
}
101+
102+
/// <summary>
/// Copies the VQ type from the options into the model metadata and saves the model,
/// using the VQ type name (prefixed with a dot) as the file suffix.
/// </summary>
public void VQModel()
{
    m_modelMetaData.VQType = m_options.VQType;

    string vqSuffix = "." + m_modelMetaData.VQType.ToString();
    SaveModel(createBackupPrevious: true, suffix: vqSuffix);
}
108+
109+
/// <summary>
/// Loads the model through the base class routine, passing <see cref="CreateTrainableParameters"/>
/// and <c>Seq2SeqModel.Create</c> as callbacks.
/// </summary>
/// <param name="suffix">Suffix forwarded unchanged to the base load routine.</param>
/// <returns>The model returned by the base load routine.</returns>
protected override Seq2SeqModel LoadModel(string suffix = "")
{
    return base.LoadModelRoutine<Model_4_ProtoBufSerializer>(CreateTrainableParameters, Seq2SeqModel.Create, suffix);
}
110+
111+
/// <summary>
/// Builds the policy networks first and the reference networks second for the given model.
/// </summary>
/// <param name="model">The model metadata both sets of networks are built from.</param>
/// <returns>Always true.</returns>
private bool CreateTrainableParameters(IModel model)
{
    // Both helpers' return values are deliberately ignored.
    _ = CreateDPOModel(model);
    _ = CreateRefModel(model);

    return true;
}
118+
119+
/// <summary>
/// (Re)creates the policy networks: decoders, the output feed forward layer, the auxiliary
/// embeddings and the target token embeddings. Wrappers left over from a previous call are
/// disposed first.
/// </summary>
/// <param name="model">The model metadata (hidden size, target vocabulary, position embedding type) the networks are built from.</param>
/// <returns>Always true.</returns>
private bool CreateDPOModel(IModel model)
{
    // Release every wrapper that is about to be replaced. The position embedding wrapper is
    // reassigned below just like the others, so it must be disposed here too; otherwise the
    // previous instance leaks each time this method runs again.
    m_decoder?.Dispose();
    m_decoderFFLayer?.Dispose();
    m_segmentEmbedding?.Dispose();
    m_posEmbedding?.Dispose();
    m_tgtEmbedding?.Dispose();

    Logger.WriteLine(Logger.Level.debug, $"Creating decoders...");

    var raDeviceIds = new RoundArray<int>(DeviceIds);
    DType elementType = m_options.AMP ? DType.Float16 : DType.Float32;

    // Everything built here is trainable only when the configured task is DPO.
    bool isDPOTask = m_options.Task == ModeEnums.DPO;

    m_decoder = Decoder.CreateDecoders(model, m_options, raDeviceIds,
        isTrainable: m_options.IsDecoderTrainable && isDPOTask, elementType: elementType);

    var decoderOutputLayer = new FeedForwardLayer("FeedForward_Decoder_0", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f,
        deviceId: raDeviceIds.GetNextItem(), isTrainable: isDPOTask, learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType);
    m_decoderFFLayer = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(decoderOutputLayer, DeviceIds);

    // The auxiliary embeddings are sized for the longer of the training and validation target length limits.
    (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim,
        Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength), model, elementType,
        isTrainable: isDPOTask, createAPE: (model.PEType == PositionEmbeddingEnums.APE));

    m_tgtEmbedding = CreateTgtEmbeddings(model, raDeviceIds, m_options.IsTgtEmbeddingTrainable && isDPOTask,
        m_options.DecoderStartLearningRateFactor, elementType);

    return true;
}
156+
157+
158+
/// <summary>
/// (Re)creates the reference networks: decoders, the output feed forward layer, the auxiliary
/// embeddings and the target token embeddings. All of them are created with isTrainable set to
/// false; the decoders, feed forward layer and token embeddings are also marked as not savable.
/// Wrappers left over from a previous call are disposed first.
/// </summary>
/// <param name="model">The model metadata (hidden size, target vocabulary, position embedding type) the networks are built from.</param>
/// <returns>Always true.</returns>
private bool CreateRefModel(IModel model)
{
    // Release every wrapper that is about to be replaced. The position embedding wrapper is
    // reassigned below just like the others, so it must be disposed here too; otherwise the
    // previous instance leaks each time this method runs again.
    ref_m_decoder?.Dispose();
    ref_m_decoderFFLayer?.Dispose();
    ref_m_segmentEmbedding?.Dispose();
    ref_m_posEmbedding?.Dispose();
    ref_m_tgtEmbedding?.Dispose();

    Logger.WriteLine(Logger.Level.debug, $"Creating decoders...");

    var raDeviceIds = new RoundArray<int>(DeviceIds);
    DType elementType = m_options.AMP ? DType.Float16 : DType.Float32;

    ref_m_decoder = Decoder.CreateDecoders(model, m_options, raDeviceIds, isTrainable: false, elementType: elementType, isSavable: false);

    // Same layer name as the policy model's output layer ("FeedForward_Decoder_0").
    var refOutputLayer = new FeedForwardLayer("FeedForward_Decoder_0", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f,
        deviceId: raDeviceIds.GetNextItem(), isTrainable: false, learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType);
    ref_m_decoderFFLayer = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(refOutputLayer, DeviceIds, savableWeights: false);

    // The auxiliary embeddings are sized for the longer of the training and validation target length limits.
    (ref_m_posEmbedding, ref_m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim,
        Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength), model, elementType,
        isTrainable: false, createAPE: (model.PEType == PositionEmbeddingEnums.APE));

    ref_m_tgtEmbedding = CreateTgtEmbeddings(model, raDeviceIds, false, m_options.DecoderStartLearningRateFactor, elementType, isSavable: false);

    return true;
}
195+
/// <summary>
/// Fetches the policy networks that live on the given device.
/// </summary>
/// <param name="deviceId">The device id; it is translated to a device index via TensorAllocator.GetDeviceIdIndex.</param>
/// <returns>
/// A tuple of (decoder, feed forward layer, target embedding, segment embedding, position embedding).
/// The last two items are null when the corresponding wrapper has not been created.
/// </returns>
private (IDecoder, IFeedForwardLayer, IWeightTensor, IWeightTensor, IWeightTensor) GetNetworksOnDeviceAt(int deviceId)
{
    var idx = TensorAllocator.GetDeviceIdIndex(deviceId);

    IDecoder decoder = m_decoder.GetNetworkOnDevice(idx);
    IFeedForwardLayer feedForward = m_decoderFFLayer.GetNetworkOnDevice(idx);
    IWeightTensor tgtEmbedding = m_tgtEmbedding.GetNetworkOnDevice(idx);
    IWeightTensor segmentEmbedding = m_segmentEmbedding?.GetNetworkOnDevice(idx);
    IWeightTensor posEmbedding = m_posEmbedding?.GetNetworkOnDevice(idx);

    return (decoder, feedForward, tgtEmbedding, segmentEmbedding, posEmbedding);
}
207+
208+
209+
/// <summary>
/// Fetches the reference-model networks that live on the given device.
/// </summary>
/// <param name="deviceId">The device id; it is translated to a device index via TensorAllocator.GetDeviceIdIndex.</param>
/// <returns>
/// A tuple of (decoder, feed forward layer, target embedding, segment embedding, position embedding).
/// The last two items are null when the corresponding wrapper has not been created.
/// </returns>
private (IDecoder, IFeedForwardLayer, IWeightTensor, IWeightTensor, IWeightTensor) GetRefNetworksOnDeviceAt(int deviceId)
{
    var idx = TensorAllocator.GetDeviceIdIndex(deviceId);

    IDecoder refDecoder = ref_m_decoder.GetNetworkOnDevice(idx);
    IFeedForwardLayer refFeedForward = ref_m_decoderFFLayer.GetNetworkOnDevice(idx);
    IWeightTensor refTgtEmbedding = ref_m_tgtEmbedding.GetNetworkOnDevice(idx);
    IWeightTensor refSegmentEmbedding = ref_m_segmentEmbedding?.GetNetworkOnDevice(idx);
    IWeightTensor refPosEmbedding = ref_m_posEmbedding?.GetNetworkOnDevice(idx);

    return (refDecoder, refFeedForward, refTgtEmbedding, refSegmentEmbedding, refPosEmbedding);
}
218+
219+
/// <summary>
/// Runs the DPO forward pass on a single device: the chosen and rejected sentences of the batch
/// are passed through both the policy networks and the reference networks by
/// Decoder.DPODecoderTrainer.
/// </summary>
/// <param name="computeGraph">The computing graph for the current device; its DeviceId selects which networks are used.</param>
/// <param name="sntPairBatch">The batch of sentence pairs. The source side is read as the chosen sentences and the target side as the rejected sentences.</param>
/// <param name="decodingOptions">Not used by this method.</param>
/// <param name="isTraining">Must be true; DPO supports training mode only.</param>
/// <returns>A single-element list whose result carries the loss (Cost), the chosen rewards and the rejected rewards; Output is null.</returns>
/// <exception cref="ArgumentException"><paramref name="isTraining"/> is false.</exception>
public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, IPairBatch sntPairBatch, DecodingOptions decodingOptions, bool isTraining)
{
    if (!isTraining)
    {
        throw new ArgumentException("The DPO is only for training mode.");
    }

    (var decoder, var decoderFFLayer, var tgtEmbedding, var segmentEmbedding, var posEmbeddings) = GetNetworksOnDeviceAt(computeGraph.DeviceId);
    (var ref_decoder, var ref_decoderFFLayer, var ref_tgtEmbedding, var ref_segmentEmbedding, var ref_posEmbeddings) = GetRefNetworksOnDeviceAt(computeGraph.DeviceId);

    // Vocabulary index of the configured DPO mask marker token. Both masks below are built relative to it.
    int messageTokenId = m_modelMetaData.TgtVocab.GetWordIndex(m_options.DPOMaskedToken, logUnk: true);

    // Chosen sentences come from the source side of the pair batch.
    var chosenSnts = sntPairBatch.GetSrcTokens();
    var chosenTokensList = m_modelMetaData.TgtVocab.GetWordIndex(chosenSnts);
    var chosenMask = computeGraph.BuildMaskAfter(chosenTokensList, messageTokenId, tgtEmbedding.ElementType);

    // Rejected sentences come from the target side of the pair batch.
    var rejectedSnts = sntPairBatch.GetTgtTokens();
    var rejectedTokensList = m_modelMetaData.TgtVocab.GetWordIndex(rejectedSnts);
    var rejectedMask = computeGraph.BuildMaskAfter(rejectedTokensList, messageTokenId, tgtEmbedding.ElementType);

    // NOTE(review): only the policy decoder is reset here - confirm the reference decoder needs no Reset.
    decoder.Reset(computeGraph.GetWeightFactory(), chosenSnts.Count);

    (var loss, var cr, var rr) = Decoder.DPODecoderTrainer(chosenTokensList, rejectedTokensList, computeGraph, decoder as GPTDecoder, ref_decoder as GPTDecoder,
        decoderFFLayer, ref_decoderFFLayer,
        tgtEmbedding, ref_tgtEmbedding,
        m_modelMetaData.TgtVocab, m_paddingType, m_options.DropoutRatio,
        segmentEmbedding, ref_segmentEmbedding,
        m_options.AMP,
        posEmbeddings, ref_posEmbeddings,
        LossScaling, m_options.PaddingAlignmentFactor, lossSmooth: m_options.LossSmooth, beta: m_options.DPOBeta, chosenMasks: chosenMask, rejectedMasks: rejectedMask);

    NetworkResult nr = new NetworkResult();
    nr.Status = NetworkResultStatus.SUCCEED;
    nr.Cost = loss;
    nr.ChosenRewards = cr;
    nr.RejectedRewards = rr;
    nr.Output = null;

    List<NetworkResult> nrs = new List<NetworkResult>();
    nrs.Add(nr);
    return nrs;
}
273+
}
274+
}

0 commit comments

Comments
 (0)