@@ -8,8 +8,7 @@ namespace Aikido.Zen.Core.Vulnerabilities
88{
99 public class AttackWaveDetector
1010 {
11- private readonly LRUCache < string , int > _suspiciousRequestsCounts ;
12- private readonly LRUCache < string , List < SuspiciousRequest > > _suspiciousRequestsSamples ;
11+ private readonly LRUCache < string , SuspiciousState > _suspiciousRequests ;
1312 private readonly LRUCache < string , long > _sentEventsMap ;
1413 private readonly int _attackWaveThreshold ;
1514 private readonly int _maxSamplesPerIp ;
@@ -29,8 +28,7 @@ public AttackWaveDetector(AttackWaveDetectorOptions options = null)
2928 _attackWaveThreshold = options . AttackWaveThreshold ?? 15 ;
3029 _maxSamplesPerIp = Math . Min ( options . MaxSamplesPerIP ?? 15 , _attackWaveThreshold ) ;
3130
32- _suspiciousRequestsCounts = new LRUCache < string , int > ( maxLruEntries , attackWaveTimeFrame ) ;
33- _suspiciousRequestsSamples = new LRUCache < string , List < SuspiciousRequest > > ( maxLruEntries , attackWaveTimeFrame ) ;
31+ _suspiciousRequests = new LRUCache < string , SuspiciousState > ( maxLruEntries , attackWaveTimeFrame ) ;
3432 _sentEventsMap = new LRUCache < string , long > ( maxLruEntries , minTimeBetweenEvents ) ;
3533 }
3634
@@ -70,8 +68,7 @@ public bool Check(Context context)
7068 }
7169
7270 // Update total counter and track unique sample
73- var suspiciousRequests = IncrementSuspiciousRequestCount ( ip ) ;
74- TrackUniqueSample ( ip , context ) ;
71+ var suspiciousRequests = TrackSuspiciousRequest ( ip , context ) ;
7572
7673 // Threshold not yet reached
7774 if ( suspiciousRequests < _attackWaveThreshold )
@@ -94,59 +91,42 @@ public IList<SuspiciousRequest> GetSamplesForIp(string ip)
9491
9592 lock ( _lock )
9693 {
97- if ( _suspiciousRequestsSamples . TryGetValue ( ip , out var samples ) && samples != null )
94+ if ( _suspiciousRequests . TryGetValue ( ip , out var state ) && state ? . Samples != null )
9895 {
99- return samples . ToList ( ) ;
96+ return state . Samples . ToList ( ) ;
10097 }
10198 }
10299
103100 return new List < SuspiciousRequest > ( ) ;
104101 }
105102
106- private int IncrementSuspiciousRequestCount ( string ip )
103+ private int TrackSuspiciousRequest ( string ip , Context context )
107104 {
108- if ( ! _suspiciousRequestsCounts . TryGetValue ( ip , out var count ) )
105+ if ( ! _suspiciousRequests . TryGetValue ( ip , out var state ) || state == null )
109106 {
110- count = 0 ;
107+ state = new SuspiciousState ( ) ;
111108 }
112109
113- count ++ ;
114- _suspiciousRequestsCounts . Set ( ip , count ) ;
115- return count ;
116- }
117-
118- private void TrackUniqueSample ( string ip , Context context )
119- {
120- if ( ! _suspiciousRequestsSamples . TryGetValue ( ip , out var samples ) || samples == null )
121- {
122- samples = new List < SuspiciousRequest > ( ) ;
123- }
124-
125- // Stop collecting once the per-IP cap is reached
126- if ( samples . Count >= _maxSamplesPerIp )
127- {
128- return ;
129- }
110+ state . Count ++ ;
130111
131- var requestSample = new SuspiciousRequest
112+ if ( state . Samples . Count < _maxSamplesPerIp )
132113 {
133- Method = context . Method ,
134- Url = BuildUrlWithQuery ( context )
135- } ;
114+ var requestSample = new SuspiciousRequest
115+ {
116+ Method = context . Method ,
117+ Url = BuildUrlWithQuery ( context )
118+ } ;
136119
137- // Only store unique method+URL combinations
138- if ( samples . Any ( s => string . Equals ( s . Method , requestSample . Method , StringComparison . OrdinalIgnoreCase )
139- && string . Equals ( s . Url , requestSample . Url , StringComparison . OrdinalIgnoreCase ) ) )
140- {
141- // Set is necessary to keep-alive the samples LRU even when no new samples are added, otherwise samples
142- // might expire before the total count LRU hits the threshold, leading to a report with no samples.
143- _suspiciousRequestsSamples . Set ( ip , samples ) ;
144- return ;
120+ // Only store unique samples
121+ if ( ! state . Samples . Any ( s => string . Equals ( s . Method , requestSample . Method , StringComparison . OrdinalIgnoreCase )
122+ && string . Equals ( s . Url , requestSample . Url , StringComparison . OrdinalIgnoreCase ) ) )
123+ {
124+ state . Samples . Add ( requestSample ) ;
125+ }
145126 }
146127
147- // Update sample list for this IP
148- samples . Add ( requestSample ) ;
149- _suspiciousRequestsSamples . Set ( ip , samples ) ;
128+ _suspiciousRequests . Set ( ip , state ) ;
129+ return state . Count ;
150130 }
151131
152132 private static string BuildUrlWithQuery ( Context context )
@@ -191,4 +171,10 @@ public class SuspiciousRequest
191171 public string Method { get ; set ; }
192172 public string Url { get ; set ; }
193173 }
174+
175+ internal class SuspiciousState
176+ {
177+ public int Count { get ; set ; }
178+ public List < SuspiciousRequest > Samples { get ; set ; } = new List < SuspiciousRequest > ( ) ;
179+ }
194180}
0 commit comments