Skip to content

Commit b69714d

Browse files
committed
wip
1 parent 59db91a commit b69714d

File tree

11 files changed

+163
-106
lines changed

11 files changed

+163
-106
lines changed

dd-java-agent/appsec/src/main/java/com/datadog/appsec/AppSecSystem.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,14 @@
2323
import datadog.trace.api.telemetry.ProductChange;
2424
import datadog.trace.api.telemetry.ProductChangeCollector;
2525
import datadog.trace.bootstrap.ActiveSubsystems;
26+
import datadog.trace.bootstrap.instrumentation.api.SpanPostProcessor;
2627
import java.util.Collections;
2728
import java.util.HashMap;
2829
import java.util.List;
2930
import java.util.Map;
3031
import java.util.Set;
3132
import java.util.concurrent.atomic.AtomicBoolean;
3233
import java.util.stream.Collectors;
33-
34-
import datadog.trace.bootstrap.instrumentation.api.SpanPostProcessor;
3534
import org.slf4j.Logger;
3635
import org.slf4j.LoggerFactory;
3736

@@ -73,7 +72,8 @@ private static void doStart(SubscriptionService gw, SharedCommunicationObjects s
7372
ApiSecurityRequestSampler requestSampler;
7473
if (Config.get().isApiSecurityEnabled()) {
7574
requestSampler = new ApiSecurityRequestSampler();
76-
SpanPostProcessor.Holder.INSTANCE = new AppSecSpanPostProcessor(requestSampler, REPLACEABLE_EVENT_PRODUCER);
75+
SpanPostProcessor.Holder.INSTANCE =
76+
new AppSecSpanPostProcessor(requestSampler, REPLACEABLE_EVENT_PRODUCER);
7777
} else {
7878
requestSampler = new ApiSecurityRequestSampler.NoOp();
7979
}

dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/ApiSecurityRequestSampler.java

+71-56
Original file line numberDiff line numberDiff line change
@@ -4,43 +4,52 @@
44
import datadog.trace.api.time.SystemTimeSource;
55
import datadog.trace.api.time.TimeSource;
66
import datadog.trace.util.NonBlockingSemaphore;
7-
8-
import javax.annotation.Nonnull;
97
import java.util.Deque;
10-
import java.util.Map;
118
import java.util.concurrent.ConcurrentHashMap;
129
import java.util.concurrent.ConcurrentLinkedDeque;
10+
import javax.annotation.Nonnull;
1311

1412
public class ApiSecurityRequestSampler {
1513

1614
/**
17-
* A maximum number of request contexts we'll keep open past the end of request at any given time. This will avoid
18-
* excessive memory usage in case of a high number of concurrent requests, and should also prevent memory leaks in
19-
* case of a bug.
15+
* A maximum number of request contexts we'll keep open past the end of request at any given time.
16+
* This will avoid excessive memory usage in case of a high number of concurrent requests, and
17+
* should also prevent memory leaks.
2018
*/
2119
private static final int MAX_POST_PROCESSING_TASKS = 4;
20+
2221
private static final int INTERVAL_SECONDS = 30;
2322
private static final int MAX_SIZE = 4096;
24-
private final Map<Long, Long> apiAccessMap; // Map<hash, timestamp>
25-
private final Deque<Long> apiAccessQueue; // hashes ordered by access time
23+
/** Mapping from endpoint hash to last access timestamp in millis. */
24+
private final ConcurrentHashMap<Long, Long> accessMap;
25+
/** Deque of endpoint hashes ordered by access time. Oldest is always first. */
26+
private final Deque<Long> accessDeque;
27+
2628
private final long expirationTimeInMs;
2729
private final int capacity;
2830
private final TimeSource timeSource;
2931

30-
final NonBlockingSemaphore counter = NonBlockingSemaphore.withPermitCount(MAX_POST_PROCESSING_TASKS);
32+
final NonBlockingSemaphore counter =
33+
NonBlockingSemaphore.withPermitCount(MAX_POST_PROCESSING_TASKS);
3134

3235
public ApiSecurityRequestSampler() {
3336
this(MAX_SIZE, INTERVAL_SECONDS * 1000, SystemTimeSource.INSTANCE);
3437
}
3538

36-
public ApiSecurityRequestSampler(int capacity, long expirationTimeInMs, @Nonnull TimeSource timeSource) {
39+
public ApiSecurityRequestSampler(
40+
int capacity, long expirationTimeInMs, @Nonnull TimeSource timeSource) {
3741
this.capacity = capacity;
3842
this.expirationTimeInMs = expirationTimeInMs;
39-
this.apiAccessMap = new ConcurrentHashMap<>(MAX_SIZE);
40-
this.apiAccessQueue = new ConcurrentLinkedDeque<>();
43+
this.accessMap = new ConcurrentHashMap<>();
44+
this.accessDeque = new ConcurrentLinkedDeque<>();
4145
this.timeSource = timeSource;
4246
}
4347

48+
/**
49+
* Prepare a request context for later sampling decision. This method should be called at request
50+
* end, and is thread-safe. If a request can potentially be sampled, this method will call {@link
51+
* AppSecRequestContext#setKeepOpenForApiSecurityPostProcessing(boolean)}.
52+
*/
4453
public void preSampleRequest(final @Nonnull AppSecRequestContext ctx) {
4554
final String route = ctx.getRoute();
4655
if (route == null) {
@@ -64,69 +73,77 @@ public void preSampleRequest(final @Nonnull AppSecRequestContext ctx) {
6473
}
6574
}
6675

76+
/** Get the final sampling decision. This method is NOT thread-safe. */
6777
public boolean sampleRequest(AppSecRequestContext ctx) {
6878
if (ctx == null) {
6979
return false;
7080
}
7181
final Long hash = ctx.getApiSecurityEndpointHash();
7282
if (hash == null) {
83+
// This should never happen; a request without a hash should have been short-circuited earlier.
7384
return false;
7485
}
7586
return updateApiAccessIfExpired(hash);
7687
}
7788

78-
/**
79-
* Updates the API access log with the given route, method, and status code. If the record already
80-
* exists and is outdated, it is updated by moving to the end of the list. If the record does not
81-
* exist, a new record is added. If the capacity limit is reached, the oldest record is removed.
82-
* This method should not be called concurrently by multiple threads, due absence of additional
83-
* synchronization for updating data structures is not required.
84-
*/
85-
public boolean updateApiAccessIfExpired(final long hash) {
89+
private boolean updateApiAccessIfExpired(final long hash) {
8690
final long currentTime = timeSource.getCurrentTimeMillis();
8791

88-
// New or updated record
89-
boolean isNewOrUpdated = false;
90-
if (!apiAccessMap.containsKey(hash)
91-
|| currentTime - apiAccessMap.get(hash) >= expirationTimeInMs) {
92+
Long lastAccess = accessMap.get(hash);
93+
if (lastAccess != null && currentTime - lastAccess < expirationTimeInMs) {
94+
return false;
95+
}
9296

97+
if (accessMap.put(hash, currentTime) == null) {
98+
accessDeque.addLast(hash);
99+
// If we added a new entry, we perform purging.
93100
cleanupExpiredEntries(currentTime);
94-
95-
apiAccessMap.put(hash, currentTime); // Update timestamp
96-
// move hash to the end of the queue
97-
apiAccessQueue.remove(hash);
98-
apiAccessQueue.addLast(hash);
99-
isNewOrUpdated = true;
100-
101-
// Remove the oldest hash if capacity is reached
102-
while (apiAccessMap.size() > this.capacity) {
103-
Long oldestHash = apiAccessQueue.pollFirst();
104-
if (oldestHash != null) {
105-
apiAccessMap.remove(oldestHash);
106-
}
107-
}
101+
} else {
102+
// This is now the most recently accessed entry.
103+
accessDeque.remove(hash);
104+
accessDeque.addLast(hash);
108105
}
109106

110-
return isNewOrUpdated;
107+
return true;
111108
}
112109

113-
public boolean isApiAccessExpired(final long hash) {
114-
long currentTime = timeSource.getCurrentTimeMillis();
115-
return !apiAccessMap.containsKey(hash)
116-
|| currentTime - apiAccessMap.get(hash) >= expirationTimeInMs;
110+
private boolean isApiAccessExpired(final long hash) {
111+
final long currentTime = timeSource.getCurrentTimeMillis();
112+
final Long lastAccess = accessMap.get(hash);
113+
return lastAccess == null || currentTime - lastAccess >= expirationTimeInMs;
117114
}
118115

119116
private void cleanupExpiredEntries(final long currentTime) {
120-
while (!apiAccessQueue.isEmpty()) {
121-
Long oldestHash = apiAccessQueue.peekFirst();
122-
if (oldestHash == null) break;
123-
124-
Long lastAccessTime = apiAccessMap.get(oldestHash);
125-
if (lastAccessTime == null || currentTime - lastAccessTime >= expirationTimeInMs) {
126-
apiAccessQueue.pollFirst(); // remove from head
127-
apiAccessMap.remove(oldestHash);
128-
} else {
129-
break; // is up-to-date
117+
// Purge all expired entries.
118+
while (!accessDeque.isEmpty()) {
119+
final Long oldestHash = accessDeque.peekFirst();
120+
if (oldestHash == null) {
121+
// Should never happen
122+
continue;
123+
}
124+
125+
final Long lastAccessTime = accessMap.get(oldestHash);
126+
if (lastAccessTime == null) {
127+
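// Should never happen. NOTE(review): `continue` leaves this head entry in place, so the loop would spin forever; poll it first.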
128+
continue;
129+
}
130+
131+
if (currentTime - lastAccessTime < expirationTimeInMs) {
132+
// The oldest hash is up-to-date, so stop here.
133+
break;
134+
}
135+
136+
accessDeque.pollFirst();
137+
accessMap.remove(oldestHash);
138+
}
139+
140+
// If we went over capacity, remove the oldest entries until we are within the limit.
141+
// This should never need to remove more than one entry.
142+
final int toRemove = accessMap.size() - this.capacity;
143+
for (int i = 0; i < toRemove; i++) {
144+
Long oldestHash = accessDeque.pollFirst();
145+
if (oldestHash != null) {
146+
accessMap.remove(oldestHash);
130147
}
131148
}
132149
}
@@ -145,13 +162,11 @@ public NoOp() {
145162
}
146163

147164
@Override
148-
public void preSampleRequest(@Nonnull AppSecRequestContext ctx) {
149-
}
165+
public void preSampleRequest(@Nonnull AppSecRequestContext ctx) {}
150166

151167
@Override
152168
public boolean sampleRequest(AppSecRequestContext ctx) {
153169
return false;
154170
}
155171
}
156-
157172
}

dd-java-agent/appsec/src/main/java/com/datadog/appsec/api/security/AppSecSpanPostProcessor.java

+18-19
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,19 @@
1111
import datadog.trace.api.gateway.RequestContextSlot;
1212
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
1313
import datadog.trace.bootstrap.instrumentation.api.SpanPostProcessor;
14-
import org.slf4j.Logger;
15-
import org.slf4j.LoggerFactory;
16-
1714
import java.util.Collections;
1815
import java.util.function.BooleanSupplier;
16+
import org.slf4j.Logger;
17+
import org.slf4j.LoggerFactory;
1918

2019
public class AppSecSpanPostProcessor implements SpanPostProcessor {
2120

2221
private static final Logger log = LoggerFactory.getLogger(AppSecSpanPostProcessor.class);
2322
private final ApiSecurityRequestSampler sampler;
2423
private final EventProducerService producerService;
2524

26-
public AppSecSpanPostProcessor(ApiSecurityRequestSampler sampler, EventProducerService producerService) {
25+
public AppSecSpanPostProcessor(
26+
ApiSecurityRequestSampler sampler, EventProducerService producerService) {
2727
this.sampler = sampler;
2828
this.producerService = producerService;
2929
}
@@ -63,21 +63,20 @@ public void process(AgentSpan span, BooleanSupplier timeoutCheck) {
6363
}
6464

6565
private void maybeExtractSchemas(AppSecRequestContext ctx) {
66-
final EventProducerService.DataSubscriberInfo sub = producerService.getDataSubscribers(KnownAddresses.WAF_CONTEXT_PROCESSOR);
67-
if (sub == null|| sub.isEmpty()) {
68-
return;
69-
}
66+
final EventProducerService.DataSubscriberInfo sub =
67+
producerService.getDataSubscribers(KnownAddresses.WAF_CONTEXT_PROCESSOR);
68+
if (sub == null || sub.isEmpty()) {
69+
return;
70+
}
7071

71-
final DataBundle bundle =
72-
new SingletonDataBundle<>(
73-
KnownAddresses.WAF_CONTEXT_PROCESSOR,
74-
Collections.singletonMap("extract-schema", true));
75-
try {
76-
GatewayContext gwCtx = new GatewayContext(false);
77-
producerService.publishDataEvent(sub, ctx, bundle, gwCtx);
78-
} catch (ExpiredSubscriberInfoException e) {
79-
log.debug("Subscriber info expired", e);
80-
}
72+
final DataBundle bundle =
73+
new SingletonDataBundle<>(
74+
KnownAddresses.WAF_CONTEXT_PROCESSOR, Collections.singletonMap("extract-schema", true));
75+
try {
76+
GatewayContext gwCtx = new GatewayContext(false);
77+
producerService.publishDataEvent(sub, ctx, bundle, gwCtx);
78+
} catch (ExpiredSubscriberInfoException e) {
79+
log.debug("Subscriber info expired", e);
80+
}
8181
}
82-
8382
}

dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/AppSecRequestContext.java

+5-3
Original file line numberDiff line numberDiff line change
@@ -600,12 +600,14 @@ public String getSessionId() {
600600
}
601601

602602
/**
603-
* Close the context and release all resources. This method is idempotent and can be called multiple times.
604-
* For each root span, this method is always called from CoreTracer#onRootSpaPublished.
603+
* Close the context and release all resources. This method is idempotent and can be called
604+
* multiple times. For each root span, this method is always called from
605+
* CoreTracer#onRootSpanPublished.
605606
*/
606607
@Override
607608
public void close() {
608-
// For API Security, we sometimes keep contexts open for late processing. In that case, this flag needs to be
609+
// For API Security, we sometimes keep contexts open for late processing. In that case, this
610+
// flag needs to be
609611
// later reset by the API Security post-processor and close must be called again.
610612
if (!keepOpenForApiSecurityPostProcessing) {
611613
closeAdditive();

dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java

-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
import datadog.trace.api.telemetry.WafMetricCollector;
4242
import datadog.trace.bootstrap.instrumentation.api.Tags;
4343
import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter;
44-
import datadog.trace.util.NonBlockingSemaphore;
4544
import datadog.trace.util.stacktrace.StackTraceEvent;
4645
import datadog.trace.util.stacktrace.StackUtils;
4746
import java.net.URI;

dd-java-agent/appsec/src/test/groovy/com/datadog/appsec/api/security/ApiSecurityRequestSamplerTest.groovy

+55-4
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ import com.datadog.appsec.gateway.AppSecRequestContext
44
import datadog.trace.api.time.ControllableTimeSource
55
import datadog.trace.test.util.DDSpecification
66

7-
import java.time.Duration
8-
97
class ApiSecurityRequestSamplerTest extends DDSpecification {
108

119
void 'happy path with single request'() {
@@ -163,6 +161,20 @@ class ApiSecurityRequestSamplerTest extends DDSpecification {
163161
!sampleDecision
164162
}
165163

164+
void 'sampleRequest with null hash'() {
165+
// This case should never happen, as a request without a hash should have been short-circuited at multiple places
166+
// before reaching this point. But checking just in case.
167+
given:
168+
def sampler = new ApiSecurityRequestSampler()
169+
def ctx = createContext('route1', 'GET', 200)
170+
171+
when:
172+
def sampleDecision = sampler.sampleRequest(ctx)
173+
174+
then:
175+
!sampleDecision
176+
}
177+
166178
void 'sampleRequest honors expiration'() {
167179
given:
168180
def ctx = createContext('route1', 'GET', 200)
@@ -201,12 +213,51 @@ class ApiSecurityRequestSamplerTest extends DDSpecification {
201213
0 * _
202214
}
203215

204-
private AppSecRequestContext createContext(final String route, final String method, int statusCode) {
216+
void 'internal accessMap never goes beyond capacity'() {
217+
given:
218+
final timeSource = new ControllableTimeSource()
219+
timeSource.set(0)
220+
final long expirationTimeInMs = 10_000
221+
final int maxCapacity = 10
222+
ApiSecurityRequestSampler sampler = new ApiSecurityRequestSampler(10, expirationTimeInMs, timeSource)
223+
224+
expect:
225+
for (int i = 0; i < maxCapacity * 10; i++) {
226+
timeSource.advance(1_000_000)
227+
final ctx = createContext('route1', 'GET', 200 + 1)
228+
ctx.setApiSecurityEndpointHash(i as long)
229+
ctx.setKeepOpenForApiSecurityPostProcessing(true)
230+
assert sampler.sampleRequest(ctx)
231+
assert sampler.accessMap.size() <= maxCapacity
232+
}
233+
}
234+
235+
void 'expired entries are purged from internal accessMap'() {
236+
given:
237+
final timeSource = new ControllableTimeSource()
238+
timeSource.set(0)
239+
final long expirationTimeInMs = 10_000
240+
final int maxCapacity = 10
241+
ApiSecurityRequestSampler sampler = new ApiSecurityRequestSampler(10, expirationTimeInMs, timeSource)
242+
243+
expect:
244+
for (int i = 0; i < maxCapacity * 10; i++) {
245+
final ctx = createContext('route1', 'GET', 200 + 1)
246+
ctx.setApiSecurityEndpointHash(i as long)
247+
ctx.setKeepOpenForApiSecurityPostProcessing(true)
248+
assert sampler.sampleRequest(ctx)
249+
assert sampler.accessMap.size() <= 2
250+
if (i % 2) {
251+
timeSource.advance(expirationTimeInMs * 1_000_000)
252+
}
253+
}
254+
}
255+
256+
private static AppSecRequestContext createContext(final String route, final String method, int statusCode) {
205257
final AppSecRequestContext ctx = new AppSecRequestContext()
206258
ctx.setRoute(route)
207259
ctx.setMethod(method)
208260
ctx.setResponseStatus(statusCode)
209261
ctx
210262
}
211-
212263
}

0 commit comments

Comments
 (0)