@@ -17,11 +17,13 @@ limitations under the License.
 package prefix
 
 import (
+	"context"
 	"encoding/binary"
 	"fmt"
 
 	"github.com/cespare/xxhash/v2"
 	k8stypes "k8s.io/apimachinery/pkg/types"
+	"sigs.k8s.io/controller-runtime/pkg/log"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
@@ -131,24 +133,11 @@ func (m *Plugin) Name() string {
 	return "prefix-cache"
 }
 
-// PostCycle records in the plugin cache the result of the scheduling selection.
-func (m *Plugin) PostCycle(ctx *types.SchedulingContext, res *types.Result) {
-	targetPod := res.TargetPod.GetPod()
-	state, err := m.getPrefixState(ctx.CycleState)
-	if err != nil {
-		ctx.Logger.Error(err, "failed to read prefix plugin cycle state")
-		return
-	}
-	m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName))
-	total := len(state.PrefixHashes)
-	matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
-	metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize)
-}
-
 // Score returns the scoring result for the given list of pods based on context.
-func (m *Plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
+func (m *Plugin) Score(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) map[types.Pod]float64 {
+	loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
 	// pre score step, hashing prompt and find longest prefix match.
-	hashes := hashPrompt(ctx, m.HashBlockSize, m.MaxPrefixBlocksToMatch)
+	hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch)
 	numServers := DefaultNumServersToMatch
 	if numServers > len(pods) {
 		numServers = len(pods)
@@ -157,8 +146,8 @@ func (m *Plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types
 		PrefixHashes:       hashes,
 		PrefixCacheServers: m.matchLongestPrefix(ctx, hashes, numServers),
 	}
-	ctx.CycleState.Write(types.StateKey(m.Name()), state)
-	ctx.Logger.V(logutil.TRACE).Info(fmt.Sprintf("cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes)
+	cycleState.Write(types.StateKey(m.Name()), state)
+	loggerTrace.Info(fmt.Sprintf("cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes)
 	// calculate the scores of pods
 	scores := make(map[types.Pod]float64, len(pods))
 
@@ -177,16 +166,31 @@ func (m *Plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types
 	return scores
 }
 
+// PostCycle records in the plugin cache the result of the scheduling selection.
+func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.Result) {
+	targetPod := res.TargetPod.GetPod()
+	state, err := m.getPrefixState(cycleState)
+	if err != nil {
+		log.FromContext(ctx).Error(err, "failed to read prefix plugin cycle state")
+		return
+	}
+	m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName))
+	total := len(state.PrefixHashes)
+	matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
+	metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize)
+}
+
 // matchLongestPrefix returns a map of servers and length of prefix that each server caches.
-func (m *Plugin) matchLongestPrefix(ctx *types.SchedulingContext, hashes []BlockHash, numServers int) map[ServerID]int {
+func (m *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash, numServers int) map[ServerID]int {
+	loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
 	res := make(map[ServerID]int)
 	// Use a greedy strategy to search from the longest prefix.
 	// NOTE: It's possible to further optimize this with a binary search.
 	for i := len(hashes) - 1; i >= 0 && len(res) < numServers; i-- {
 		hash := hashes[i]
 		cachedServers := m.indexer.Get(hash)
 		if len(cachedServers) > 0 {
-			ctx.Logger.V(logutil.TRACE).Info("Found cached servers", "cachedServers", cachedServers, "total # blocks", len(hashes), "longest prefix", i)
+			loggerTrace.Info("Found cached servers", "cachedServers", cachedServers, "total # blocks", len(hashes), "longest prefix", i)
 			for server := range cachedServers {
 				// Update servers with their longest prefix match.
 				// If we already found this server with longer prefix match, don't update it.
@@ -218,21 +222,22 @@ func (m *Plugin) getPrefixState(cycleState *types.CycleState) (*schedulingContex
 // hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
 // hash(0) is the hash of the model name, since different models generally don't share prefix cache.
 // For block i, hash(i) = hash(block i content, hash(i-1)).
-func hashPrompt(ctx *types.SchedulingContext, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
-	prompt := []byte(ctx.Req.Prompt)
+func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
+	loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
+	prompt := []byte(request.Prompt)
 	if len(prompt) < cacheBlockSize {
-		ctx.Logger.V(logutil.DEBUG).Info("Request body too small for prefix cache", "size", len(prompt), "block size", cacheBlockSize)
+		loggerDebug.Info("Request body too small for prefix cache", "size", len(prompt), "block size", cacheBlockSize)
 		return nil
 	}
 	if len(prompt) > cacheBlockSize*maxPrefixBlocks {
-		ctx.Logger.V(logutil.DEBUG).Info("Truncating input", "size", len(prompt), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
+		loggerDebug.Info("Truncating input", "size", len(prompt), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
 		prompt = prompt[:maxPrefixBlocks*cacheBlockSize]
 	}
 	// Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
 	// If the last block is smaller than cacheBlockSize, it will be ignored.
 	res := make([]BlockHash, 0, 1+len(prompt)/cacheBlockSize)
 	// Add the model to the first block hash so that different models have different hashes even with the same body.
-	res = append(res, BlockHash(xxhash.Sum64String(ctx.Req.TargetModel)))
+	res = append(res, BlockHash(xxhash.Sum64String(request.TargetModel)))
 	for i := 0; i+cacheBlockSize <= len(prompt); i += cacheBlockSize {
 		block := prompt[i : i+cacheBlockSize]
 		prevBlockHash := res[len(res)-1]
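For reference, a minimal standalone sketch (not part of this diff) of the chained block hashing that hashPrompt describes: hash(0) covers the model name, and hash(i) combines block i's content with hash(i-1), so each block hash identifies the whole prefix up to that block. The exact byte layout used to mix the previous hash into each block is an assumption for illustration; the plugin's encoding may differ.

package main

import (
	"encoding/binary"
	"fmt"

	"github.com/cespare/xxhash/v2"
)

// BlockHash mirrors the plugin's 64-bit block hash type.
type BlockHash uint64

// hashPromptSketch chains block hashes: the first entry hashes the model name,
// and each following entry hashes the block content together with the previous
// hash. Blocks smaller than blockSize at the end are ignored, as in the plugin.
func hashPromptSketch(model, prompt string, blockSize int) []BlockHash {
	res := make([]BlockHash, 0, 1+len(prompt)/blockSize)
	res = append(res, BlockHash(xxhash.Sum64String(model)))
	for i := 0; i+blockSize <= len(prompt); i += blockSize {
		block := []byte(prompt[i : i+blockSize])
		prev := res[len(res)-1]
		// Assumed encoding: append the previous hash (little-endian) to the
		// block bytes, then hash the result with xxhash.
		buf := make([]byte, len(block)+8)
		copy(buf, block)
		binary.LittleEndian.PutUint64(buf[len(block):], uint64(prev))
		res = append(res, BlockHash(xxhash.Sum64(buf)))
	}
	return res
}

func main() {
	// Two prompts sharing their first 16 bytes produce identical hashes for the
	// first two blocks, which is what lets the indexer match cached prefixes.
	fmt.Println(hashPromptSketch("model-a", "The quick brown fox jumps over", 8))
	fmt.Println(hashPromptSketch("model-a", "The quick brown cat sits still", 8))
}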