Skip to content

Commit c236eaf

Browse files
authored
backport: retry restarting HNS (#3529, #3540) (#3563)
Signed-off-by: Evan Baker <[email protected]>
1 parent 94c524d commit c236eaf

File tree

2 files changed

+297
-17
lines changed

2 files changed

+297
-17
lines changed

platform/os_windows.go

Lines changed: 93 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/Azure/azure-container-networking/log"
1717
"github.com/Azure/azure-container-networking/platform/windows/adapter"
1818
"github.com/Azure/azure-container-networking/platform/windows/adapter/mellanox"
19+
"github.com/avast/retry-go/v4"
1920
"github.com/pkg/errors"
2021
"go.uber.org/zap"
2122
"golang.org/x/sys/windows"
@@ -232,32 +233,107 @@ func restartHNS(ctx context.Context) error {
232233
}
233234
defer service.Close()
234235
// Stop the service
235-
_, err = service.Control(svc.Stop)
236-
if err != nil {
237-
return errors.Wrap(err, "could not stop service")
236+
log.Printf("Stopping HNS service")
237+
_ = retry.Do(
238+
tryStopServiceFn(ctx, service),
239+
retry.UntilSucceeded(),
240+
retry.Context(ctx),
241+
retry.DelayType(retry.BackOffDelay),
242+
)
243+
// Start the service again
244+
log.Printf("Starting HNS service")
245+
_ = retry.Do(
246+
tryStartServiceFn(ctx, service),
247+
retry.UntilSucceeded(),
248+
retry.Context(ctx),
249+
retry.DelayType(retry.BackOffDelay),
250+
)
251+
log.Printf("HNS service started")
252+
return nil
253+
}
254+
255+
type managedService interface {
256+
Control(control svc.Cmd) (svc.Status, error)
257+
Query() (svc.Status, error)
258+
Start(args ...string) error
259+
}
260+
261+
func tryStartServiceFn(ctx context.Context, service managedService) func() error {
262+
shouldStart := func(state svc.State) bool {
263+
return !(state == svc.Running || state == svc.StartPending)
238264
}
239-
// Wait for the service to stop
240-
ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
241-
defer ticker.Stop()
242-
for { // hacky cancellable do-while
265+
return func() error {
243266
status, err := service.Query()
244267
if err != nil {
245268
return errors.Wrap(err, "could not query service status")
246269
}
247-
if status.State == svc.Stopped {
248-
break
270+
if shouldStart(status.State) {
271+
err = service.Start()
272+
if err != nil {
273+
return errors.Wrap(err, "could not start service")
274+
}
249275
}
250-
select {
251-
case <-ctx.Done():
252-
return errors.New("context cancelled")
253-
case <-ticker.C:
276+
// Wait for the service to start
277+
deadline, cancel := context.WithTimeout(ctx, 90*time.Second)
278+
defer cancel()
279+
ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
280+
defer ticker.Stop()
281+
for {
282+
status, err := service.Query()
283+
if err != nil {
284+
return errors.Wrap(err, "could not query service status")
285+
}
286+
if status.State == svc.Running {
287+
log.Printf("service started")
288+
break
289+
}
290+
select {
291+
case <-deadline.Done():
292+
return deadline.Err() //nolint:wrapcheck // error has sufficient context
293+
case <-ticker.C:
294+
}
254295
}
296+
return nil
255297
}
256-
// Start the service again
257-
if err := service.Start(); err != nil {
258-
return errors.Wrap(err, "could not start service")
298+
}
299+
300+
func tryStopServiceFn(ctx context.Context, service managedService) func() error {
301+
shouldStop := func(state svc.State) bool {
302+
return !(state == svc.Stopped || state == svc.StopPending)
303+
}
304+
return func() error {
305+
status, err := service.Query()
306+
if err != nil {
307+
return errors.Wrap(err, "could not query service status")
308+
}
309+
if shouldStop(status.State) {
310+
_, err = service.Control(svc.Stop)
311+
if err != nil {
312+
return errors.Wrap(err, "could not stop service")
313+
}
314+
}
315+
// Wait for the service to stop
316+
deadline, cancel := context.WithTimeout(ctx, 90*time.Second)
317+
defer cancel()
318+
ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
319+
defer ticker.Stop()
320+
for {
321+
status, err := service.Query()
322+
if err != nil {
323+
return errors.Wrap(err, "could not query service status")
324+
}
325+
if status.State == svc.Stopped {
326+
log.Printf("service stopped")
327+
break
328+
}
329+
select {
330+
case <-deadline.Done():
331+
return deadline.Err() //nolint:wrapcheck // error has sufficient context
332+
case <-ticker.C:
333+
}
334+
}
335+
return nil
259336
}
260-
return nil
261337
}
262338

263339
func HasMellanoxAdapter() bool {

platform/os_windows_test.go

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package platform
22

33
import (
4+
"context"
45
"errors"
56
"os/exec"
67
"testing"
@@ -9,6 +10,7 @@ import (
910
"github.com/golang/mock/gomock"
1011
"github.com/stretchr/testify/assert"
1112
"github.com/stretchr/testify/require"
13+
"golang.org/x/sys/windows/svc"
1214
)
1315

1416
var errTestFailure = errors.New("test failure")
@@ -98,3 +100,205 @@ func TestExecuteCommandError(t *testing.T) {
98100
assert.ErrorAs(t, err, &xErr)
99101
assert.Equal(t, 1, xErr.ExitCode())
100102
}
103+
104+
type mockManagedService struct {
105+
queryFuncs []func() (svc.Status, error)
106+
controlFunc func(svc.Cmd) (svc.Status, error)
107+
startFunc func(args ...string) error
108+
}
109+
110+
func (m *mockManagedService) Query() (svc.Status, error) {
111+
queryFunc := m.queryFuncs[0]
112+
m.queryFuncs = m.queryFuncs[1:]
113+
return queryFunc()
114+
}
115+
116+
func (m *mockManagedService) Control(cmd svc.Cmd) (svc.Status, error) {
117+
return m.controlFunc(cmd)
118+
}
119+
120+
func (m *mockManagedService) Start(args ...string) error {
121+
return m.startFunc(args...)
122+
}
123+
124+
func TestTryStopServiceFn(t *testing.T) {
125+
tests := []struct {
126+
name string
127+
queryFuncs []func() (svc.Status, error)
128+
controlFunc func(svc.Cmd) (svc.Status, error)
129+
expectError bool
130+
}{
131+
{
132+
name: "Service already stopped",
133+
queryFuncs: []func() (svc.Status, error){
134+
func() (svc.Status, error) {
135+
return svc.Status{State: svc.Stopped}, nil
136+
},
137+
func() (svc.Status, error) {
138+
return svc.Status{State: svc.Stopped}, nil
139+
},
140+
},
141+
controlFunc: nil,
142+
expectError: false,
143+
},
144+
{
145+
name: "Service running and stops successfully",
146+
queryFuncs: []func() (svc.Status, error){
147+
func() (svc.Status, error) {
148+
return svc.Status{State: svc.Running}, nil
149+
},
150+
func() (svc.Status, error) {
151+
return svc.Status{State: svc.Stopped}, nil
152+
},
153+
},
154+
controlFunc: func(svc.Cmd) (svc.Status, error) {
155+
return svc.Status{State: svc.Stopped}, nil
156+
},
157+
expectError: false,
158+
},
159+
{
160+
name: "Service running and stops after multiple attempts",
161+
queryFuncs: []func() (svc.Status, error){
162+
func() (svc.Status, error) {
163+
return svc.Status{State: svc.Running}, nil
164+
},
165+
func() (svc.Status, error) {
166+
return svc.Status{State: svc.Running}, nil
167+
},
168+
func() (svc.Status, error) {
169+
return svc.Status{State: svc.Running}, nil
170+
},
171+
func() (svc.Status, error) {
172+
return svc.Status{State: svc.Stopped}, nil
173+
},
174+
},
175+
controlFunc: func(svc.Cmd) (svc.Status, error) {
176+
return svc.Status{State: svc.Stopped}, nil
177+
},
178+
expectError: false,
179+
},
180+
{
181+
name: "Service running and fails to stop",
182+
queryFuncs: []func() (svc.Status, error){
183+
func() (svc.Status, error) {
184+
return svc.Status{State: svc.Running}, nil
185+
},
186+
},
187+
controlFunc: func(svc.Cmd) (svc.Status, error) {
188+
return svc.Status{State: svc.Running}, errors.New("failed to stop service") //nolint:err113 // test error
189+
},
190+
expectError: true,
191+
},
192+
{
193+
name: "Service query fails",
194+
queryFuncs: []func() (svc.Status, error){
195+
func() (svc.Status, error) {
196+
return svc.Status{}, errors.New("failed to query service status") //nolint:err113 // test error
197+
},
198+
},
199+
controlFunc: nil,
200+
expectError: true,
201+
},
202+
}
203+
for _, tt := range tests {
204+
t.Run(tt.name, func(t *testing.T) {
205+
service := &mockManagedService{
206+
queryFuncs: tt.queryFuncs,
207+
controlFunc: tt.controlFunc,
208+
}
209+
err := tryStopServiceFn(context.Background(), service)()
210+
if tt.expectError {
211+
assert.Error(t, err)
212+
return
213+
}
214+
assert.NoError(t, err)
215+
})
216+
}
217+
}
218+
219+
func TestTryStartServiceFn(t *testing.T) {
220+
tests := []struct {
221+
name string
222+
queryFuncs []func() (svc.Status, error)
223+
startFunc func(...string) error
224+
expectError bool
225+
}{
226+
{
227+
name: "Service already running",
228+
queryFuncs: []func() (svc.Status, error){
229+
func() (svc.Status, error) {
230+
return svc.Status{State: svc.Running}, nil
231+
},
232+
func() (svc.Status, error) {
233+
return svc.Status{State: svc.Running}, nil
234+
},
235+
},
236+
startFunc: nil,
237+
expectError: false,
238+
},
239+
{
240+
name: "Service already starting",
241+
queryFuncs: []func() (svc.Status, error){
242+
func() (svc.Status, error) {
243+
return svc.Status{State: svc.StartPending}, nil
244+
},
245+
func() (svc.Status, error) {
246+
return svc.Status{State: svc.Running}, nil
247+
},
248+
},
249+
startFunc: nil,
250+
expectError: false,
251+
},
252+
{
253+
name: "Service starts successfully",
254+
queryFuncs: []func() (svc.Status, error){
255+
func() (svc.Status, error) {
256+
return svc.Status{State: svc.Stopped}, nil
257+
},
258+
func() (svc.Status, error) {
259+
return svc.Status{State: svc.Running}, nil
260+
},
261+
},
262+
startFunc: func(...string) error {
263+
return nil
264+
},
265+
expectError: false,
266+
},
267+
{
268+
name: "Service fails to start",
269+
queryFuncs: []func() (svc.Status, error){
270+
func() (svc.Status, error) {
271+
return svc.Status{State: svc.Stopped}, nil
272+
},
273+
},
274+
startFunc: func(...string) error {
275+
return errors.New("failed to start service") //nolint:err113 // test error
276+
},
277+
expectError: true,
278+
},
279+
{
280+
name: "Service query fails",
281+
queryFuncs: []func() (svc.Status, error){
282+
func() (svc.Status, error) {
283+
return svc.Status{}, errors.New("failed to query service status") //nolint:err113 // test error
284+
},
285+
},
286+
startFunc: nil,
287+
expectError: true,
288+
},
289+
}
290+
for _, tt := range tests {
291+
t.Run(tt.name, func(t *testing.T) {
292+
service := &mockManagedService{
293+
queryFuncs: tt.queryFuncs,
294+
startFunc: tt.startFunc,
295+
}
296+
err := tryStartServiceFn(context.Background(), service)()
297+
if tt.expectError {
298+
assert.Error(t, err)
299+
return
300+
}
301+
assert.NoError(t, err)
302+
})
303+
}
304+
}

0 commit comments

Comments
 (0)