From 817c341e43693146e7d93a4cfe6ea2fc28e350f5 Mon Sep 17 00:00:00 2001 From: Pierre Precourt Date: Tue, 11 Feb 2025 02:41:58 -0800 Subject: [PATCH] Add support for workflows using the callback server in unit tests. This change requires: - Delaying the initialization of detectors to when the test runs rather than when it is being created; - Change to the test interface: instead of providing the detector tested, each test is now only provided with the name of the detector. It is now the responsibility of each test to initialize the detectors depending on their environment; This change might make unit tests slightly slower, especially as the list of plugin grows, but provide a lot more flexibility for testing hermetically. PiperOrigin-RevId: 725549365 Change-Id: I58f2a6f1f8955615b8ecbd1abe1c97aa1d62499d --- .../proto/templated_plugin_tests.proto | 16 ++- .../TemplatedDetectorDynamicTest.java | 83 ++++++++---- .../CallbackServerActionRunnerTest.java | 125 ++++++++++++++++++ 3 files changed, 195 insertions(+), 29 deletions(-) create mode 100644 templated/templateddetector/src/test/java/com/google/tsunami/plugins/detectors/templateddetector/actions/CallbackServerActionRunnerTest.java diff --git a/templated/templateddetector/proto/templated_plugin_tests.proto b/templated/templateddetector/proto/templated_plugin_tests.proto index 345a0e862..c90ea63af 100644 --- a/templated/templateddetector/proto/templated_plugin_tests.proto +++ b/templated/templateddetector/proto/templated_plugin_tests.proto @@ -18,15 +18,29 @@ message TemplatedPluginTests { } message Test { + message MockCallbackConfig { + // Whether the callback server should be mocked. + // Disabled by default. + bool enabled = 1; + + // The return value for the callback server. That is, whether the callback + // server reports the secret having been interracted with or not. + // Only used if enabled is true. + bool has_interaction = 2; + } + // The name of the test. string name = 1; // Whether this test ensure that the vulnerability is found or not. bool expect_vulnerability = 2; + // Configuration for the mocking the callback server behavior. + MockCallbackConfig mock_callback_server = 3; + // The action being tested. oneof anyAction { - HttpTestAction mock_http_server = 3; + HttpTestAction mock_http_server = 4; } } diff --git a/templated/templateddetector/src/test/java/com/google/tsunami/plugins/detectors/templateddetector/TemplatedDetectorDynamicTest.java b/templated/templateddetector/src/test/java/com/google/tsunami/plugins/detectors/templateddetector/TemplatedDetectorDynamicTest.java index abb4c51dd..baae44cc8 100644 --- a/templated/templateddetector/src/test/java/com/google/tsunami/plugins/detectors/templateddetector/TemplatedDetectorDynamicTest.java +++ b/templated/templateddetector/src/test/java/com/google/tsunami/plugins/detectors/templateddetector/TemplatedDetectorDynamicTest.java @@ -35,6 +35,7 @@ import com.google.tsunami.common.time.testing.FakeUtcClock; import com.google.tsunami.common.time.testing.FakeUtcClockModule; import com.google.tsunami.plugin.payload.testing.FakePayloadGeneratorModule; +import com.google.tsunami.plugin.payload.testing.PayloadTestHelper; import com.google.tsunami.proto.NetworkService; import com.google.tsunami.proto.TargetInfo; import com.google.tsunami.proto.TransportProtocol; @@ -62,6 +63,7 @@ public final class TemplatedDetectorDynamicTest { private Environment environment; private MockWebServer mockWebServer; + private MockWebServer mockCallbackServer; private static final FakeUtcClock fakeUtcClock = FakeUtcClock.create().setNow(Instant.parse("2020-01-01T00:00:00.00Z")); @@ -74,9 +76,10 @@ public void nextBytes(byte[] bytes) { }; @Before - public void setupMockWebServer() throws IOException { + public void setupMockServers() throws IOException { environment = new Environment(false); mockWebServer = new MockWebServer(); + mockCallbackServer = new MockWebServer(); var baseUrl = "http://" + mockWebServer.getHostName() + ":" + mockWebServer.getPort() + "/"; environment.set("T_NS_BASEURL", baseUrl); } @@ -84,11 +87,22 @@ public void setupMockWebServer() throws IOException { @After public void tearDown() throws IOException { mockWebServer.shutdown(); + mockCallbackServer.shutdown(); } @Test @TestParameters(valuesProvider = TestProvider.class) - public void runTest(TemplatedDetector detector, TemplatedPluginTests.Test testCase) { + public void runTest(String pluginName, TemplatedPluginTests.Test testCase) { + var detectors = getDetectorsForCase(testCase); + + if (!detectors.containsKey(pluginName)) { + throw new IllegalArgumentException( + "Plugin '" + + pluginName + + "' not found (ensure the tested_plugin field is set correctly)."); + } + + var detector = detectors.get(pluginName); switch (testCase.getAnyActionCase()) { case MOCK_HTTP_SERVER: forHttpAction(detector, testCase); @@ -98,6 +112,38 @@ public void runTest(TemplatedDetector detector, TemplatedPluginTests.Test testCa } } + private final ImmutableMap getDetectorsForCase( + TemplatedPluginTests.Test testCase) { + // Inject the adequate callback server configuration. + FakePayloadGeneratorModule.Builder payloadGeneratorModuleBuilder = + FakePayloadGeneratorModule.builder().setSecureRng(testSecureRandom); + + if (testCase.getMockCallbackServer().getEnabled()) { + payloadGeneratorModuleBuilder.setCallbackServer(mockCallbackServer); + + try { + if (testCase.getMockCallbackServer().getHasInteraction()) { + mockCallbackServer.enqueue(PayloadTestHelper.generateMockSuccessfulCallbackResponse()); + } else { + mockCallbackServer.enqueue(PayloadTestHelper.generateMockUnsuccessfulCallbackResponse()); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + // Inject dependencies and get the detectors. + var bootstrap = new TemplatedDetectorBootstrapModule(); + bootstrap.setForceLoadDetectors(true); + Guice.createInjector( + new FakeUtcClockModule(fakeUtcClock), + new HttpClientModule.Builder().build(), + payloadGeneratorModuleBuilder.build(), + bootstrap); + + return bootstrap.getDetectors(); + } + private final void forHttpAction(TemplatedDetector detector, TemplatedPluginTests.Test testCase) { ImmutableList httpServices = ImmutableList.of( @@ -199,34 +245,21 @@ private final MockResponse createResponse(HttpTestAction.MockResponse testRespon static final class TestProvider extends TestParametersValuesProvider { @Override public ImmutableList provideValues(Context context) { - var detectors = getDetectors(); return getResourceNames().stream() .map(TestProvider::loadPlugin) .filter(plugin -> plugin != null) - .flatMap(plugin -> parametersForPlugin(plugin, detectors).stream()) + .flatMap(plugin -> parametersForPlugin(plugin).stream()) .collect(toImmutableList()); } - private static final ImmutableMap getDetectors() { - var bootstrap = new TemplatedDetectorBootstrapModule(); - bootstrap.setForceLoadDetectors(true); - Guice.createInjector( - new FakeUtcClockModule(fakeUtcClock), - new HttpClientModule.Builder().build(), - FakePayloadGeneratorModule.builder().setSecureRng(testSecureRandom).build(), - bootstrap); - return bootstrap.getDetectors(); - } - - private static ImmutableList generateCommonTests( - TemplatedDetector detector) { + private static ImmutableList generateCommonTests(String pluginName) { // Echo server test: plugins should never return a vulnerability when the response just // contains the request. - var testName = detector.getName() + ", autogenerated_whenEchoServer_returnsFalse"; + var testName = pluginName + ", autogenerated_whenEchoServer_returnsFalse"; return ImmutableList.of( TestParametersValues.builder() .name(testName) - .addParameter("detector", detector) + .addParameter("pluginName", pluginName) .addParameter( "testCase", TemplatedPluginTests.Test.newBuilder() @@ -242,7 +275,7 @@ private static ImmutableList generateCommonTests( } private static ImmutableList parametersForPlugin( - TemplatedPluginTests pluginTests, ImmutableMap detectors) { + TemplatedPluginTests pluginTests) { var pluginName = pluginTests.getConfig().getTestedPlugin(); if (pluginTests.getConfig().getDisabled()) { @@ -250,15 +283,9 @@ private static ImmutableList parametersForPlugin( return ImmutableList.of(); } - if (!detectors.containsKey(pluginName)) { - logger.atWarning().log("Plugin '%s' not found or disabled. Skipping test.", pluginName); - return ImmutableList.of(); - } - - var detector = detectors.get(pluginName); var testsBuilder = ImmutableList.builder(); // Inject tests that are common to all plugins. - testsBuilder.addAll(generateCommonTests(detector)); + testsBuilder.addAll(generateCommonTests(pluginName)); // Tests defined in the plugin test file. pluginTests.getTestsList().stream() @@ -266,7 +293,7 @@ private static ImmutableList parametersForPlugin( t -> TestParametersValues.builder() .name(pluginName + ", " + t.getName()) - .addParameter("detector", detector) + .addParameter("pluginName", pluginName) .addParameter("testCase", t) .build()) .forEach(testsBuilder::add); diff --git a/templated/templateddetector/src/test/java/com/google/tsunami/plugins/detectors/templateddetector/actions/CallbackServerActionRunnerTest.java b/templated/templateddetector/src/test/java/com/google/tsunami/plugins/detectors/templateddetector/actions/CallbackServerActionRunnerTest.java new file mode 100644 index 000000000..9dcf5efcb --- /dev/null +++ b/templated/templateddetector/src/test/java/com/google/tsunami/plugins/detectors/templateddetector/actions/CallbackServerActionRunnerTest.java @@ -0,0 +1,125 @@ +package com.google.tsunami.plugins.detectors.templateddetector.actions; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.inject.Guice; +import com.google.tsunami.common.net.http.HttpClientModule; +import com.google.tsunami.plugin.TcsClient; +import com.google.tsunami.plugin.payload.testing.FakePayloadGeneratorModule; +import com.google.tsunami.plugin.payload.testing.PayloadTestHelper; +import com.google.tsunami.plugins.detectors.templateddetector.Environment; +import com.google.tsunami.proto.NetworkService; +import com.google.tsunami.templatedplugin.proto.CallbackServerAction; +import com.google.tsunami.templatedplugin.proto.PluginAction; +import java.io.IOException; +import java.security.SecureRandom; +import java.util.Arrays; +import javax.inject.Inject; +import okhttp3.mockwebserver.MockWebServer; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class CallbackServerActionRunnerTest { + private CallbackServerActionRunner runner; + private Environment environment; + private MockWebServer mockCallbackServer; + private NetworkService service; + + @Inject private TcsClient tcsClient; + + private static final SecureRandom testSecureRandom = + new SecureRandom() { + @Override + public void nextBytes(byte[] bytes) { + Arrays.fill(bytes, (byte) 0xFF); + } + }; + + @Before + public void setup() { + this.environment = new Environment(false); + this.environment.set("T_CBS_SECRET", "irrelevant"); + this.service = NetworkService.getDefaultInstance(); + } + + @Before + public void setupMockHttp() { + this.mockCallbackServer = new MockWebServer(); + } + + @After + public void tearMockHttp() throws IOException { + this.mockCallbackServer.shutdown(); + } + + @Test + public void checkAction_whenCallbackServerDisabled_returnsFalse() throws IOException { + PluginAction action = + PluginAction.newBuilder() + .setName("action") + .setCallbackServer( + CallbackServerAction.newBuilder() + .setActionType(CallbackServerAction.ActionType.CHECK)) + .build(); + + setupCallbackServer(false, false); + + assertThat(runner.run(this.service, action, this.environment)).isFalse(); + } + + @Test + public void checkAction_whenCallbackServerReturnsFalse_returnsFalse() throws IOException { + PluginAction action = + PluginAction.newBuilder() + .setName("action") + .setCallbackServer( + CallbackServerAction.newBuilder() + .setActionType(CallbackServerAction.ActionType.CHECK)) + .build(); + + setupCallbackServer(true, false); + + assertThat(runner.run(this.service, action, this.environment)).isFalse(); + assertThat(this.mockCallbackServer.getRequestCount()).isEqualTo(1); + } + + @Test + public void checkAction_whenCallbackServerReturnsTrue_returnsTrue() throws IOException { + PluginAction action = + PluginAction.newBuilder() + .setName("action") + .setCallbackServer( + CallbackServerAction.newBuilder() + .setActionType(CallbackServerAction.ActionType.CHECK)) + .build(); + + setupCallbackServer(true, true); + + assertThat(runner.run(this.service, action, this.environment)).isTrue(); + assertThat(this.mockCallbackServer.getRequestCount()).isEqualTo(1); + } + + private final void setupCallbackServer(boolean enabled, boolean response) throws IOException { + FakePayloadGeneratorModule.Builder payloadGeneratorModuleBuilder = + FakePayloadGeneratorModule.builder().setSecureRng(testSecureRandom); + + if (enabled) { + payloadGeneratorModuleBuilder.setCallbackServer(mockCallbackServer); + + if (response) { + mockCallbackServer.enqueue(PayloadTestHelper.generateMockSuccessfulCallbackResponse()); + } else { + mockCallbackServer.enqueue(PayloadTestHelper.generateMockUnsuccessfulCallbackResponse()); + } + } + + Guice.createInjector( + new HttpClientModule.Builder().build(), payloadGeneratorModuleBuilder.build()) + .injectMembers(this); + this.runner = new CallbackServerActionRunner(tcsClient, false); + } +}