Skip to content

Commit 73f69fb

Browse files
add tests for putting session on hold
1 parent b5d1048 commit 73f69fb

File tree

1 file changed

+310
-0
lines changed

1 file changed

+310
-0
lines changed

server/ai_session_test.go

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
package server
2+
3+
import (
4+
"container/heap"
5+
"context"
6+
"math/rand"
7+
"strconv"
8+
"sync"
9+
"testing"
10+
"time"
11+
12+
"github.com/livepeer/go-livepeer/ai/worker"
13+
"github.com/livepeer/go-livepeer/core"
14+
"github.com/livepeer/go-livepeer/net"
15+
"github.com/livepeer/go-tools/drivers"
16+
"github.com/stretchr/testify/assert"
17+
)
18+
19+
func TestProcessAIRequest_RetryableError(t *testing.T) {
20+
n, _ := core.NewLivepeerNode(nil, "", nil)
21+
n.AIProcesssingRetryTimeout = 1 * time.Millisecond //allows processAIRequest to run one time
22+
s := LivepeerServer{LivepeerNode: n}
23+
penalty := 3
24+
cap := core.Capability_AudioToText
25+
modelID := "audio-model/1"
26+
27+
orchSess1 := StubBroadcastSession("http://local.host/1")
28+
orchSess2 := StubBroadcastSession("http://local.host/2")
29+
orchSess3 := StubBroadcastSession("http://local.host/3")
30+
orchSess4 := StubBroadcastSession("http://local.host/4")
31+
orchSessions := []*BroadcastSession{orchSess1, orchSess2, orchSess3, orchSess4}
32+
33+
n.OrchestratorPool = createOrchestratorPool(orchSessions)
34+
35+
//warmCaps := newAICapabilities(cap, modelID, true, "")
36+
//coldCaps := newAICapabilities(cap, modelID, true, "")
37+
warmSelector := newStubMinLSSelector()
38+
coldSelector := newStubMinLSSelector()
39+
suspender := newSuspender()
40+
41+
warmPool := AISessionPool{
42+
selector: warmSelector,
43+
sessMap: make(map[string]*BroadcastSession),
44+
sessionsOnHold: make(map[string][]*BroadcastSession),
45+
penalty: penalty,
46+
suspender: suspender,
47+
}
48+
49+
coldPool := AISessionPool{
50+
selector: coldSelector,
51+
sessMap: make(map[string]*BroadcastSession),
52+
sessionsOnHold: make(map[string][]*BroadcastSession),
53+
penalty: penalty,
54+
suspender: suspender,
55+
}
56+
//add sessions to pool
57+
warmPool.Add(orchSessions)
58+
59+
selector := &AISessionSelector{
60+
cap: cap,
61+
modelID: modelID,
62+
node: n,
63+
warmPool: &warmPool,
64+
coldPool: &coldPool,
65+
ttl: 10 * time.Second,
66+
lastRefreshTime: time.Now(),
67+
initialPoolSize: 4,
68+
suspender: suspender,
69+
penalty: penalty,
70+
os: &stubOSSession{},
71+
}
72+
73+
s.AISessionManager = NewAISessionManager(n, 10*time.Second)
74+
selectorKey := strconv.Itoa(int(cap)) + "_" + modelID
75+
s.AISessionManager.selectors[selectorKey] = selector
76+
77+
reqID := string(core.RandomManifestID())
78+
drivers.NodeStorage = drivers.NewMemoryDriver(nil)
79+
params := aiRequestParams{
80+
node: n,
81+
os: drivers.NodeStorage.NewSession(reqID),
82+
sessManager: s.AISessionManager,
83+
requestID: reqID,
84+
}
85+
86+
testReq := worker.GenAudioToTextMultipartRequestBody{
87+
ModelId: &modelID,
88+
}
89+
90+
// Mock isRetryableError to return true
91+
originalIsRetryableError := isRetryableError
92+
defer func() { isRetryableError = originalIsRetryableError }()
93+
isRetryableError = mockRetryableErrorTrue
94+
95+
//select one session
96+
_, err := processAIRequest(context.TODO(), params, testReq)
97+
if err != nil {
98+
t.Logf("Unexpected error: %v", err)
99+
}
100+
assert := assert.New(t)
101+
assert.Less(len(warmSelector.unknownSessions), 4)
102+
assert.Equal(len(warmPool.sessionsOnHold[reqID]), 1)
103+
var wg sync.WaitGroup
104+
wg.Add(1)
105+
go func() {
106+
defer wg.Done()
107+
time.Sleep(2 * time.Second)
108+
}()
109+
wg.Wait()
110+
_, sessionStillOnHold := warmPool.sessionsOnHold[reqID]
111+
assert.False(sessionStillOnHold)
112+
113+
// now test another error to confirm session is suspended
114+
isRetryableError = mockRetryableErrorTrue
115+
_, err = processAIRequest(context.TODO(), params, testReq)
116+
if err != nil {
117+
t.Logf("non retryable error: %v", err)
118+
}
119+
assert.Greater(len(warmPool.suspender.list), 0)
120+
assert.Less(len(warmSelector.unknownSessions), 4)
121+
assert.Equal(0, warmSelector.knownSessions.Len())
122+
}
123+
124+
func TestPutSessionOnHold(t *testing.T) {
125+
n, _ := core.NewLivepeerNode(nil, "", nil)
126+
n.AIProcesssingRetryTimeout = 1 * time.Millisecond //allows processAIRequest to run one time
127+
s := LivepeerServer{LivepeerNode: n}
128+
penalty := 3
129+
cap := core.Capability_AudioToText
130+
modelID := "audio-model/1"
131+
132+
orchSess1 := StubBroadcastSession("http://local.host/1")
133+
orchSessions := []*BroadcastSession{orchSess1}
134+
135+
n.OrchestratorPool = createOrchestratorPool(orchSessions)
136+
137+
warmSelector := newStubMinLSSelector()
138+
coldSelector := newStubMinLSSelector()
139+
suspender := newSuspender()
140+
141+
warmPool := AISessionPool{
142+
selector: warmSelector,
143+
sessMap: make(map[string]*BroadcastSession),
144+
sessionsOnHold: make(map[string][]*BroadcastSession),
145+
penalty: penalty,
146+
suspender: suspender,
147+
}
148+
149+
coldPool := AISessionPool{
150+
selector: coldSelector,
151+
sessMap: make(map[string]*BroadcastSession),
152+
sessionsOnHold: make(map[string][]*BroadcastSession),
153+
penalty: penalty,
154+
suspender: suspender,
155+
}
156+
//add sessions to pool
157+
warmPool.Add(orchSessions)
158+
159+
selector := &AISessionSelector{
160+
cap: cap,
161+
modelID: modelID,
162+
node: n,
163+
warmPool: &warmPool,
164+
coldPool: &coldPool,
165+
ttl: 10 * time.Second,
166+
lastRefreshTime: time.Now(),
167+
initialPoolSize: 4,
168+
suspender: suspender,
169+
penalty: penalty,
170+
os: &stubOSSession{},
171+
}
172+
173+
s.AISessionManager = NewAISessionManager(n, 10*time.Second)
174+
selectorKey := strconv.Itoa(int(cap)) + "_" + modelID
175+
s.AISessionManager.selectors[selectorKey] = selector
176+
177+
reqID := string(core.RandomManifestID())
178+
//select one session and put it on hold
179+
sess := selector.Select(context.TODO())
180+
s.AISessionManager.PutSessionOnHold(context.TODO(), reqID, sess)
181+
182+
assert := assert.New(t)
183+
assert.Equal(1, len(warmPool.sessionsOnHold[reqID]))
184+
assert.Equal(0, warmPool.selector.Size())
185+
186+
//sessions should release in 1 millisecond
187+
var wg sync.WaitGroup
188+
wg.Add(1)
189+
go func() {
190+
defer wg.Done()
191+
time.Sleep(10 * time.Millisecond)
192+
}()
193+
wg.Wait()
194+
assert.Equal(0, len(warmPool.sessionsOnHold))
195+
assert.Equal(1, warmPool.selector.Size())
196+
assert.Equal(sess.Transcoder(), warmSelector.unknownSessions[0].Transcoder())
197+
198+
//select 2 sessions and put on hold for same reqID
199+
orchSess2 := StubBroadcastSession("http://local.host/2")
200+
orchSessions = []*BroadcastSession{orchSess2}
201+
warmPool.Add(orchSessions)
202+
selector.lastRefreshTime = time.Now()
203+
204+
sess1 := selector.Select(context.TODO())
205+
sess2 := selector.Select(context.TODO())
206+
s.AISessionManager.PutSessionOnHold(context.TODO(), reqID, sess1)
207+
s.AISessionManager.PutSessionOnHold(context.TODO(), reqID, sess2)
208+
assert.Equal(2, len(warmPool.sessionsOnHold[reqID]))
209+
assert.Equal(1, len(warmPool.sessionsOnHold))
210+
assert.Equal(0, warmPool.selector.Size())
211+
212+
// sessions should release after 1 millisecond
213+
wg.Add(1)
214+
go func() {
215+
defer wg.Done()
216+
time.Sleep(10 * time.Millisecond)
217+
}()
218+
wg.Wait()
219+
assert.Equal(0, len(warmPool.sessionsOnHold))
220+
assert.Equal(2, warmPool.selector.Size())
221+
}
222+
223+
func createOrchestratorPool(sessions []*BroadcastSession) *stubDiscovery {
224+
sd := &stubDiscovery{}
225+
226+
// populate stub discovery
227+
for idx, sess := range sessions {
228+
authToken := &net.AuthToken{Token: stubAuthToken.Token, SessionId: string(core.RandomManifestID()), Expiration: stubAuthToken.Expiration}
229+
sess.OrchestratorInfo = &net.OrchestratorInfo{
230+
PriceInfo: &net.PriceInfo{PricePerUnit: int64(idx), PixelsPerUnit: 1},
231+
TicketParams: &net.TicketParams{},
232+
AuthToken: authToken,
233+
Transcoder: sess.Transcoder(),
234+
}
235+
236+
sd.infos = append(sd.infos, sess.OrchestratorInfo)
237+
}
238+
239+
return sd
240+
}
241+
242+
// selector to test behavior of AI session selection that is re-used and moves
243+
// sessions from unknownSessions to knownSessions at completion
244+
type stubMinLSSelector struct {
245+
knownSessions *sessHeap
246+
unknownSessions []*BroadcastSession
247+
}
248+
249+
func newStubMinLSSelector() *stubMinLSSelector {
250+
knownSessions := &sessHeap{}
251+
heap.Init(knownSessions)
252+
253+
return &stubMinLSSelector{
254+
knownSessions: knownSessions,
255+
}
256+
}
257+
258+
var tries int
259+
260+
func mockRetryableErrorTrue(err error) bool {
261+
if tries < 1 {
262+
tries++
263+
return true
264+
} else {
265+
return false
266+
}
267+
}
268+
269+
func mockRetryableErrorFalse(err error) bool {
270+
return false
271+
}
272+
273+
func (s *stubMinLSSelector) Add(sessions []*BroadcastSession) {
274+
s.unknownSessions = append(s.unknownSessions, sessions...)
275+
}
276+
277+
func (s *stubMinLSSelector) Complete(sess *BroadcastSession) {
278+
heap.Push(s.knownSessions, sess)
279+
}
280+
281+
func (s *stubMinLSSelector) Select(ctx context.Context) *BroadcastSession {
282+
sess := s.knownSessions.Peek()
283+
if sess == nil {
284+
randSelected := rand.Intn(len(s.unknownSessions))
285+
sess := s.unknownSessions[randSelected]
286+
s.removeUnknownSession(randSelected)
287+
return sess
288+
}
289+
290+
//return the known session selected
291+
return heap.Pop(s.knownSessions).(*BroadcastSession)
292+
}
293+
294+
// Size returns the number of sessions stored by the selector
295+
func (s *stubMinLSSelector) Size() int {
296+
return len(s.unknownSessions) + s.knownSessions.Len()
297+
}
298+
299+
// Clear resets the selector's state
300+
func (s *stubMinLSSelector) Clear() {
301+
s.unknownSessions = nil
302+
s.knownSessions = &sessHeap{}
303+
//s.stakeRdr = nil //not used in this test
304+
}
305+
306+
func (s *stubMinLSSelector) removeUnknownSession(i int) {
307+
n := len(s.unknownSessions)
308+
s.unknownSessions[n-1], s.unknownSessions[i] = s.unknownSessions[i], s.unknownSessions[n-1]
309+
s.unknownSessions = s.unknownSessions[:n-1]
310+
}

0 commit comments

Comments
 (0)