Skip to content

Commit 2c66bc7

Browse files
committed
Update key generation for KV cache.
1 parent 243f358 commit 2c66bc7

File tree

1 file changed

+17
-1
lines changed
  • Seq2SeqSharp/Applications

1 file changed

+17
-1
lines changed

Seq2SeqSharp/Applications/GPT.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
using Seq2SeqSharp.Tools;
2121
using Seq2SeqSharp.Utils;
2222
using TensorSharp;
23+
using ManagedCuda.BasicTypes;
2324

2425
namespace Seq2SeqSharp.Applications
2526
{
@@ -144,13 +145,28 @@ private bool CreateTrainableParameters(IModel model)
144145
m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx));
145146
}
146147

148+
/// <summary>
149+
/// Generate key for kv cache. </s> will be removed from the key if it exist
150+
/// </summary>
151+
/// <param name="strs"></param>
152+
/// <returns></returns>
147153
private string GenerateCacheKey(List<List<string>> strs)
148154
{
149155
List<string> r = new List<string>();
150156

151157
foreach (var str in strs)
152158
{
153-
r.Add(string.Join(" ", str));
159+
List<string> normStr = new List<string>();
160+
foreach (string word in str)
161+
{
162+
if (word == "</s>")
163+
{
164+
continue;
165+
}
166+
normStr.Add(word);
167+
}
168+
169+
r.Add(string.Join(" ", normStr));
154170
}
155171

156172
return string.Join("\t", r);

0 commit comments

Comments
 (0)