Skip to content

Commit 38aa20c

Browse files
authored
Merge pull request #201 from AikidoSec/feat-attack-wave-detection
Use a single LRU cache for suspicious requests
2 parents c93b440 + 4d46650 commit 38aa20c

File tree

1 file changed

+29
-43
lines changed

1 file changed

+29
-43
lines changed

Aikido.Zen.Core/Vulnerabilities/AttackWave/AttackWaveDetector.cs

Lines changed: 29 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ namespace Aikido.Zen.Core.Vulnerabilities
88
{
99
public class AttackWaveDetector
1010
{
11-
private readonly LRUCache<string, int> _suspiciousRequestsCounts;
12-
private readonly LRUCache<string, List<SuspiciousRequest>> _suspiciousRequestsSamples;
11+
private readonly LRUCache<string, SuspiciousState> _suspiciousRequests;
1312
private readonly LRUCache<string, long> _sentEventsMap;
1413
private readonly int _attackWaveThreshold;
1514
private readonly int _maxSamplesPerIp;
@@ -29,8 +28,7 @@ public AttackWaveDetector(AttackWaveDetectorOptions options = null)
2928
_attackWaveThreshold = options.AttackWaveThreshold ?? 15;
3029
_maxSamplesPerIp = Math.Min(options.MaxSamplesPerIP ?? 15, _attackWaveThreshold);
3130

32-
_suspiciousRequestsCounts = new LRUCache<string, int>(maxLruEntries, attackWaveTimeFrame);
33-
_suspiciousRequestsSamples = new LRUCache<string, List<SuspiciousRequest>>(maxLruEntries, attackWaveTimeFrame);
31+
_suspiciousRequests = new LRUCache<string, SuspiciousState>(maxLruEntries, attackWaveTimeFrame);
3432
_sentEventsMap = new LRUCache<string, long>(maxLruEntries, minTimeBetweenEvents);
3533
}
3634

@@ -70,8 +68,7 @@ public bool Check(Context context)
7068
}
7169

7270
// Update total counter and track unique sample
73-
var suspiciousRequests = IncrementSuspiciousRequestCount(ip);
74-
TrackUniqueSample(ip, context);
71+
var suspiciousRequests = TrackSuspiciousRequest(ip, context);
7572

7673
// Threshold not yet reached
7774
if (suspiciousRequests < _attackWaveThreshold)
@@ -94,59 +91,42 @@ public IList<SuspiciousRequest> GetSamplesForIp(string ip)
9491

9592
lock (_lock)
9693
{
97-
if (_suspiciousRequestsSamples.TryGetValue(ip, out var samples) && samples != null)
94+
if (_suspiciousRequests.TryGetValue(ip, out var state) && state?.Samples != null)
9895
{
99-
return samples.ToList();
96+
return state.Samples.ToList();
10097
}
10198
}
10299

103100
return new List<SuspiciousRequest>();
104101
}
105102

106-
private int IncrementSuspiciousRequestCount(string ip)
103+
private int TrackSuspiciousRequest(string ip, Context context)
107104
{
108-
if (!_suspiciousRequestsCounts.TryGetValue(ip, out var count))
105+
if (!_suspiciousRequests.TryGetValue(ip, out var state) || state == null)
109106
{
110-
count = 0;
107+
state = new SuspiciousState();
111108
}
112109

113-
count++;
114-
_suspiciousRequestsCounts.Set(ip, count);
115-
return count;
116-
}
117-
118-
private void TrackUniqueSample(string ip, Context context)
119-
{
120-
if (!_suspiciousRequestsSamples.TryGetValue(ip, out var samples) || samples == null)
121-
{
122-
samples = new List<SuspiciousRequest>();
123-
}
124-
125-
// Stop collecting once the per-IP cap is reached
126-
if (samples.Count >= _maxSamplesPerIp)
127-
{
128-
return;
129-
}
110+
state.Count++;
130111

131-
var requestSample = new SuspiciousRequest
112+
if (state.Samples.Count < _maxSamplesPerIp)
132113
{
133-
Method = context.Method,
134-
Url = BuildUrlWithQuery(context)
135-
};
114+
var requestSample = new SuspiciousRequest
115+
{
116+
Method = context.Method,
117+
Url = BuildUrlWithQuery(context)
118+
};
136119

137-
// Only store unique method+URL combinations
138-
if (samples.Any(s => string.Equals(s.Method, requestSample.Method, StringComparison.OrdinalIgnoreCase)
139-
&& string.Equals(s.Url, requestSample.Url, StringComparison.OrdinalIgnoreCase)))
140-
{
141-
// Set is necessary to keep-alive the samples LRU even when no new samples are added, otherwise samples
142-
// might expire before the total count LRU hits the threshold, leading to a report with no samples.
143-
_suspiciousRequestsSamples.Set(ip, samples);
144-
return;
120+
// Only store unique samples
121+
if (!state.Samples.Any(s => string.Equals(s.Method, requestSample.Method, StringComparison.OrdinalIgnoreCase)
122+
&& string.Equals(s.Url, requestSample.Url, StringComparison.OrdinalIgnoreCase)))
123+
{
124+
state.Samples.Add(requestSample);
125+
}
145126
}
146127

147-
// Update sample list for this IP
148-
samples.Add(requestSample);
149-
_suspiciousRequestsSamples.Set(ip, samples);
128+
_suspiciousRequests.Set(ip, state);
129+
return state.Count;
150130
}
151131

152132
private static string BuildUrlWithQuery(Context context)
@@ -191,4 +171,10 @@ public class SuspiciousRequest
191171
public string Method { get; set; }
192172
public string Url { get; set; }
193173
}
174+
175+
internal class SuspiciousState
176+
{
177+
public int Count { get; set; }
178+
public List<SuspiciousRequest> Samples { get; set; } = new List<SuspiciousRequest>();
179+
}
194180
}

0 commit comments

Comments
 (0)