44import datadog .trace .api .time .SystemTimeSource ;
55import datadog .trace .api .time .TimeSource ;
66import datadog .trace .util .NonBlockingSemaphore ;
7-
8- import javax .annotation .Nonnull ;
97import java .util .Deque ;
10- import java .util .Map ;
118import java .util .concurrent .ConcurrentHashMap ;
129import java .util .concurrent .ConcurrentLinkedDeque ;
10+ import javax .annotation .Nonnull ;
1311
1412public class ApiSecurityRequestSampler {
1513
1614 /**
17- * A maximum number of request contexts we'll keep open past the end of request at any given time. This will avoid
18- * excessive memory usage in case of a high number of concurrent requests, and should also prevent memory leaks in
19- * case of a bug .
15+ * A maximum number of request contexts we'll keep open past the end of request at any given time.
16+ * This will avoid excessive memory usage in case of a high number of concurrent requests, and
17+ * should also prevent memory leaks .
2018 */
2119 private static final int MAX_POST_PROCESSING_TASKS = 4 ;
20+
2221 private static final int INTERVAL_SECONDS = 30 ;
2322 private static final int MAX_SIZE = 4096 ;
24- private final Map <Long , Long > apiAccessMap ; // Map<hash, timestamp>
25- private final Deque <Long > apiAccessQueue ; // hashes ordered by access time
23+ /** Mapping from endpoint hash to last access timestamp in millis. */
24+ private final ConcurrentHashMap <Long , Long > accessMap ;
25+ /** Deque of endpoint hashes ordered by access time. Oldest is always first. */
26+ private final Deque <Long > accessDeque ;
27+
2628 private final long expirationTimeInMs ;
2729 private final int capacity ;
2830 private final TimeSource timeSource ;
2931
30- final NonBlockingSemaphore counter = NonBlockingSemaphore .withPermitCount (MAX_POST_PROCESSING_TASKS );
32+ final NonBlockingSemaphore counter =
33+ NonBlockingSemaphore .withPermitCount (MAX_POST_PROCESSING_TASKS );
3134
3235 public ApiSecurityRequestSampler () {
3336 this (MAX_SIZE , INTERVAL_SECONDS * 1000 , SystemTimeSource .INSTANCE );
3437 }
3538
36- public ApiSecurityRequestSampler (int capacity , long expirationTimeInMs , @ Nonnull TimeSource timeSource ) {
39+ public ApiSecurityRequestSampler (
40+ int capacity , long expirationTimeInMs , @ Nonnull TimeSource timeSource ) {
3741 this .capacity = capacity ;
3842 this .expirationTimeInMs = expirationTimeInMs ;
39- this .apiAccessMap = new ConcurrentHashMap <>(MAX_SIZE );
40- this .apiAccessQueue = new ConcurrentLinkedDeque <>();
43+ this .accessMap = new ConcurrentHashMap <>();
44+ this .accessDeque = new ConcurrentLinkedDeque <>();
4145 this .timeSource = timeSource ;
4246 }
4347
48+ /**
49+ * Prepare a request context for later sampling decision. This method should be called at request
50+ * end, and is thread-safe. If a request can potentially be sampled, this method will call {@link
51+ * AppSecRequestContext#setKeepOpenForApiSecurityPostProcessing(boolean)}.
52+ */
4453 public void preSampleRequest (final @ Nonnull AppSecRequestContext ctx ) {
4554 final String route = ctx .getRoute ();
4655 if (route == null ) {
@@ -64,69 +73,77 @@ public void preSampleRequest(final @Nonnull AppSecRequestContext ctx) {
6473 }
6574 }
6675
76+ /** Get the final sampling decision. This method is NOT thread-safe. */
6777 public boolean sampleRequest (AppSecRequestContext ctx ) {
6878 if (ctx == null ) {
6979 return false ;
7080 }
7181 final Long hash = ctx .getApiSecurityEndpointHash ();
7282 if (hash == null ) {
83+ // This should never happen, it should have been short-circuited before.
7384 return false ;
7485 }
7586 return updateApiAccessIfExpired (hash );
7687 }
7788
78- /**
79- * Updates the API access log with the given route, method, and status code. If the record already
80- * exists and is outdated, it is updated by moving to the end of the list. If the record does not
81- * exist, a new record is added. If the capacity limit is reached, the oldest record is removed.
82- * This method should not be called concurrently by multiple threads, due absence of additional
83- * synchronization for updating data structures is not required.
84- */
85- public boolean updateApiAccessIfExpired (final long hash ) {
89+ private boolean updateApiAccessIfExpired (final long hash ) {
8690 final long currentTime = timeSource .getCurrentTimeMillis ();
8791
88- // New or updated record
89- boolean isNewOrUpdated = false ;
90- if (! apiAccessMap . containsKey ( hash )
91- || currentTime - apiAccessMap . get ( hash ) >= expirationTimeInMs ) {
92+ Long lastAccess = accessMap . get ( hash );
93+ if ( lastAccess != null && currentTime - lastAccess < expirationTimeInMs ) {
94+ return false ;
95+ }
9296
97+ if (accessMap .put (hash , currentTime ) == null ) {
98+ accessDeque .addLast (hash );
99+ // If we added a new entry, we perform purging.
93100 cleanupExpiredEntries (currentTime );
94-
95- apiAccessMap .put (hash , currentTime ); // Update timestamp
96- // move hash to the end of the queue
97- apiAccessQueue .remove (hash );
98- apiAccessQueue .addLast (hash );
99- isNewOrUpdated = true ;
100-
101- // Remove the oldest hash if capacity is reached
102- while (apiAccessMap .size () > this .capacity ) {
103- Long oldestHash = apiAccessQueue .pollFirst ();
104- if (oldestHash != null ) {
105- apiAccessMap .remove (oldestHash );
106- }
107- }
101+ } else {
102+ // This is now the most recently accessed entry.
103+ accessDeque .remove (hash );
104+ accessDeque .addLast (hash );
108105 }
109106
110- return isNewOrUpdated ;
107+ return true ;
111108 }
112109
113- public boolean isApiAccessExpired (final long hash ) {
114- long currentTime = timeSource .getCurrentTimeMillis ();
115- return ! apiAccessMap . containsKey (hash )
116- || currentTime - apiAccessMap . get ( hash ) >= expirationTimeInMs ;
110+ private boolean isApiAccessExpired (final long hash ) {
111+ final long currentTime = timeSource .getCurrentTimeMillis ();
112+ final Long lastAccess = accessMap . get (hash );
113+ return lastAccess == null || currentTime - lastAccess >= expirationTimeInMs ;
117114 }
118115
119116 private void cleanupExpiredEntries (final long currentTime ) {
120- while (!apiAccessQueue .isEmpty ()) {
121- Long oldestHash = apiAccessQueue .peekFirst ();
122- if (oldestHash == null ) break ;
123-
124- Long lastAccessTime = apiAccessMap .get (oldestHash );
125- if (lastAccessTime == null || currentTime - lastAccessTime >= expirationTimeInMs ) {
126- apiAccessQueue .pollFirst (); // remove from head
127- apiAccessMap .remove (oldestHash );
128- } else {
129- break ; // is up-to-date
117+ // Purge all expired entries.
118+ while (!accessDeque .isEmpty ()) {
119+ final Long oldestHash = accessDeque .peekFirst ();
120+ if (oldestHash == null ) {
121+ // Should never happen
122+ continue ;
123+ }
124+
125+ final Long lastAccessTime = accessMap .get (oldestHash );
126+ if (lastAccessTime == null ) {
127+ // Should never happen
128+ continue ;
129+ }
130+
131+ if (currentTime - lastAccessTime < expirationTimeInMs ) {
132+ // The oldest hash is up-to-date, so stop here.
133+ break ;
134+ }
135+
136+ accessDeque .pollFirst ();
137+ accessMap .remove (oldestHash );
138+ }
139+
140+ // If we went over capacity, remove the oldest entries until we are within the limit.
141+ // This should never be more than 1.
142+ final int toRemove = accessMap .size () - this .capacity ;
143+ for (int i = 0 ; i < toRemove ; i ++) {
144+ Long oldestHash = accessDeque .pollFirst ();
145+ if (oldestHash != null ) {
146+ accessMap .remove (oldestHash );
130147 }
131148 }
132149 }
@@ -145,13 +162,11 @@ public NoOp() {
145162 }
146163
147164 @ Override
148- public void preSampleRequest (@ Nonnull AppSecRequestContext ctx ) {
149- }
165+ public void preSampleRequest (@ Nonnull AppSecRequestContext ctx ) {}
150166
151167 @ Override
152168 public boolean sampleRequest (AppSecRequestContext ctx ) {
153169 return false ;
154170 }
155171 }
156-
157172}
0 commit comments