forked from tarantool/luatest
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhooks.lua
340 lines (294 loc) · 11.1 KB
/
hooks.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
--- Provide extra methods for hooks.
--
-- Preloaded hooks extend base hooks.
-- They behave like the pytest fixture with the `autouse` parameter.
--
-- @usage
--
-- local hooks = require('luatest.hooks')
--
-- hooks.before_suite_preloaded(...)
-- hooks.after_suite_preloaded(...)
--
-- hooks.before_all_preloaded(...)
-- hooks.after_all_preloaded(...)
--
-- hooks.before_each_preloaded(...)
-- hooks.after_each_preloaded(...)
--
-- @module luatest.hooks
local log = require('luatest.log')
local utils = require('luatest.utils')
local comparator = require('luatest.comparator')
local export = {}
-- Registry of preloaded hooks, keyed by hook type.
-- Every list stores `{fn, params}` pairs in registration order.
local preloaded_hooks = {}
for _, hooks_type in ipairs({
    'before_suite', 'after_suite',
    'before_all', 'after_all',
    'before_each', 'after_each',
}) do
    preloaded_hooks[hooks_type] = {}
end
--- Add a preloaded `before` hook for the `suite` scope.
-- Preloaded hooks of this kind are executed ahead of the hooks registered
-- through the regular before_suite() call.
--
-- @func fn Callback that performs the preparation.
function export.before_suite_preloaded(fn)
    local registered = preloaded_hooks.before_suite
    -- Preloaded hooks carry an empty params table, so they match any params.
    registered[#registered + 1] = {fn, {}}
end
--- Add a preloaded `after` hook for the `suite` scope.
-- Preloaded hooks of this kind are executed once the hooks registered
-- through the regular after_suite() call have finished.
--
-- @func fn Callback that performs the cleanup.
function export.after_suite_preloaded(fn)
    local registered = preloaded_hooks.after_suite
    -- Preloaded hooks carry an empty params table, so they match any params.
    registered[#registered + 1] = {fn, {}}
end
--- Add a preloaded `before` hook for the `all` scope.
-- Preloaded hooks of this kind are executed ahead of the hooks registered
-- through the regular before_all() call.
--
-- @func fn Callback that performs the preparation.
function export.before_all_preloaded(fn)
    local registered = preloaded_hooks.before_all
    -- Preloaded hooks carry an empty params table, so they match any params.
    registered[#registered + 1] = {fn, {}}
end
--- Add a preloaded `after` hook for the `all` scope.
-- Preloaded hooks of this kind are executed once the hooks registered
-- through the regular after_all() call have finished.
--
-- @func fn Callback that performs the cleanup.
function export.after_all_preloaded(fn)
    local registered = preloaded_hooks.after_all
    -- Preloaded hooks carry an empty params table, so they match any params.
    registered[#registered + 1] = {fn, {}}
end
--- Add a preloaded `before` hook for the `each` scope.
-- Preloaded hooks of this kind are executed ahead of the hooks registered
-- through the regular before_each() call.
--
-- @func fn Callback that performs the preparation.
function export.before_each_preloaded(fn)
    local registered = preloaded_hooks.before_each
    -- Preloaded hooks carry an empty params table, so they match any params.
    registered[#registered + 1] = {fn, {}}
end
--- Add a preloaded `after` hook for the `each` scope.
-- Preloaded hooks of this kind are executed once the hooks registered
-- through the regular after_each() call have finished.
--
-- @func fn Callback that performs the cleanup.
function export.after_each_preloaded(fn)
    local registered = preloaded_hooks.after_each
    -- Preloaded hooks carry an empty params table, so they match any params.
    registered[#registered + 1] = {fn, {}}
end
-- Tell whether every entry of `required` is matched by `actual`.
--
-- Each required value is compared with the value stored under the same key
-- in `actual` using comparator.equals. Keys present only in `actual` are
-- ignored, so an empty `required` table always matches.
--
-- @tab required Params a hook was registered with.
-- @tab actual Params to check against (indexed only if `required` is non-empty).
-- @treturn boolean true when all required params match, false otherwise.
local function check_params(required, actual)
    local satisfied = true
    for key, expected in pairs(required) do
        if not comparator.equals(expected, actual[key]) then
            satisfied = false
            break
        end
    end
    return satisfied
end
-- Install a hook registry of the given type on `object`.
--
-- Three fields are created on `object`:
--   * `<hooks_type>_hooks`    - list of registered `{fn, params}` pairs;
--   * `<hooks_type>`          - registration function, called either as
--                               `(fn)` or as `(params, fn)`;
--   * `run_<hooks_type>`      - runs the registered hooks whose params match
--                               `object.params`, together with the preloaded ones.
-- A copy of the registration function is also kept in `_original_<hooks_type>`.
--
-- @tab object Table to extend.
-- @string hooks_type Hook type name, expected to start with 'before_' or 'after_'.
-- @tab[opt] preloaded_hook List of preloaded `{fn, params}` pairs; may be nil.
local function define_hooks(object, hooks_type, preloaded_hook)
    local registry_key = hooks_type .. '_hooks'
    local registered = {}
    object[registry_key] = registered

    object[hooks_type] = function(...)
        local params, fn = ...
        if fn == nil then
            -- Single-argument form: the only argument is the hook itself.
            params, fn = {}, params
        end
        assert(type(params) == 'table',
            string.format('params should be table, got %s', type(params)))
        assert(type(fn) == 'function',
            string.format('hook should be function, got %s', type(fn)))
        table.insert(registered, {fn, params})
    end
    object['_original_' .. hooks_type] = object[hooks_type] -- for legacy hooks support

    -- Run the preloaded hooks: before_* in registration order,
    -- after_* in reverse registration order.
    local function run_preloaded_hooks()
        if preloaded_hook == nil then
            return
        end
        local first, last, step = 1, #preloaded_hook, 1
        if hooks_type:startswith('after_') then
            first, last, step = last, first, -1
        end
        for idx = first, last, step do
            local entry = preloaded_hook[idx]
            if check_params(entry[2], object.params) then
                entry[1](object)
            end
        end
    end

    object['run_' .. hooks_type] = function()
        -- before_*: preloaded hooks go first.
        if hooks_type:startswith('before_') then
            run_preloaded_hooks()
        end

        -- The list is read from the object on every run, so a replaced
        -- `<hooks_type>_hooks` field is honoured.
        for _, entry in ipairs(object[registry_key]) do
            if check_params(entry[2], object.params) then
                entry[1](object)
            end
        end

        -- after_*: preloaded hooks go last.
        if hooks_type:startswith('after_') then
            run_preloaded_hooks()
        end
    end
end
-- Install a per-test ("named") hook registry of the given type on `object`.
--
-- Three fields are created on `object`:
--   * `<hooks_type>_hooks` - map from full test name (`<object.name>.<test name>`)
--                            to a list of `{fn, params}` pairs;
--   * `<hooks_type>`       - registration function, called either as
--                            `(test_name, fn)` or as `(test_name, params, fn)`;
--   * `run_<hooks_type>`   - takes a test and runs the hooks registered for it
--                            whose params match `object.params`.
--
-- @tab object Table to extend; its `name` field is used to build full test names.
-- @string hooks_type Hook type name, e.g. 'before_test'.
local function define_named_hooks(object, hooks_type)
    local registry_key = hooks_type .. '_hooks'
    local registered = {}
    object[registry_key] = registered

    object[hooks_type] = function(...)
        local test_name, params, fn = ...
        if fn == nil then
            -- Two-argument form: the second argument is the hook itself.
            params, fn = {}, params
        end
        assert(type(test_name) == 'string',
            string.format('test name should be string, got %s', type(test_name)))
        assert(type(params) == 'table',
            string.format('params should be table, got %s', type(params)))
        assert(type(fn) == 'function',
            string.format('hook should be function, got %s', type(fn)))

        local full_name = object.name .. '.' .. test_name
        local bucket = registered[full_name]
        if not bucket then
            bucket = {}
            registered[full_name] = bucket
        end
        table.insert(bucket, {fn, params})
    end

    object['run_' .. hooks_type] = function(test)
        local by_test = object[registry_key]
        local lookup_name = test.name

        -- When parametrized groups are defined, named hooks are stored under
        -- the super group test name, while the test being run carries the
        -- name specific to the parametrized group. Convert it back to the
        -- super group form before the lookup.
        if object.super_group then
            local parts, count = utils.split_test_name(lookup_name)
            lookup_name = object.super_group.name .. '.' .. parts[count]
        end

        local entries = by_test[lookup_name]
        if not entries then
            return
        end
        for _, entry in ipairs(entries) do
            if check_params(entry[2], object.params) then
                -- Hooks receive the object they were defined on, not the test.
                entry[1](object)
            end
        end
    end
end
-- Attach all group-level hook registries to `group`.
--
-- The each/all hook types are wired to their preloaded hook lists;
-- the named test hooks (before_test/after_test) have no preloaded part.
--
-- @tab group Group to extend.
-- @return The same `group`.
function export._define_group_hooks(group)
    local scoped_types = {'before_each', 'after_each', 'before_all', 'after_all'}
    for _, hooks_type in ipairs(scoped_types) do
        define_hooks(group, hooks_type, preloaded_hooks[hooks_type])
    end

    local named_types = {'before_test', 'after_test'}
    for _, hooks_type in ipairs(named_types) do
        define_named_hooks(group, hooks_type)
    end

    return group
end
-- Attach the suite-level hook registries (before_suite/after_suite) to the
-- `luatest` table, each wired to its preloaded hook list.
--
-- @tab luatest Table to extend.
function export._define_suite_hooks(luatest)
    for _, hooks_type in ipairs({'before_suite', 'after_suite'}) do
        define_hooks(luatest, hooks_type, preloaded_hooks[hooks_type])
    end
end
-- Run a group-level hook (e.g. `before_all` / `after_all`) through the
-- runner's protected call.
--
-- The `run_<hooks_type>` function is used when the group's `<hooks_type>`
-- field is still the one stored in `_original_<hooks_type>`. Otherwise the
-- field was reassigned by the user (legacy API) and that value is called
-- directly.
--
-- The name passed to protected_call is derived from `hooks_type`, so an
-- `after_all` failure is no longer labelled as `before_all`.
-- NOTE(review): the label is presumably only used for error/trace
-- formatting inside protected_call -- confirm against the runner.
--
-- @param runner Object providing `protected_call(instance, method, pretty_name)`.
-- @tab[opt] group Group to run the hook for; nothing is run when it is nil.
-- @string hooks_type Hook type name, e.g. 'before_all' or 'after_all'.
-- @return The protected_call result when its status is not 'success', otherwise nil.
local function run_group_hooks(runner, group, hooks_type)
    if not group then
        return nil
    end

    local result
    local hooks_runner = group['run_' .. hooks_type]
    -- If _original_%hook_name% is not equal to %hook_name%, it means
    -- that this method was assigned by user (legacy API).
    if hooks_runner and group[hooks_type] == group['_original_' .. hooks_type] then
        local pretty_name = group.name .. '.run_' .. hooks_type .. '_hooks'
        result = runner:protected_call(group, hooks_runner, pretty_name)
    elseif group[hooks_type] then
        local pretty_name = group.name .. '.' .. hooks_type
        result = runner:protected_call(group, group[hooks_type], pretty_name)
    end

    if result and result.status ~= 'success' then
        return result
    end
    return nil
end
-- Run the per-test hooks of one type for `test`.
--
-- First the legacy group method (`legacy_name`, e.g. setup/teardown) is
-- called when the group defines it as a function, then the group's
-- `run_<hooks_type>` function when present. Each call goes through
-- `self:protected_call` and its result is passed to `self:update_status`.
--
-- @param self Runner instance.
-- @tab test Test whose `group` field holds the hooks.
-- @string hooks_type Hook type name, e.g. 'before_each'.
-- @string legacy_name Name of the legacy group method, e.g. 'setup'.
local function run_test_hooks(self, test, hooks_type, legacy_name)
    log.info('Run hook %s', hooks_type)
    local group = test.group

    -- Support for group.setup/teardown methods (legacy API)
    local legacy_hook = group[legacy_name]
    if type(legacy_hook) == 'function' then
        local pretty_name = group.name .. '.' .. legacy_name
        self:update_status(test, self:protected_call(group, legacy_hook, pretty_name))
    end

    -- Looked up only after the legacy hook has run, as before.
    local hooks_runner = group['run_' .. hooks_type]
    if hooks_runner then
        self:update_status(test, self:protected_call(group, hooks_runner))
    end
end
-- Run the named (per-test) hooks of one type for `test`.
--
-- The group's `run_<hooks_type>` function, when present, is invoked through
-- `self:protected_call` with the test itself as the instance, and the result
-- is passed to `self:update_status`.
--
-- @param self Runner instance.
-- @tab test Test whose `group` field holds the hooks.
-- @string hooks_type Hook type name, e.g. 'before_test'.
local function run_named_test_hooks(self, test, hooks_type)
    log.info('Run hook %s', hooks_type)
    local hooks_runner = test.group['run_' .. hooks_type]
    if not hooks_runner then
        return
    end
    self:update_status(test, self:protected_call(test, hooks_runner))
end
-- Patch the runner's metatable methods so that hooks are executed around
-- tests, groups and the whole suite.
--
-- Four methods of `Runner.mt` are wrapped via utils.patch:
-- `invoke_test_function`, `start_group`, `end_group` and `run_tests`.
-- NOTE(review): utils.patch presumably replaces the method with the function
-- returned by the given factory, passing the previous implementation as
-- `super` -- confirm against luatest.utils.
--
-- @tab Runner Runner class whose `mt` table is patched.
function export._patch_runner(Runner)
    -- Most recently invoked test; a failing after_all hook is reported on it.
    local last_test = nil

    -- Run test hooks.
    -- If the before_all hook of the test's group failed, the test is not run
    -- and the hook's error is copied to the test.
    utils.patch(Runner.mt, 'invoke_test_function', function(super)
        return function(self, test, ...)
            last_test = test

            local before_all_error = test.group._before_all_hook_error
            if before_all_error then
                return self:update_status(test, before_all_error)
            end

            local repeat_count = self.exe_repeat or 1
            for _ = 1, repeat_count do
                -- Stop repeating as soon as the test is no longer successful.
                if not test:is('success') then
                    break
                end

                run_test_hooks(self, test, 'before_each', 'setup')
                run_named_test_hooks(self, test, 'before_test')

                -- The test body is skipped when a before hook changed the status.
                if test:is('success') then
                    log.info('Start test %s', test.name)
                    super(self, test, ...)
                    log.info('End test %s', test.name)
                end

                -- After hooks run regardless of the test outcome.
                run_named_test_hooks(self, test, 'after_test')
                run_test_hooks(self, test, 'after_each', 'teardown')
            end
        end
    end)

    -- Run the group's before_all hook and save a possible error on the group.
    utils.patch(Runner.mt, 'start_group', function(super)
        return function(self, group)
            super(self, group)

            -- Check while starting the group that 'before_test' and
            -- 'after_test' hooks are defined only for existing tests.
            local named_types = {'before_test', 'after_test'}
            for _, hooks_type in ipairs(named_types) do
                local named_hooks = group[hooks_type .. '_hooks']
                for full_test_name in pairs(named_hooks) do
                    local parts, count = utils.split_test_name(full_test_name)
                    local short_name = parts[count]
                    if not group[short_name] then
                        local message = string.format(
                            "There is no test with name '%s' but hook '%s' is defined for it",
                            short_name, hooks_type)
                        error(message)
                    end
                end
            end

            group._before_all_hook_error = run_group_hooks(self, group, 'before_all')
        end
    end)

    -- Run the group's after_all hook and save a possible error to `last_test`.
    utils.patch(Runner.mt, 'end_group', function(super)
        return function(self, group)
            local failure = run_group_hooks(self, group, 'after_all')
            if failure then
                failure.message = 'Failure in after_all hook: ' .. tostring(failure.message)
                self:update_status(last_test, failure)
            end
            super(self, group)
        end
    end)

    -- Run suite hooks around the whole test run; nothing happens for an
    -- empty test list.
    utils.patch(Runner.mt, 'run_tests', function(super)
        return function(self, tests)
            if #tests == 0 then
                return
            end

            local function run_suite()
                self.luatest.run_before_suite()
                super(self, tests)
            end
            local function finish_suite()
                self.luatest.run_after_suite()
            end
            return utils.reraise_and_ensure(run_suite, nil, finish_suite)
        end
    end)
end
return export