diff --git a/README.md b/README.md index 85313ca98..27ec54292 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ Firewall autonomously protects your Node.js applications against: * 🛡️ [Command injection attacks](https://owasp.org/www-community/attacks/Command_Injection) * 🛡️ [Prototype pollution](./docs/prototype-pollution.md) * 🛡️ [Path traversal attacks](https://owasp.org/www-community/attacks/Path_Traversal) +* 🛡️ [Server-side request forgery (SSRF)](./docs/ssrf.md) * 🚀 More to come (see the [public roadmap](https://github.com/orgs/AikidoSec/projects/2/views/1))! Firewall operates autonomously on the same server as your Node.js app to: diff --git a/docs/ssrf.md b/docs/ssrf.md new file mode 100644 index 000000000..29dc0b510 --- /dev/null +++ b/docs/ssrf.md @@ -0,0 +1,31 @@ +# Server-side request forgery (SSRF) + +Aikido firewall for Node.js 16+ secures your app against server-side request forgery (SSRF) attacks. SSRF vulnerabilities allow attackers to send crafted requests to internal services, bypassing firewalls and security controls. Runtime blocks SSRF attacks by intercepting and validating requests to internal services. + +## Example + +``` +GET https://your-app.com/files?url=http://localhost:3000/private +``` + +```js +const response = http.request(req.query.url); +``` + +In this example, an attacker sends a request to `localhost:3000/private` from your server. Firewall can intercept the request and block it, preventing the attacker from accessing internal services. + +``` +GET https://your-app.com/files?url=http://localtest.me:3000/private +``` + +In this example, the attacker sends a request to `localtest.me:3000/private`, which resolves to `127.0.0.1`. Firewall can intercept the request and block it, preventing the attacker from accessing internal services. + +We don't protect against stored SSRF attacks, where an attacker injects a malicious URL into your app's database. To prevent stored SSRF attacks, validate and sanitize user input before storing it in your database. + +## Which built-in modules are protected? + +Firewall protects against SSRF attacks in the following built-in modules: +* `http` +* `https` +* `undici` +* `globalThis.fetch` (Node.js 18+) diff --git a/library/agent/Agent.ts b/library/agent/Agent.ts index af3821e80..ec5044c13 100644 --- a/library/agent/Agent.ts +++ b/library/agent/Agent.ts @@ -425,6 +425,10 @@ export class Agent { return this.routes; } + log(message: string) { + this.logger.log(message); + } + async flushStats(timeoutInMS: number) { this.statistics.forceCompress(); await this.sendHeartbeat(timeoutInMS); diff --git a/library/agent/Attack.ts b/library/agent/Attack.ts index e076b2572..6d1bbe473 100644 --- a/library/agent/Attack.ts +++ b/library/agent/Attack.ts @@ -2,7 +2,8 @@ export type Kind = | "nosql_injection" | "sql_injection" | "shell_injection" - | "path_traversal"; + | "path_traversal" + | "ssrf"; export function attackKindHumanName(kind: Kind) { switch (kind) { @@ -14,5 +15,7 @@ export function attackKindHumanName(kind: Kind) { return "a shell injection"; case "path_traversal": return "a path traversal attack"; + case "ssrf": + return "a server-side request forgery"; } } diff --git a/library/agent/applyHooks.test.ts b/library/agent/applyHooks.test.ts index ddc4a26ab..4c4202f6c 100644 --- a/library/agent/applyHooks.test.ts +++ b/library/agent/applyHooks.test.ts @@ -94,9 +94,9 @@ t.test("it tries to wrap method that does not exist", async (t) => { }); t.same(logger.getMessages(), [ - "Failed to wrap method does_not_exist in module shell-quote", - "Failed to wrap method another_method_that_does_not_exist in module shell-quote", "Failed to wrap method another_second_method_that_does_not_exist in module shell-quote", + "Failed to wrap method another_method_that_does_not_exist in module shell-quote", + "Failed to wrap method does_not_exist in module shell-quote", ]); }); diff --git a/library/agent/applyHooks.ts b/library/agent/applyHooks.ts index 1b7fbf435..bd75787ed 100644 --- a/library/agent/applyHooks.ts +++ b/library/agent/applyHooks.ts @@ -92,17 +92,15 @@ export function applyHooks(hooks: Hooks, agent: Agent) { return; } - const interceptor = g.getMethodInterceptor(); - - if (!interceptor) { - return; - } - - if (interceptor instanceof ModifyingArgumentsMethodInterceptor) { - wrapWithArgumentModification(global, interceptor, "global", agent); - } else { - wrapWithoutArgumentModification(global, interceptor, "global", agent); - } + g.getMethodInterceptors() + .reverse() // Reverse to make sure we wrap in the order they were added + .forEach((interceptor) => { + if (interceptor instanceof ModifyingArgumentsMethodInterceptor) { + wrapWithArgumentModification(global, interceptor, name, agent); + } else { + wrapWithoutArgumentModification(global, interceptor, name, agent); + } + }); }); return wrapped; @@ -364,15 +362,18 @@ function wrapSubject( return; } - subject.getMethodInterceptors().forEach((method) => { - if (method instanceof ModifyingArgumentsMethodInterceptor) { - wrapWithArgumentModification(theSubject, method, module, agent); - } else if (method instanceof MethodInterceptor) { - wrapWithoutArgumentModification(theSubject, method, module, agent); - } else if (method instanceof MethodResultInterceptor) { - wrapWithResult(theSubject, method, module, agent); - } else { - wrapNewInstance(theSubject, method, module, agent); - } - }); + subject + .getMethodInterceptors() + .reverse() // Reverse to make sure we wrap in the order they were added + .forEach((method) => { + if (method instanceof ModifyingArgumentsMethodInterceptor) { + wrapWithArgumentModification(theSubject, method, module, agent); + } else if (method instanceof MethodInterceptor) { + wrapWithoutArgumentModification(theSubject, method, module, agent); + } else if (method instanceof MethodResultInterceptor) { + wrapWithResult(theSubject, method, module, agent); + } else { + wrapNewInstance(theSubject, method, module, agent); + } + }); } diff --git a/library/agent/hooks/Global.ts b/library/agent/hooks/Global.ts index 465555cb6..2f9dc00cf 100644 --- a/library/agent/hooks/Global.ts +++ b/library/agent/hooks/Global.ts @@ -5,10 +5,8 @@ import { } from "./ModifyingArgumentsInterceptor"; export class Global { - private method: - | MethodInterceptor - | ModifyingArgumentsMethodInterceptor - | undefined = undefined; + private methods: (MethodInterceptor | ModifyingArgumentsMethodInterceptor)[] = + []; constructor(private readonly name: string) { if (!this.name) { @@ -22,7 +20,8 @@ export class Global { * This is the preferred way to use when wrapping methods */ inspect(interceptor: Interceptor) { - this.method = new MethodInterceptor(this.name, interceptor); + const method = new MethodInterceptor(this.name, interceptor); + this.methods.push(method); return this; } @@ -35,10 +34,11 @@ export class Global { * Don't use this unless you have to, it's better to use inspect */ modifyArguments(interceptor: ModifyingArgumentsInterceptor) { - this.method = new ModifyingArgumentsMethodInterceptor( + const method = new ModifyingArgumentsMethodInterceptor( this.name, interceptor ); + this.methods.push(method); return this; } @@ -47,7 +47,7 @@ export class Global { return this.name; } - getMethodInterceptor() { - return this.method; + getMethodInterceptors() { + return this.methods; } } diff --git a/library/agent/logger/LoggerForTesting.ts b/library/agent/logger/LoggerForTesting.ts index ca56de31f..f559130e5 100644 --- a/library/agent/logger/LoggerForTesting.ts +++ b/library/agent/logger/LoggerForTesting.ts @@ -1,7 +1,7 @@ import { Logger } from "./Logger"; export class LoggerForTesting implements Logger { - private readonly messages: string[] = []; + private messages: string[] = []; log(message: string) { this.messages.push(message); @@ -10,4 +10,8 @@ export class LoggerForTesting implements Logger { getMessages() { return this.messages; } + + clear() { + this.messages = []; + } } diff --git a/library/helpers/tryParseURL.ts b/library/helpers/tryParseURL.ts index 181e93b53..b3a25161c 100644 --- a/library/helpers/tryParseURL.ts +++ b/library/helpers/tryParseURL.ts @@ -1,4 +1,4 @@ -export function tryParseURL(url: string) { +export function tryParseURL(url: string): URL | undefined { try { return new URL(url); } catch { diff --git a/library/sinks/Fetch.test.ts b/library/sinks/Fetch.test.ts index 9aca380a0..f4ccd3238 100644 --- a/library/sinks/Fetch.test.ts +++ b/library/sinks/Fetch.test.ts @@ -1,9 +1,49 @@ +/* eslint-disable prefer-rest-params */ import * as t from "tap"; import { Agent } from "../agent/Agent"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; -import { Token } from "../agent/api/Token"; +import { Context, runWithContext } from "../agent/Context"; import { LoggerNoop } from "../agent/logger/LoggerNoop"; +import { wrap } from "../helpers/wrap"; import { Fetch } from "./Fetch"; +import * as dns from "dns"; + +const calls: Record = {}; +wrap(dns, "lookup", function lookup(original) { + return function lookup() { + const hostname = arguments[0]; + + if (!calls[hostname]) { + calls[hostname] = 0; + } + + calls[hostname]++; + + if (hostname === "thisdomainpointstointernalip.com") { + return original.apply(this, [ + "localhost", + ...Array.from(arguments).slice(1), + ]); + } + + original.apply(this, arguments); + }; +}); + +const context: Context = { + remoteAddress: "::1", + method: "POST", + url: "http://localhost:4000", + query: {}, + headers: {}, + body: { + image: "http://localhost:4000/api/internal", + }, + cookies: {}, + routeParams: {}, + source: "express", + route: "/posts/:id", +}; t.test( "it works", @@ -13,7 +53,7 @@ t.test( true, new LoggerNoop(), new ReportingAPIForTesting(), - new Token("123"), + undefined, undefined ); agent.start([new Fetch()]); @@ -39,5 +79,51 @@ t.test( t.same(agent.getHostnames().asArray(), []); agent.getHostnames().clear(); + + await runWithContext(context, async () => { + await fetch("https://google.com"); + const error = await t.rejects(() => + fetch("http://localhost:4000/api/internal") + ); + if (error instanceof Error) { + t.same( + error.message, + "Aikido firewall has blocked a server-side request forgery: fetch(...) originating from body.image" + ); + } + + const error2 = await t.rejects(() => + fetch(new URL("http://localhost:4000/api/internal")) + ); + if (error2 instanceof Error) { + t.same( + error2.message, + "Aikido firewall has blocked a server-side request forgery: fetch(...) originating from body.image" + ); + } + }); + + await runWithContext( + { + ...context, + ...{ body: { image: "http://thisdomainpointstointernalip.com" } }, + }, + async () => { + const error = await t.rejects(() => + fetch("http://thisdomainpointstointernalip.com") + ); + if (error instanceof Error) { + t.same( + // @ts-expect-error Type is not defined + error.cause.message, + "Aikido firewall has blocked a server-side request forgery: fetch(...) originating from body.image" + ); + } + + // Ensure the lookup is only called once per hostname + // Otherwise, it could be vulnerable to TOCTOU + t.same(calls["thisdomainpointstointernalip.com"], 1); + } + ); } ); diff --git a/library/sinks/Fetch.ts b/library/sinks/Fetch.ts index e5aeb13e4..539281799 100644 --- a/library/sinks/Fetch.ts +++ b/library/sinks/Fetch.ts @@ -1,31 +1,126 @@ +import { lookup } from "dns"; import { Agent } from "../agent/Agent"; +import { getContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; +import { InterceptorResult } from "../agent/hooks/MethodInterceptor"; import { Wrapper } from "../agent/Wrapper"; import { getPortFromURL } from "../helpers/getPortFromURL"; +import { tryParseURL } from "../helpers/tryParseURL"; +import { checkContextForSSRF } from "../vulnerabilities/ssrf/checkContextForSSRF"; +import { inspectDNSLookupCalls } from "../vulnerabilities/ssrf/inspectDNSLookupCalls"; export class Fetch implements Wrapper { - inspectFetch(args: unknown[], agent: Agent) { + private patchedGlobalDispatcher = false; + + private inspectHostname( + agent: Agent, + hostname: string, + port: number | undefined + ): InterceptorResult { + // Let the agent know that we are connecting to this hostname + // This is to build a list of all hostnames that the application is connecting to + agent.onConnectHostname(hostname, port); + const context = getContext(); + + if (!context) { + return undefined; + } + + return checkContextForSSRF({ + hostname: hostname, + operation: "fetch", + context: context, + }); + } + + inspectFetch(args: unknown[], agent: Agent): InterceptorResult { if (args.length > 0) { if (typeof args[0] === "string" && args[0].length > 0) { - try { - const url = new URL(args[0]); - if (url.hostname.length > 0) { - agent.onConnectHostname(url.hostname, getPortFromURL(url)); + const url = tryParseURL(args[0]); + if (url) { + const attack = this.inspectHostname( + agent, + url.hostname, + getPortFromURL(url) + ); + if (attack) { + return attack; } - } catch (e) { - // Ignore } } if (args[0] instanceof URL && args[0].hostname.length > 0) { - agent.onConnectHostname(args[0].hostname, getPortFromURL(args[0])); + const attack = this.inspectHostname( + agent, + args[0].hostname, + getPortFromURL(args[0]) + ); + if (attack) { + return attack; + } } } + + return undefined; + } + + // We'll set a global dispatcher that will allow us to inspect the resolved IPs (and thus preventing TOCTOU attacks) + private patchGlobalDispatcher(agent: Agent) { + const undiciGlobalDispatcherSymbol = Symbol.for( + "undici.globalDispatcher.1" + ); + + // @ts-expect-error Type is not defined + const dispatcher = globalThis[undiciGlobalDispatcherSymbol]; + + if (!dispatcher) { + agent.log( + `global dispatcher not found for fetch, we can't provide protection!` + ); + return; + } + + if (dispatcher.constructor.name !== "Agent") { + agent.log( + `Expected Agent as global dispatcher for fetch but found ${dispatcher.constructor.name}, we can't provide protection!` + ); + return; + } + + try { + // @ts-expect-error Type is not defined + globalThis[undiciGlobalDispatcherSymbol] = new dispatcher.constructor({ + connect: { + lookup: inspectDNSLookupCalls(lookup, agent, "fetch", "fetch"), + }, + }); + } catch (error) { + agent.log( + `Failed to patch global dispatcher for fetch, we can't provide protection!` + ); + } } wrap(hooks: Hooks) { + if (typeof globalThis.fetch === "function") { + // Fetch is lazy loaded in Node.js + // By calling fetch() we ensure that the global dispatcher is available + // @ts-expect-error Type is not defined + globalThis.fetch().catch(() => {}); + } + hooks .addGlobal("fetch") - .inspect((args, subject, agent) => this.inspectFetch(args, agent)); + // Whenever a request is made, we'll check the hostname whether it's a private IP + .inspect((args, subject, agent) => this.inspectFetch(args, agent)) + // We're not really modifying the arguments here, but we need to patch the global dispatcher + .modifyArguments((args, subject, agent) => { + if (!this.patchedGlobalDispatcher) { + this.patchGlobalDispatcher(agent); + this.patchedGlobalDispatcher = true; + } + + return args; + }); } } diff --git a/library/sinks/HTTPRequest.test.ts b/library/sinks/HTTPRequest.test.ts index ae822b6a2..879360842 100644 --- a/library/sinks/HTTPRequest.test.ts +++ b/library/sinks/HTTPRequest.test.ts @@ -1,25 +1,52 @@ +/* eslint-disable prefer-rest-params */ +import * as dns from "dns"; import * as t from "tap"; import { Agent } from "../agent/Agent"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; import { Context, runWithContext } from "../agent/Context"; import { LoggerNoop } from "../agent/logger/LoggerNoop"; +import { wrap } from "../helpers/wrap"; import { HTTPRequest } from "./HTTPRequest"; +const calls: Record = {}; +wrap(dns, "lookup", function lookup(original) { + return function lookup() { + const hostname = arguments[0]; + + if (!calls[hostname]) { + calls[hostname] = 0; + } + + calls[hostname]++; + + if (hostname === "thisdomainpointstointernalip.com") { + return original.apply(this, [ + "localhost", + ...Array.from(arguments).slice(1), + ]); + } + + original.apply(this, arguments); + }; +}); + const context: Context = { remoteAddress: "::1", method: "POST", url: "http://localhost:4000", query: {}, headers: {}, - body: {}, + body: { + image: "http://localhost:4000/api/internal", + }, cookies: {}, routeParams: {}, source: "express", route: "/posts/:id", }; -t.test("it works", async (t) => { +t.test("it works", (t) => { const agent = new Agent( true, new LoggerNoop(), @@ -32,6 +59,7 @@ t.test("it works", async (t) => { t.same(agent.getHostnames().asArray(), []); const http = require("http"); + const https = require("https"); runWithContext(context, () => { const google = http.request("http://aikido.dev"); @@ -43,8 +71,6 @@ t.test("it works", async (t) => { ]); agent.getHostnames().clear(); - const https = require("https"); - runWithContext(context, () => { const google = https.request("https://aikido.dev"); google.end(); @@ -65,6 +91,7 @@ t.test("it works", async (t) => { hostname: "aikido.dev", port: undefined, }); + t.same(withoutPort instanceof http.ClientRequest, true); withoutPort.end(); t.same(agent.getHostnames().asArray(), [ { hostname: "aikido.dev", port: 443 }, @@ -76,12 +103,14 @@ t.test("it works", async (t) => { port: undefined, }); httpWithoutPort.end(); + t.same(httpWithoutPort instanceof http.ClientRequest, true); t.same(agent.getHostnames().asArray(), [ { hostname: "aikido.dev", port: 80 }, ]); agent.getHostnames().clear(); const withPort = https.request({ hostname: "aikido.dev", port: 443 }); + t.same(withPort instanceof http.ClientRequest, true); withPort.end(); t.same(agent.getHostnames().asArray(), [ { hostname: "aikido.dev", port: 443 }, @@ -89,6 +118,7 @@ t.test("it works", async (t) => { agent.getHostnames().clear(); const withStringPort = https.request({ hostname: "aikido.dev", port: "443" }); + t.same(withStringPort instanceof http.ClientRequest, true); withStringPort.end(); t.same(agent.getHostnames().asArray(), [ { hostname: "aikido.dev", port: "443" }, @@ -99,4 +129,84 @@ t.test("it works", async (t) => { t.throws(() => https.request("invalid url")); t.same(agent.getHostnames().asArray(), []); agent.getHostnames().clear(); + + runWithContext( + { ...context, ...{ body: { image: "thisdomainpointstointernalip.com" } } }, + () => { + https + .request("https://thisdomainpointstointernalip.com") + .on("error", (error) => { + t.match( + error.message, + "Aikido firewall has blocked a server-side request forgery: https.request(...) originating from body.image" + ); + + // Ensure the lookup is only called once per hostname + // Otherwise, it could be vulnerable to TOCTOU + t.same(calls["thisdomainpointstointernalip.com"], 1); + }) + .on("finish", () => { + t.fail("should not finish"); + }) + .end(); + } + ); + + runWithContext(context, () => { + // With lookup function specified + const google = http.request("http://google.com", { lookup: dns.lookup }); + google.end(); + + // With options object + const google2 = http.request("http://google.com", {}); + google2.end(); + }); + + runWithContext(context, () => { + // Safe request + const google = https.request("https://google.com"); + google.end(); + + // With string URL + const error = t.throws(() => + https.request("https://localhost:4000/api/internal") + ); + if (error instanceof Error) { + t.same( + error.message, + "Aikido firewall has blocked a server-side request forgery: https.request(...) originating from body.image" + ); + } + + // With URL object + const error2 = t.throws(() => + https.request(new URL("https://localhost:4000/api/internal")) + ); + if (error2 instanceof Error) { + t.same( + error2.message, + "Aikido firewall has blocked a server-side request forgery: https.request(...) originating from body.image" + ); + } + + // With object like URL + const error3 = t.throws(() => + https.request({ + protocol: "https:", + hostname: "localhost", + port: 4000, + path: "/api/internal", + }) + ); + if (error3 instanceof Error) { + t.same( + error3.message, + "Aikido firewall has blocked a server-side request forgery: https.request(...) originating from body.image" + ); + } + }); + + setTimeout(() => { + t.end(); + }, 1000); }); diff --git a/library/sinks/HTTPRequest.ts b/library/sinks/HTTPRequest.ts index 02ca754fa..b251dcc31 100644 --- a/library/sinks/HTTPRequest.ts +++ b/library/sinks/HTTPRequest.ts @@ -1,17 +1,54 @@ +import { lookup } from "dns"; +import type { RequestOptions } from "http"; import { Agent } from "../agent/Agent"; +import { getContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; +import { InterceptorResult } from "../agent/hooks/MethodInterceptor"; import { Wrapper } from "../agent/Wrapper"; import { getPortFromURL } from "../helpers/getPortFromURL"; import { isPlainObject } from "../helpers/isPlainObject"; +import { checkContextForSSRF } from "../vulnerabilities/ssrf/checkContextForSSRF"; +import { inspectDNSLookupCalls } from "../vulnerabilities/ssrf/inspectDNSLookupCalls"; export class HTTPRequest implements Wrapper { - inspectHttpRequest(args: unknown[], agent: Agent, module: string) { + private inspectHostname( + agent: Agent, + hostname: string, + port: number | undefined, + module: string + ): InterceptorResult { + // Let the agent know that we are connecting to this hostname + // This is to build a list of all hostnames that the application is connecting to + agent.onConnectHostname(hostname, port); + const context = getContext(); + + if (!context) { + return undefined; + } + + return checkContextForSSRF({ + hostname: hostname, + operation: `${module}.request`, + context: context, + }); + } + + // eslint-disable-next-line max-lines-per-function + private inspectHttpRequest(args: unknown[], agent: Agent, module: string) { if (args.length > 0) { if (typeof args[0] === "string" && args[0].length > 0) { try { const url = new URL(args[0]); if (url.hostname.length > 0) { - agent.onConnectHostname(url.hostname, getPortFromURL(url)); + const attack = this.inspectHostname( + agent, + url.hostname, + getPortFromURL(url), + module + ); + if (attack) { + return attack; + } } } catch (e) { // Ignore @@ -19,7 +56,15 @@ export class HTTPRequest implements Wrapper { } if (args[0] instanceof URL && args[0].hostname.length > 0) { - agent.onConnectHostname(args[0].hostname, getPortFromURL(args[0])); + const attack = this.inspectHostname( + agent, + args[0].hostname, + getPortFromURL(args[0]), + module + ); + if (attack) { + return attack; + } } if ( @@ -37,24 +82,84 @@ export class HTTPRequest implements Wrapper { port = parseInt(args[0].port, 10); } - agent.onConnectHostname(args[0].hostname, port); + const attack = this.inspectHostname( + agent, + args[0].hostname, + port, + module + ); + if (attack) { + return attack; + } } } + + return undefined; + } + + private monitorDNSLookups( + args: unknown[], + agent: Agent, + module: string + ): unknown[] { + const context = getContext(); + + if (!context) { + return args; + } + + const optionObj = args.find((arg): arg is RequestOptions => + isPlainObject(arg) + ); + + if (!optionObj) { + return args.concat([ + { + lookup: inspectDNSLookupCalls( + lookup, + agent, + module, + `${module}.request` + ), + }, + ]); + } + + if (optionObj.lookup) { + optionObj.lookup = inspectDNSLookupCalls( + optionObj.lookup, + agent, + module, + `${module}.request` + ) as RequestOptions["lookup"]; + } else { + optionObj.lookup = inspectDNSLookupCalls( + lookup, + agent, + module, + `${module}.request` + ) as RequestOptions["lookup"]; + } + + return args; } wrap(hooks: Hooks) { - hooks - .addBuiltinModule("http") - .addSubject((exports) => exports) - .inspect("request", (args, subject, agent) => - this.inspectHttpRequest(args, agent, "http") - ); - - hooks - .addBuiltinModule("https") - .addSubject((exports) => exports) - .inspect("request", (args, subject, agent) => - this.inspectHttpRequest(args, agent, "https") - ); + const modules = ["http", "https"]; + + modules.forEach((module) => { + hooks + .addBuiltinModule(module) + .addSubject((exports) => exports) + // Whenever a request is made, we'll check the hostname whether it's a private IP + .inspect("request", (args, subject, agent) => + this.inspectHttpRequest(args, agent, module) + ) + // Whenever a request is made, we'll modify the options to pass a custom lookup function + // that will inspect resolved IP address (and thus preventing TOCTOU attacks) + .modifyArguments("request", (args, subject, agent) => + this.monitorDNSLookups(args, agent, module) + ); + }); } } diff --git a/library/sinks/Undici.test.ts b/library/sinks/Undici.test.ts index 39a90561c..d847ec2ac 100644 --- a/library/sinks/Undici.test.ts +++ b/library/sinks/Undici.test.ts @@ -1,11 +1,51 @@ +/* eslint-disable prefer-rest-params */ +import * as dns from "dns"; import * as t from "tap"; -import { fetch, request } from "undici"; import { Agent } from "../agent/Agent"; import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting"; import { Token } from "../agent/api/Token"; -import { LoggerNoop } from "../agent/logger/LoggerNoop"; +import { Context, runWithContext } from "../agent/Context"; +import { LoggerForTesting } from "../agent/logger/LoggerForTesting"; +import { wrap } from "../helpers/wrap"; import { Undici } from "./Undici"; +const calls: Record = {}; +wrap(dns, "lookup", function lookup(original) { + return function lookup() { + const hostname = arguments[0]; + + if (!calls[hostname]) { + calls[hostname] = 0; + } + + calls[hostname]++; + + if (hostname === "thisdomainpointstointernalip.com") { + return original.apply(this, [ + "localhost", + ...Array.from(arguments).slice(1), + ]); + } + + original.apply(this, arguments); + }; +}); + +const context: Context = { + remoteAddress: "::1", + method: "POST", + url: "http://localhost:4000", + query: {}, + headers: {}, + body: { + image: "http://localhost:4000/api/internal", + }, + cookies: {}, + routeParams: {}, + source: "express", + route: "/posts/:id", +}; + t.test( "it works", { @@ -14,9 +54,10 @@ t.test( : false, }, async () => { + const logger = new LoggerForTesting(); const agent = new Agent( true, - new LoggerNoop(), + logger, new ReportingAPIForTesting(), new Token("123"), undefined @@ -24,7 +65,12 @@ t.test( agent.start([new Undici()]); - const { request, fetch } = require("undici"); + const { + request, + fetch, + setGlobalDispatcher, + Agent: UndiciAgent, + } = require("undici"); await request("https://aikido.dev"); t.same(agent.getHostnames().asArray(), [ @@ -88,5 +134,82 @@ t.test( await t.rejects(() => request("invalid url")); await t.rejects(() => request({ hostname: "" })); + + await runWithContext(context, async () => { + await request("https://google.com"); + const error = await t.rejects(() => + request("http://localhost:4000/api/internal") + ); + if (error instanceof Error) { + t.same( + error.message, + "Aikido firewall has blocked a server-side request forgery: undici.request(...) originating from body.image" + ); + } + const error2 = await t.rejects(() => + request(new URL("http://localhost:4000/api/internal")) + ); + if (error2 instanceof Error) { + t.same( + error2.message, + "Aikido firewall has blocked a server-side request forgery: undici.request(...) originating from body.image" + ); + } + const error3 = await t.rejects(() => + request({ + protocol: "http:", + hostname: "localhost", + port: 4000, + path: "/api/internal", + }) + ); + if (error3 instanceof Error) { + t.same( + error3.message, + "Aikido firewall has blocked a server-side request forgery: undici.request(...) originating from body.image" + ); + } + }); + + await runWithContext( + { ...context, routeParams: { param: "http://0" } }, + async () => { + const error = await t.rejects(() => request("http://0")); + if (error instanceof Error) { + t.same( + error.message, + "Aikido firewall has blocked a server-side request forgery: undici.request(...) originating from routeParams.param" + ); + } + } + ); + + await runWithContext( + { + ...context, + body: { image: "http://thisdomainpointstointernalip.com" }, + }, + async () => { + const error = await t.rejects(() => + request("http://thisdomainpointstointernalip.com") + ); + if (error instanceof Error) { + t.same( + error.message, + "Aikido firewall has blocked a server-side request forgery: undici.[method](...) originating from body.image" + ); + } + + // Ensure the lookup is only called once per hostname + // Otherwise, it could be vulnerable to TOCTOU + t.same(calls["thisdomainpointstointernalip.com"], 1); + } + ); + + logger.clear(); + setGlobalDispatcher(new UndiciAgent({})); + t.same(logger.getMessages(), [ + "undici.setGlobalDispatcher was called, we can't provide protection!", + ]); } ); diff --git a/library/sinks/Undici.ts b/library/sinks/Undici.ts index cf46478ec..76c8f887f 100644 --- a/library/sinks/Undici.ts +++ b/library/sinks/Undici.ts @@ -1,25 +1,81 @@ +import { lookup } from "dns"; import { Agent } from "../agent/Agent"; +import { getContext } from "../agent/Context"; import { Hooks } from "../agent/hooks/Hooks"; +import { InterceptorResult } from "../agent/hooks/MethodInterceptor"; import { Wrapper } from "../agent/Wrapper"; import { getPortFromURL } from "../helpers/getPortFromURL"; import { isPlainObject } from "../helpers/isPlainObject"; +import { tryParseURL } from "../helpers/tryParseURL"; +import { checkContextForSSRF } from "../vulnerabilities/ssrf/checkContextForSSRF"; +import { inspectDNSLookupCalls } from "../vulnerabilities/ssrf/inspectDNSLookupCalls"; + +const methods = [ + "request", + "stream", + "pipeline", + "connect", + "fetch", + "upgrade", +]; export class Undici implements Wrapper { - inspect(args: unknown[], agent: Agent) { + private patchedGlobalDispatcher = false; + + private inspectHostname( + agent: Agent, + hostname: string, + port: number | undefined, + method: string + ): InterceptorResult { + // Let the agent know that we are connecting to this hostname + // This is to build a list of all hostnames that the application is connecting to + agent.onConnectHostname(hostname, port); + const context = getContext(); + + if (!context) { + return undefined; + } + + return checkContextForSSRF({ + hostname: hostname, + operation: `undici.${method}`, + context, + }); + } + + // eslint-disable-next-line max-lines-per-function + private inspect( + args: unknown[], + agent: Agent, + method: string + ): InterceptorResult { if (args.length > 0) { if (typeof args[0] === "string" && args[0].length > 0) { - try { - const url = new URL(args[0]); - if (url.hostname.length > 0) { - agent.onConnectHostname(url.hostname, getPortFromURL(url)); + const url = tryParseURL(args[0]); + if (url) { + const attack = this.inspectHostname( + agent, + url.hostname, + getPortFromURL(url), + method + ); + if (attack) { + return attack; } - } catch (e) { - // Ignore } } if (args[0] instanceof URL && args[0].hostname.length > 0) { - agent.onConnectHostname(args[0].hostname, getPortFromURL(args[0])); + const attack = this.inspectHostname( + agent, + args[0].hostname, + getPortFromURL(args[0]), + method + ); + if (attack) { + return attack; + } } if ( @@ -40,21 +96,69 @@ export class Undici implements Wrapper { port = parseInt(args[0].port, 10); } - agent.onConnectHostname(args[0].hostname, port); + const attack = this.inspectHostname( + agent, + args[0].hostname, + port, + method + ); + if (attack) { + return attack; + } } } + + return undefined; + } + + private patchGlobalDispatcher(agent: Agent) { + const undici = require("undici"); + + // We'll set a global dispatcher that will inspect the resolved IP address (and thus preventing TOCTOU attacks) + undici.setGlobalDispatcher( + new undici.Agent({ + connect: { + lookup: inspectDNSLookupCalls( + lookup, + agent, + "undici", + // We don't know the method here, so we just use "undici.[method]" + "undici.[method]" + ), + }, + }) + ); } wrap(hooks: Hooks) { - hooks + const undici = hooks .addPackage("undici") .withVersion("^4.0.0 || ^5.0.0 || ^6.0.0") - .addSubject((exports) => exports) - .inspect("request", (args, subject, agent) => this.inspect(args, agent)) - .inspect("stream", (args, subject, agent) => this.inspect(args, agent)) - .inspect("pipeline", (args, subject, agent) => this.inspect(args, agent)) - .inspect("connect", (args, subject, agent) => this.inspect(args, agent)) - .inspect("fetch", (args, subject, agent) => this.inspect(args, agent)) - .inspect("upgrade", (args, subject, agent) => this.inspect(args, agent)); + .addSubject((exports) => exports); + + undici.inspect("setGlobalDispatcher", (args, subject, agent) => { + if (this.patchedGlobalDispatcher) { + agent.log( + `undici.setGlobalDispatcher was called, we can't provide protection!` + ); + } + }); + + methods.forEach((method) => { + undici + // Whenever a request is made, we'll check the hostname whether it's a private IP + .inspect(method, (args, subject, agent) => + this.inspect(args, agent, method) + ) + // We're not really modifying the arguments here, but we need to patch the global dispatcher + .modifyArguments(method, (args, subject, agent) => { + if (!this.patchedGlobalDispatcher) { + this.patchGlobalDispatcher(agent); + this.patchedGlobalDispatcher = true; + } + + return args; + }); + }); } } diff --git a/library/vulnerabilities/ssrf/checkContextForSSRF.ts b/library/vulnerabilities/ssrf/checkContextForSSRF.ts new file mode 100644 index 000000000..895c1ed0f --- /dev/null +++ b/library/vulnerabilities/ssrf/checkContextForSSRF.ts @@ -0,0 +1,47 @@ +import { Context } from "../../agent/Context"; +import { InterceptorResult } from "../../agent/hooks/MethodInterceptor"; +import { Source } from "../../agent/Source"; +import { extractStringsFromUserInput } from "../../helpers/extractStringsFromUserInput"; +import { containsPrivateIPAddress } from "./containsPrivateIPAddress"; +import { findHostnameInUserInput } from "./findHostnameInUserInput"; + +/** + * This function goes over all the different input types in the context and checks + * if it possibly implies SSRF, if so the function returns an InterceptorResult + */ +export function checkContextForSSRF({ + hostname, + operation, + context, +}: { + hostname: string; + operation: string; + context: Context; +}): InterceptorResult { + for (const source of [ + "body", + "query", + "headers", + "cookies", + "routeParams", + "graphql", + "xml", + ] as Source[]) { + if (context[source]) { + const userInput = extractStringsFromUserInput(context[source]); + for (const [str, path] of userInput.entries()) { + const found = findHostnameInUserInput(str, hostname); + if (found && containsPrivateIPAddress(hostname)) { + return { + operation: operation, + kind: "ssrf", + source: source, + pathToPayload: path, + metadata: {}, + payload: str, + }; + } + } + } + } +} diff --git a/library/vulnerabilities/ssrf/containsPrivateIPAddress.test.ts b/library/vulnerabilities/ssrf/containsPrivateIPAddress.test.ts new file mode 100644 index 000000000..024c4f606 --- /dev/null +++ b/library/vulnerabilities/ssrf/containsPrivateIPAddress.test.ts @@ -0,0 +1,189 @@ +import * as t from "tap"; +import { containsPrivateIPAddress } from "./containsPrivateIPAddress"; + +const publicIPs = [ + "44.37.112.180", + "46.192.247.73", + "71.12.102.112", + "101.0.26.90", + "111.211.73.40", + "156.238.194.84", + "164.101.185.82", + "223.231.138.242", + "::1fff:0.0.0.0", + "::1fff:10.0.0.0", + "::1fff:0:0.0.0.0", + "::1fff:0:10.0.0.0", + "2001:2:ffff:ffff:ffff:ffff:ffff:ffff", + "64:ff9a::0.0.0.0", + "64:ff9a::255.255.255.255", + "99::", + "99::ffff:ffff:ffff:ffff", + "101::", + "101::ffff:ffff:ffff:ffff", + "2000::", + "2000::ffff:ffff:ffff:ffff:ffff:ffff", + "2001:10::", + "2001:1f:ffff:ffff:ffff:ffff:ffff:ffff", + "2001:db7::", + "2001:db7:ffff:ffff:ffff:ffff:ffff:ffff", + "2001:db9::", + "fb00::", + "fbff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", + "fec0::", +]; + +const privateIPs = [ + "0.0.0.0", + "0000.0000.0000.0000", + "0000.0000", + "0.0.0.1", + "0.0.0.7", + "0.0.0.255", + "0.0.255.255", + "0.1.255.255", + "0.15.255.255", + "0.63.255.255", + "0.255.255.254", + "0.255.255.255", + "10.0.0.0", + "10.0.0.1", + "10.0.0.01", + "10.0.0.001", + "10.255.255.254", + "10.255.255.255", + "100.64.0.0", + "100.64.0.1", + "100.127.255.254", + "100.127.255.255", + "127.0.0.0", + "127.0.0.1", + "127.0.0.01", + "127.1", + "127.0.1", + "127.000.000.1", + "127.255.255.254", + "127.255.255.255", + "169.254.0.0", + "169.254.0.1", + "169.254.255.254", + "169.254.255.255", + "172.16.0.0", + "172.16.0.1", + "172.16.0.001", + "172.31.255.254", + "172.31.255.255", + "192.0.0.0", + "192.0.0.1", + "192.0.0.6", + "192.0.0.7", + "192.0.0.8", + "192.0.0.9", + "192.0.0.10", + "192.0.0.11", + "192.0.0.170", + "192.0.0.171", + "192.0.0.254", + "192.0.0.255", + "192.0.2.0", + "192.0.2.1", + "192.0.2.254", + "192.0.2.255", + "192.31.196.0", + "192.31.196.1", + "192.31.196.254", + "192.31.196.255", + "192.52.193.0", + "192.52.193.1", + "192.52.193.254", + "192.52.193.255", + "192.88.99.0", + "192.88.99.1", + "192.88.99.254", + "192.88.99.255", + "192.168.0.0", + "192.168.0.1", + "192.168.255.254", + "192.168.255.255", + "192.175.48.0", + "192.175.48.1", + "192.175.48.254", + "192.175.48.255", + "198.18.0.0", + "198.18.0.1", + "198.19.255.254", + "198.19.255.255", + "198.51.100.0", + "198.51.100.1", + "198.51.100.254", + "198.51.100.255", + "203.0.113.0", + "203.0.113.1", + "203.0.113.254", + "203.0.113.255", + "240.0.0.0", + "240.0.0.1", + "224.0.0.0", + "224.0.0.1", + "255.0.0.0", + "255.192.0.0", + "255.240.0.0", + "255.254.0.0", + "255.255.0.0", + "255.255.255.0", + "255.255.255.248", + "255.255.255.254", + "255.255.255.255", + "0000:0000:0000:0000:0000:0000:0000:0000", + "::", + "::1", + "::ffff:0.0.0.0", + "::ffff:127.0.0.1", + "fe80::", + "fe80::1", + "fe80::abc:1", + "febf:ffff:ffff:ffff:ffff:ffff:ffff:ffff", + "fc00::", + "fc00::1", + "fc00::abc:1", + "fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", + "2130706433", + "0x7f000001", + + // AWS metadata + "fd00:ec2::254", + "169.254.169.254", +]; + +const invalidIPs = [ + "100::ffff::", + "::ffff:0.0.255.255.255", + "::ffff:0.255.255.255.255", +]; + +t.test("public IPs", async (t) => { + for (let ip of publicIPs) { + if (ip.includes(":")) { + ip = `[${ip}]`; // IPv6 are enclosed in brackets + } + t.same(containsPrivateIPAddress(ip), false, `Expected ${ip} to be public`); + } +}); + +t.test("private IPs", async (t) => { + for (let ip of privateIPs) { + if (ip.includes(":")) { + ip = `[${ip}]`; // IPv6 are enclosed in brackets + } + t.same(containsPrivateIPAddress(ip), true, `Expected ${ip} to be private`); + } +}); + +t.test("invalid IPs", async (t) => { + for (let ip of invalidIPs) { + if (ip.includes(":")) { + ip = `[${ip}]`; // IPv6 are enclosed in brackets + } + t.same(containsPrivateIPAddress(ip), false, `Expected ${ip} to be invalid`); + } +}); diff --git a/library/vulnerabilities/ssrf/containsPrivateIPAddress.ts b/library/vulnerabilities/ssrf/containsPrivateIPAddress.ts new file mode 100644 index 000000000..edcfca56a --- /dev/null +++ b/library/vulnerabilities/ssrf/containsPrivateIPAddress.ts @@ -0,0 +1,38 @@ +import { tryParseURL } from "../../helpers/tryParseURL"; +import { isPrivateIP } from "./isPrivateIP"; + +/** + * Check if the hostname contains a private IP address + * This function is used to detect obvious SSRF attacks (with a private IP address being used as the hostname) + * + * Examples + * http://192.168.0.1/some/path + * http://[::1]/some/path + * http://localhost/some/path + * + * This function gets to see "192.168.0.1", "[::1]", and "localhost" + * + * We won't flag this-domain-points-to-a-private-ip.com + * This will be handled by the inspectDNSLookupCalls function + */ +export function containsPrivateIPAddress(hostname: string): boolean { + if (hostname === "localhost") { + return true; + } + + const url = tryParseURL(`http://${hostname}`); + if (!url) { + return false; + } + + // IPv6 addresses are enclosed in square brackets + // e.g. http://[::1] + if (url.hostname.startsWith("[") && url.hostname.endsWith("]")) { + const ipv6 = url.hostname.substring(1, url.hostname.length - 1); + if (isPrivateIP(ipv6)) { + return true; + } + } + + return isPrivateIP(url.hostname); +} diff --git a/library/vulnerabilities/ssrf/findHostnameInUserInput.test.ts b/library/vulnerabilities/ssrf/findHostnameInUserInput.test.ts new file mode 100644 index 000000000..bd2e3e4db --- /dev/null +++ b/library/vulnerabilities/ssrf/findHostnameInUserInput.test.ts @@ -0,0 +1,89 @@ +import * as t from "tap"; +import { findHostnameInUserInput } from "./findHostnameInUserInput"; + +t.test("returns false if user input and hostname are empty", async (t) => { + t.same(findHostnameInUserInput("", ""), false); +}); + +t.test("returns false if user input is empty", async (t) => { + t.same(findHostnameInUserInput("", "example.com"), false); +}); + +t.test("returns false if hostname is empty", async (t) => { + t.same(findHostnameInUserInput("http://example.com", ""), false); +}); + +t.test("it parses hostname from user input", async (t) => { + t.same(findHostnameInUserInput("http://localhost", "localhost"), true); +}); + +t.test("it parses special IP", async (t) => { + t.same(findHostnameInUserInput("http://localhost", "localhost"), true); +}); + +t.test("it parses hostname from user input with path behind it", async (t) => { + t.same(findHostnameInUserInput("http://localhost/path", "localhost"), true); +}); + +t.test( + "it parses hostname from user input with misspelled protocol", + async (t) => { + t.same(findHostnameInUserInput("http:/localhost", "localhost"), true); + } +); + +t.test( + "it parses hostname from user input without protocol separator", + async (t) => { + t.same(findHostnameInUserInput("http:localhost", "localhost"), true); + } +); + +t.test( + "it parses hostname from user input with misspelled protocol and path behind it", + async (t) => { + t.same( + findHostnameInUserInput("http:/localhost/path/path", "localhost"), + true + ); + } +); + +t.test( + "it parses hostname from user input without protocol and path behind it", + async (t) => { + t.same(findHostnameInUserInput("localhost/path/path", "localhost"), true); + } +); + +t.test("it flags FTP as protocol", async (t) => { + t.same(findHostnameInUserInput("ftp://localhost", "localhost"), true); +}); + +t.test("it parses hostname from user input", async (t) => { + t.same(findHostnameInUserInput("localhost", "localhost"), true); +}); + +t.test("it ignores invalid URLs", async (t) => { + t.same(findHostnameInUserInput("http://", "localhost"), false); +}); + +t.test("user input is smaller than hostname", async (t) => { + t.same(findHostnameInUserInput("localhost", "localhost localhost"), false); +}); + +t.test("it find IP address inside URL", async () => { + t.same( + findHostnameInUserInput( + "http://169.254.169.254/latest/meta-data/", + "169.254.169.254" + ), + true + ); +}); + +t.test("it find IP address with strange notation inside URL", async () => { + t.same(findHostnameInUserInput("http://2130706433", "2130706433"), true); + t.same(findHostnameInUserInput("http://127.1", "127.1"), true); + t.same(findHostnameInUserInput("http://127.0.1", "127.0.1"), true); +}); diff --git a/library/vulnerabilities/ssrf/findHostnameInUserInput.ts b/library/vulnerabilities/ssrf/findHostnameInUserInput.ts new file mode 100644 index 000000000..3e27f0d71 --- /dev/null +++ b/library/vulnerabilities/ssrf/findHostnameInUserInput.ts @@ -0,0 +1,25 @@ +import { tryParseURL } from "../../helpers/tryParseURL"; + +export function findHostnameInUserInput( + userInput: string, + hostname: string +): boolean { + if (userInput.length <= 1) { + return false; + } + + const hostnameURL = tryParseURL(`http://${hostname}`); + if (!hostnameURL) { + return false; + } + + const variants = [userInput, `http://${userInput}`]; + for (const variant of variants) { + const userInputURL = tryParseURL(variant); + if (userInputURL && userInputURL.hostname === hostnameURL.hostname) { + return true; + } + } + + return false; +} diff --git a/library/vulnerabilities/ssrf/imds.test.ts b/library/vulnerabilities/ssrf/imds.test.ts new file mode 100644 index 000000000..36b6ccbe8 --- /dev/null +++ b/library/vulnerabilities/ssrf/imds.test.ts @@ -0,0 +1,24 @@ +import * as t from "tap"; +import { isIMDSIPAddress } from "./imds"; + +t.test("it returns true for IMDS IP addresses", async (t) => { + t.same( + isIMDSIPAddress( + new URL("http://169.254.169.254/latest/meta-data/").hostname + ), + true + ); + t.same( + isIMDSIPAddress( + new URL("http://[fd00:ec2::254]/latest/meta-data/").hostname + .replace("[", "") + .replace("]", "") + ), + true + ); +}); + +t.test("it returns false for non-IMDS IP addresses", async (t) => { + t.same(isIMDSIPAddress("1.2.3.4"), false); + t.same(isIMDSIPAddress("example.com"), false); +}); diff --git a/library/vulnerabilities/ssrf/imds.ts b/library/vulnerabilities/ssrf/imds.ts new file mode 100644 index 000000000..2ca2fc43c --- /dev/null +++ b/library/vulnerabilities/ssrf/imds.ts @@ -0,0 +1,22 @@ +import { BlockList } from "net"; + +const IMDSAddresses = new BlockList(); + +// This IP address is used by AWS EC2 instances to access the instance metadata service (IMDS) +// We should block any requests to these IP addresses +// This prevents STORED SSRF attacks that try to access the instance metadata service +IMDSAddresses.addAddress("169.254.169.254", "ipv4"); +IMDSAddresses.addAddress("fd00:ec2::254", "ipv6"); + +export function isIMDSIPAddress(ip: string): boolean { + return IMDSAddresses.check(ip) || IMDSAddresses.check(ip, "ipv6"); +} + +// Google cloud uses the same IP addresses for its metadata service +// However, you need to set specific headers to access it +// In order to not block legitimate requests, we should allow the IP addresses for Google Cloud +const trustedHosts = ["metadata.google.internal", "metadata.goog"]; + +export function isTrustedHostname(hostname: string): boolean { + return trustedHosts.includes(hostname); +} diff --git a/library/vulnerabilities/ssrf/inspectDNSLookupCalls.test.ts b/library/vulnerabilities/ssrf/inspectDNSLookupCalls.test.ts new file mode 100644 index 000000000..1e92c4af0 --- /dev/null +++ b/library/vulnerabilities/ssrf/inspectDNSLookupCalls.test.ts @@ -0,0 +1,509 @@ +import { LookupAddress, lookup } from "dns"; +import * as t from "tap"; +import { Agent } from "../../agent/Agent"; +import { ReportingAPIForTesting } from "../../agent/api/ReportingAPIForTesting"; +import { Token } from "../../agent/api/Token"; +import { Context, runWithContext } from "../../agent/Context"; +import { LoggerNoop } from "../../agent/logger/LoggerNoop"; +import { inspectDNSLookupCalls } from "./inspectDNSLookupCalls"; + +const context: Context = { + remoteAddress: "::1", + method: "POST", + url: "http://localhost:4000", + query: {}, + headers: {}, + body: { + image: "http://localhost", + }, + cookies: {}, + routeParams: {}, + source: "express", + route: "/posts/:id", +}; + +t.test("it resolves private IPv4 without context", (t) => { + const logger = new LoggerNoop(); + const api = new ReportingAPIForTesting(); + const token = new Token("123"); + const agent = new Agent(true, logger, api, token, undefined); + agent.start([]); + + const wrappedLookup = inspectDNSLookupCalls( + lookup, + agent, + "module", + "operation" + ); + + wrappedLookup("localhost", { family: 4 }, (err, address) => { + t.same(err, null); + t.same(address, "127.0.0.1"); + t.end(); + }); +}); + +t.test("it resolves private IPv6 without context", (t) => { + const logger = new LoggerNoop(); + const api = new ReportingAPIForTesting(); + const token = new Token("123"); + const agent = new Agent(true, logger, api, token, undefined); + agent.start([]); + + const wrappedLookup = inspectDNSLookupCalls( + lookup, + agent, + "module", + "operation" + ); + + wrappedLookup("localhost", (err, address) => { + t.same(err, null); + t.same(address, process.version.startsWith("v16") ? "127.0.0.1" : "::1"); + t.end(); + }); +}); + +t.test("it blocks lookup in blocking mode", (t) => { + const logger = new LoggerNoop(); + const api = new ReportingAPIForTesting(); + const token = new Token("123"); + const agent = new Agent(true, logger, api, token, undefined); + agent.start([]); + api.clear(); + + const wrappedLookup = inspectDNSLookupCalls( + lookup, + agent, + "module", + "operation" + ); + + runWithContext(context, () => { + wrappedLookup("localhost", (err, address) => { + t.same(err instanceof Error, true); + t.same( + err.message, + "Aikido firewall has blocked a server-side request forgery: operation(...) originating from body.image" + ); + t.same(address, undefined); + t.match(api.getEvents(), [ + { + type: "detected_attack", + attack: { + kind: "ssrf", + }, + }, + ]); + t.end(); + }); + }); +}); + +t.test("it allows resolved public IP", (t) => { + const logger = new LoggerNoop(); + const api = new ReportingAPIForTesting(); + const token = new Token("123"); + const agent = new Agent(true, logger, api, token, undefined); + agent.start([]); + api.clear(); + + const wrappedLookup = inspectDNSLookupCalls( + lookup, + agent, + "module", + "operation" + ); + + runWithContext( + { ...context, body: { image: "http://www.google.be" } }, + () => { + wrappedLookup("www.google.be", (err, address) => { + t.same(err, null); + t.ok(typeof address === "string"); + t.same(api.getEvents(), []); + t.end(); + }); + } + ); +}); + +t.test( + "it does not block resolved private IP if not found in user input", + (t) => { + const logger = new LoggerNoop(); + const api = new ReportingAPIForTesting(); + const token = new Token("123"); + const agent = new Agent(true, logger, api, token, undefined); + agent.start([]); + api.clear(); + + const wrappedLookup = inspectDNSLookupCalls( + lookup, + agent, + "module", + "operation" + ); + + runWithContext({ ...context, body: undefined }, () => { + wrappedLookup("localhost", (err, address) => { + t.same(err, null); + t.same( + address, + process.version.startsWith("v16") ? "127.0.0.1" : "::1" + ); + t.same(api.getEvents(), []); + t.end(); + }); + }); + } +); + +t.test( + "it does not block resolved private IP if endpoint protection is turned off", + async (t) => { + const logger = new LoggerNoop(); + const api = new ReportingAPIForTesting({ + success: true, + heartbeatIntervalInMS: 10 * 60 * 1000, + endpoints: [ + { + method: "POST", + route: "/posts/:id", + forceProtectionOff: true, + rateLimiting: { + enabled: false, + windowSizeInMS: 60 * 1000, + maxRequests: 100, + }, + }, + ], + blockedUserIds: [], + allowedIPAddresses: [], + configUpdatedAt: 0, + }); + const token = new Token("123"); + const agent = new Agent(true, logger, api, token, undefined); + agent.start([]); + + await new Promise((resolve) => setTimeout(resolve, 0)); + + api.clear(); + + const wrappedLookup = inspectDNSLookupCalls( + lookup, + agent, + "module", + "operation" + ); + + await new Promise((resolve) => { + runWithContext(context, () => { + wrappedLookup("localhost", (err, address) => { + t.same(err, null); + t.same( + address, + process.version.startsWith("v16") ? "127.0.0.1" : "::1" + ); + t.same(api.getEvents(), []); + resolve(); + }); + }); + }); + } +); + +t.test("it blocks lookup in blocking mode with all option", (t) => { + const logger = new LoggerNoop(); + const api = new ReportingAPIForTesting(); + const token = new Token("123"); + const agent = new Agent(true, logger, api, token, undefined); + agent.start([]); + + const wrappedLookup = inspectDNSLookupCalls( + lookup, + agent, + "module", + "operation" + ); + + runWithContext(context, () => { + wrappedLookup("localhost", { all: true }, (err, addresses) => { + t.same(err instanceof Error, true); + t.same( + err.message, + "Aikido firewall has blocked a server-side request forgery: operation(...) originating from body.image" + ); + t.same(addresses, undefined); + t.end(); + }); + }); +}); + +t.test("it does not block in dry mode", (t) => { + const logger = new LoggerNoop(); + const api = new ReportingAPIForTesting(); + const token = new Token("123"); + const agent = new Agent(false, logger, api, token, undefined); + agent.start([]); + api.clear(); + + const wrappedLookup = inspectDNSLookupCalls( + lookup, + agent, + "module", + "operation" + ); + + runWithContext(context, () => { + wrappedLookup("localhost", (err, address) => { + t.same(err, null); + t.same(address, process.version.startsWith("v16") ? "127.0.0.1" : "::1"); + t.match(api.getEvents(), [ + { + type: "detected_attack", + attack: { + kind: "ssrf", + }, + }, + ]); + t.end(); + }); + }); +}); + +t.test("it ignores invalid args", (t) => { + const logger = new LoggerNoop(); + const api = new ReportingAPIForTesting(); + const token = new Token("123"); + const agent = new Agent(true, logger, api, token, undefined); + agent.start([]); + + const wrappedLookup = inspectDNSLookupCalls( + lookup, + agent, + "module", + "operation" + ); + + const error = t.throws(() => wrappedLookup()); + if (error instanceof Error) { + // The "callback" argument must be of type function + t.match(error.message, /callback/i); + } + t.end(); +}); + +t.test("it ignores if lookup returns error", (t) => { + const logger = new LoggerNoop(); + const api = new ReportingAPIForTesting(); + const token = new Token("123"); + const agent = new Agent(true, logger, api, token, undefined); + agent.start([]); + + const wrappedLookup = inspectDNSLookupCalls( + (_, callback) => callback(new Error("lookup failed")), + agent, + "module", + "operation" + ); + + wrappedLookup("localhost", (err, address) => { + t.same(err instanceof Error, true); + t.same(err.message, "lookup failed"); + t.same(address, undefined); + t.end(); + }); +}); + +const imdsMockLookup = ( + hostname: string, + options: any, + callback: ( + err: any | null, + address: string | LookupAddress[], + family: number + ) => void +) => { + if ( + hostname === "imds.test.com" || + hostname === "metadata.google.internal" || + hostname === "metadata.goog" + ) { + return callback(null, "169.254.169.254", 4); + } + return lookup(hostname, options, callback); +}; + +t.test("Blocks IMDS SSRF with untrusted domain", async (t) => { + const logger = new LoggerNoop(); + const api = new ReportingAPIForTesting(); + const token = new Token("123"); + const agent = new Agent(true, logger, api, token, undefined); + agent.start([]); + + const wrappedLookup = inspectDNSLookupCalls( + imdsMockLookup, + agent, + "module", + "operation" + ); + + await Promise.all([ + new Promise((resolve) => { + wrappedLookup("imds.test.com", { family: 4 }, (err, addresses) => { + t.same(err instanceof Error, true); + t.same( + err.message, + "Aikido firewall has blocked a server-side request forgery: operation(...) originating from unknown source" + ); + t.same(addresses, undefined); + resolve(); + }); + }), + new Promise((resolve) => { + runWithContext(context, () => { + wrappedLookup("imds.test.com", { family: 4 }, (err, addresses) => { + t.same(err instanceof Error, true); + t.same( + err.message, + "Aikido firewall has blocked a server-side request forgery: operation(...) originating from unknown source" + ); + t.same(addresses, undefined); + resolve(); + }); + }); + }), + ]); +}); + +t.test( + "it ignores IMDS SSRF with untrusted domain when endpoint protection is force off", + async (t) => { + const logger = new LoggerNoop(); + const api = new ReportingAPIForTesting({ + success: true, + heartbeatIntervalInMS: 10 * 60 * 1000, + endpoints: [ + { + method: "POST", + route: "/posts/:id", + forceProtectionOff: true, + rateLimiting: { + enabled: false, + windowSizeInMS: 60 * 1000, + maxRequests: 100, + }, + }, + ], + blockedUserIds: [], + allowedIPAddresses: [], + configUpdatedAt: 0, + }); + const token = new Token("123"); + const agent = new Agent(true, logger, api, token, undefined); + agent.start([]); + + // Wait for the agent to start + await new Promise((resolve) => setTimeout(resolve, 0)); + + const wrappedLookup = inspectDNSLookupCalls( + imdsMockLookup, + agent, + "module", + "operation" + ); + + runWithContext(context, () => { + wrappedLookup("imds.test.com", { family: 4 }, (err, addresses) => { + t.same(err, null); + t.same(addresses, "169.254.169.254"); + t.end(); + }); + }); + } +); + +t.test("Does not block IMDS SSRF with Google metadata domain", async (t) => { + const logger = new LoggerNoop(); + const api = new ReportingAPIForTesting(); + const token = new Token("123"); + const agent = new Agent(true, logger, api, token, undefined); + agent.start([]); + + const wrappedLookup = inspectDNSLookupCalls( + imdsMockLookup, + agent, + "module", + "operation" + ); + + await Promise.all([ + new Promise((resolve) => { + wrappedLookup( + "metadata.google.internal", + { family: 4 }, + (err, addresses) => { + t.same(err, null); + t.same(addresses, "169.254.169.254"); + resolve(); + } + ); + }), + new Promise((resolve) => { + runWithContext(context, () => { + wrappedLookup( + "metadata.google.internal", + { family: 4 }, + (err, addresses) => { + t.same(err, null); + t.same(addresses, "169.254.169.254"); + resolve(); + } + ); + }); + }), + ]); +}); + +t.test("it ignores when the argument is an IP address", async (t) => { + const logger = new LoggerNoop(); + const api = new ReportingAPIForTesting(); + const token = new Token("123"); + const agent = new Agent(true, logger, api, token, undefined); + agent.start([]); + + const wrappedLookup = inspectDNSLookupCalls( + lookup, + agent, + "module", + "operation" + ); + + await Promise.all([ + new Promise((resolve) => { + runWithContext( + { ...context, routeParams: { id: "169.254.169.254" } }, + () => { + wrappedLookup("169.254.169.254", (err, address) => { + t.same(err, null); + t.same(address, "169.254.169.254"); + resolve(); + }); + } + ); + }), + new Promise((resolve) => { + runWithContext( + { ...context, routeParams: { id: "fd00:ec2::254" } }, + () => { + wrappedLookup("fd00:ec2::254", (err, address) => { + t.same(err, null); + t.same(address, "fd00:ec2::254"); + resolve(); + }); + } + ); + }), + ]); +}); diff --git a/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts b/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts new file mode 100644 index 000000000..e779f46af --- /dev/null +++ b/library/vulnerabilities/ssrf/inspectDNSLookupCalls.ts @@ -0,0 +1,218 @@ +import { isIP } from "net"; +import { LookupAddress } from "node:dns"; +import { Agent } from "../../agent/Agent"; +import { attackKindHumanName } from "../../agent/Attack"; +import { Context, getContext } from "../../agent/Context"; +import { Source } from "../../agent/Source"; +import { extractStringsFromUserInput } from "../../helpers/extractStringsFromUserInput"; +import { isPlainObject } from "../../helpers/isPlainObject"; +import { findHostnameInUserInput } from "./findHostnameInUserInput"; +import { isPrivateIP } from "./isPrivateIP"; +import { isIMDSIPAddress, isTrustedHostname } from "./imds"; + +export function inspectDNSLookupCalls( + lookup: Function, + agent: Agent, + module: string, + operation: string +): Function { + return function inspectDNSLookup(...args: unknown[]) { + const hostname = + args.length > 0 && typeof args[0] === "string" ? args[0] : undefined; + const callback = args.find((arg) => typeof arg === "function"); + + // If the hostname is an IP address, or if the callback is missing, we don't need to inspect the resolved IPs + if (!hostname || isIP(hostname) || !callback) { + return lookup(...args); + } + + const options = args.find((arg) => isPlainObject(arg)) as + | Record + | undefined; + + const argsToApply = options + ? [ + hostname, + options, + wrapDNSLookupCallback( + callback as Function, + hostname, + module, + agent, + operation + ), + ] + : [ + hostname, + wrapDNSLookupCallback( + callback as Function, + hostname, + module, + agent, + operation + ), + ]; + + lookup(...argsToApply); + }; +} + +// eslint-disable-next-line max-lines-per-function +function wrapDNSLookupCallback( + callback: Function, + hostname: string, + module: string, + agent: Agent, + operation: string +): Function { + // eslint-disable-next-line max-lines-per-function + return function wrappedDNSLookupCallback( + err: Error, + addresses: string | LookupAddress[], + family: number + ) { + if (err) { + return callback(err); + } + + const context = getContext(); + + if (context) { + const endpoint = agent.getConfig().getEndpoint(context); + + if (endpoint && endpoint.endpoint.forceProtectionOff) { + // User disabled protection for this endpoint, we don't need to inspect the resolved IPs + // Just call the original callback to allow the DNS lookup + return callback(err, addresses, family); + } + } + + const resolvedIPAddresses = getResolvedIPAddresses(addresses); + + if (resolvesToIMDSIP(resolvedIPAddresses, hostname)) { + // Block stored SSRF attack that target IMDS IP addresses + // An attacker could have stored a hostname in a database that points to an IMDS IP address + // We don't check if the user input contains the hostname because there's no context + if (agent.shouldBlock()) { + return callback( + new Error( + `Aikido firewall has blocked ${attackKindHumanName("ssrf")}: ${operation}(...) originating from unknown source` + ) + ); + } + } + + if (!context) { + // If there's no context, we can't check if the hostname is in the context + // Just call the original callback to allow the DNS lookup + return callback(err, addresses, family); + } + + const privateIP = resolvedIPAddresses.find(isPrivateIP); + + if (!privateIP) { + // If the hostname doesn't resolve to a private IP address, it's not an SSRF attack + // Just call the original callback to allow the DNS lookup + return callback(err, addresses, family); + } + + const found = findHostnameInContext(hostname, context); + + if (!found) { + // If we can't find the hostname in the context, it's not an SSRF attack + // Just call the original callback to allow the DNS lookup + return callback(err, addresses, family); + } + + agent.onDetectedAttack({ + module: module, + operation: operation, + kind: "ssrf", + source: found.source, + blocked: agent.shouldBlock(), + stack: new Error().stack!, + path: found.pathToPayload, + metadata: {}, + request: context, + payload: found.payload, + }); + + if (agent.shouldBlock()) { + return callback( + new Error( + `Aikido firewall has blocked ${attackKindHumanName("ssrf")}: ${operation}(...) originating from ${found.source}${found.pathToPayload}` + ) + ); + } + + // If the attack should not be blocked + // Just call the original callback to allow the DNS lookup + return callback(err, addresses, family); + }; +} + +type Location = { + source: Source; + pathToPayload: string; + payload: string; +}; + +function findHostnameInContext( + hostname: string, + context: Context +): Location | undefined { + for (const source of [ + "body", + "query", + "headers", + "cookies", + "routeParams", + "graphql", + "xml", + ] as Source[]) { + if (context[source]) { + const userInput = extractStringsFromUserInput(context[source]); + for (const [str, path] of userInput.entries()) { + const found = findHostnameInUserInput(str, hostname); + if (found) { + return { + source: source, + pathToPayload: path, + payload: str, + }; + } + } + } + } + + return undefined; +} + +function getResolvedIPAddresses(addresses: string | LookupAddress[]): string[] { + const resolvedIPAddresses: string[] = []; + for (const address of Array.isArray(addresses) ? addresses : [addresses]) { + if (typeof address === "string") { + resolvedIPAddresses.push(address); + continue; + } + + if (isPlainObject(address) && address.address) { + resolvedIPAddresses.push(address.address); + } + } + + return resolvedIPAddresses; +} + +function resolvesToIMDSIP( + resolvedIPAddresses: string[], + hostname: string +): boolean { + // Allow access to Google Cloud metadata service as you need to set specific headers to access it + // We don't want to block legitimate requests + if (isTrustedHostname(hostname)) { + return false; + } + + return resolvedIPAddresses.some((ip) => isIMDSIPAddress(ip)); +} diff --git a/library/vulnerabilities/ssrf/isPrivateIP.ts b/library/vulnerabilities/ssrf/isPrivateIP.ts new file mode 100644 index 000000000..9f2a98015 --- /dev/null +++ b/library/vulnerabilities/ssrf/isPrivateIP.ts @@ -0,0 +1,51 @@ +import { BlockList, isIPv4, isIPv6 } from "net"; + +// Taken from https://github.com/frenchbread/private-ip/blob/master/src/index.ts +const PRIVATE_IP_RANGES = [ + "0.0.0.0/8", + "10.0.0.0/8", + "100.64.0.0/10", + "127.0.0.0/8", + "169.254.0.0/16", + "172.16.0.0/12", + "192.0.0.0/24", + "192.0.2.0/24", + "192.31.196.0/24", + "192.52.193.0/24", + "192.88.99.0/24", + "192.168.0.0/16", + "192.175.48.0/24", + "198.18.0.0/15", + "198.51.100.0/24", + "203.0.113.0/24", + "240.0.0.0/4", + "224.0.0.0/4", + "255.255.255.255/32", +]; + +const PRIVATE_IPV6_RANGES = [ + "::/128", // Unspecified address + "::1/128", // Loopback address + "fc00::/7", // Unique local address (ULA) + "fe80::/10", // Link-local address (LLA) + "::ffff:127.0.0.1/128", // IPv4-mapped address +]; + +const privateIp = new BlockList(); + +PRIVATE_IP_RANGES.forEach((range) => { + const [ip, mask] = range.split("/"); + privateIp.addSubnet(ip, parseInt(mask, 10)); +}); + +PRIVATE_IPV6_RANGES.forEach((range) => { + const [ip, mask] = range.split("/"); + privateIp.addSubnet(ip, parseInt(mask, 10), "ipv6"); +}); + +export function isPrivateIP(ip: string): boolean { + return ( + (isIPv4(ip) && privateIp.check(ip)) || + (isIPv6(ip) && privateIp.check(ip, "ipv6")) + ); +} diff --git a/sample-apps/express-mongodb/README.md b/sample-apps/express-mongodb/README.md index 75a782a71..48caa6b4f 100644 --- a/sample-apps/express-mongodb/README.md +++ b/sample-apps/express-mongodb/README.md @@ -8,4 +8,5 @@ Try the following URLs: * http://localhost:4000/ add a few posts * http://localhost:4000/?search=title search for posts with title -* http://localhost:4000/?search[$ne]=null will abuse the vulnerability parameter to return all posts +* http://localhost:4000/?search[$ne]=null will abuse the vulnerable parameter to return all posts +* http://localhost:4000/images?url=http://localhost:80 will vulnerable parameter to fetch an image from a private server diff --git a/sample-apps/express-mongodb/app.js b/sample-apps/express-mongodb/app.js index 291098cc8..6e0fd8f7f 100644 --- a/sample-apps/express-mongodb/app.js +++ b/sample-apps/express-mongodb/app.js @@ -125,6 +125,27 @@ async function main(port) { }); }); + app.get( + "/images", + asyncHandler(async (req, res) => { + // This code is vulnerable to SSRF + const url = req.query.url; + + if (!url) { + return res.status(400).send("url parameter is required"); + } + + const response = await fetch(url, { + method: "GET", + }); + + const buffer = await (await response.blob()).arrayBuffer(); + + res.attachment("image.jpg"); + res.send(Buffer.from(buffer)); + }) + ); + return new Promise((resolve, reject) => { try { app.listen(port, () => {