
Commit 7617439

remove SchedulingContext, flatten scheduler interfaces (#889)
Signed-off-by: Nir Rozenbaum <[email protected]>
1 parent fdea49d commit 7617439

21 files changed: +161 −195 lines

pkg/epp/scheduling/framework/plugins.go

Lines changed: 7 additions & 5 deletions
@@ -17,6 +17,8 @@ limitations under the License.
 package framework
 
 import (
+	"context"
+
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 )
 
@@ -46,31 +48,31 @@ type ProfilePicker interface {
 // Filter defines the interface for filtering a list of pods based on context.
 type Filter interface {
 	Plugin
-	Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod
+	Filter(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) []types.Pod
 }
 
 // Scorer defines the interface for scoring a list of pods based on context.
 // Scorers must score pods with a value within the range of [0,1] where 1 is the highest score.
 type Scorer interface {
 	Plugin
-	Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64
+	Score(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) map[types.Pod]float64
 }
 
 // Picker picks the final pod(s) to send the request to.
 type Picker interface {
 	Plugin
-	Pick(ctx *types.SchedulingContext, scoredPods []*types.ScoredPod) *types.Result
+	Pick(ctx context.Context, cycleState *types.CycleState, scoredPods []*types.ScoredPod) *types.Result
 }
 
 // PostCycle is called by the scheduler after it selects a targetPod for the request in the SchedulerProfile cycle.
 type PostCycle interface {
 	Plugin
-	PostCycle(ctx *types.SchedulingContext, res *types.Result)
+	PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.Result)
 }
 
 // PostResponse is called by the scheduler after a successful response was sent.
 // The given pod argument is the pod that served the request.
 type PostResponse interface {
 	Plugin
-	PostResponse(ctx *types.SchedulingContext, pod types.Pod)
+	PostResponse(ctx context.Context, response *types.LLMResponse, targetPod types.Pod)
 }
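For orientation, here is a minimal sketch (not part of the commit) of a third-party filter written against the flattened interface: the request and per-cycle state now arrive as explicit arguments instead of a SchedulingContext. The ModelAffinityExample type and its behavior are hypothetical; the fields it touches (GetMetrics().ActiveModels, TargetModel) are the ones visible elsewhere in this diff, and Plugin is assumed to require only Name().

package example

import (
	"context"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// Compile-time check that the sketch satisfies the flattened Filter interface.
var _ framework.Filter = &ModelAffinityExample{}

// ModelAffinityExample is a hypothetical plugin used only to illustrate the new signature.
type ModelAffinityExample struct{}

func (f *ModelAffinityExample) Name() string { return "model-affinity-example" }

// Filter keeps only pods that already have the request's target model active.
// Note the explicit ctx/request/cycleState arguments replacing *types.SchedulingContext.
func (f *ModelAffinityExample) Filter(_ context.Context, request *types.LLMRequest, _ *types.CycleState, pods []types.Pod) []types.Pod {
	filtered := make([]types.Pod, 0, len(pods))
	for _, pod := range pods {
		if _, active := pod.GetMetrics().ActiveModels[request.TargetModel]; active {
			filtered = append(filtered, pod)
		}
	}
	return filtered
}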

pkg/epp/scheduling/framework/plugins/filter/decision_tree_filter.go

Lines changed: 8 additions & 5 deletions
@@ -17,6 +17,9 @@ limitations under the License.
 package filter
 
 import (
+	"context"
+
+	"sigs.k8s.io/controller-runtime/pkg/log"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
@@ -53,9 +56,9 @@ func (f *DecisionTreeFilter) Name() string {
 }
 
 // Filter filters out pods that doesn't meet the filter criteria.
-func (f *DecisionTreeFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
-	loggerTrace := ctx.Logger.V(logutil.TRACE)
-	filteredPod := f.Current.Filter(ctx, pods)
+func (f *DecisionTreeFilter) Filter(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) []types.Pod {
+	loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
+	filteredPod := f.Current.Filter(ctx, request, cycleState, pods)
 
 	next := f.NextOnSuccessOrFailure
 	if len(filteredPod) > 0 {
@@ -68,7 +71,7 @@ func (f *DecisionTreeFilter) Filter(ctx *types.SchedulingContext, pods []types.P
 		}
 		loggerTrace.Info("Filter succeeded", "filter", f.Name(), "next", next.Name(), "filteredPodCount", len(filteredPod))
 		// On success, pass the filtered result to the next filter.
-		return next.Filter(ctx, filteredPod)
+		return next.Filter(ctx, request, cycleState, filteredPod)
 	} else {
 		if f.NextOnFailure == nil && f.NextOnSuccessOrFailure == nil {
 			// No succeeding filters to run, return.
@@ -79,6 +82,6 @@ func (f *DecisionTreeFilter) Filter(ctx *types.SchedulingContext, pods []types.P
 		}
 		loggerTrace.Info("Filter failed", "filter", f.Name(), "next", next.Name())
 		// On failure, pass the initial set of pods to the next filter.
-		return next.Filter(ctx, pods)
+		return next.Filter(ctx, request, cycleState, pods)
 	}
 }
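As a usage sketch (not part of the commit), composing a decision tree under the new signature looks like the following. The DecisionTreeFilter fields (Current, NextOnSuccessOrFailure) are the ones referenced above; NewLowQueueFilter() is an assumed constructor following the NewLoraAffinityFilter() pattern seen in the tests.

package example

import (
	"context"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/filter"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// runTree drives a small decision tree; every node now receives the same
// explicit ctx/request/cycleState triple instead of a SchedulingContext.
func runTree(ctx context.Context, req *types.LLMRequest, state *types.CycleState, pods []types.Pod) []types.Pod {
	tree := &filter.DecisionTreeFilter{
		// NewLowQueueFilter is assumed; NewLoraAffinityFilter appears in the tests above.
		Current: filter.NewLowQueueFilter(),
		NextOnSuccessOrFailure: &filter.DecisionTreeFilter{
			Current: filter.NewLoraAffinityFilter(),
		},
	}
	return tree.Filter(ctx, req, state, pods)
}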

pkg/epp/scheduling/framework/plugins/filter/filter_test.go

Lines changed: 3 additions & 6 deletions
@@ -39,7 +39,7 @@ func (f *filterAll) Name() string {
 	return "filter all"
 }
 
-func (f *filterAll) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
+func (f *filterAll) Filter(_ context.Context, _ *types.LLMRequest, _ *types.CycleState, pods []types.Pod) []types.Pod {
 	return []types.Pod{}
 }
 
@@ -174,8 +174,7 @@ func TestFilter(t *testing.T) {
 
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			ctx := types.NewSchedulingContext(context.Background(), test.req, nil, test.input)
-			got := test.filter.Filter(ctx, test.input)
+			got := test.filter.Filter(context.Background(), test.req, types.NewCycleState(), test.input)
 
 			if diff := cmp.Diff(test.output, got); diff != "" {
 				t.Errorf("Unexpected output (-want +got): %v", diff)
@@ -231,8 +230,6 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {
 			},
 		},
 	}
-	ctx := types.NewSchedulingContext(context.Background(), req, nil, pods)
-
 	// Run the filter function multiple times and count the results
 	affinityCount := 0
 	availableCount := 0
@@ -245,7 +242,7 @@ func TestLoRASoftAffinityDistribution(t *testing.T) {
 	LoraAffinityFilter := NewLoraAffinityFilter()
 
 	for i := 0; i < numIterations; i++ {
-		result := LoraAffinityFilter.Filter(ctx, pods)
+		result := LoraAffinityFilter.Filter(context.Background(), req, types.NewCycleState(), pods)
 
 		// Check which type of pod was returned
 		if len(result) != 1 {

pkg/epp/scheduling/framework/plugins/filter/least_kvcache_filter.go

Lines changed: 2 additions & 1 deletion
@@ -17,6 +17,7 @@ limitations under the License.
 package filter
 
 import (
+	"context"
 	"math"
 
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
@@ -44,7 +45,7 @@ func (f *LeastKVCacheFilter) Name() string {
 }
 
 // Filter filters out pods that doesn't meet the filter criteria.
-func (f *LeastKVCacheFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
+func (f *LeastKVCacheFilter) Filter(_ context.Context, _ *types.LLMRequest, _ *types.CycleState, pods []types.Pod) []types.Pod {
 	filteredPods := []types.Pod{}
 
 	min := math.MaxFloat64

pkg/epp/scheduling/framework/plugins/filter/least_queue_filter.go

Lines changed: 2 additions & 1 deletion
@@ -17,6 +17,7 @@ limitations under the License.
 package filter
 
 import (
+	"context"
 	"math"
 
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
@@ -44,7 +45,7 @@ func (f *LeastQueueFilter) Name() string {
 }
 
 // Filter filters out pods that doesn't meet the filter criteria.
-func (f *LeastQueueFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
+func (f *LeastQueueFilter) Filter(_ context.Context, _ *types.LLMRequest, _ *types.CycleState, pods []types.Pod) []types.Pod {
 	filteredPods := []types.Pod{}
 
 	min := math.MaxInt

pkg/epp/scheduling/framework/plugins/filter/lora_affinity_filter.go

Lines changed: 4 additions & 3 deletions
@@ -17,6 +17,7 @@ limitations under the License.
 package filter
 
 import (
+	"context"
 	"math/rand"
 	"time"
 
@@ -52,15 +53,15 @@ func (f *LoraAffinityFilter) Name() string {
 }
 
 // Filter filters out pods that doesn't meet the filter criteria.
-func (f *LoraAffinityFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
+func (f *LoraAffinityFilter) Filter(_ context.Context, request *types.LLMRequest, _ *types.CycleState, pods []types.Pod) []types.Pod {
 	// Pre-allocate slices with estimated capacity
 	filtered_affinity := make([]types.Pod, 0, len(pods))
 	filtered_available := make([]types.Pod, 0, len(pods))
 
 	// Categorize pods based on affinity and availability
 	for _, pod := range pods {
-		_, active := pod.GetMetrics().ActiveModels[ctx.Req.TargetModel]
-		_, waiting := pod.GetMetrics().WaitingModels[ctx.Req.TargetModel]
+		_, active := pod.GetMetrics().ActiveModels[request.TargetModel]
+		_, waiting := pod.GetMetrics().WaitingModels[request.TargetModel]
 
 		if active || waiting {
 			filtered_affinity = append(filtered_affinity, pod)

pkg/epp/scheduling/framework/plugins/filter/low_queue_filter.go

Lines changed: 3 additions & 1 deletion
@@ -17,6 +17,8 @@ limitations under the License.
 package filter
 
 import (
+	"context"
+
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
@@ -43,7 +45,7 @@ func (f *LowQueueFilter) Name() string {
 }
 
 // Filter filters out pods that doesn't meet the filter criteria.
-func (f *LowQueueFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
+func (f *LowQueueFilter) Filter(_ context.Context, _ *types.LLMRequest, _ *types.CycleState, pods []types.Pod) []types.Pod {
 	filteredPods := []types.Pod{}
 
 	for _, pod := range pods {

pkg/epp/scheduling/framework/plugins/filter/sheddable_capacity_filter.go

Lines changed: 4 additions & 2 deletions
@@ -17,6 +17,8 @@ limitations under the License.
 package filter
 
 import (
+	"context"
+
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
@@ -45,8 +47,8 @@ func (f *SheddableCapacityFilter) Name() string {
 }
 
 // Filter filters out pods that doesn't meet the filter criteria.
-func (f *SheddableCapacityFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
-	if ctx.Req.Critical {
+func (f *SheddableCapacityFilter) Filter(_ context.Context, request *types.LLMRequest, _ *types.CycleState, pods []types.Pod) []types.Pod {
+	if request.Critical {
 		return pods // // Allow all pods to passthrough if the request is critical, even if all pods reach their capacity.
 	}
 

pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go

Lines changed: 30 additions & 25 deletions
@@ -17,11 +17,13 @@ limitations under the License.
 package prefix
 
 import (
+	"context"
 	"encoding/binary"
 	"fmt"
 
 	"github.com/cespare/xxhash/v2"
 	k8stypes "k8s.io/apimachinery/pkg/types"
+	"sigs.k8s.io/controller-runtime/pkg/log"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
@@ -131,24 +133,11 @@ func (m *Plugin) Name() string {
 	return "prefix-cache"
 }
 
-// PostCycle records in the plugin cache the result of the scheduling selection.
-func (m *Plugin) PostCycle(ctx *types.SchedulingContext, res *types.Result) {
-	targetPod := res.TargetPod.GetPod()
-	state, err := m.getPrefixState(ctx.CycleState)
-	if err != nil {
-		ctx.Logger.Error(err, "failed to read prefix plugin cycle state")
-		return
-	}
-	m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName))
-	total := len(state.PrefixHashes)
-	matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
-	metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize)
-}
-
 // Score returns the scoring result for the given list of pods based on context.
-func (m *Plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
+func (m *Plugin) Score(ctx context.Context, request *types.LLMRequest, cycleState *types.CycleState, pods []types.Pod) map[types.Pod]float64 {
+	loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
 	// pre score step, hashing prompt and find longest prefix match.
-	hashes := hashPrompt(ctx, m.HashBlockSize, m.MaxPrefixBlocksToMatch)
+	hashes := hashPrompt(ctx, request, m.HashBlockSize, m.MaxPrefixBlocksToMatch)
 	numServers := DefaultNumServersToMatch
 	if numServers > len(pods) {
 		numServers = len(pods)
@@ -157,8 +146,8 @@ func (m *Plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types
 		PrefixHashes: hashes,
 		PrefixCacheServers: m.matchLongestPrefix(ctx, hashes, numServers),
 	}
-	ctx.CycleState.Write(types.StateKey(m.Name()), state)
-	ctx.Logger.V(logutil.TRACE).Info(fmt.Sprintf("cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes)
+	cycleState.Write(types.StateKey(m.Name()), state)
+	loggerTrace.Info(fmt.Sprintf("cached servers: %+v", state.PrefixCacheServers), "hashes", state.PrefixHashes)
 	// calculate the scores of pods
 	scores := make(map[types.Pod]float64, len(pods))
 
@@ -177,16 +166,31 @@ func (m *Plugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types
 	return scores
 }
 
+// PostCycle records in the plugin cache the result of the scheduling selection.
+func (m *Plugin) PostCycle(ctx context.Context, cycleState *types.CycleState, res *types.Result) {
+	targetPod := res.TargetPod.GetPod()
+	state, err := m.getPrefixState(cycleState)
+	if err != nil {
+		log.FromContext(ctx).Error(err, "failed to read prefix plugin cycle state")
+		return
+	}
+	m.indexer.Add(state.PrefixHashes, ServerID(targetPod.NamespacedName))
+	total := len(state.PrefixHashes)
+	matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
+	metrics.RecordPrefixCacheMatch(matchLen*m.HashBlockSize, total*m.HashBlockSize)
+}
+
 // matchLongestPrefix returns a map of servers and length of prefix that each server caches.
-func (m *Plugin) matchLongestPrefix(ctx *types.SchedulingContext, hashes []BlockHash, numServers int) map[ServerID]int {
+func (m *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash, numServers int) map[ServerID]int {
+	loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
 	res := make(map[ServerID]int)
 	// Use a greedy strategy to search from the longest prefix.
 	// NOTE: It's possible to further optimize this with a binary search.
 	for i := len(hashes) - 1; i >= 0 && len(res) < numServers; i-- {
 		hash := hashes[i]
 		cachedServers := m.indexer.Get(hash)
 		if len(cachedServers) > 0 {
-			ctx.Logger.V(logutil.TRACE).Info("Found cached servers", "cachedServers", cachedServers, "total # blocks", len(hashes), "longest prefix", i)
+			loggerTrace.Info("Found cached servers", "cachedServers", cachedServers, "total # blocks", len(hashes), "longest prefix", i)
 			for server := range cachedServers {
 				// Update servers with their longest prefix match.
 				// If we already found this server with longer prefix match, don't update it.
@@ -218,21 +222,22 @@ func (m *Plugin) getPrefixState(cycleState *types.CycleState) (*schedulingContex
 // hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
 // hash(0) is the hash of the model name, since different models generally don't share prefix cache.
 // For block i, hash(i) = hash(block i content, hash(i-1)).
-func hashPrompt(ctx *types.SchedulingContext, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
-	prompt := []byte(ctx.Req.Prompt)
+func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
+	loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
+	prompt := []byte(request.Prompt)
 	if len(prompt) < cacheBlockSize {
-		ctx.Logger.V(logutil.DEBUG).Info("Request body too small for prefix cache", "size", len(prompt), "block size", cacheBlockSize)
+		loggerDebug.Info("Request body too small for prefix cache", "size", len(prompt), "block size", cacheBlockSize)
 		return nil
 	}
 	if len(prompt) > cacheBlockSize*maxPrefixBlocks {
-		ctx.Logger.V(logutil.DEBUG).Info("Truncating input", "size", len(prompt), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
+		loggerDebug.Info("Truncating input", "size", len(prompt), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
 		prompt = prompt[:maxPrefixBlocks*cacheBlockSize]
 	}
 	// Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
 	// If the last block is smaller than cacheBlockSize, it will be ignored.
 	res := make([]BlockHash, 0, 1+len(prompt)/cacheBlockSize)
 	// Add the model to the first block hash so that different models have different hashes even with the same body.
-	res = append(res, BlockHash(xxhash.Sum64String(ctx.Req.TargetModel)))
+	res = append(res, BlockHash(xxhash.Sum64String(request.TargetModel)))
 	for i := 0; i+cacheBlockSize <= len(prompt); i += cacheBlockSize {
 		block := prompt[i : i+cacheBlockSize]
 		prevBlockHash := res[len(res)-1]
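To make the hashing scheme described in the hashPrompt comment concrete, here is a standalone sketch of chained block hashing: hash(0) covers the model name, and each block hash folds in the previous hash, so two prompts share hashes only along their common prefix. The byte layout below is an assumption for illustration; the plugin's exact encoding is not shown in this diff.

package example

import (
	"encoding/binary"

	"github.com/cespare/xxhash/v2"
)

// chainedBlockHashes illustrates the scheme: hash(0) = hash(model),
// hash(i) = hash(block i content, hash(i-1)).
func chainedBlockHashes(model, prompt string, blockSize int) []uint64 {
	body := []byte(prompt)
	hashes := make([]uint64, 0, 1+len(body)/blockSize)
	hashes = append(hashes, xxhash.Sum64String(model)) // hash(0): the model name
	for i := 0; i+blockSize <= len(body); i += blockSize {
		prev := hashes[len(hashes)-1]
		buf := make([]byte, 8+blockSize)
		binary.LittleEndian.PutUint64(buf, prev) // fold in hash(i-1)
		copy(buf[8:], body[i:i+blockSize])       // block i content
		hashes = append(hashes, xxhash.Sum64(buf))
	}
	return hashes
}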
