Skip to content

Commit

Permalink
Shortcircuit TBV visitor early for types we don't want to visit
Browse files Browse the repository at this point in the history
  • Loading branch information
manuel-alvarez-alvarez committed Nov 6, 2024
1 parent 21c0b2d commit 4712a4b
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ public Flow<Void> apply(final RequestContext ctx, final Object o) {
}

static boolean visitProtobufArtifact(@Nonnull final Class<?> kls) {
if (CharSequence.class.isAssignableFrom(kls)) {
return true; // we want to visit all strings in the message
}
if (kls.getSuperclass().getName().startsWith(GENERATED_MESSAGE)) {
return true; // GRPC custom messages
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,16 @@ public void onSessionValue(@Nonnull String name, Object value) {
* types.
*/
private static boolean visitClass(final Class<?> clazz) {
if (CharSequence.class.isAssignableFrom(clazz)) {
return true; // we want to visit all strings added to the session
}
if (clazz.isArray()) {
return true; // we also want to visit arrays
}
if ((Iterable.class.isAssignableFrom(clazz) || Map.class.isAssignableFrom(clazz))) {
final String className = clazz.getName();
return ALLOWED_COLLECTION_PKGS.apply(className) > 0;
return ALLOWED_COLLECTION_PKGS.apply(className)
> 0; // ignore unknown collection types (e.g. lazy ones)
}
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ private State visit(final int depth, final String path, final Object value) {
if (depth > maxDepth) {
return CONTINUE;
}
if (!visited.add(value)) {
if (!classFilter.test(value.getClass()) || !visited.add(value)) {
return CONTINUE;
}
State state = CONTINUE;
Expand Down Expand Up @@ -107,9 +107,6 @@ private State visitArray(final int depth, final String path, final Object[] arra
}

private State visitMap(final int depth, final String path, final Map<?, ?> map) {
if (!classFilter.test(map.getClass())) {
return CONTINUE;
}
final int mapDepth = depth + 1;
for (final Map.Entry<?, ?> entry : map.entrySet()) {
final Object key = entry.getKey();
Expand All @@ -133,9 +130,6 @@ private State visitMap(final int depth, final String path, final Map<?, ?> map)
}

private State visitIterable(final int depth, final String path, final Iterable<?> iterable) {
if (!classFilter.test(iterable.getClass())) {
return CONTINUE;
}
final int iterableDepth = depth + 1;
int index = 0;
for (final Object item : iterable) {
Expand All @@ -153,7 +147,7 @@ private State visitIterable(final int depth, final String path, final Iterable<?
private State visitObject(final int depth, final String path, final Object value) {
final int childDepth = depth + 1;
State state = visitor.visit(path, value);
if (state != State.CONTINUE || !classFilter.test(value.getClass())) {
if (state != State.CONTINUE) {
return state;
}
Class<?> klass = value.getClass();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ class GrpcRequestMessageHandlerTest extends IastModuleImplTestBase {
when: 'the message is not a protobuf instance'
ObjectVisitor.visit(nonProtobufMessage, visitor, filter)

then: 'only the root object is visited'
1 * visitor.visit('root', nonProtobufMessage) >> CONTINUE
then: 'nothing is visited'
0 * visitor._

when: 'the message is a protobuf message'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,22 @@ import com.datadog.iast.model.VulnerabilityType
import com.datadog.iast.taint.Ranges
import datadog.trace.api.iast.InstrumentationBridge
import datadog.trace.api.iast.SourceTypes
import foo.bar.Pojo
import foo.bar.VisitableClass
import net.bytebuddy.ByteBuddy
import net.bytebuddy.implementation.MethodDelegation
import net.bytebuddy.implementation.bind.annotation.AllArguments
import net.bytebuddy.implementation.bind.annotation.RuntimeType
import net.bytebuddy.implementation.bind.annotation.SuperMethod
import net.bytebuddy.implementation.bind.annotation.This

import java.lang.reflect.Method
import java.util.concurrent.atomic.AtomicInteger

import static java.util.Arrays.asList
import static net.bytebuddy.matcher.ElementMatchers.isConstructor
import static net.bytebuddy.matcher.ElementMatchers.named
import static net.bytebuddy.matcher.ElementMatchers.not

class TrustBoundaryViolationModuleTest extends IastModuleImplTestBase {

Expand Down Expand Up @@ -146,6 +159,32 @@ class TrustBoundaryViolationModuleTest extends IastModuleImplTestBase {
0 * reporter.report(_, _ as Vulnerability)
}

void 'test that proxy objects are not visited in TBV'() {
setup:
final pojo = new ByteBuddy()
.subclass(Pojo)
.method(not(isConstructor().or(named('toString'))))
.intercept(MethodDelegation.to(PojoInterceptor))
.make()
.load(Pojo.classLoader)
.loaded
.newInstance()

when:
module.onSessionValue('test', pojo)

then:
PojoInterceptor.INVOCATIONS.get() == 0

when:
pojo.setId(23)
pojo.setName('test')
pojo.equals(new Pojo(id: 12, name: 'another'))

then:
PojoInterceptor.INVOCATIONS.get() == 3
}

private static void assertVulnerability(final Vulnerability vuln, String expectedValue) {
assert vuln != null
assert vuln.getType() == VulnerabilityType.TRUST_BOUNDARY_VIOLATION
Expand All @@ -170,3 +209,16 @@ class DynamicList<E> {
throw new UnsupportedOperationException('Do not touch me!')
}
}

class PojoInterceptor {

static final AtomicInteger INVOCATIONS = new AtomicInteger(0)

@RuntimeType
static Object intercept(@This Object self,
@AllArguments Object[] args,
@SuperMethod Method superMethod) throws Throwable {
INVOCATIONS.addAndGet(1)
return superMethod.invoke(self, args)
}
}
38 changes: 38 additions & 0 deletions dd-java-agent/agent-iast/src/test/java/foo/bar/Pojo.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package foo.bar;

import java.util.Objects;

public class Pojo {

private int id;
private String name;

public int getId() {
return id;
}

public void setId(int id) {
this.id = id;
}

public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Pojo pojo = (Pojo) o;
return id == pojo.id;
}

@Override
public int hashCode() {
return Objects.hashCode(id);
}
}

0 comments on commit 4712a4b

Please sign in to comment.