1
1
"""Warning checker."""
2
- from typing import Dict , List , Sequence
2
+ from collections import defaultdict
3
+ from itertools import groupby
4
+ from typing import Dict , List , NamedTuple , Sequence
3
5
from warnings import warn
4
6
5
7
from .spy_events import (
8
10
SpyEvent ,
9
11
VerifyRehearsal ,
10
12
WhenRehearsal ,
13
+ SpyRehearsal ,
11
14
match_event ,
12
15
)
13
16
from .warnings import DecoyWarning , MiscalledStubWarning , RedundantVerifyWarning
@@ -23,44 +26,62 @@ def check(all_calls: Sequence[AnySpyEvent]) -> None:
23
26
_check_no_redundant_verify (all_calls )
24
27
25
28
29
+ class _Call (NamedTuple ):
30
+ event : SpyEvent
31
+ all_rehearsals : List [SpyRehearsal ]
32
+ matching_rehearsals : List [SpyRehearsal ]
33
+
34
+
26
35
def _check_no_miscalled_stubs (all_events : Sequence [AnySpyEvent ]) -> None :
27
36
"""Ensure every call matches a rehearsal, if the spy has rehearsals."""
28
- all_calls_by_id : Dict [int , List [AnySpyEvent ]] = {}
37
+ all_events_by_id : Dict [int , List [AnySpyEvent ]] = defaultdict (list )
38
+ all_calls_by_id : Dict [int , List [_Call ]] = defaultdict (list )
29
39
30
40
for event in all_events :
31
- if isinstance (event .payload , SpyCall ):
32
- spy_id = event .spy .id
33
- spy_calls = all_calls_by_id .get (spy_id , [])
34
- all_calls_by_id [spy_id ] = [* spy_calls , event ]
41
+ all_events_by_id [event .spy .id ].append (event )
42
+
43
+ for events in all_events_by_id .values ():
44
+ for index , event in enumerate (events ):
45
+ if isinstance (event , SpyEvent ) and isinstance (event .payload , SpyCall ):
46
+ when_rehearsals = [
47
+ rehearsal
48
+ for rehearsal in events [0 :index ]
49
+ if isinstance (rehearsal , WhenRehearsal )
50
+ and isinstance (rehearsal .payload , SpyCall )
51
+ ]
52
+ verify_rehearsals = [
53
+ rehearsal
54
+ for rehearsal in events [index + 1 :]
55
+ if isinstance (rehearsal , VerifyRehearsal )
56
+ and isinstance (rehearsal .payload , SpyCall )
57
+ ]
58
+
59
+ all_rehearsals : List [SpyRehearsal ] = [
60
+ * when_rehearsals ,
61
+ * verify_rehearsals ,
62
+ ]
63
+ matching_rehearsals = [
64
+ rehearsal
65
+ for rehearsal in all_rehearsals
66
+ if match_event (event , rehearsal )
67
+ ]
68
+
69
+ all_calls_by_id [event .spy .id ].append (
70
+ _Call (event , all_rehearsals , matching_rehearsals )
71
+ )
35
72
36
73
for spy_calls in all_calls_by_id .values ():
37
- unmatched : List [SpyEvent ] = []
38
-
39
- for index , call in enumerate (spy_calls ):
40
- past_stubs = [
41
- wr for wr in spy_calls [0 :index ] if isinstance (wr , WhenRehearsal )
42
- ]
43
-
44
- matched_past_stubs = [wr for wr in past_stubs if match_event (call , wr )]
45
-
46
- matched_future_verifies = [
47
- vr
48
- for vr in spy_calls [index + 1 :]
49
- if isinstance (vr , VerifyRehearsal ) and match_event (call , vr )
50
- ]
51
-
52
- if (
53
- isinstance (call , SpyEvent )
54
- and len (past_stubs ) > 0
55
- and len (matched_past_stubs ) == 0
56
- and len (matched_future_verifies ) == 0
57
- ):
58
- unmatched = [* unmatched , call ]
59
- if index == len (spy_calls ) - 1 :
60
- _warn (MiscalledStubWarning (calls = unmatched , rehearsals = past_stubs ))
61
- elif isinstance (call , WhenRehearsal ) and len (unmatched ) > 0 :
62
- _warn (MiscalledStubWarning (calls = unmatched , rehearsals = past_stubs ))
63
- unmatched = []
74
+ for rehearsals , grouped_calls in groupby (spy_calls , lambda c : c .all_rehearsals ):
75
+ calls = list (grouped_calls )
76
+ is_stubbed = any (isinstance (r , WhenRehearsal ) for r in rehearsals )
77
+
78
+ if is_stubbed and all (len (c .matching_rehearsals ) == 0 for c in calls ):
79
+ _warn (
80
+ MiscalledStubWarning (
81
+ calls = [c .event for c in calls ],
82
+ rehearsals = rehearsals ,
83
+ )
84
+ )
64
85
65
86
66
87
def _check_no_redundant_verify (all_calls : Sequence [AnySpyEvent ]) -> None :
0 commit comments