4
4
import datadog .trace .api .time .SystemTimeSource ;
5
5
import datadog .trace .api .time .TimeSource ;
6
6
import datadog .trace .util .NonBlockingSemaphore ;
7
-
8
- import javax .annotation .Nonnull ;
9
7
import java .util .Deque ;
10
- import java .util .Map ;
11
8
import java .util .concurrent .ConcurrentHashMap ;
12
9
import java .util .concurrent .ConcurrentLinkedDeque ;
10
+ import javax .annotation .Nonnull ;
13
11
14
12
public class ApiSecurityRequestSampler {
15
13
16
14
/**
17
- * A maximum number of request contexts we'll keep open past the end of request at any given time. This will avoid
18
- * excessive memory usage in case of a high number of concurrent requests, and should also prevent memory leaks in
19
- * case of a bug .
15
+ * A maximum number of request contexts we'll keep open past the end of request at any given time.
16
+ * This will avoid excessive memory usage in case of a high number of concurrent requests, and
17
+ * should also prevent memory leaks .
20
18
*/
21
19
private static final int MAX_POST_PROCESSING_TASKS = 4 ;
20
+
22
21
private static final int INTERVAL_SECONDS = 30 ;
23
22
private static final int MAX_SIZE = 4096 ;
24
- private final Map <Long , Long > apiAccessMap ; // Map<hash, timestamp>
25
- private final Deque <Long > apiAccessQueue ; // hashes ordered by access time
23
+ /** Mapping from endpoint hash to last access timestamp in millis. */
24
+ private final ConcurrentHashMap <Long , Long > accessMap ;
25
+ /** Deque of endpoint hashes ordered by access time. Oldest is always first. */
26
+ private final Deque <Long > accessDeque ;
27
+
26
28
private final long expirationTimeInMs ;
27
29
private final int capacity ;
28
30
private final TimeSource timeSource ;
29
31
30
- final NonBlockingSemaphore counter = NonBlockingSemaphore .withPermitCount (MAX_POST_PROCESSING_TASKS );
32
+ final NonBlockingSemaphore counter =
33
+ NonBlockingSemaphore .withPermitCount (MAX_POST_PROCESSING_TASKS );
31
34
32
35
public ApiSecurityRequestSampler () {
33
36
this (MAX_SIZE , INTERVAL_SECONDS * 1000 , SystemTimeSource .INSTANCE );
34
37
}
35
38
36
- public ApiSecurityRequestSampler (int capacity , long expirationTimeInMs , @ Nonnull TimeSource timeSource ) {
39
+ public ApiSecurityRequestSampler (
40
+ int capacity , long expirationTimeInMs , @ Nonnull TimeSource timeSource ) {
37
41
this .capacity = capacity ;
38
42
this .expirationTimeInMs = expirationTimeInMs ;
39
- this .apiAccessMap = new ConcurrentHashMap <>(MAX_SIZE );
40
- this .apiAccessQueue = new ConcurrentLinkedDeque <>();
43
+ this .accessMap = new ConcurrentHashMap <>();
44
+ this .accessDeque = new ConcurrentLinkedDeque <>();
41
45
this .timeSource = timeSource ;
42
46
}
43
47
48
+ /**
49
+ * Prepare a request context for later sampling decision. This method should be called at request
50
+ * end, and is thread-safe. If a request can potentially be sampled, this method will call {@link
51
+ * AppSecRequestContext#setKeepOpenForApiSecurityPostProcessing(boolean)}.
52
+ */
44
53
public void preSampleRequest (final @ Nonnull AppSecRequestContext ctx ) {
45
54
final String route = ctx .getRoute ();
46
55
if (route == null ) {
@@ -64,69 +73,77 @@ public void preSampleRequest(final @Nonnull AppSecRequestContext ctx) {
64
73
}
65
74
}
66
75
76
+ /** Get the final sampling decision. This method is NOT thread-safe. */
67
77
public boolean sampleRequest (AppSecRequestContext ctx ) {
68
78
if (ctx == null ) {
69
79
return false ;
70
80
}
71
81
final Long hash = ctx .getApiSecurityEndpointHash ();
72
82
if (hash == null ) {
83
+ // This should never happen, it should have been short-circuited before.
73
84
return false ;
74
85
}
75
86
return updateApiAccessIfExpired (hash );
76
87
}
77
88
78
- /**
79
- * Updates the API access log with the given route, method, and status code. If the record already
80
- * exists and is outdated, it is updated by moving to the end of the list. If the record does not
81
- * exist, a new record is added. If the capacity limit is reached, the oldest record is removed.
82
- * This method should not be called concurrently by multiple threads, due absence of additional
83
- * synchronization for updating data structures is not required.
84
- */
85
- public boolean updateApiAccessIfExpired (final long hash ) {
89
+ private boolean updateApiAccessIfExpired (final long hash ) {
86
90
final long currentTime = timeSource .getCurrentTimeMillis ();
87
91
88
- // New or updated record
89
- boolean isNewOrUpdated = false ;
90
- if (! apiAccessMap . containsKey ( hash )
91
- || currentTime - apiAccessMap . get ( hash ) >= expirationTimeInMs ) {
92
+ Long lastAccess = accessMap . get ( hash );
93
+ if ( lastAccess != null && currentTime - lastAccess < expirationTimeInMs ) {
94
+ return false ;
95
+ }
92
96
97
+ if (accessMap .put (hash , currentTime ) == null ) {
98
+ accessDeque .addLast (hash );
99
+ // If we added a new entry, we perform purging.
93
100
cleanupExpiredEntries (currentTime );
94
-
95
- apiAccessMap .put (hash , currentTime ); // Update timestamp
96
- // move hash to the end of the queue
97
- apiAccessQueue .remove (hash );
98
- apiAccessQueue .addLast (hash );
99
- isNewOrUpdated = true ;
100
-
101
- // Remove the oldest hash if capacity is reached
102
- while (apiAccessMap .size () > this .capacity ) {
103
- Long oldestHash = apiAccessQueue .pollFirst ();
104
- if (oldestHash != null ) {
105
- apiAccessMap .remove (oldestHash );
106
- }
107
- }
101
+ } else {
102
+ // This is now the most recently accessed entry.
103
+ accessDeque .remove (hash );
104
+ accessDeque .addLast (hash );
108
105
}
109
106
110
- return isNewOrUpdated ;
107
+ return true ;
111
108
}
112
109
113
- public boolean isApiAccessExpired (final long hash ) {
114
- long currentTime = timeSource .getCurrentTimeMillis ();
115
- return ! apiAccessMap . containsKey (hash )
116
- || currentTime - apiAccessMap . get ( hash ) >= expirationTimeInMs ;
110
+ private boolean isApiAccessExpired (final long hash ) {
111
+ final long currentTime = timeSource .getCurrentTimeMillis ();
112
+ final Long lastAccess = accessMap . get (hash );
113
+ return lastAccess == null || currentTime - lastAccess >= expirationTimeInMs ;
117
114
}
118
115
119
116
private void cleanupExpiredEntries (final long currentTime ) {
120
- while (!apiAccessQueue .isEmpty ()) {
121
- Long oldestHash = apiAccessQueue .peekFirst ();
122
- if (oldestHash == null ) break ;
123
-
124
- Long lastAccessTime = apiAccessMap .get (oldestHash );
125
- if (lastAccessTime == null || currentTime - lastAccessTime >= expirationTimeInMs ) {
126
- apiAccessQueue .pollFirst (); // remove from head
127
- apiAccessMap .remove (oldestHash );
128
- } else {
129
- break ; // is up-to-date
117
+ // Purge all expired entries.
118
+ while (!accessDeque .isEmpty ()) {
119
+ final Long oldestHash = accessDeque .peekFirst ();
120
+ if (oldestHash == null ) {
121
+ // Should never happen
122
+ continue ;
123
+ }
124
+
125
+ final Long lastAccessTime = accessMap .get (oldestHash );
126
+ if (lastAccessTime == null ) {
127
+ // Should never happen
128
+ continue ;
129
+ }
130
+
131
+ if (currentTime - lastAccessTime < expirationTimeInMs ) {
132
+ // The oldest hash is up-to-date, so stop here.
133
+ break ;
134
+ }
135
+
136
+ accessDeque .pollFirst ();
137
+ accessMap .remove (oldestHash );
138
+ }
139
+
140
+ // If we went over capacity, remove the oldest entries until we are within the limit.
141
+ // This should never be more than 1.
142
+ final int toRemove = accessMap .size () - this .capacity ;
143
+ for (int i = 0 ; i < toRemove ; i ++) {
144
+ Long oldestHash = accessDeque .pollFirst ();
145
+ if (oldestHash != null ) {
146
+ accessMap .remove (oldestHash );
130
147
}
131
148
}
132
149
}
@@ -145,13 +162,11 @@ public NoOp() {
145
162
}
146
163
147
164
@ Override
148
- public void preSampleRequest (@ Nonnull AppSecRequestContext ctx ) {
149
- }
165
+ public void preSampleRequest (@ Nonnull AppSecRequestContext ctx ) {}
150
166
151
167
@ Override
152
168
public boolean sampleRequest (AppSecRequestContext ctx ) {
153
169
return false ;
154
170
}
155
171
}
156
-
157
172
}
0 commit comments