Skip to content

Commit

Permalink
fix(ai): fix orchestrator suspension for AI jobs (#3393)
Browse files Browse the repository at this point in the history
  • Loading branch information
ad-astra-video authored Feb 18, 2025
1 parent d84c0c6 commit 39db9b6
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 11 deletions.
1 change: 1 addition & 0 deletions server/ai_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -1519,6 +1519,7 @@ func processAIRequest(ctx context.Context, params aiRequestParams, req interface
break
}

// Suspend the session on other errors.
clog.Infof(ctx, "Error submitting request modelID=%v try=%v orch=%v err=%v", modelID, tries, sess.Transcoder(), err)
params.sessManager.Remove(ctx, sess) //TODO: Improve session selection logic for live-video-to-video

Expand Down
58 changes: 47 additions & 11 deletions server/ai_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,16 @@ type AISessionPool struct {
sessMap map[string]*BroadcastSession
inUseSess []*BroadcastSession
suspender *suspender
penalty int
mu sync.RWMutex
}

func NewAISessionPool(selector BroadcastSessionsSelector, suspender *suspender) *AISessionPool {
// NewAISessionPool creates an AISessionPool backed by the given selector.
// The suspender is shared with other pools (see NewAISessionSelector, which
// passes the same suspender to both the warm and cold pools), so suspensions
// are tracked across pool refreshes. penalty is the base suspension penalty
// applied when a session is removed from the pool.
func NewAISessionPool(selector BroadcastSessionsSelector, suspender *suspender, penalty int) *AISessionPool {
	// The mutex is intentionally left at its zero value: a sync.RWMutex is
	// ready to use without explicit initialization.
	return &AISessionPool{
		selector:  selector,
		sessMap:   make(map[string]*BroadcastSession),
		suspender: suspender,
		penalty:   penalty,
	}
}
Expand Down Expand Up @@ -101,10 +103,6 @@ func (pool *AISessionPool) Add(sessions []*BroadcastSession) {
pool.mu.Lock()
defer pool.mu.Unlock()

// If we try to add new sessions to the pool the suspender
// should treat this as a refresh
pool.suspender.signalRefresh()

var uniqueSessions []*BroadcastSession
for _, sess := range sessions {
if _, ok := pool.sessMap[sess.Transcoder()]; ok {
Expand All @@ -126,10 +124,14 @@ func (pool *AISessionPool) Remove(sess *BroadcastSession) {
delete(pool.sessMap, sess.Transcoder())
pool.inUseSess = removeSessionFromList(pool.inUseSess, sess)

// Magic number for now
penalty := 3
// If this method is called assume that the orch should be suspended
// as well
// as well. Since AISessionManager re-uses the pools the suspension
// penalty needs to consider the current suspender count to set the penalty
lastCount, ok := pool.suspender.list[sess.Transcoder()]
penalty := pool.suspender.count + pool.penalty
if ok {
penalty -= lastCount
}
pool.suspender.suspend(sess.Transcoder(), penalty)
}

Expand All @@ -156,12 +158,14 @@ type AISessionSelector struct {
// The time until the pools should be refreshed with orchs from discovery
ttl time.Duration
lastRefreshTime time.Time
initialPoolSize int

cap core.Capability
modelID string

node *core.LivepeerNode
suspender *suspender
penalty int
os drivers.OSSession
}

Expand All @@ -180,8 +184,10 @@ func NewAISessionSelector(ctx context.Context, cap core.Capability, modelID stri
// The latency score in this context is just the latency of the last completed request for a session
// The "good enough" latency score is set to 0.0 so the selector will always select unknown sessions first
minLS := 0.0
warmPool := NewAISessionPool(NewMinLSSelector(stakeRdr, minLS, node.SelectionAlgorithm, node.OrchPerfScore, warmCaps), suspender)
coldPool := NewAISessionPool(NewMinLSSelector(stakeRdr, minLS, node.SelectionAlgorithm, node.OrchPerfScore, coldCaps), suspender)
// Session pool suspender starts at 0. Suspension is 3 requests if there are errors from the orchestrator
penalty := 3
warmPool := NewAISessionPool(NewMinLSSelector(stakeRdr, minLS, node.SelectionAlgorithm, node.OrchPerfScore, warmCaps), suspender, penalty)
coldPool := NewAISessionPool(NewMinLSSelector(stakeRdr, minLS, node.SelectionAlgorithm, node.OrchPerfScore, coldCaps), suspender, penalty)
sel := &AISessionSelector{
warmPool: warmPool,
coldPool: coldPool,
Expand All @@ -190,6 +196,7 @@ func NewAISessionSelector(ctx context.Context, cap core.Capability, modelID stri
modelID: modelID,
node: node,
suspender: suspender,
penalty: penalty,
os: drivers.NodeStorage.NewSession(strconv.Itoa(int(cap)) + "_" + modelID),
}

Expand Down Expand Up @@ -218,11 +225,26 @@ func newAICapabilities(cap core.Capability, modelID string, warm bool, minVersio
return caps
}

// SelectorIsEmpty reports whether no orchestrators remain in either the warm
// or the cold pool.
func (sel *AISessionSelector) SelectorIsEmpty() bool {
	return sel.warmPool.Size() == 0 && sel.coldPool.Size() == 0
}

func (sel *AISessionSelector) Select(ctx context.Context) *AISession {
shouldRefreshSelector := func() bool {
discoveryPoolSize := int(math.Min(float64(sel.node.OrchestratorPool.Size()), float64(sel.initialPoolSize)))

// If the selector is empty, release all orchestrators from suspension and
// try refresh.
if sel.SelectorIsEmpty() {
clog.Infof(ctx, "refreshing sessions, no orchestrators in pools")
for i := 0; i < sel.penalty; i++ {
sel.suspender.signalRefresh()
}
}

// Refresh if the # of sessions across warm and cold pools falls below the smaller of the maxRefreshSessionsThreshold and
// 1/2 the total # of orchs that can be queried during discovery
discoveryPoolSize := sel.node.OrchestratorPool.Size()
if sel.warmPool.Size()+sel.coldPool.Size() < int(math.Min(maxRefreshSessionsThreshold, math.Ceil(float64(discoveryPoolSize)/2.0))) {
return true
}
Expand Down Expand Up @@ -272,6 +294,10 @@ func (sel *AISessionSelector) Remove(sess *AISession) {
}

func (sel *AISessionSelector) Refresh(ctx context.Context) error {
// If we try to add new sessions to the pool the suspender
// should treat this as a refresh
sel.suspender.signalRefresh()

sessions, err := sel.getSessions(ctx)
if err != nil {
return err
Expand All @@ -286,6 +312,13 @@ func (sel *AISessionSelector) Refresh(ctx context.Context) error {
continue
}

// We request 100 orchestrators in getSessions above so all Orchestrators are returned with refreshed information
// This keeps the suspended Orchestrators out of the pool until the selector is empty or 30 minutes has passed (refresh happens every 10 minutes)
if sel.suspender.Suspended(sess.Transcoder()) > 0 {
clog.V(common.DEBUG).Infof(ctx, "skipping suspended orchestrator=%s", sess.Transcoder())
continue
}

// If the constraint for the modelID are missing skip this session
modelConstraint, ok := constraints.Models[sel.modelID]
if !ok {
Expand All @@ -301,6 +334,7 @@ func (sel *AISessionSelector) Refresh(ctx context.Context) error {

sel.warmPool.Add(warmSessions)
sel.coldPool.Add(coldSessions)
sel.initialPoolSize = len(warmSessions) + len(coldSessions) + len(sel.suspender.list)

sel.lastRefreshTime = time.Now()

Expand Down Expand Up @@ -371,6 +405,8 @@ func (c *AISessionManager) Select(ctx context.Context, cap core.Capability, mode
return nil, err
}

clog.V(common.DEBUG).Infof(ctx, "selected orchestrator=%s", sess.Transcoder())

return sess, nil
}

Expand Down
82 changes: 82 additions & 0 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,88 @@ func getAvailableTranscodingOptionsHandler() http.Handler {
})
}

// poolOrchestrator contains information about an orchestrator in a pool.
// poolOrchestrator describes a single orchestrator session within a pool, as
// reported by the /getAISessionPoolsInfo endpoint.
type poolOrchestrator struct {
	// Url is the orchestrator's URI (populated from sess.Transcoder()).
	Url string `json:"url"`
	// LatencyScore is the latency score of the session's last completed request.
	LatencyScore float64 `json:"latency_score"`
	// InFlight is the number of segments currently in flight for this session.
	InFlight int `json:"in_flight"`
}

// aiPoolInfo contains information about an AI pool.
type aiPoolInfo struct {
Size int `json:"size"`
InUse int `json:"in_use"`
Orchestrators []poolOrchestrator `json:"orchestrators"`
}

// suspendedInfo contains information about suspended orchestrators.
type suspendedInfo struct {
List map[string]int `json:"list"`
CurrentCount int `json:"current_count"`
}

// aiOrchestratorPools contains information about all AI pools.
type aiOrchestratorPools struct {
Cold aiPoolInfo `json:"cold"`
Warm aiPoolInfo `json:"warm"`
LastRefresh time.Time `json:"last_refresh"`
Suspended suspendedInfo `json:"suspended"`
}

// getAIOrchestratorPoolsInfoHandler returns information about AI orchestrator pools.
func (s *LivepeerServer) getAIPoolsInfoHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
aiPoolsInfoResp := make(map[string]aiOrchestratorPools)

s.AISessionManager.mu.Lock()
defer s.AISessionManager.mu.Unlock()

// Return if no selectors are present.
if len(s.AISessionManager.selectors) == 0 {
glog.Warning("Orchestrator pools are not yet initialized")
respondJson(w, aiPoolsInfoResp)
return
}

// Loop through selectors and get pools info.
for cap, pool := range s.AISessionManager.selectors {
warmPool := aiPoolInfo{
Size: pool.warmPool.Size(),
InUse: len(pool.warmPool.inUseSess),
}
for _, sess := range pool.warmPool.sessMap {
poolOrchestrator := poolOrchestrator{
Url: sess.Transcoder(),
LatencyScore: sess.LatencyScore,
InFlight: len(sess.SegsInFlight),
}
warmPool.Orchestrators = append(warmPool.Orchestrators, poolOrchestrator)
}

coldPool := aiPoolInfo{
Size: pool.coldPool.Size(),
InUse: len(pool.coldPool.inUseSess),
}
for _, sess := range pool.coldPool.sessMap {
coldPool.Orchestrators = append(coldPool.Orchestrators, poolOrchestrator{
Url: sess.Transcoder(),
LatencyScore: sess.LatencyScore,
InFlight: len(sess.SegsInFlight),
})
}

aiPoolsInfoResp[cap] = aiOrchestratorPools{
Cold: coldPool,
Warm: warmPool,
LastRefresh: pool.lastRefreshTime,
Suspended: suspendedInfo{List: pool.suspender.list, CurrentCount: pool.suspender.count},
}
}

respondJson(w, aiPoolsInfoResp)
})
}

// Rounds
func currentRoundHandler(client eth.LivepeerEthClient) http.Handler {
return mustHaveClient(client, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
1 change: 1 addition & 0 deletions server/webserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ func (s *LivepeerServer) cliWebServerHandlers(bindAddr string) *http.ServeMux {
mux.Handle("/getBroadcastConfig", getBroadcastConfigHandler())
mux.Handle("/getAvailableTranscodingOptions", getAvailableTranscodingOptionsHandler())
mux.Handle("/setMaxPriceForCapability", mustHaveFormParams(s.setMaxPriceForCapability(), "maxPricePerUnit", "pixelsPerUnit", "currency", "pipeline", "modelID"))
mux.Handle("/getAISessionPoolsInfo", s.getAIPoolsInfoHandler())

// Rounds
mux.Handle("/currentRound", currentRoundHandler(client))
Expand Down

0 comments on commit 39db9b6

Please sign in to comment.