Skip to content

Commit 75d66b9

Browse files
committed
wip
1 parent f050d3d commit 75d66b9

File tree

3 files changed

+142
-52
lines changed

3 files changed

+142
-52
lines changed

dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecurityRequestSampler.java

+28-30
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import com.datadog.appsec.gateway.AppSecRequestContext;
44
import datadog.trace.util.NonBlockingSemaphore;
55

6+
import javax.annotation.Nonnull;
67
import java.util.Deque;
78
import java.util.Map;
89
import java.util.concurrent.ConcurrentHashMap;
@@ -32,38 +33,42 @@ public ApiSecurityRequestSampler() {
3233
public ApiSecurityRequestSampler(int capacity, long expirationTimeInMs) {
3334
this.capacity = capacity;
3435
this.expirationTimeInMs = expirationTimeInMs;
35-
this.apiAccessMap = new ConcurrentHashMap<>();
36+
this.apiAccessMap = new ConcurrentHashMap<>(MAX_SIZE);
3637
this.apiAccessQueue = new ConcurrentLinkedDeque<>();
3738
}
3839

39-
public void preSampleRequest(final AppSecRequestContext ctx) {
40-
if (!isValid(ctx)) {
40+
public void preSampleRequest(final @Nonnull AppSecRequestContext ctx) {
41+
final String route = ctx.getRoute();
42+
if (route == null) {
4143
return;
4244
}
43-
44-
if (!isApiAccessExpired(ctx.getRoute(), ctx.getMethod(), ctx.getResponseStatus())) {
45+
final String method = ctx.getMethod();
46+
if (method == null) {
47+
return;
48+
}
49+
final int statusCode = ctx.getResponseStatus();
50+
if (statusCode == 0) {
51+
return;
52+
}
53+
long hash = computeApiHash(route, method, statusCode);
54+
ctx.setApiSecurityEndpointHash(hash);
55+
if (!isApiAccessExpired(hash)) {
4556
return;
4657
}
47-
4858
if (counter.acquire()) {
4959
ctx.setKeepOpenForApiSecurityPostProcessing(true);
5060
}
5161
}
5262

5363
public boolean sampleRequest(AppSecRequestContext ctx) {
54-
if (!isValid(ctx)) {
64+
if (ctx == null) {
5565
return false;
5666
}
57-
58-
return updateApiAccessIfExpired(
59-
ctx.getRoute(), ctx.getMethod(), ctx.getResponseStatus());
60-
}
61-
62-
private boolean isValid(AppSecRequestContext ctx) {
63-
return ctx != null
64-
&& ctx.getRoute() != null
65-
&& ctx.getMethod() != null
66-
&& ctx.getResponseStatus() != 0;
67+
final Long hash = ctx.getApiSecurityEndpointHash();
68+
if (hash == null) {
69+
return false;
70+
}
71+
return updateApiAccessIfExpired(hash);
6772
}
6873

6974
/**
@@ -72,15 +77,9 @@ private boolean isValid(AppSecRequestContext ctx) {
7277
* exist, a new record is added. If the capacity limit is reached, the oldest record is removed.
7378
* This method should not be called concurrently by multiple threads, due absence of additional
7479
* synchronization for updating data structures is not required.
75-
*
76-
* @param route The route of the API endpoint request
77-
* @param method The method of the API request
78-
* @param statusCode The HTTP response status code of the API request
79-
* @return return true if the record was updated or added, false otherwise
8080
*/
81-
public boolean updateApiAccessIfExpired(String route, String method, int statusCode) {
82-
long currentTime = System.currentTimeMillis();
83-
long hash = computeApiHash(route, method, statusCode);
81+
public boolean updateApiAccessIfExpired(final long hash) {
82+
final long currentTime = System.currentTimeMillis();
8483

8584
// New or updated record
8685
boolean isNewOrUpdated = false;
@@ -107,14 +106,13 @@ public boolean updateApiAccessIfExpired(String route, String method, int statusC
107106
return isNewOrUpdated;
108107
}
109108

110-
public boolean isApiAccessExpired(String route, String method, int statusCode) {
109+
public boolean isApiAccessExpired(final long hash) {
111110
long currentTime = System.currentTimeMillis();
112-
long hash = computeApiHash(route, method, statusCode);
113111
return !apiAccessMap.containsKey(hash)
114112
|| currentTime - apiAccessMap.get(hash) > expirationTimeInMs;
115113
}
116114

117-
private void cleanupExpiredEntries(long currentTime) {
115+
private void cleanupExpiredEntries(final long currentTime) {
118116
while (!apiAccessQueue.isEmpty()) {
119117
Long oldestHash = apiAccessQueue.peekFirst();
120118
if (oldestHash == null) break;
@@ -129,7 +127,7 @@ private void cleanupExpiredEntries(long currentTime) {
129127
}
130128
}
131129

132-
private long computeApiHash(String route, String method, int statusCode) {
130+
private long computeApiHash(final String route, final String method, final int statusCode) {
133131
long result = 17;
134132
result = 31 * result + route.hashCode();
135133
result = 31 * result + method.hashCode();
@@ -143,7 +141,7 @@ public NoOp() {
143141
}
144142

145143
@Override
146-
public void preSampleRequest(AppSecRequestContext ctx) {
144+
public void preSampleRequest(@Nonnull AppSecRequestContext ctx) {
147145
}
148146

149147
@Override

dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/AppSecRequestContext.java

+9-4
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ public class AppSecRequestContext implements DataBundle, Closeable {
143143
private volatile String sessionId;
144144

145145
private volatile boolean keepOpenForApiSecurityPostProcessing;
146+
private volatile Long apiSecurityEndpointHash;
146147

147148
private static final AtomicIntegerFieldUpdater<AppSecRequestContext> WAF_TIMEOUTS_UPDATER =
148149
AtomicIntegerFieldUpdater.newUpdater(AppSecRequestContext.class, "wafTimeouts");
@@ -388,10 +389,6 @@ public String getRoute() {
388389
}
389390

390391
public void setRoute(String route) {
391-
if (this.route != null && this.route.compareToIgnoreCase(route) != 0) {
392-
throw new IllegalStateException(
393-
"Forbidden attempt to set different route for given request context");
394-
}
395392
this.route = route;
396393
}
397394

@@ -403,6 +400,14 @@ public boolean isKeepOpenForApiSecurityPostProcessing() {
403400
return this.keepOpenForApiSecurityPostProcessing;
404401
}
405402

403+
public void setApiSecurityEndpointHash(long hash) {
404+
this.apiSecurityEndpointHash = hash;
405+
}
406+
407+
public Long getApiSecurityEndpointHash() {
408+
return this.apiSecurityEndpointHash;
409+
}
410+
406411
void addRequestHeader(String name, String value) {
407412
if (finishedRequestHeaders) {
408413
throw new IllegalStateException("Request headers were said to be finished before");

dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecurityRequestSamplerTest.groovy

+105-18
Original file line numberDiff line numberDiff line change
@@ -7,55 +7,54 @@ class ApiSecurityRequestSamplerTest extends DDSpecification {
77

88
void 'happy path with single request'() {
99
given:
10-
def ctx = Mock(AppSecRequestContext)
10+
def ctx = Spy(createContext('route1', 'GET', 200))
1111
def sampler = new ApiSecurityRequestSampler()
1212

1313
when:
1414
sampler.preSampleRequest(ctx)
1515

1616
then:
17-
_ * ctx.getRoute() >> 'route1'
18-
_ * ctx.getMethod() >> 'GET'
19-
_ * ctx.getResponseStatus() >> 200
17+
1 * ctx.getRoute()
18+
1 * ctx.getMethod()
19+
1 * ctx.getResponseStatus()
2020
1 * ctx.setKeepOpenForApiSecurityPostProcessing(true)
21+
1 * ctx.setApiSecurityEndpointHash(_)
2122
0 * _
2223

2324
when:
2425
def sampleDecision = sampler.sampleRequest(ctx)
2526

2627
then:
2728
sampleDecision
28-
_ * ctx.getRoute() >> 'route1'
29-
_ * ctx.getMethod() >> 'GET'
30-
_ * ctx.getResponseStatus() >> 200
3129
_ * ctx.isKeepOpenForApiSecurityPostProcessing() >> true
30+
1 * ctx.getApiSecurityEndpointHash()
3231
0 * _
3332
}
3433

3534
void 'second request is not sampled for the same endpoint'() {
35+
Long hash
3636
given:
37-
AppSecRequestContext ctx1 = Mock(AppSecRequestContext)
38-
AppSecRequestContext ctx2 = Mock(AppSecRequestContext)
37+
AppSecRequestContext ctx1 = Spy(createContext('route1', 'GET', 200))
38+
AppSecRequestContext ctx2 = Spy(createContext('route1', 'GET', 200))
3939
def sampler = new ApiSecurityRequestSampler()
4040

4141
when:
4242
sampler.preSampleRequest(ctx1)
4343
def sampleDecision = sampler.sampleRequest(ctx1)
44+
sampler.counter.release()
4445

4546
then:
4647
sampleDecision
47-
_ * ctx1.getRoute() >> 'route1'
48-
_ * ctx1.getMethod() >> 'GET'
49-
_ * ctx1.getResponseStatus() >> 200
5048
_ * _
5149

5250
when:
5351
sampler.preSampleRequest(ctx2)
5452

5553
then:
56-
_ * ctx2.getRoute() >> 'route1'
57-
_ * ctx2.getMethod() >> 'GET'
58-
_ * ctx2.getResponseStatus() >> 200
54+
1 * ctx2.getRoute()
55+
1 * ctx2.getMethod()
56+
1 * ctx2.getResponseStatus()
57+
1 * ctx2.setApiSecurityEndpointHash(_)
5958
0 * ctx2.setKeepOpenForApiSecurityPostProcessing(_)
6059
0 * _
6160

@@ -64,10 +63,98 @@ class ApiSecurityRequestSamplerTest extends DDSpecification {
6463

6564
then:
6665
!sampleDecision
67-
_ * ctx2.getRoute() >> 'route1'
68-
_ * ctx2.getMethod() >> 'GET'
69-
_ * ctx2.getResponseStatus() >> 200
66+
1 * ctx2.getApiSecurityEndpointHash()
7067
0 * _
7168
}
7269

70+
void 'preSampleRequest with maximum concurrent contexts'() {
71+
given:
72+
final ctx1 = Spy(createContext('route2', 'GET', 200))
73+
final ctx2 = Spy(createContext('route3', 'GET', 200))
74+
final sampler = new ApiSecurityRequestSampler()
75+
assert sampler.MAX_POST_PROCESSING_TASKS > 0
76+
77+
when: 'exhaust the maximum number of concurrent contexts'
78+
for (int i = 0; i < sampler.MAX_POST_PROCESSING_TASKS; i++) {
79+
sampler.preSampleRequest(createContext('route1', 'GET', 200 + i))
80+
}
81+
82+
and: 'try to sample one more'
83+
sampler.preSampleRequest(ctx1)
84+
85+
then:
86+
1 * ctx1.getRoute()
87+
1 * ctx1.getMethod()
88+
1 * ctx1.getResponseStatus()
89+
1 * ctx1.setApiSecurityEndpointHash(_)
90+
0 * _
91+
92+
when: 'release one context'
93+
sampler.counter.release()
94+
95+
and: 'next can be sampled'
96+
sampler.preSampleRequest(ctx2)
97+
98+
then:
99+
1 * ctx2.getRoute()
100+
1 * ctx2.getMethod()
101+
1 * ctx2.getResponseStatus()
102+
1 * ctx2.setApiSecurityEndpointHash(_)
103+
1 * ctx2.setKeepOpenForApiSecurityPostProcessing(true)
104+
0 * _
105+
}
106+
107+
void 'preSampleRequest with null route'() {
108+
given:
109+
def ctx = Spy(createContext(null, 'GET', 200))
110+
def sampler = new ApiSecurityRequestSampler()
111+
112+
when:
113+
def sampleDecision = sampler.preSampleRequest(ctx)
114+
115+
then:
116+
!sampleDecision
117+
1 * ctx.getRoute()
118+
0 * _
119+
}
120+
121+
void 'preSampleRequest with null method'() {
122+
given:
123+
def ctx = Spy(createContext('route1', null, 200))
124+
def sampler = new ApiSecurityRequestSampler()
125+
126+
when:
127+
def sampleDecision = sampler.preSampleRequest(ctx)
128+
129+
then:
130+
!sampleDecision
131+
1 * ctx.getRoute()
132+
1 * ctx.getMethod()
133+
0 * _
134+
}
135+
136+
void 'preSampleRequest with 0 status code'() {
137+
given:
138+
def ctx = Spy(createContext('route1', 'GET', 0))
139+
def sampler = new ApiSecurityRequestSampler()
140+
141+
when:
142+
def sampleDecision = sampler.preSampleRequest(ctx)
143+
144+
then:
145+
!sampleDecision
146+
1 * ctx.getRoute()
147+
1 * ctx.getMethod()
148+
1 * ctx.getResponseStatus()
149+
0 * _
150+
}
151+
152+
private AppSecRequestContext createContext(final String route, final String method, int statusCode) {
153+
final AppSecRequestContext ctx = new AppSecRequestContext()
154+
ctx.setRoute(route)
155+
ctx.setMethod(method)
156+
ctx.setResponseStatus(statusCode)
157+
ctx
158+
}
159+
73160
}

0 commit comments

Comments
 (0)