1
1
package platform
2
2
3
3
import (
4
+ "context"
4
5
"errors"
5
6
"os/exec"
6
7
"testing"
@@ -9,6 +10,7 @@ import (
9
10
"github.com/golang/mock/gomock"
10
11
"github.com/stretchr/testify/assert"
11
12
"github.com/stretchr/testify/require"
13
+ "golang.org/x/sys/windows/svc"
12
14
)
13
15
14
16
var errTestFailure = errors .New ("test failure" )
@@ -98,3 +100,205 @@ func TestExecuteCommandError(t *testing.T) {
98
100
assert .ErrorAs (t , err , & xErr )
99
101
assert .Equal (t , 1 , xErr .ExitCode ())
100
102
}
103
+
104
+ type mockManagedService struct {
105
+ queryFuncs []func () (svc.Status , error )
106
+ controlFunc func (svc.Cmd ) (svc.Status , error )
107
+ startFunc func (args ... string ) error
108
+ }
109
+
110
+ func (m * mockManagedService ) Query () (svc.Status , error ) {
111
+ queryFunc := m .queryFuncs [0 ]
112
+ m .queryFuncs = m .queryFuncs [1 :]
113
+ return queryFunc ()
114
+ }
115
+
116
+ func (m * mockManagedService ) Control (cmd svc.Cmd ) (svc.Status , error ) {
117
+ return m .controlFunc (cmd )
118
+ }
119
+
120
+ func (m * mockManagedService ) Start (args ... string ) error {
121
+ return m .startFunc (args ... )
122
+ }
123
+
124
+ func TestTryStopServiceFn (t * testing.T ) {
125
+ tests := []struct {
126
+ name string
127
+ queryFuncs []func () (svc.Status , error )
128
+ controlFunc func (svc.Cmd ) (svc.Status , error )
129
+ expectError bool
130
+ }{
131
+ {
132
+ name : "Service already stopped" ,
133
+ queryFuncs : []func () (svc.Status , error ){
134
+ func () (svc.Status , error ) {
135
+ return svc.Status {State : svc .Stopped }, nil
136
+ },
137
+ func () (svc.Status , error ) {
138
+ return svc.Status {State : svc .Stopped }, nil
139
+ },
140
+ },
141
+ controlFunc : nil ,
142
+ expectError : false ,
143
+ },
144
+ {
145
+ name : "Service running and stops successfully" ,
146
+ queryFuncs : []func () (svc.Status , error ){
147
+ func () (svc.Status , error ) {
148
+ return svc.Status {State : svc .Running }, nil
149
+ },
150
+ func () (svc.Status , error ) {
151
+ return svc.Status {State : svc .Stopped }, nil
152
+ },
153
+ },
154
+ controlFunc : func (svc.Cmd ) (svc.Status , error ) {
155
+ return svc.Status {State : svc .Stopped }, nil
156
+ },
157
+ expectError : false ,
158
+ },
159
+ {
160
+ name : "Service running and stops after multiple attempts" ,
161
+ queryFuncs : []func () (svc.Status , error ){
162
+ func () (svc.Status , error ) {
163
+ return svc.Status {State : svc .Running }, nil
164
+ },
165
+ func () (svc.Status , error ) {
166
+ return svc.Status {State : svc .Running }, nil
167
+ },
168
+ func () (svc.Status , error ) {
169
+ return svc.Status {State : svc .Running }, nil
170
+ },
171
+ func () (svc.Status , error ) {
172
+ return svc.Status {State : svc .Stopped }, nil
173
+ },
174
+ },
175
+ controlFunc : func (svc.Cmd ) (svc.Status , error ) {
176
+ return svc.Status {State : svc .Stopped }, nil
177
+ },
178
+ expectError : false ,
179
+ },
180
+ {
181
+ name : "Service running and fails to stop" ,
182
+ queryFuncs : []func () (svc.Status , error ){
183
+ func () (svc.Status , error ) {
184
+ return svc.Status {State : svc .Running }, nil
185
+ },
186
+ },
187
+ controlFunc : func (svc.Cmd ) (svc.Status , error ) {
188
+ return svc.Status {State : svc .Running }, errors .New ("failed to stop service" ) //nolint:err113 // test error
189
+ },
190
+ expectError : true ,
191
+ },
192
+ {
193
+ name : "Service query fails" ,
194
+ queryFuncs : []func () (svc.Status , error ){
195
+ func () (svc.Status , error ) {
196
+ return svc.Status {}, errors .New ("failed to query service status" ) //nolint:err113 // test error
197
+ },
198
+ },
199
+ controlFunc : nil ,
200
+ expectError : true ,
201
+ },
202
+ }
203
+ for _ , tt := range tests {
204
+ t .Run (tt .name , func (t * testing.T ) {
205
+ service := & mockManagedService {
206
+ queryFuncs : tt .queryFuncs ,
207
+ controlFunc : tt .controlFunc ,
208
+ }
209
+ err := tryStopServiceFn (context .Background (), service )()
210
+ if tt .expectError {
211
+ assert .Error (t , err )
212
+ return
213
+ }
214
+ assert .NoError (t , err )
215
+ })
216
+ }
217
+ }
218
+
219
+ func TestTryStartServiceFn (t * testing.T ) {
220
+ tests := []struct {
221
+ name string
222
+ queryFuncs []func () (svc.Status , error )
223
+ startFunc func (... string ) error
224
+ expectError bool
225
+ }{
226
+ {
227
+ name : "Service already running" ,
228
+ queryFuncs : []func () (svc.Status , error ){
229
+ func () (svc.Status , error ) {
230
+ return svc.Status {State : svc .Running }, nil
231
+ },
232
+ func () (svc.Status , error ) {
233
+ return svc.Status {State : svc .Running }, nil
234
+ },
235
+ },
236
+ startFunc : nil ,
237
+ expectError : false ,
238
+ },
239
+ {
240
+ name : "Service already starting" ,
241
+ queryFuncs : []func () (svc.Status , error ){
242
+ func () (svc.Status , error ) {
243
+ return svc.Status {State : svc .StartPending }, nil
244
+ },
245
+ func () (svc.Status , error ) {
246
+ return svc.Status {State : svc .Running }, nil
247
+ },
248
+ },
249
+ startFunc : nil ,
250
+ expectError : false ,
251
+ },
252
+ {
253
+ name : "Service starts successfully" ,
254
+ queryFuncs : []func () (svc.Status , error ){
255
+ func () (svc.Status , error ) {
256
+ return svc.Status {State : svc .Stopped }, nil
257
+ },
258
+ func () (svc.Status , error ) {
259
+ return svc.Status {State : svc .Running }, nil
260
+ },
261
+ },
262
+ startFunc : func (... string ) error {
263
+ return nil
264
+ },
265
+ expectError : false ,
266
+ },
267
+ {
268
+ name : "Service fails to start" ,
269
+ queryFuncs : []func () (svc.Status , error ){
270
+ func () (svc.Status , error ) {
271
+ return svc.Status {State : svc .Stopped }, nil
272
+ },
273
+ },
274
+ startFunc : func (... string ) error {
275
+ return errors .New ("failed to start service" ) //nolint:err113 // test error
276
+ },
277
+ expectError : true ,
278
+ },
279
+ {
280
+ name : "Service query fails" ,
281
+ queryFuncs : []func () (svc.Status , error ){
282
+ func () (svc.Status , error ) {
283
+ return svc.Status {}, errors .New ("failed to query service status" ) //nolint:err113 // test error
284
+ },
285
+ },
286
+ startFunc : nil ,
287
+ expectError : true ,
288
+ },
289
+ }
290
+ for _ , tt := range tests {
291
+ t .Run (tt .name , func (t * testing.T ) {
292
+ service := & mockManagedService {
293
+ queryFuncs : tt .queryFuncs ,
294
+ startFunc : tt .startFunc ,
295
+ }
296
+ err := tryStartServiceFn (context .Background (), service )()
297
+ if tt .expectError {
298
+ assert .Error (t , err )
299
+ return
300
+ }
301
+ assert .NoError (t , err )
302
+ })
303
+ }
304
+ }
0 commit comments