Skip to content

Commit

Permalink
Add support for workflows using the callback server in unit tests.
Browse files Browse the repository at this point in the history
This change requires:
- Delaying the initialization of detectors to when the test runs rather than when it is being created;
- Change to the test interface: instead of providing the detector tested, each test is now only provided with the name of the detector. It is now the responsibility of each test to initialize the detectors depending on their environment;

This change might make unit tests slightly slower, especially as the list of plugin grows, but provide a lot more flexibility for testing hermetically.

PiperOrigin-RevId: 725549365
Change-Id: I58f2a6f1f8955615b8ecbd1abe1c97aa1d62499d
  • Loading branch information
tooryx authored and copybara-github committed Feb 11, 2025
1 parent 64c5541 commit 817c341
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 29 deletions.
16 changes: 15 additions & 1 deletion templated/templateddetector/proto/templated_plugin_tests.proto
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,29 @@ message TemplatedPluginTests {
}

message Test {
message MockCallbackConfig {
// Whether the callback server should be mocked.
// Disabled by default.
bool enabled = 1;

// The return value for the callback server. That is, whether the callback
// server reports the secret having been interracted with or not.
// Only used if enabled is true.
bool has_interaction = 2;
}

// The name of the test.
string name = 1;

// Whether this test ensure that the vulnerability is found or not.
bool expect_vulnerability = 2;

// Configuration for the mocking the callback server behavior.
MockCallbackConfig mock_callback_server = 3;

// The action being tested.
oneof anyAction {
HttpTestAction mock_http_server = 3;
HttpTestAction mock_http_server = 4;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import com.google.tsunami.common.time.testing.FakeUtcClock;
import com.google.tsunami.common.time.testing.FakeUtcClockModule;
import com.google.tsunami.plugin.payload.testing.FakePayloadGeneratorModule;
import com.google.tsunami.plugin.payload.testing.PayloadTestHelper;
import com.google.tsunami.proto.NetworkService;
import com.google.tsunami.proto.TargetInfo;
import com.google.tsunami.proto.TransportProtocol;
Expand Down Expand Up @@ -62,6 +63,7 @@ public final class TemplatedDetectorDynamicTest {

private Environment environment;
private MockWebServer mockWebServer;
private MockWebServer mockCallbackServer;

private static final FakeUtcClock fakeUtcClock =
FakeUtcClock.create().setNow(Instant.parse("2020-01-01T00:00:00.00Z"));
Expand All @@ -74,21 +76,33 @@ public void nextBytes(byte[] bytes) {
};

@Before
public void setupMockWebServer() throws IOException {
public void setupMockServers() throws IOException {
environment = new Environment(false);
mockWebServer = new MockWebServer();
mockCallbackServer = new MockWebServer();
var baseUrl = "http://" + mockWebServer.getHostName() + ":" + mockWebServer.getPort() + "/";
environment.set("T_NS_BASEURL", baseUrl);
}

@After
public void tearDown() throws IOException {
mockWebServer.shutdown();
mockCallbackServer.shutdown();
}

@Test
@TestParameters(valuesProvider = TestProvider.class)
public void runTest(TemplatedDetector detector, TemplatedPluginTests.Test testCase) {
public void runTest(String pluginName, TemplatedPluginTests.Test testCase) {
var detectors = getDetectorsForCase(testCase);

if (!detectors.containsKey(pluginName)) {
throw new IllegalArgumentException(
"Plugin '"
+ pluginName
+ "' not found (ensure the tested_plugin field is set correctly).");
}

var detector = detectors.get(pluginName);
switch (testCase.getAnyActionCase()) {
case MOCK_HTTP_SERVER:
forHttpAction(detector, testCase);
Expand All @@ -98,6 +112,38 @@ public void runTest(TemplatedDetector detector, TemplatedPluginTests.Test testCa
}
}

private final ImmutableMap<String, TemplatedDetector> getDetectorsForCase(
TemplatedPluginTests.Test testCase) {
// Inject the adequate callback server configuration.
FakePayloadGeneratorModule.Builder payloadGeneratorModuleBuilder =
FakePayloadGeneratorModule.builder().setSecureRng(testSecureRandom);

if (testCase.getMockCallbackServer().getEnabled()) {
payloadGeneratorModuleBuilder.setCallbackServer(mockCallbackServer);

try {
if (testCase.getMockCallbackServer().getHasInteraction()) {
mockCallbackServer.enqueue(PayloadTestHelper.generateMockSuccessfulCallbackResponse());
} else {
mockCallbackServer.enqueue(PayloadTestHelper.generateMockUnsuccessfulCallbackResponse());
}
} catch (IOException e) {
throw new RuntimeException(e);
}
}

// Inject dependencies and get the detectors.
var bootstrap = new TemplatedDetectorBootstrapModule();
bootstrap.setForceLoadDetectors(true);
Guice.createInjector(
new FakeUtcClockModule(fakeUtcClock),
new HttpClientModule.Builder().build(),
payloadGeneratorModuleBuilder.build(),
bootstrap);

return bootstrap.getDetectors();
}

private final void forHttpAction(TemplatedDetector detector, TemplatedPluginTests.Test testCase) {
ImmutableList<NetworkService> httpServices =
ImmutableList.of(
Expand Down Expand Up @@ -199,34 +245,21 @@ private final MockResponse createResponse(HttpTestAction.MockResponse testRespon
static final class TestProvider extends TestParametersValuesProvider {
@Override
public ImmutableList<TestParametersValues> provideValues(Context context) {
var detectors = getDetectors();
return getResourceNames().stream()
.map(TestProvider::loadPlugin)
.filter(plugin -> plugin != null)
.flatMap(plugin -> parametersForPlugin(plugin, detectors).stream())
.flatMap(plugin -> parametersForPlugin(plugin).stream())
.collect(toImmutableList());
}

private static final ImmutableMap<String, TemplatedDetector> getDetectors() {
var bootstrap = new TemplatedDetectorBootstrapModule();
bootstrap.setForceLoadDetectors(true);
Guice.createInjector(
new FakeUtcClockModule(fakeUtcClock),
new HttpClientModule.Builder().build(),
FakePayloadGeneratorModule.builder().setSecureRng(testSecureRandom).build(),
bootstrap);
return bootstrap.getDetectors();
}

private static ImmutableList<TestParametersValues> generateCommonTests(
TemplatedDetector detector) {
private static ImmutableList<TestParametersValues> generateCommonTests(String pluginName) {
// Echo server test: plugins should never return a vulnerability when the response just
// contains the request.
var testName = detector.getName() + ", autogenerated_whenEchoServer_returnsFalse";
var testName = pluginName + ", autogenerated_whenEchoServer_returnsFalse";
return ImmutableList.of(
TestParametersValues.builder()
.name(testName)
.addParameter("detector", detector)
.addParameter("pluginName", pluginName)
.addParameter(
"testCase",
TemplatedPluginTests.Test.newBuilder()
Expand All @@ -242,31 +275,25 @@ private static ImmutableList<TestParametersValues> generateCommonTests(
}

private static ImmutableList<TestParametersValues> parametersForPlugin(
TemplatedPluginTests pluginTests, ImmutableMap<String, TemplatedDetector> detectors) {
TemplatedPluginTests pluginTests) {
var pluginName = pluginTests.getConfig().getTestedPlugin();

if (pluginTests.getConfig().getDisabled()) {
logger.atWarning().log("Plugin '%s' tests are disabled.", pluginName);
return ImmutableList.of();
}

if (!detectors.containsKey(pluginName)) {
logger.atWarning().log("Plugin '%s' not found or disabled. Skipping test.", pluginName);
return ImmutableList.of();
}

var detector = detectors.get(pluginName);
var testsBuilder = ImmutableList.<TestParametersValues>builder();
// Inject tests that are common to all plugins.
testsBuilder.addAll(generateCommonTests(detector));
testsBuilder.addAll(generateCommonTests(pluginName));

// Tests defined in the plugin test file.
pluginTests.getTestsList().stream()
.map(
t ->
TestParametersValues.builder()
.name(pluginName + ", " + t.getName())
.addParameter("detector", detector)
.addParameter("pluginName", pluginName)
.addParameter("testCase", t)
.build())
.forEach(testsBuilder::add);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package com.google.tsunami.plugins.detectors.templateddetector.actions;

import static com.google.common.truth.Truth.assertThat;

import com.google.inject.Guice;
import com.google.tsunami.common.net.http.HttpClientModule;
import com.google.tsunami.plugin.TcsClient;
import com.google.tsunami.plugin.payload.testing.FakePayloadGeneratorModule;
import com.google.tsunami.plugin.payload.testing.PayloadTestHelper;
import com.google.tsunami.plugins.detectors.templateddetector.Environment;
import com.google.tsunami.proto.NetworkService;
import com.google.tsunami.templatedplugin.proto.CallbackServerAction;
import com.google.tsunami.templatedplugin.proto.PluginAction;
import java.io.IOException;
import java.security.SecureRandom;
import java.util.Arrays;
import javax.inject.Inject;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public final class CallbackServerActionRunnerTest {
private CallbackServerActionRunner runner;
private Environment environment;
private MockWebServer mockCallbackServer;
private NetworkService service;

@Inject private TcsClient tcsClient;

private static final SecureRandom testSecureRandom =
new SecureRandom() {
@Override
public void nextBytes(byte[] bytes) {
Arrays.fill(bytes, (byte) 0xFF);
}
};

@Before
public void setup() {
this.environment = new Environment(false);
this.environment.set("T_CBS_SECRET", "irrelevant");
this.service = NetworkService.getDefaultInstance();
}

@Before
public void setupMockHttp() {
this.mockCallbackServer = new MockWebServer();
}

@After
public void tearMockHttp() throws IOException {
this.mockCallbackServer.shutdown();
}

@Test
public void checkAction_whenCallbackServerDisabled_returnsFalse() throws IOException {
PluginAction action =
PluginAction.newBuilder()
.setName("action")
.setCallbackServer(
CallbackServerAction.newBuilder()
.setActionType(CallbackServerAction.ActionType.CHECK))
.build();

setupCallbackServer(false, false);

assertThat(runner.run(this.service, action, this.environment)).isFalse();
}

@Test
public void checkAction_whenCallbackServerReturnsFalse_returnsFalse() throws IOException {
PluginAction action =
PluginAction.newBuilder()
.setName("action")
.setCallbackServer(
CallbackServerAction.newBuilder()
.setActionType(CallbackServerAction.ActionType.CHECK))
.build();

setupCallbackServer(true, false);

assertThat(runner.run(this.service, action, this.environment)).isFalse();
assertThat(this.mockCallbackServer.getRequestCount()).isEqualTo(1);
}

@Test
public void checkAction_whenCallbackServerReturnsTrue_returnsTrue() throws IOException {
PluginAction action =
PluginAction.newBuilder()
.setName("action")
.setCallbackServer(
CallbackServerAction.newBuilder()
.setActionType(CallbackServerAction.ActionType.CHECK))
.build();

setupCallbackServer(true, true);

assertThat(runner.run(this.service, action, this.environment)).isTrue();
assertThat(this.mockCallbackServer.getRequestCount()).isEqualTo(1);
}

private final void setupCallbackServer(boolean enabled, boolean response) throws IOException {
FakePayloadGeneratorModule.Builder payloadGeneratorModuleBuilder =
FakePayloadGeneratorModule.builder().setSecureRng(testSecureRandom);

if (enabled) {
payloadGeneratorModuleBuilder.setCallbackServer(mockCallbackServer);

if (response) {
mockCallbackServer.enqueue(PayloadTestHelper.generateMockSuccessfulCallbackResponse());
} else {
mockCallbackServer.enqueue(PayloadTestHelper.generateMockUnsuccessfulCallbackResponse());
}
}

Guice.createInjector(
new HttpClientModule.Builder().build(), payloadGeneratorModuleBuilder.build())
.injectMembers(this);
this.runner = new CallbackServerActionRunner(tcsClient, false);
}
}

0 comments on commit 817c341

Please sign in to comment.