diff --git a/cmd/kubehound/config.go b/cmd/kubehound/config.go index f64b95909..f1ebc1e6a 100644 --- a/cmd/kubehound/config.go +++ b/cmd/kubehound/config.go @@ -20,7 +20,7 @@ var ( Short: "Show the current configuration", Long: `[devOnly] Show the current configuration`, PreRunE: func(cobraCmd *cobra.Command, args []string) error { - return cmd.InitializeKubehoundConfig(cobraCmd.Context(), "", true, true) + return cmd.InitializeKubehoundConfig(cobraCmd.Context(), cfgFile, true, true) }, RunE: func(cobraCmd *cobra.Command, args []string) error { // Adding datadog setup diff --git a/cmd/kubehound/dumper.go b/cmd/kubehound/dumper.go index b4e4b5050..e0d7de076 100644 --- a/cmd/kubehound/dumper.go +++ b/cmd/kubehound/dumper.go @@ -36,7 +36,7 @@ var ( viper.BindPFlag(config.IngestorAPIEndpoint, cobraCmd.Flags().Lookup("khaas-server")) //nolint: errcheck viper.BindPFlag(config.IngestorAPIInsecure, cobraCmd.Flags().Lookup("insecure")) //nolint: errcheck - return cmd.InitializeKubehoundConfig(cobraCmd.Context(), "", true, true) + return cmd.InitializeKubehoundConfig(cobraCmd.Context(), cfgFile, true, true) }, RunE: func(cobraCmd *cobra.Command, args []string) error { // using compress feature @@ -62,7 +62,7 @@ var ( return fmt.Errorf("dump core: %w", err) } // Running the ingestion on KHaaS - if cobraCmd.Flags().Lookup("khaas-server").Value.String() != "" { + if khCfg.Ingestor.API.Endpoint != "" { return core.CoreClientGRPCIngest(cobraCmd.Context(), khCfg.Ingestor, khCfg.Dynamic.ClusterName, khCfg.Dynamic.RunID.String()) } @@ -77,7 +77,7 @@ var ( PreRunE: func(cobraCmd *cobra.Command, args []string) error { viper.Set(config.CollectorFileDirectory, args[0]) - return cmd.InitializeKubehoundConfig(cobraCmd.Context(), "", true, true) + return cmd.InitializeKubehoundConfig(cobraCmd.Context(), cfgFile, true, true) }, RunE: func(cobraCmd *cobra.Command, args []string) error { // Passing the Kubehound config from viper diff --git a/cmd/kubehound/ingest.go b/cmd/kubehound/ingest.go index fe2423d24..8f436086b 100644 --- a/cmd/kubehound/ingest.go +++ b/cmd/kubehound/ingest.go @@ -29,7 +29,7 @@ var ( PreRunE: func(cobraCmd *cobra.Command, args []string) error { cmd.BindFlagCluster(cobraCmd) - return cmd.InitializeKubehoundConfig(cobraCmd.Context(), "", true, true) + return cmd.InitializeKubehoundConfig(cobraCmd.Context(), cfgFile, true, true) }, RunE: func(cobraCmd *cobra.Command, args []string) error { // Passing the Kubehound config from viper @@ -56,7 +56,7 @@ var ( cobraCmd.MarkFlagRequired("cluster") //nolint: errcheck } - return cmd.InitializeKubehoundConfig(cobraCmd.Context(), "", false, true) + return cmd.InitializeKubehoundConfig(cobraCmd.Context(), cfgFile, false, true) }, RunE: func(cobraCmd *cobra.Command, args []string) error { // Passing the Kubehound config from viper diff --git a/cmd/kubehound/root.go b/cmd/kubehound/root.go index fedc496c1..fa82e1eaf 100644 --- a/cmd/kubehound/root.go +++ b/cmd/kubehound/root.go @@ -76,9 +76,9 @@ var ( ) func init() { - rootCmd.Flags().StringVarP(&cfgFile, "config", "c", cfgFile, "application config file") + rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", cfgFile, "application config file") - rootCmd.Flags().BoolVar(&skipBackend, "skip-backend", skipBackend, "skip the auto deployment of the backend stack (janusgraph, mongodb, and UI)") + rootCmd.PersistentFlags().BoolVar(&skipBackend, "skip-backend", skipBackend, "skip the auto deployment of the backend stack (janusgraph, mongodb, and UI)") cmd.InitRootCmd(rootCmd) } diff --git a/configs/etc/kubehound-reference.yaml b/configs/etc/kubehound-reference.yaml index f1328644f..cd0179e81 100644 --- a/configs/etc/kubehound-reference.yaml +++ b/configs/etc/kubehound-reference.yaml @@ -62,6 +62,9 @@ janusgraph: # Timeout on requests to the JanusGraph DB instance connection_timeout: 30s + # Number of worker threads for the JanusGraph writer pool + writer_worker_count: 10 + # # Datadog telemetry configuration # @@ -114,10 +117,10 @@ builder: # worker_pool_capacity: 100 # # Batch size for edge inserts - # batch_size: 500 + # batch_size: 250 # # Small batch size for edge inserts - # batch_size_small: 75 + # batch_size_small: 50 # # Cluster impact batch size for edge inserts # batch_size_cluster_impact: 1 diff --git a/configs/etc/kubehound.yaml b/configs/etc/kubehound.yaml index 7271f7813..0a8c6ee57 100644 --- a/configs/etc/kubehound.yaml +++ b/configs/etc/kubehound.yaml @@ -37,19 +37,22 @@ janusgraph: # Timeout on requests to the JanusGraph DB instance connection_timeout: 30s + # Number of worker threads for the JanusGraph writer pool + writer_worker_count: 10 + # Graph builder configuration builder: # Vertex builder configuration vertex: # Batch size for vertex inserts - batch_size: 500 + batch_size: 250 # Edge builder configuration edge: worker_pool_size: 2 # Batch size for edge inserts - batch_size: 500 + batch_size: 250 # Cluster impact batch size for edge inserts batch_size_cluster_impact: 10 diff --git a/deployments/kubehound/graph/dsl/kubehound/src/main/java/com/datadog/ase/kubehound/KubeHoundTraversalDsl.java b/deployments/kubehound/graph/dsl/kubehound/src/main/java/com/datadog/ase/kubehound/KubeHoundTraversalDsl.java index a759b4b96..05d998c93 100644 --- a/deployments/kubehound/graph/dsl/kubehound/src/main/java/com/datadog/ase/kubehound/KubeHoundTraversalDsl.java +++ b/deployments/kubehound/graph/dsl/kubehound/src/main/java/com/datadog/ase/kubehound/KubeHoundTraversalDsl.java @@ -34,16 +34,21 @@ import static org.apache.tinkerpop.gremlin.process.traversal.Scope.local; import static org.apache.tinkerpop.gremlin.structure.Column.values; - /** - * This KubeHound DSL is meant to be used with the Kubernetes attack graph created by the KubeHound application. + * This KubeHound DSL is meant to be used with the Kubernetes attack graph + * created by the KubeHound application. *

- * All DSLs should extend {@code GraphTraversal.Admin} and be suffixed with "TraversalDsl". Simply add DSL traversal - * methods to this interface. Use Gremlin's steps to build the underlying traversal in these methods to ensure - * compatibility with the rest of the TinkerPop stack and provider implementations. + * All DSLs should extend {@code GraphTraversal.Admin} and be suffixed with + * "TraversalDsl". Simply add DSL traversal + * methods to this interface. Use Gremlin's steps to build the underlying + * traversal in these methods to ensure + * compatibility with the rest of the TinkerPop stack and provider + * implementations. *

- * Arguments provided to the {@code GremlinDsl} annotation are all optional. In this case, a {@code traversalSource} is - * specified which points to a specific implementation to use. Had that argument not been specified then a default + * Arguments provided to the {@code GremlinDsl} annotation are all optional. In + * this case, a {@code traversalSource} is + * specified which points to a specific implementation to use. Had that argument + * not been specified then a default * {@code TraversalSource} would have been generated. */ @GremlinDsl(traversalSource = "com.datadog.ase.kubehound.KubeHoundTraversalSourceDsl") @@ -54,7 +59,8 @@ public interface KubeHoundTraversalDsl extends GraphTraversal.Admin public static final int PATH_HOPS_MIN_DEFAULT = 6; /** - * From a {@code Vertex} traverse immediate edges to display the next set of possible attacks and targets. + * From a {@code Vertex} traverse immediate edges to display the next set of + * possible attacks and targets. * */ public default GraphTraversal attacks() { @@ -62,78 +68,85 @@ public default GraphTraversal attacks() { } /** - * From a {@code Vertex} filter on whether incoming vertices are critical assets. + * From a {@code Vertex} filter on whether incoming vertices are critical + * assets. */ - @GremlinDsl.AnonymousMethod(returnTypeParameters = {"A", "A"}, methodTypeParameters = {"A"}) + @GremlinDsl.AnonymousMethod(returnTypeParameters = { "A", "A" }, methodTypeParameters = { "A" }) public default GraphTraversal critical() { return has("critical", true); } /** - * From a {@code Vertex} traverse edges until {@code maxHops} is exceeded or a critical asset is reached and return all paths. + * From a {@code Vertex} traverse edges until {@code maxHops} is exceeded or a + * critical asset is reached and return all paths. * * @param maxHops the maximum number of hops in an attack path */ public default GraphTraversal criticalPaths(int maxHops) { - if (maxHops < PATH_HOPS_MIN) throw new IllegalArgumentException(String.format("maxHops must be >= %d", PATH_HOPS_MIN)); - if (maxHops > PATH_HOPS_MAX) throw new IllegalArgumentException(String.format("maxHops must be <= %d", PATH_HOPS_MAX)); + if (maxHops < PATH_HOPS_MIN) + throw new IllegalArgumentException(String.format("maxHops must be >= %d", PATH_HOPS_MIN)); + if (maxHops > PATH_HOPS_MAX) + throw new IllegalArgumentException(String.format("maxHops must be <= %d", PATH_HOPS_MAX)); - return repeat(( - (KubeHoundTraversalDsl) __.outE()) + return repeat(((KubeHoundTraversalDsl) __.outE()) .inV() - .simplePath() - ).until( - __.has("critical", true) - .or() - .loops() - .is(maxHops) - ).has("critical", true) - .path(); + .simplePath()).until( + __.has("critical", true) + .or() + .loops() + .is(maxHops)) + .has("critical", true) + .path(); } /** - * From a {@code Vertex} traverse edges until a critical asset is reached and return all paths. + * From a {@code Vertex} traverse edges until a critical asset is reached and + * return all paths. */ public default GraphTraversal criticalPaths() { return criticalPaths(PATH_HOPS_DEFAULT); } /** - * From a {@code Vertex} traverse edges EXCLUDING labels provided in {@code exclusions} until {@code maxHops} is exceeded or - * a critical asset is reached and return all paths. + * From a {@code Vertex} traverse edges EXCLUDING labels provided in + * {@code exclusions} until {@code maxHops} is exceeded or + * a critical asset is reached and return all paths. * - * @param maxHops the maximum number of hops in an attack path + * @param maxHops the maximum number of hops in an attack path * @param exclusions edge labels to exclude from paths */ public default GraphTraversal criticalPathsFilter(int maxHops, String... exclusions) { - if (exclusions.length <= 0) throw new IllegalArgumentException("exclusions must be provided (otherwise use criticalPaths())"); - if (maxHops < PATH_HOPS_MIN) throw new IllegalArgumentException(String.format("maxHops must be >= %d", PATH_HOPS_MIN)); - if (maxHops > PATH_HOPS_MAX) throw new IllegalArgumentException(String.format("maxHops must be <= %d", PATH_HOPS_MAX)); - - return repeat(( - (KubeHoundTraversalDsl) __.outE()) - .hasLabel(P.not(P.within(exclusions))) + if (exclusions.length <= 0) + throw new IllegalArgumentException("exclusions must be provided (otherwise use criticalPaths())"); + if (maxHops < PATH_HOPS_MIN) + throw new IllegalArgumentException(String.format("maxHops must be >= %d", PATH_HOPS_MIN)); + if (maxHops > PATH_HOPS_MAX) + throw new IllegalArgumentException(String.format("maxHops must be <= %d", PATH_HOPS_MAX)); + + return repeat(((KubeHoundTraversalDsl) __.outE()) + .has("class", P.not(P.within(exclusions))) .inV() - .simplePath() - ).until( - __.has("critical", true) - .or() - .loops() - .is(maxHops) - ).has("critical", true) - .path(); + .simplePath()).until( + __.has("critical", true) + .or() + .loops() + .is(maxHops)) + .has("critical", true) + .path(); } /** - * From a {@code Vertex} filter on whether incoming vertices have at least one path to a critical asset. + * From a {@code Vertex} filter on whether incoming vertices have at least one + * path to a critical asset. */ - @GremlinDsl.AnonymousMethod(returnTypeParameters = {"A", "A"}, methodTypeParameters = {"A"}) + @GremlinDsl.AnonymousMethod(returnTypeParameters = { "A", "A" }, methodTypeParameters = { "A" }) public default GraphTraversal hasCriticalPath() { - return where(__.criticalPaths().limit(1)); + return where(__.criticalPaths().limit(1)); } /** - * From a {@code Vertex} returns the hop count of the shortest path to a critical asset. + * From a {@code Vertex} returns the hop count of the shortest path to a + * critical asset. * */ public default GraphTraversal minHopsToCritical() { @@ -141,61 +154,66 @@ public default GraphTraversal minHopsToCritical() } /** - * From a {@code Vertex} returns the hop count of the shortest path to a critical asset. - * + * From a {@code Vertex} returns the hop count of the shortest path to a + * critical asset. + * * @param maxHops the maximum number of hops in an attack path to consider * */ public default GraphTraversal minHopsToCritical(int maxHops) { - if (maxHops < PATH_HOPS_MIN) throw new IllegalArgumentException(String.format("maxHops must be >= %d", PATH_HOPS_MIN)); - if (maxHops > PATH_HOPS_MAX) throw new IllegalArgumentException(String.format("maxHops must be <= %d", PATH_HOPS_MAX)); - - return repeat(( - (KubeHoundTraversalDsl) __.out()) - .simplePath() - ).until( - __.has("critical", true) - .or() - .loops() - .is(maxHops) - ).has("critical", true) - .path() - .count(local) - .min(); + if (maxHops < PATH_HOPS_MIN) + throw new IllegalArgumentException(String.format("maxHops must be >= %d", PATH_HOPS_MIN)); + if (maxHops > PATH_HOPS_MAX) + throw new IllegalArgumentException(String.format("maxHops must be <= %d", PATH_HOPS_MAX)); + + return repeat(((KubeHoundTraversalDsl) __.out()) + .simplePath()).until( + __.has("critical", true) + .or() + .loops() + .is(maxHops)) + .has("critical", true) + .path() + .count(local) + .min(); } /** - * From a {@code Vertex} returns a group count (by label) of paths to a critical asset. + * From a {@code Vertex} returns a group count (by label) of paths to a critical + * asset. * */ public default GraphTraversal> criticalPathsFreq() { - return criticalPathsFreq(PATH_HOPS_DEFAULT); + return criticalPathsFreq(PATH_HOPS_DEFAULT); } /** - * From a {@code Vertex} returns a group count (by label) of paths to a critical asset. + * From a {@code Vertex} returns a group count (by label) of paths to a critical + * asset. * * @param maxHops the maximum number of hops in an attack path */ public default GraphTraversal> criticalPathsFreq(int maxHops) { - if (maxHops < PATH_HOPS_MIN) throw new IllegalArgumentException(String.format("maxHops must be >= %d", PATH_HOPS_MIN)); - if (maxHops > PATH_HOPS_MAX) throw new IllegalArgumentException(String.format("maxHops must be <= %d", PATH_HOPS_MAX)); + if (maxHops < PATH_HOPS_MIN) + throw new IllegalArgumentException(String.format("maxHops must be >= %d", PATH_HOPS_MIN)); + if (maxHops > PATH_HOPS_MAX) + throw new IllegalArgumentException(String.format("maxHops must be <= %d", PATH_HOPS_MAX)); return repeat( (KubeHoundTraversalDsl) __.outE() - .inV() - .simplePath() - ).emit() - .until( - __.has("critical", true) - .or() - .loops() - .is(maxHops) - ).has("critical", true) - .path() - .by(T.label) - .groupCount() - .order(local) - .by(__.select(values), Order.desc); + .inV() + .simplePath()) + .emit() + .until( + __.has("critical", true) + .or() + .loops() + .is(maxHops)) + .has("critical", true) + .path() + .by(T.label) + .groupCount() + .order(local) + .by(__.select(values), Order.desc); } } diff --git a/deployments/kubehound/graph/dsl/kubehound/src/main/java/com/datadog/ase/kubehound/KubeHoundTraversalSourceDsl.java b/deployments/kubehound/graph/dsl/kubehound/src/main/java/com/datadog/ase/kubehound/KubeHoundTraversalSourceDsl.java index 78b29facd..6352602e1 100644 --- a/deployments/kubehound/graph/dsl/kubehound/src/main/java/com/datadog/ase/kubehound/KubeHoundTraversalSourceDsl.java +++ b/deployments/kubehound/graph/dsl/kubehound/src/main/java/com/datadog/ase/kubehound/KubeHoundTraversalSourceDsl.java @@ -60,13 +60,14 @@ public GraphTraversal cluster(String... names) { if (names.length > 0) { traversal = traversal.has("cluster", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices from the specified KubeHound run(s) + * Starts a traversal that finds all vertices from the specified KubeHound + * run(s) * * @param ids list of run ids to filter on */ @@ -75,13 +76,14 @@ public GraphTraversal run(String... ids) { if (ids.length > 0) { traversal = traversal.has("runID", P.within(ids)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices with a "Container" label and optionally allows filtering of those + * Starts a traversal that finds all vertices with a "Container" label and + * optionally allows filtering of those * vertices on the "name" property. * * @param names list of container names to filter on @@ -89,16 +91,17 @@ public GraphTraversal run(String... ids) { public GraphTraversal containers(String... names) { GraphTraversal traversal = this.clone().V(); - traversal = traversal.hasLabel("Container"); + traversal = traversal.has("class", "Container"); if (names.length > 0) { traversal = traversal.has("name", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices with a "Pod" label and optionally allows filtering of those + * Starts a traversal that finds all vertices with a "Pod" label and optionally + * allows filtering of those * vertices on the "name" property. * * @param names list of pod names to filter on @@ -106,16 +109,17 @@ public GraphTraversal containers(String... names) { public GraphTraversal pods(String... names) { GraphTraversal traversal = this.clone().V(); - traversal = traversal.hasLabel("Pod"); + traversal = traversal.has("class", "Pod"); if (names.length > 0) { traversal = traversal.has("name", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices with a "Node" label and optionally allows filtering of those + * Starts a traversal that finds all vertices with a "Node" label and optionally + * allows filtering of those * vertices on the "name" property. * * @param names list of node names to filter on @@ -123,33 +127,35 @@ public GraphTraversal pods(String... names) { public GraphTraversal nodes(String... names) { GraphTraversal traversal = this.clone().V(); - traversal = traversal.hasLabel("Node"); + traversal = traversal.has("class", "Node"); if (names.length > 0) { traversal = traversal.has("name", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all container escape edges from a Container vertex to a Node vertex - * and optionally allows filtering of those vertices on the "nodeNames" property. + * Starts a traversal that finds all container escape edges from a Container + * vertex to a Node vertex + * and optionally allows filtering of those vertices on the "nodeNames" + * property. * * @param nodeNames list of node names to filter on - + * */ public GraphTraversal escapes(String... nodeNames) { GraphTraversal traversal = this.clone().V(); traversal = traversal - .hasLabel("Container") - .outE() - .inV() - .hasLabel("Node"); + .has("class", "Container") + .outE() + .inV() + .has("class", "Node"); if (nodeNames.length > 0) { traversal = traversal.has("name", P.within(nodeNames)); - } + } return traversal.path(); } @@ -159,183 +165,194 @@ public GraphTraversal escapes(String... nodeNames) { */ public GraphTraversal endpoints() { GraphTraversal traversal = this.clone().V(); - - traversal = traversal.hasLabel("Endpoint"); + + traversal = traversal.has("class", "Endpoint"); return traversal; } /** - * Starts a traversal that finds all vertices with a "Endpoint" label and optionally allows filtering of those + * Starts a traversal that finds all vertices with a "Endpoint" label and + * optionally allows filtering of those * vertices on the "exposure" property. * * @param exposure EndpointExposure enum value to filter on */ public GraphTraversal endpoints(EndpointExposure exposure) { if (exposure.ordinal() > EndpointExposure.Max.ordinal()) { - throw new IllegalArgumentException(String.format("invalid exposure value (must be <= %d)", EndpointExposure.Max.ordinal())); + throw new IllegalArgumentException( + String.format("invalid exposure value (must be <= %d)", EndpointExposure.Max.ordinal())); } if (exposure.ordinal() < EndpointExposure.None.ordinal()) { - throw new IllegalArgumentException(String.format("invalid exposure value (must be >= %d)", EndpointExposure.None.ordinal())); + throw new IllegalArgumentException( + String.format("invalid exposure value (must be >= %d)", EndpointExposure.None.ordinal())); } GraphTraversal traversal = this.clone().V(); - + traversal = traversal - .hasLabel("Endpoint") - .has("exposure", P.gte(exposure.ordinal())); + .has("class", "Endpoint") + .has("exposure", P.gte(exposure.ordinal())); return traversal; } /** - * Starts a traversal that finds all vertices with a "Endpoint" label exposed OUTSIDE the cluster as a service + * Starts a traversal that finds all vertices with a "Endpoint" label exposed + * OUTSIDE the cluster as a service * and optionally allows filtering of those vertices on the "portName" property. * * @param portNames list of port names to filter on */ public GraphTraversal services(String... portNames) { GraphTraversal traversal = this.clone().V(); - + traversal = traversal - .hasLabel("Endpoint") - .has("exposure", P.gte(EndpointExposure.External.ordinal())); + .has("class", "Endpoint") + .has("exposure", P.gte(EndpointExposure.External.ordinal())); if (portNames.length > 0) { traversal = traversal.has("portName", P.within(portNames)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices with a "Volume" label and optionally allows filtering of those + * Starts a traversal that finds all vertices with a "Volume" label and + * optionally allows filtering of those * vertices on the "name" property. * * @param names list of volume names to filter on */ public GraphTraversal volumes(String... names) { GraphTraversal traversal = this.clone().V(); - - traversal = traversal.hasLabel("Volume"); + + traversal = traversal.has("class", "Volume"); if (names.length > 0) { traversal = traversal.has("name", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices representing volume host mounts and optionally allows filtering of those + * Starts a traversal that finds all vertices representing volume host mounts + * and optionally allows filtering of those * vertices on the "sourcePath" property. * * @param sourcePaths list of host source paths to filter on */ public GraphTraversal hostMounts(String... sourcePaths) { GraphTraversal traversal = this.clone().V(); - + traversal = traversal - .hasLabel("Volume") - .has("type", "HostPath"); + .has("class", "Volume") + .has("type", "HostPath"); if (sourcePaths.length > 0) { traversal = traversal.has("sourcePath", P.within(sourcePaths)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices with a "Identity" label and optionally allows filtering of those + * Starts a traversal that finds all vertices with a "Identity" label and + * optionally allows filtering of those * vertices on the "name" property. * * @param names list of identity names to filter on */ public GraphTraversal identities(String... names) { GraphTraversal traversal = this.clone().V(); - - traversal = traversal.hasLabel("Identity"); + + traversal = traversal.has("class", "Identity"); if (names.length > 0) { traversal = traversal.has("name", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices representing service accounts and optionally allows filtering of those + * Starts a traversal that finds all vertices representing service accounts and + * optionally allows filtering of those * vertices on the "name" property. * * @param names list of service account names to filter on */ public GraphTraversal sas(String... names) { GraphTraversal traversal = this.clone().V(); - + traversal = traversal - .hasLabel("Identity") - .has("type", "ServiceAccount"); + .has("class", "Identity") + .has("type", "ServiceAccount"); if (names.length > 0) { traversal = traversal.has("name", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices representing users and optionally allows filtering of those + * Starts a traversal that finds all vertices representing users and optionally + * allows filtering of those * vertices on the "name" property. * * @param names list of user names to filter on */ public GraphTraversal users(String... names) { GraphTraversal traversal = this.clone().V(); - + traversal = traversal - .hasLabel("Identity") - .has("type", "User"); + .has("class", "Identity") + .has("type", "User"); if (names.length > 0) { traversal = traversal.has("name", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices representing groups and optionally allows filtering of those + * Starts a traversal that finds all vertices representing groups and optionally + * allows filtering of those * vertices on the "name" property. * * @param names list of groups names to filter on */ public GraphTraversal groups(String... names) { GraphTraversal traversal = this.clone().V(); - + traversal = traversal - .hasLabel("Identity") - .has("type", "Group"); + .has("class", "Identity") + .has("type", "Group"); if (names.length > 0) { traversal = traversal.has("name", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices with a "PermissionSet" label and optionally allows filtering of those + * Starts a traversal that finds all vertices with a "PermissionSet" label and + * optionally allows filtering of those * vertices on the "role" property. * * @param roles list of underlying role names to filter on */ public GraphTraversal permissions(String... roles) { GraphTraversal traversal = this.clone().V(); - - traversal = traversal.hasLabel("PermissionSet"); + + traversal = traversal.has("class", "PermissionSet"); if (roles.length > 0) { traversal = traversal.has("role", P.within(roles)); - } + } return traversal; } diff --git a/docs/queries/dsl.md b/docs/queries/dsl.md index 3b54fea33..59729f30d 100644 --- a/docs/queries/dsl.md +++ b/docs/queries/dsl.md @@ -21,19 +21,19 @@ _DSL definition code available [here](https://github.com/DataDog/KubeHound/blob/ | Method | Gremlin equivalent | Example usage | | --------------------------- | ----------------------------------------------------- | --------------------------------------------------------------------------- | -| `.cluster([string...])` | `.hasLabel("Cluster")` | `kh.cluster("kind-kubehound.local")` | -| `.containers([string...])` | `.hasLabel("Container")` | `kh.cluster("kind-kubehound.local").containers("nginx")` | -| `.endpoints([int])` | `.hasLabel("Endpoint")` | `kh.cluster("kind-kubehound.local").endpoints(3)` | -| `.hostMounts([string...])` | `.hasLabel("Volume").has("type", "HostPath")` | `kh.cluster("kind-kubehound.local").hostMounts("/proc")` | -| `.nodes([string...])` | `.hasLabel("Node")` | `kh.cluster("kind-kubehound.local").nodes("control-plane")` | -| `.permissions([string...])` | `.hasLabel("PermissionSet")` | `kh.cluster("kind-kubehound.local").permissions("system::kube-controller")` | -| `.pods([string...])` | `.hasLabel("Pod")` | `kh.cluster("kind-kubehound.local").pods("app-pod")` | -| `.run([string...])` | `.has("runID", P.within(ids)` | `kh.run("01he5ebh73tah762qgdd5k4wqp")` | -| `.services([string...])` | `.hasLabel("Endpoint").has("exposure", EXTERNAL)` | `kh.cluster("kind-kubehound.local").services("app-front-proxy")` | -| `.sas([string...])` | `.hasLabel("Identity").has("type", "ServiceAccount")` | `kh.cluster("kind-kubehound.local").sas("postgres-admin")` | -| `.users([string...])` | `.hasLabel("Identity").has("type", "User")` | `kh.cluster("kind-kubehound.local").users("user@domain.tld")` | -| `.groups([string...])` | `.hasLabel("Identity").has("type", "Group")` | `kh.cluster("kind-kubehound.local").groups("engineering")` | -| `.volumes([string...])` | `.hasLabel("Volume")` | `kh.cluster("kind-kubehound.local").volumes("db-data")` | +| `.cluster([string...])` | `.has("class","Cluster")` | `kh.cluster("kind-kubehound.local")` | +| `.containers([string...])` | `.has("class","Container")` | `kh.cluster("kind-kubehound.local").containers("nginx")` | +| `.endpoints([int])` | `.has("class","Endpoint")` | `kh.cluster("kind-kubehound.local").endpoints(3)` | +| `.hostMounts([string...])` | `.has("class","Volume").has("type", "HostPath")` | `kh.cluster("kind-kubehound.local").hostMounts("/proc")` | +| `.nodes([string...])` | `.has("class","Node")` | `kh.cluster("kind-kubehound.local").nodes("control-plane")` | +| `.permissions([string...])` | `.has("class","PermissionSet")` | `kh.cluster("kind-kubehound.local").permissions("system::kube-controller")` | +| `.pods([string...])` | `.has("class","Pod")` | `kh.cluster("kind-kubehound.local").pods("app-pod")` | +| `.run([string...])` | `.has("runID", P.within(ids))` | `kh.run("01he5ebh73tah762qgdd5k4wqp")` | +| `.services([string...])` | `.has("class","Endpoint").has("exposure", "EXTERNAL")` | `kh.cluster("kind-kubehound.local").services("app-front-proxy")` | +| `.sas([string...])` | `.has("class","Identity").has("type", "ServiceAccount")` | `kh.cluster("kind-kubehound.local").sas("postgres-admin")` | +| `.users([string...])` | `.has("class","Identity").has("type", "User")` | `kh.cluster("kind-kubehound.local").users("user@domain.tld")` | +| `.groups([string...])` | `.has("class","Identity").has("type", "Group")` | `kh.cluster("kind-kubehound.local").groups("engineering")` | +| `.volumes([string...])` | `.has("class","Volume")` | `kh.cluster("kind-kubehound.local").volumes("db-data")` | ### Retrieving attack oriented data diff --git a/docs/queries/gremlin.md b/docs/queries/gremlin.md index fed428b7d..bd94b47c8 100644 --- a/docs/queries/gremlin.md +++ b/docs/queries/gremlin.md @@ -1,92 +1,92 @@ -# Queries +# Queries You can query KubeHound data stored in the JanusGraph database by using the [Gremlin query language](https://docs.janusgraph.org/getting-started/gremlin/). ## Basic queries -``` java title="Count the number of pods in the cluster" -g.V().hasLabel("Pod").count() +```java title="Count the number of pods in the cluster" +g.V().has("class","Pod").count() ``` -``` java title="View all possible container escapes in the cluster" -g.V().hasLabel("Container").outE().inV().hasLabel("Node").path() +```java title="View all possible container escapes in the cluster" +g.V().has("class","Container").outE().inV().has("class","Node").path() ``` -``` java title="List the names of all possible attacks in the cluster" +```java title="List the names of all possible attacks in the cluster" g.E().groupCount().by(label) ``` -``` java title="View all the mounted host path volumes in the cluster" -g.V().hasLabel("Volume").has("type", "HostPath").groupCount().by("sourcePath") +```java title="View all the mounted host path volumes in the cluster" +g.V().has("class","Volume").has("type", "HostPath").groupCount().by("sourcePath") ``` -``` java title="View host path mounts that can be exploited to escape to a node" -g.E().hasLabel("EXPLOIT_HOST_READ", "EXPLOIT_HOST_WRITE").outV().groupCount().by("sourcePath") +```java title="View host path mounts that can be exploited to escape to a node" +g.E().has("class","EXPLOIT_HOST_READ", "EXPLOIT_HOST_WRITE").outV().groupCount().by("sourcePath") ``` -``` java title="View all service endpoints by service name in the cluster" +```java title="View all service endpoints by service name in the cluster" // Leveraging the "EndpointExposureType" enum value to filter only on services // c.f. https://github.com/DataDog/KubeHound/blob/main/pkg/kubehound/models/shared/constants.go -g.V().hasLabel("Endpoint").has("exposure", 3).groupCount().by("serviceEndpoint") +g.V().has("class","Endpoint").has("exposure", 3).groupCount().by("serviceEndpoint") ``` ## Basic attack paths -``` java title="All paths between an endpoint and a node" -g.V().hasLabel("Endpoint").repeat(out().simplePath()).until(hasLabel("Node")).path() +```java title="All paths between an endpoint and a node" +g.V().has("class","Endpoint").repeat(out().simplePath()).until(has("class","Node")).path() ``` -``` java title="All paths (up to 5 hops) between a container and a node" -g.V().hasLabel("Container").repeat(out().simplePath()).until(hasLabel("Node").or().loops().is(5)).hasLabel("Node").path() +```java title="All paths (up to 5 hops) between a container and a node" +g.V().has("class","Container").repeat(out().simplePath()).until(has("class","Node").or().loops().is(5)).has("class","Node").path() ``` -``` java title="All attack paths (up to 6 hops) from any compomised identity (e.g. service account) to a critical asset" -g.V().hasLabel("Identity").repeat(out().simplePath()).until(has("critical", true).or().loops().is(6)).has("critical", true).path().limit(5) +```java title="All attack paths (up to 6 hops) from any compomised identity (e.g. service account) to a critical asset" +g.V().has("class","Identity").repeat(out().simplePath()).until(has("critical", true).or().loops().is(6)).has("critical", true).path().limit(5) ``` -## Attack paths from compromised assets +## Attack paths from compromised assets ### Containers -``` java title="Attack paths (up to 10 hops) from a known breached container to any critical asset" -g.V().hasLabel("Container").has("name", "nsenter-pod").repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).path() +```java title="Attack paths (up to 10 hops) from a known breached container to any critical asset" +g.V().has("class","Container").has("name", "nsenter-pod").repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).path() ``` -``` java title="Attack paths (up to 10 hops) from a known backdoored container image to any critical asset" -g.V().hasLabel("Container").has("image", TextP.containing("malicious-image")).repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).path() +```java title="Attack paths (up to 10 hops) from a known backdoored container image to any critical asset" +g.V().has("class","Container").has("image", TextP.containing("malicious-image")).repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).path() ``` ### Credentials -``` java title="Attack paths (up to 10 hops) from a known breached identity to a critical asset" -g.V().hasLabel("Identity").has("name", "compromised-sa").repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).path() +```java title="Attack paths (up to 10 hops) from a known breached identity to a critical asset" +g.V().has("class","Identity").has("name", "compromised-sa").repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).path() ``` ### Endpoints -``` java title="Attack paths (up to 6 hops) from any endpoint to a critical asset:" -g.V().hasLabel("Endpoint").repeat(out().simplePath()).until(has("critical", true).or().loops().is(6)).has("critical", true).path().limit(5) +```java title="Attack paths (up to 6 hops) from any endpoint to a critical asset:" +g.V().has("class","Endpoint").repeat(out().simplePath()).until(has("critical", true).or().loops().is(6)).has("critical", true).path().limit(5) ``` -``` java title="Attack paths (up to 10 hops) from a known risky endpoint (e.g JMX) to a critical asset" -g.V().hasLabel("Endpoint").has("portName", "jmx").repeat(out().simplePath()).until(has("critical", true).or().loops().is(6)).has("critical", true).path().limit(5) +```java title="Attack paths (up to 10 hops) from a known risky endpoint (e.g JMX) to a critical asset" +g.V().has("class","Endpoint").has("portName", "jmx").repeat(out().simplePath()).until(has("critical", true).or().loops().is(6)).has("critical", true).path().limit(5) ``` ## Risk assessment -``` java title="What is the shortest exploitable path between an exposed service and a critical asset?" -g.V().hasLabel("Endpoint").has("exposure", gte(3)).repeat(out().simplePath()).until(has("critical", true).or().loops().is(7)).has("critical", true).path().count(local).min() +```java title="What is the shortest exploitable path between an exposed service and a critical asset?" +g.V().has("class","Endpoint").has("exposure", gte(3)).repeat(out().simplePath()).until(has("critical", true).or().loops().is(7)).has("critical", true).path().count(local).min() ``` -``` java title="What percentage of external facing services have an exploitable path to a critical asset?" +```java title="What percentage of external facing services have an exploitable path to a critical asset?" // Leveraging the "EndpointExposureType" enum value to filter only on services // c.f. https://github.com/DataDog/KubeHound/blob/main/pkg/kubehound/models/shared/constants.go // Base case -g.V().hasLabel("Endpoint").has("exposure", gte(3)).count() +g.V().has("class","Endpoint").has("exposure", gte(3)).count() // Has a critical path -g.V().hasLabel("Endpoint").has("exposure", gte(3)).where(repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).limit(1)).count() +g.V().has("class","Endpoint").has("exposure", gte(3)).where(repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).limit(1)).count() ``` ## CVE impact assessment @@ -96,30 +96,30 @@ You can also use KubeHound to determine if workloads in your cluster may be vuln First, evaluate if a known vulnerable image is running in the cluster: ```java -g.V().hasLabel("Container").has("image", TextP.containing("elasticsearch")).groupCount().by("image") +g.V().has("class","Container").has("image", TextP.containing("elasticsearch")).groupCount().by("image") ``` Then, check any exposed services that could be affected and have a path to a critical asset. This helps prioritizing patching and remediation. ```java -g.V().hasLabel("Container").has("image", "dockerhub.com/elasticsearch:7.1.4").where(inE("ENDPOINT_EXPLOIT").outV().has("exposure", gte(3))).where(repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).limit(1)) +g.V().has("class","Container").has("image", "dockerhub.com/elasticsearch:7.1.4").where(inE("ENDPOINT_EXPLOIT").outV().has("exposure", gte(3))).where(repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).limit(1)) ``` ## Assessing the value of implementing new security controls To verify concrete impact, this can be achieved by comparing the difference in the key risk metrics above, before and after the control change. To simulate the impact of introducing a control (e.g to evaluate ROI), we can add conditions to our path queries. For example if we wanted to evaluate the impact of adding a gatekeeper rule that would deny the use of `hostPID` we can use the following: -``` java title="What percentage level of attack path reduction was achieved by the introduction of a control?" +```java title="What percentage level of attack path reduction was achieved by the introduction of a control?" // Calculate the base case -g.V().hasLabel("Endpoint").has("exposure", gte(3)).repeat(out().simplePath()).until(has("critical", true).or().loops().is(6)).has("critical", true).path().count() +g.V().has("class","Endpoint").has("exposure", gte(3)).repeat(out().simplePath()).until(has("critical", true).or().loops().is(6)).has("critical", true).path().count() // Calculate the impact of preventing CE_NSENTER attack -g.V().hasLabel("Endpoint").has("exposure", gte(3)).repeat(outE().not(hasLabel("CE_NSENTER")).inV().simplePath()).emit().until(has("critical", true).or().loops().is(6)).has("critical", true).path().count() +g.V().has("class","Endpoint").has("exposure", gte(3)).repeat(outE().not(has("class","CE_NSENTER")).inV().simplePath()).emit().until(has("critical", true).or().loops().is(6)).has("critical", true).path().count() ``` -``` java title="What type of control would cut off the largest number of attack paths to a specific asset in the cluster?" +```java title="What type of control would cut off the largest number of attack paths to a specific asset in the cluster?" // We count the number of instances of unique attack paths using -g.V().hasLabel("Container").repeat(outE().inV().simplePath()).emit() +g.V().has("class","Container").repeat(outE().inV().simplePath()).emit() .until(has("critical", true).or().loops().is(6)).has("critical", true) .path().by(label).groupCount().order(local).by(select(values), desc) @@ -136,15 +136,15 @@ g.V().hasLabel("Container").repeat(outE().inV().simplePath()).emit() ## Threat modelling -``` java title="All unique attack paths by labels to a specific asset (here, the cluster-admin role)" -g.V().hasLabel("Container", "Identity") +```java title="All unique attack paths by labels to a specific asset (here, the cluster-admin role)" +g.V().has("class","Container", "Identity") .repeat(out().simplePath()) .until(has("name", "cluster-admin").or().loops().is(5)) -.has("name", "cluster-admin").hasLabel("Role").path().as("p").by(label).dedup().select("p").path() +.has("name", "cluster-admin").has("class","Role").path().as("p").by(label).dedup().select("p").path() ``` -``` java title="All unique attack paths by labels to a any critical asset" -g.V().hasLabel("Container", "Identity") +```java title="All unique attack paths by labels to a any critical asset" +g.V().has("class","Container", "Identity") .repeat(out().simplePath()) .until(has("critical", true).or().loops().is(5)) .has("critical", true).path().as("p").by(label).dedup().select("p").path() @@ -160,10 +160,11 @@ To get started with Gremlin, have a look at the following tutorials: For large clusters it is recommended to add a `limit()` step to **all** queries where the graph output will be examined in the UI to prevent overloading it. An example looking for attack paths possible from a sample of 5 containers would look like: ```go -g.V().hasLabel("Container").limit(5).outE() +g.V().has("class","Container").limit(5).outE() ``` Additional tips: + - For queries to be displayed in the UI, try to limit the output to 1000 elements or less - Enable `large cluster optimizations` via configuration file if queries are returning too slowly -- Try to filter the initial element of queries by namespace/service/app to avoid generating too many results, for instance `g.V().hasLabel("Container").has("namespace", "your-namespace")` +- Try to filter the initial element of queries by namespace/service/app to avoid generating too many results, for instance `g.V().has("class","Container").has("namespace", "your-namespace")` diff --git a/pkg/config/builder.go b/pkg/config/builder.go index dc82d1946..a00f9b10b 100644 --- a/pkg/config/builder.go +++ b/pkg/config/builder.go @@ -3,11 +3,11 @@ package config const ( DefaultEdgeWorkerPoolSize = 5 DefaultEdgeWorkerPoolCapacity = 100 - DefaultEdgeBatchSize = 500 + DefaultEdgeBatchSize = 250 DefaultEdgeBatchSizeSmall = DefaultEdgeBatchSize / 5 DefaultEdgeBatchSizeClusterImpact = 10 - DefaultVertexBatchSize = 500 + DefaultVertexBatchSize = 250 DefaultVertexBatchSizeSmall = DefaultVertexBatchSize / 5 DefaultStopOnError = false diff --git a/pkg/config/config.go b/pkg/config/config.go index 8616a593f..d00018c11 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -80,7 +80,6 @@ func NewKubehoundConfig(ctx context.Context, configPath string, inLine bool) *Ku var cfg *KubehoundConfig switch { case len(configPath) != 0: - l.Info("Loading application configuration from file", log.String("path", configPath)) cfg = MustLoadConfig(ctx, configPath) case inLine: l.Info("Loading application from inline command") @@ -119,6 +118,9 @@ func SetDefaultValues(ctx context.Context, v *viper.Viper) { // Defaults values for JanusGraph v.SetDefault(JanusGraphUrl, DefaultJanusGraphUrl) v.SetDefault(JanusGrapTimeout, DefaultConnectionTimeout) + v.SetDefault(JanusGraphWriterTimeout, defaultJanusGraphWriterTimeout) + v.SetDefault(JanusGraphWriterMaxRetry, defaultJanusGraphWriterMaxRetry) + v.SetDefault(JanusGraphWriterWorkerCount, defaultJanusGraphWriterWorkerCount) // Profiler values v.SetDefault(TelemetryProfilerPeriod, DefaultProfilerPeriod) @@ -157,6 +159,9 @@ func SetEnvOverrides(ctx context.Context, c *viper.Viper) { res = multierror.Append(res, c.BindEnv(MongoUrl, "KH_MONGODB_URL")) res = multierror.Append(res, c.BindEnv(JanusGraphUrl, "KH_JANUSGRAPH_URL")) + res = multierror.Append(res, c.BindEnv(JanusGraphWriterMaxRetry, "KH_JANUSGRAPH_WRITER_MAX_RETRY")) + res = multierror.Append(res, c.BindEnv(JanusGraphWriterTimeout, "KH_JANUSGRAPH_WRITER_TIMEOUT")) + res = multierror.Append(res, c.BindEnv(JanusGraphWriterWorkerCount, "KH_JANUSGRAPH_WRITER_WORKER_COUNT")) res = multierror.Append(res, c.BindEnv(IngestorAPIEndpoint, "KH_INGESTOR_API_ENDPOINT")) res = multierror.Append(res, c.BindEnv(IngestorAPIInsecure, "KH_INGESTOR_API_INSECURE")) @@ -166,6 +171,11 @@ func SetEnvOverrides(ctx context.Context, c *viper.Viper) { res = multierror.Append(res, c.BindEnv(IngestorArchiveName, "KH_INGESTOR_ARCHIVE_NAME")) res = multierror.Append(res, c.BindEnv(IngestorBlobRegion, "KH_INGESTOR_REGION")) + res = multierror.Append(res, c.BindEnv("builder.vertex.batch_size", "KH_BUILDER_VERTEX_BATCH_SIZE")) + res = multierror.Append(res, c.BindEnv("builder.vertex.batch_size_small", "KH_BUILDER_VERTEX_BATCH_SIZE_SMALL")) + res = multierror.Append(res, c.BindEnv("builder.edge.batch_size", "KH_BUILDER_EDGE_BATCH_SIZE")) + res = multierror.Append(res, c.BindEnv("builder.edge.batch_size_small", "KH_BUILDER_EDGE_BATCH_SIZE_SMALL")) + res = multierror.Append(res, c.BindEnv(TelemetryStatsdUrl, "STATSD_URL")) res = multierror.Append(res, c.BindEnv(TelemetryTracerUrl, "TRACE_AGENT_URL")) @@ -196,10 +206,12 @@ func unmarshalConfig(v *viper.Viper) (*KubehoundConfig, error) { // NewConfig creates a new config instance from the provided file using viper. func NewConfig(ctx context.Context, v *viper.Viper, configPath string) (*KubehoundConfig, error) { + l := log.Logger(ctx) // Configure default values SetDefaultValues(ctx, v) // Loading inLine config path + l.Info("Loading application configuration from file", log.String("path", configPath)) v.SetConfigType(DefaultConfigType) v.SetConfigFile(configPath) diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 22d1b7cce..6b91dbd8f 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -52,6 +52,9 @@ func TestMustLoadConfig(t *testing.T) { JanusGraph: JanusGraphConfig{ URL: "ws://localhost:8182/gremlin", ConnectionTimeout: DefaultConnectionTimeout, + WriterTimeout: defaultJanusGraphWriterTimeout, + WriterMaxRetry: defaultJanusGraphWriterMaxRetry, + WriterWorkerCount: defaultJanusGraphWriterWorkerCount, }, Telemetry: TelemetryConfig{ Statsd: StatsdConfig{ @@ -64,21 +67,21 @@ func TestMustLoadConfig(t *testing.T) { }, Builder: BuilderConfig{ Vertex: VertexBuilderConfig{ - BatchSize: 500, - BatchSizeSmall: 100, + BatchSize: 250, + BatchSizeSmall: 50, }, Edge: EdgeBuilderConfig{ LargeClusterOptimizations: DefaultLargeClusterOptimizations, WorkerPoolSize: 5, WorkerPoolCapacity: 100, - BatchSize: 500, - BatchSizeSmall: 100, + BatchSize: 250, + BatchSizeSmall: 50, BatchSizeClusterImpact: 10, }, }, Ingestor: IngestorConfig{ API: IngestorAPIConfig{ - Endpoint: "127.0.0.1:9000", + Endpoint: "", Insecure: false, }, Blob: &BlobConfig{ @@ -126,6 +129,9 @@ func TestMustLoadConfig(t *testing.T) { JanusGraph: JanusGraphConfig{ URL: "ws://localhost:8182/gremlin", ConnectionTimeout: DefaultConnectionTimeout, + WriterTimeout: defaultJanusGraphWriterTimeout, + WriterMaxRetry: defaultJanusGraphWriterMaxRetry, + WriterWorkerCount: defaultJanusGraphWriterWorkerCount, }, Telemetry: TelemetryConfig{ Statsd: StatsdConfig{ @@ -139,7 +145,7 @@ func TestMustLoadConfig(t *testing.T) { Builder: BuilderConfig{ Vertex: VertexBuilderConfig{ BatchSize: 1000, - BatchSizeSmall: 100, + BatchSizeSmall: 50, }, Edge: EdgeBuilderConfig{ LargeClusterOptimizations: true, @@ -152,7 +158,7 @@ func TestMustLoadConfig(t *testing.T) { }, Ingestor: IngestorConfig{ API: IngestorAPIConfig{ - Endpoint: "127.0.0.1:9000", + Endpoint: "", Insecure: false, }, Blob: &BlobConfig{ diff --git a/pkg/config/ingestor.go b/pkg/config/ingestor.go index 324ea762f..5fdf73b1d 100644 --- a/pkg/config/ingestor.go +++ b/pkg/config/ingestor.go @@ -1,7 +1,7 @@ package config const ( - DefaultIngestorAPIEndpoint = "127.0.0.1:9000" + DefaultIngestorAPIEndpoint = "" DefaultIngestorAPIInsecure = false DefaultBucketName = "" // we want to let it empty because we can easily abort if it's not configured DefaultTempDir = "/tmp/kubehound" diff --git a/pkg/config/janusgraph.go b/pkg/config/janusgraph.go index 726c0dd43..1b3e294ad 100644 --- a/pkg/config/janusgraph.go +++ b/pkg/config/janusgraph.go @@ -7,12 +7,24 @@ import ( const ( DefaultJanusGraphUrl = "ws://localhost:8182/gremlin" - JanusGraphUrl = "janusgraph.url" - JanusGrapTimeout = "janusgraph.connection_timeout" + defaultJanusGraphWriterTimeout = 60 * time.Second + defaultJanusGraphWriterMaxRetry = 3 + defaultJanusGraphWriterWorkerCount = 10 + + JanusGraphUrl = "janusgraph.url" + JanusGrapTimeout = "janusgraph.connection_timeout" + JanusGraphWriterTimeout = "janusgraph.writer_timeout" + JanusGraphWriterMaxRetry = "janusgraph.writer_max_retry" + JanusGraphWriterWorkerCount = "janusgraph.writer_worker_count" ) // JanusGraphConfig configures JanusGraph specific parameters. type JanusGraphConfig struct { URL string `mapstructure:"url"` // JanusGraph specific configuration ConnectionTimeout time.Duration `mapstructure:"connection_timeout"` + + // JanusGraph vertex/edge writer configuration + WriterTimeout time.Duration `mapstructure:"writer_timeout"` + WriterMaxRetry int `mapstructure:"writer_max_retry"` + WriterWorkerCount int `mapstructure:"writer_worker_count"` } diff --git a/pkg/kubehound/graph/edge/escape_umh_core_pattern.go b/pkg/kubehound/graph/edge/escape_umh_core_pattern.go index 22340ce58..40ee89093 100644 --- a/pkg/kubehound/graph/edge/escape_umh_core_pattern.go +++ b/pkg/kubehound/graph/edge/escape_umh_core_pattern.go @@ -54,32 +54,32 @@ func (e *EscapeCorePattern) Stream(ctx context.Context, store storedb.Provider, }, { "$lookup": bson.M{ - "as": "procMountContainers", - "from": "volumes", - "let": bson.M{ - "rootContainerId": "$container_id", - }, + "as": "procMountContainers", + "from": "volumes", + "foreignField": "pod_id", + "localField": "pod_id", "pipeline": []bson.M{ { "$match": bson.M{ "$and": bson.A{ - bson.M{"$expr": bson.M{ - "$eq": bson.A{ - "$container_id", "$$rootContainerId", - }, + bson.M{"type": shared.VolumeTypeHost}, + bson.M{"source": bson.M{ + "$in": ProcMountList, }}, + bson.M{"runtime.runID": e.runtime.RunID.String()}, + bson.M{"runtime.cluster": e.runtime.ClusterName}, }, - "type": shared.VolumeTypeHost, - "source": bson.M{ - "$in": ProcMountList, - }, - "runtime.runID": e.runtime.RunID.String(), - "runtime.cluster": e.runtime.ClusterName, }, }, }, }, }, + { + "$unwind": bson.M{ + "path": "$procMountContainers", + "preserveNullAndEmptyArrays": false, + }, + }, { "$project": bson.M{ "_id": 1, diff --git a/pkg/kubehound/graph/edge/escape_var_log_symlink.go b/pkg/kubehound/graph/edge/escape_var_log_symlink.go index e8e993a90..da2744c1b 100644 --- a/pkg/kubehound/graph/edge/escape_var_log_symlink.go +++ b/pkg/kubehound/graph/edge/escape_var_log_symlink.go @@ -60,13 +60,13 @@ func (e *EscapeVarLogSymlink) Traversal() types.EdgeTraversal { return func(source *gremlin.GraphTraversalSource, inserts []any) *gremlin.GraphTraversal { g := source.GetGraphTraversal() // reduce the graph to only these permission sets - g.V(inserts...).HasLabel("PermissionSet"). + g.V(inserts...).Has("class", "PermissionSet"). // get identity vertices InE("PERMISSION_DISCOVER").OutV(). // get container vertices InE("IDENTITY_ASSUME").OutV(). // save container vertices as "c" so we can link to it to the node via CE_VAR_LOG_SYMLINK - HasLabel("Container").As("c"). + Has("class", "Container").As("c"). // Get all the volumes OutE("VOLUME_DISCOVER").InV(). Has("type", shared.VolumeTypeHost). @@ -74,7 +74,7 @@ func (e *EscapeVarLogSymlink) Traversal() types.EdgeTraversal { Has("sourcePath", P.Within("/", "/var", "/var/log")). // get the node related to that volume mount InE("VOLUME_ACCESS").OutV(). - HasLabel("Node").As("n"). + Has("class", "Node").As("n"). AddE("CE_VAR_LOG_SYMLINK").From("c").To("n"). Barrier().Limit(0) diff --git a/pkg/kubehound/graph/edge/pod_create.go b/pkg/kubehound/graph/edge/pod_create.go index e0450ef5d..fd31f3bf0 100644 --- a/pkg/kubehound/graph/edge/pod_create.go +++ b/pkg/kubehound/graph/edge/pod_create.go @@ -86,7 +86,8 @@ func (e *PodCreate) Traversal() types.EdgeTraversal { } else { // In smaller clusters we can still show the (large set of) attack paths generated by this attack g.V(). - HasLabel("Node"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). Has("class", "Node"). As("n"). V(inserts...). diff --git a/pkg/kubehound/graph/edge/pod_exec.go b/pkg/kubehound/graph/edge/pod_exec.go index a9a3d6b91..e200b9227 100644 --- a/pkg/kubehound/graph/edge/pod_exec.go +++ b/pkg/kubehound/graph/edge/pod_exec.go @@ -86,7 +86,8 @@ func (e *PodExec) Traversal() types.EdgeTraversal { } else { // In smaller clusters we can still show the (large set of) attack paths generated by this attack g.V(). - HasLabel("Pod"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). Has("class", "Pod"). As("p"). V(inserts...). diff --git a/pkg/kubehound/graph/edge/pod_patch.go b/pkg/kubehound/graph/edge/pod_patch.go index ff4627bf8..18910368a 100644 --- a/pkg/kubehound/graph/edge/pod_patch.go +++ b/pkg/kubehound/graph/edge/pod_patch.go @@ -86,7 +86,8 @@ func (e *PodPatch) Traversal() types.EdgeTraversal { } else { // In smaller clusters we can still show the (large set of) attack paths generated by this attack g.V(). - HasLabel("Pod"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). Has("class", "Pod"). As("p"). V(inserts...). diff --git a/pkg/kubehound/graph/edge/role_bind_crb_cr_cr.go b/pkg/kubehound/graph/edge/role_bind_crb_cr_cr.go index ca84dfac0..c0c2c149c 100644 --- a/pkg/kubehound/graph/edge/role_bind_crb_cr_cr.go +++ b/pkg/kubehound/graph/edge/role_bind_crb_cr_cr.go @@ -7,7 +7,6 @@ import ( "github.com/DataDog/KubeHound/pkg/kubehound/graph/adapter" "github.com/DataDog/KubeHound/pkg/kubehound/graph/types" "github.com/DataDog/KubeHound/pkg/kubehound/models/converter" - "github.com/DataDog/KubeHound/pkg/kubehound/risk" "github.com/DataDog/KubeHound/pkg/kubehound/storage/cache" "github.com/DataDog/KubeHound/pkg/kubehound/storage/storedb" "github.com/DataDog/KubeHound/pkg/kubehound/store/collections" @@ -53,19 +52,16 @@ func (e *RoleBindCrbCrCr) Traversal() types.EdgeTraversal { return func(source *gremlin.GraphTraversalSource, inserts []any) *gremlin.GraphTraversal { g := source.GetGraphTraversal() - // Gathering all sensitives roles - sensitiveRoles := make([]string, 0, len(risk.CriticalRoleMap)) - for k := range risk.CriticalRoleMap { - sensitiveRoles = append(sensitiveRoles, k) - } - if e.cfg.LargeClusterOptimizations { // For larger clusters simply target specific roles to reduce number of attack paths g.V(). - HasLabel("PermissionSet"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). + Has("class", "PermissionSet"). Has("isNamespaced", false). // Temporary measure, until we scan and flag for sensitive roles - Has("role", P.Within(sensitiveRoles)). + Has("critical", true). + // Has("role", P.Within(sensitiveRoles)). As("r"). V(inserts...). Has("critical", false). @@ -75,7 +71,9 @@ func (e *RoleBindCrbCrCr) Traversal() types.EdgeTraversal { } else { // In smaller clusters we can still show the (large set of) attack paths generated by this attack g.V(). - HasLabel("PermissionSet"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). + Has("class", "PermissionSet"). Has("isNamespaced", false). As("i"). V(inserts...). diff --git a/pkg/kubehound/graph/edge/role_bind_crb_cr_r.go b/pkg/kubehound/graph/edge/role_bind_crb_cr_r.go index ecb09c7d2..3f2ea6702 100644 --- a/pkg/kubehound/graph/edge/role_bind_crb_cr_r.go +++ b/pkg/kubehound/graph/edge/role_bind_crb_cr_r.go @@ -7,7 +7,6 @@ import ( "github.com/DataDog/KubeHound/pkg/kubehound/graph/adapter" "github.com/DataDog/KubeHound/pkg/kubehound/graph/types" "github.com/DataDog/KubeHound/pkg/kubehound/models/converter" - "github.com/DataDog/KubeHound/pkg/kubehound/risk" "github.com/DataDog/KubeHound/pkg/kubehound/storage/cache" "github.com/DataDog/KubeHound/pkg/kubehound/storage/storedb" "github.com/DataDog/KubeHound/pkg/kubehound/store/collections" @@ -53,19 +52,16 @@ func (e *RoleBindCrbCrR) Traversal() types.EdgeTraversal { return func(source *gremlin.GraphTraversalSource, inserts []any) *gremlin.GraphTraversal { g := source.GetGraphTraversal() - // Gathering all sensitives roles - sensitiveRoles := make([]string, 0, len(risk.CriticalRoleMap)) - for k := range risk.CriticalRoleMap { - sensitiveRoles = append(sensitiveRoles, k) - } - if e.cfg.LargeClusterOptimizations { // For larger clusters simply target specific roles to reduce number of attack paths g.V(). - HasLabel("PermissionSet"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). + Has("class", "PermissionSet"). Has("isNamespaced", true). // Temporary measure, until we scan and flag for sensitive roles - Has("role", P.Within(sensitiveRoles)). + Has("critical", true). + // Has("role", P.Within(sensitiveRoles)). As("r"). V(inserts...). Has("critical", false). @@ -75,7 +71,9 @@ func (e *RoleBindCrbCrR) Traversal() types.EdgeTraversal { } else { // In smaller clusters we can still show the (large set of) attack paths generated by this attack g.V(). - HasLabel("PermissionSet"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). + Has("class", "PermissionSet"). Has("isNamespaced", true). As("i"). V(inserts...). diff --git a/pkg/kubehound/graph/edge/token_bruteforce.go b/pkg/kubehound/graph/edge/token_bruteforce.go index e40fc36f9..2f9b2a300 100644 --- a/pkg/kubehound/graph/edge/token_bruteforce.go +++ b/pkg/kubehound/graph/edge/token_bruteforce.go @@ -64,7 +64,9 @@ func (e *TokenBruteforce) Traversal() types.EdgeTraversal { if e.cfg.LargeClusterOptimizations { // For larger clusters simply target the system:masters group to reduce redundant attack paths g.V(). - HasLabel("Identity"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). + Has("class", "Identity"). Has("name", "system:masters"). As("i"). V(inserts...). @@ -75,7 +77,8 @@ func (e *TokenBruteforce) Traversal() types.EdgeTraversal { } else { // In smaller clusters we can still show the (large set of) attack paths generated by this attack g.V(). - HasLabel("Identity"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). Has("class", "Identity"). As("i"). V(inserts...). diff --git a/pkg/kubehound/graph/edge/token_list.go b/pkg/kubehound/graph/edge/token_list.go index beeea1999..e3b97b9fb 100644 --- a/pkg/kubehound/graph/edge/token_list.go +++ b/pkg/kubehound/graph/edge/token_list.go @@ -64,7 +64,9 @@ func (e *TokenList) Traversal() types.EdgeTraversal { if e.cfg.LargeClusterOptimizations { // For larger clusters simply target the system:masters group to reduce redundant attack paths g.V(). - HasLabel("Identity"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). + Has("class", "Identity"). Has("name", "system:masters"). As("i"). V(inserts...). @@ -75,7 +77,8 @@ func (e *TokenList) Traversal() types.EdgeTraversal { } else { // In smaller clusters we can still show the (large set of) attack paths generated by this attack g.V(). - HasLabel("Identity"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). Has("class", "Identity"). As("i"). V(inserts...). diff --git a/pkg/kubehound/storage/graphdb/errors.go b/pkg/kubehound/storage/graphdb/errors.go new file mode 100644 index 000000000..fe6772ee8 --- /dev/null +++ b/pkg/kubehound/storage/graphdb/errors.go @@ -0,0 +1,22 @@ +package graphdb + +import "fmt" + +// batchWriterError is an error type that wraps an error and indicates whether the +// error is retryable. +type batchWriterError struct { + err error + retryable bool +} + +func (e batchWriterError) Error() string { + if e.err == nil { + return fmt.Sprintf("batch writer error (retriable:%v)", e.retryable) + } + + return fmt.Sprintf("batch writer error (retriable:%v): %v", e.retryable, e.err.Error()) +} + +func (e batchWriterError) Unwrap() error { + return e.err +} diff --git a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go index e8c2619ae..47fe37ee3 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go @@ -6,6 +6,7 @@ import ( "fmt" "sync" "sync/atomic" + "time" "github.com/DataDog/KubeHound/pkg/kubehound/graph/edge" "github.com/DataDog/KubeHound/pkg/kubehound/graph/types" @@ -29,21 +30,24 @@ type JanusGraphEdgeWriter struct { gremlin types.EdgeTraversal // Gremlin traversal generator function drc *gremlingo.DriverRemoteConnection // Gremlin driver remote connection traversalSource *gremlingo.GraphTraversalSource // Transacted graph traversal source - inserts []any // Object data to be inserted in the graph - mu sync.Mutex // Mutex protecting access to the inserts array - consumerChan chan []any // Channel consuming inserts for async writing writingInFlight *sync.WaitGroup // Wait group tracking current unfinished writes - batchSize int // Batchsize of graph DB inserts qcounter int32 // Track items queued wcounter int32 // Track items writtn tags []string // Telemetry tags + writerTimeout time.Duration // Timeout for the writer + maxRetry int // Maximum number of retries for failed writes + mb *microBatcher // Micro batcher to batch writes } // NewJanusGraphAsyncEdgeWriter creates a new bulk edge writer instance. func NewJanusGraphAsyncEdgeWriter(ctx context.Context, drc *gremlingo.DriverRemoteConnection, - e edge.Builder, opts ...WriterOption) (*JanusGraphEdgeWriter, error) { - - options := &writerOptions{} + e edge.Builder, opts ...WriterOption, +) (*JanusGraphEdgeWriter, error) { + options := &writerOptions{ + WriterTimeout: defaultWriterTimeout, + MaxRetry: defaultMaxRetry, + WriterWorkerCount: defaultWriterWorkerCount, + } for _, opt := range opts { opt(options) } @@ -53,54 +57,92 @@ func NewJanusGraphAsyncEdgeWriter(ctx context.Context, drc *gremlingo.DriverRemo builder: builder, gremlin: e.Traversal(), drc: drc, - inserts: make([]any, 0, e.BatchSize()), traversalSource: gremlingo.Traversal_().WithRemote(drc), - batchSize: e.BatchSize(), writingInFlight: &sync.WaitGroup{}, - consumerChan: make(chan []any, e.BatchSize()*channelSizeBatchFactor), tags: append(options.Tags, tag.Label(e.Label()), tag.Builder(builder)), + writerTimeout: options.WriterTimeout, + maxRetry: options.MaxRetry, } - jw.startBackgroundWriter(ctx) + // Create a new micro batcher to batch the inserts with split and retry logic. + jw.mb = newMicroBatcher(log.Trace(ctx), e.BatchSize(), options.WriterWorkerCount, func(ctx context.Context, a []any) error { + // Increment the writingInFlight wait group to track the number of writes in progress. + jw.writingInFlight.Add(1) + defer jw.writingInFlight.Done() + + // Try to write the batch to the graph DB. + if err := jw.batchWrite(ctx, a); err != nil { + var bwe *batchWriterError + if errors.As(err, &bwe) && bwe.retryable { + // If the write operation failed and is retryable, split the batch and retry. + return jw.splitAndRetry(ctx, 0, a) + } + + return err + } + + return nil + }) + jw.mb.Start(ctx) return &jw, nil } -// startBackgroundWriter starts a background go routine -func (jgv *JanusGraphEdgeWriter) startBackgroundWriter(ctx context.Context) { - go func() { - for { - select { - case data := <-jgv.consumerChan: - // closing the channel shoud stop the go routine - if data == nil { - return - } - - _ = statsd.Count(ctx, metric.BackgroundWriterCall, 1, jgv.tags, 1) - err := jgv.batchWrite(ctx, data) - if err != nil { - log.Trace(ctx).Errorf("write data in background batch writer: %v", err) - } - - _ = statsd.Decr(ctx, metric.QueueSize, jgv.tags, 1) - case <-ctx.Done(): - log.Trace(ctx).Info("Closed background janusgraph worker on context cancel") - - return +// retrySplitAndRequeue will split the batch into smaller chunks and requeue them for writing. +func (jgv *JanusGraphEdgeWriter) splitAndRetry(ctx context.Context, retryCount int, payload []any) error { + _ = statsd.Count(ctx, metric.RetryWriterCall, 1, jgv.tags, 1) + + // If we have reached the maximum number of retries, return an error. + if retryCount >= jgv.maxRetry { + return fmt.Errorf("max retry count reached: %d", retryCount) + } + + // Compute the new batch size. + newBatchSize := len(payload) / 2 + + log.Trace(ctx).Warnf("Retrying write operation with smaller edge batch (n:%d -> %d, r:%d)", len(payload), newBatchSize, retryCount) + + var leftErr, rightErr error + + // Split the batch into smaller chunks and retry them. + if len(payload[:newBatchSize]) > 0 { + if leftErr = jgv.batchWrite(ctx, payload[:newBatchSize]); leftErr == nil { + var bwe *batchWriterError + if errors.As(leftErr, &bwe) && bwe.retryable { + return jgv.splitAndRetry(ctx, retryCount+1, payload[:newBatchSize]) + } + } + } + + // Process the right side of the batch. + if len(payload[newBatchSize:]) > 0 { + if rightErr = jgv.batchWrite(ctx, payload[newBatchSize:]); rightErr != nil { + var bwe *batchWriterError + if errors.As(rightErr, &bwe) && bwe.retryable { + return jgv.splitAndRetry(ctx, retryCount+1, payload[newBatchSize:]) } } - }() + } + + // Return the first error encountered. + switch { + case leftErr != nil && rightErr != nil: + return fmt.Errorf("left: %w, right: %w", leftErr, rightErr) + case leftErr != nil: + return leftErr + case rightErr != nil: + return rightErr + } + + return nil } // batchWrite will write a batch of entries into the graph DB and block until the write completes. -// Callers are responsible for doing an Add(1) to the writingInFlight wait group to ensure proper synchronization. func (jgv *JanusGraphEdgeWriter) batchWrite(ctx context.Context, data []any) error { span, ctx := span.SpanRunFromContext(ctx, span.JanusGraphBatchWrite) span.SetTag(tag.LabelTag, jgv.builder) var err error defer func() { span.Finish(tracer.WithError(err)) }() - defer jgv.writingInFlight.Done() datalen := len(data) _ = statsd.Count(ctx, metric.EdgeWrite, int64(datalen), jgv.tags, 1) @@ -109,17 +151,28 @@ func (jgv *JanusGraphEdgeWriter) batchWrite(ctx context.Context, data []any) err op := jgv.gremlin(jgv.traversalSource, data) promise := op.Iterate() - err = <-promise - if err != nil { - return fmt.Errorf("%s edge insert: %w", jgv.builder, err) + + // Wait for the write operation to complete or timeout. + select { + case <-ctx.Done(): + // If the context is cancelled, return the error. + return ctx.Err() + case <-time.After(jgv.writerTimeout): + // If the write operation takes too long, return an error. + return &batchWriterError{ + err: errors.New("edge write operation timed out"), + retryable: true, + } + case err := <-promise: + if err != nil { + return fmt.Errorf("%s edge insert: %w", jgv.builder, err) + } } return nil } func (jgv *JanusGraphEdgeWriter) Close(ctx context.Context) error { - close(jgv.consumerChan) - return nil } @@ -131,29 +184,17 @@ func (jgv *JanusGraphEdgeWriter) Flush(ctx context.Context) error { var err error defer func() { span.Finish(tracer.WithError(err)) }() - jgv.mu.Lock() - defer jgv.mu.Unlock() - if jgv.traversalSource == nil { return errors.New("janusGraph traversalSource is not initialized") } - if len(jgv.inserts) != 0 { - _ = statsd.Incr(ctx, metric.FlushWriterCall, jgv.tags, 1) - - jgv.writingInFlight.Add(1) - err = jgv.batchWrite(ctx, jgv.inserts) - if err != nil { - log.Trace(ctx).Errorf("batch write %s: %+v", jgv.builder, err) - jgv.writingInFlight.Wait() - - return err - } - - log.Trace(ctx).Debugf("Done flushing %s writes. clearing the queue", jgv.builder) - jgv.inserts = nil + // Flush the micro batcher. + err = jgv.mb.Flush(ctx) + if err != nil { + return fmt.Errorf("micro batcher flush: %w", err) } + // Wait for all writes to complete. jgv.writingInFlight.Wait() log.Trace(ctx).Debugf("Edge writer %d %s queued", jgv.qcounter, jgv.builder) @@ -163,23 +204,7 @@ func (jgv *JanusGraphEdgeWriter) Flush(ctx context.Context) error { } func (jgv *JanusGraphEdgeWriter) Queue(ctx context.Context, v any) error { - jgv.mu.Lock() - defer jgv.mu.Unlock() - atomic.AddInt32(&jgv.qcounter, 1) - jgv.inserts = append(jgv.inserts, v) - if len(jgv.inserts) > jgv.batchSize { - copied := make([]any, len(jgv.inserts)) - copy(copied, jgv.inserts) - - jgv.writingInFlight.Add(1) - jgv.consumerChan <- copied - _ = statsd.Incr(ctx, metric.QueueSize, jgv.tags, 1) - - // cleanup the ops array after we have copied it to the channel - jgv.inserts = nil - } - - return nil + return jgv.mb.Enqueue(ctx, v) } diff --git a/pkg/kubehound/storage/graphdb/janusgraph_provider.go b/pkg/kubehound/storage/graphdb/janusgraph_provider.go index 8a4d04dd9..c2a2673c9 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_provider.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_provider.go @@ -130,6 +130,9 @@ func (jgp *JanusGraphProvider) VertexWriter(ctx context.Context, v vertex.Builde c cache.CacheProvider, opts ...WriterOption) (AsyncVertexWriter, error) { opts = append(opts, WithTags(jgp.tags)) + opts = append(opts, WithWriterWorkerCount(jgp.cfg.JanusGraph.WriterWorkerCount)) + opts = append(opts, WithWriterTimeout(jgp.cfg.JanusGraph.WriterTimeout)) + opts = append(opts, WithWriterMaxRetry(jgp.cfg.JanusGraph.WriterMaxRetry)) return NewJanusGraphAsyncVertexWriter(ctx, jgp.drc, v, c, opts...) } @@ -137,6 +140,9 @@ func (jgp *JanusGraphProvider) VertexWriter(ctx context.Context, v vertex.Builde // EdgeWriter creates a new AsyncEdgeWriter instance to enable asynchronous bulk inserts of edges. func (jgp *JanusGraphProvider) EdgeWriter(ctx context.Context, e edge.Builder, opts ...WriterOption) (AsyncEdgeWriter, error) { opts = append(opts, WithTags(jgp.tags)) + opts = append(opts, WithWriterWorkerCount(jgp.cfg.JanusGraph.WriterWorkerCount)) + opts = append(opts, WithWriterTimeout(jgp.cfg.JanusGraph.WriterTimeout)) + opts = append(opts, WithWriterMaxRetry(jgp.cfg.JanusGraph.WriterMaxRetry)) return NewJanusGraphAsyncEdgeWriter(ctx, jgp.drc, e, opts...) } @@ -154,7 +160,7 @@ func (jgp *JanusGraphProvider) Clean(ctx context.Context, cluster string) error span, ctx := span.SpanRunFromContext(ctx, span.IngestorClean) defer func() { span.Finish(tracer.WithError(err)) }() l := log.Trace(ctx) - l.Infof("Cleaning cluster", log.FieldClusterKey, cluster) + l.Info("Cleaning cluster", log.String(log.FieldClusterKey, cluster)) g := gremlin.Traversal_().WithRemote(jgp.drc) tx := g.Tx() defer tx.Close() diff --git a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go index 43396746c..5facc1b71 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go @@ -6,6 +6,7 @@ import ( "fmt" "sync" "sync/atomic" + "time" "github.com/DataDog/KubeHound/pkg/kubehound/graph/types" "github.com/DataDog/KubeHound/pkg/kubehound/graph/vertex" @@ -27,22 +28,25 @@ type JanusGraphVertexWriter struct { gremlin types.VertexTraversal // Gremlin traversal generator function drc *gremlin.DriverRemoteConnection // Gremlin driver remote connection traversalSource *gremlin.GraphTraversalSource // Transacted graph traversal source - inserts []any // Object data to be inserted in the graph - mu sync.Mutex // Mutex protecting access to the inserts array - consumerChan chan []any // Channel consuming inserts for async writing writingInFlight *sync.WaitGroup // Wait group tracking current unfinished writes - batchSize int // Batchsize of graph DB inserts qcounter int32 // Track items queued wcounter int32 // Track items writtn tags []string // Telemetry tags cache cache.AsyncWriter // Cache writer to cache store id -> vertex id mappings + writerTimeout time.Duration // Timeout for the writer + maxRetry int // Maximum number of retries for failed writes + mb *microBatcher // Micro batcher to batch writes } // NewJanusGraphAsyncVertexWriter creates a new bulk vertex writer instance. func NewJanusGraphAsyncVertexWriter(ctx context.Context, drc *gremlin.DriverRemoteConnection, - v vertex.Builder, c cache.CacheProvider, opts ...WriterOption) (*JanusGraphVertexWriter, error) { - - options := &writerOptions{} + v vertex.Builder, c cache.CacheProvider, opts ...WriterOption, +) (*JanusGraphVertexWriter, error) { + options := &writerOptions{ + WriterTimeout: defaultWriterTimeout, + MaxRetry: defaultMaxRetry, + WriterWorkerCount: defaultWriterWorkerCount, + } for _, opt := range opts { opt(options) } @@ -56,45 +60,36 @@ func NewJanusGraphAsyncVertexWriter(ctx context.Context, drc *gremlin.DriverRemo builder: v.Label(), gremlin: v.Traversal(), drc: drc, - inserts: make([]any, 0, v.BatchSize()), traversalSource: gremlin.Traversal_().WithRemote(drc), - batchSize: v.BatchSize(), writingInFlight: &sync.WaitGroup{}, - consumerChan: make(chan []any, v.BatchSize()*channelSizeBatchFactor), tags: append(options.Tags, tag.Label(v.Label()), tag.Builder(v.Label())), cache: cw, + writerTimeout: options.WriterTimeout, + maxRetry: options.MaxRetry, } - jw.startBackgroundWriter(ctx) - - return &jw, nil -} - -// startBackgroundWriter starts a background go routine -func (jgv *JanusGraphVertexWriter) startBackgroundWriter(ctx context.Context) { - go func() { - for { - select { - case data := <-jgv.consumerChan: - // closing the channel shoud stop the go routine - if data == nil { - return - } - - _ = statsd.Count(ctx, metric.BackgroundWriterCall, 1, jgv.tags, 1) - err := jgv.batchWrite(ctx, data) - if err != nil { - log.Trace(ctx).Errorf("Write data in background batch writer: %v", err) - } - - _ = statsd.Decr(ctx, metric.QueueSize, jgv.tags, 1) - case <-ctx.Done(): - log.Trace(ctx).Info("Closed background janusgraph worker on context cancel") - - return + // Create a new micro batcher to batch the inserts with split and retry logic. + jw.mb = newMicroBatcher(log.Trace(ctx), v.BatchSize(), options.WriterWorkerCount, func(ctx context.Context, a []any) error { + // Increment the writingInFlight wait group to track the number of writes in progress. + jw.writingInFlight.Add(1) + defer jw.writingInFlight.Done() + + // Try to write the batch to the graph DB. + if err := jw.batchWrite(ctx, a); err != nil { + var bwe *batchWriterError + if errors.As(err, &bwe) && bwe.retryable { + // If the write operation failed and is retryable, split the batch and retry. + return jw.splitAndRetry(ctx, 0, a) } + + return err } - }() + + return nil + }) + jw.mb.Start(ctx) + + return &jw, nil } func (jgv *JanusGraphVertexWriter) cacheIds(ctx context.Context, idMap []*gremlin.Result) error { @@ -121,41 +116,127 @@ func (jgv *JanusGraphVertexWriter) cacheIds(ctx context.Context, idMap []*gremli } // batchWrite will write a batch of entries into the graph DB and block until the write completes. -// Callers are responsible for doing an Add(1) to the writingInFlight wait group to ensure proper synchronization. func (jgv *JanusGraphVertexWriter) batchWrite(ctx context.Context, data []any) error { + _ = statsd.Count(ctx, metric.BackgroundWriterCall, 1, jgv.tags, 1) + span, ctx := span.SpanRunFromContext(ctx, span.JanusGraphBatchWrite) span.SetTag(tag.LabelTag, jgv.builder) var err error defer func() { span.Finish(tracer.WithError(err)) }() - defer jgv.writingInFlight.Done() datalen := len(data) _ = statsd.Count(ctx, metric.VertexWrite, int64(datalen), jgv.tags, 1) log.Trace(ctx).Debugf("Batch write JanusGraphVertexWriter with %d elements", datalen) atomic.AddInt32(&jgv.wcounter, int32(datalen)) //nolint:gosec // disable G115 - op := jgv.gremlin(jgv.traversalSource, data) - raw, err := op.Project("id", "storeID"). - By(gremlin.T.Id). - By("storeID"). - ToList() - if err != nil { - return fmt.Errorf("%s vertex insert: %w", jgv.builder, err) + // Create a channel to signal the completion of the write operation. + errChan := make(chan error, 1) + + // We need to ensure that the write operation is completed within a certain + // time frame to avoid blocking the writer indefinitely if the backend + // is unresponsive. + go func() { + // Create a new gremlin operation to insert the data into the graph. + op := jgv.gremlin(jgv.traversalSource, data) + raw, err := op.Project("id", "storeID"). + By(gremlin.T.Id). + By("storeID"). + ToList() + if err != nil { + errChan <- fmt.Errorf("%s vertex insert: %w", jgv.builder, err) + + return + } + + // Gremlin will return a list of maps containing and vertex id and store + // id values for each vertex inserted. + // We need to parse each map entry and add to our cache. + if err = jgv.cacheIds(ctx, raw); err != nil { + errChan <- fmt.Errorf("cache ids: %w", err) + + return + } + + errChan <- nil + }() + + // Wait for the write operation to complete or timeout. + select { + case <-ctx.Done(): + // If the context is cancelled, return the error. + return ctx.Err() + case <-time.After(jgv.writerTimeout): + // If the write operation takes too long, return an error. + return &batchWriterError{ + err: errors.New("vertex write operation timed out"), + retryable: true, + } + case err = <-errChan: + if err != nil { + return fmt.Errorf("janusgraph batch write: %w", err) + } } - // Gremlin will return a list of maps containing and vertex id and store id values for each vertex inserted. - // We need to parse each map entry and add to our cache. - if err = jgv.cacheIds(ctx, raw); err != nil { - return err + return nil +} + +// retrySplitAndRequeue will split the batch into smaller chunks and requeue them for writing. +func (jgv *JanusGraphVertexWriter) splitAndRetry(ctx context.Context, retryCount int, payload []any) error { + _ = statsd.Count(ctx, metric.RetryWriterCall, 1, jgv.tags, 1) + + // If we have reached the maximum number of retries, return an error. + if retryCount >= jgv.maxRetry { + return fmt.Errorf("max retry count reached: %d", retryCount) + } + + // Compute the new batch size. + newBatchSize := len(payload) / 2 + + log.Trace(ctx).Warnf("Retrying write operation with smaller vertex batch (n:%d -> %d, r:%d)", len(payload), newBatchSize, retryCount) + + var leftErr, rightErr error + + // Split the batch into smaller chunks and retry them. + if len(payload[:newBatchSize]) > 0 { + if leftErr = jgv.batchWrite(ctx, payload[:newBatchSize]); leftErr == nil { + var bwe *batchWriterError + if errors.As(leftErr, &bwe) && bwe.retryable { + return jgv.splitAndRetry(ctx, retryCount+1, payload[:newBatchSize]) + } + } + } + + // Process the right side of the batch. + if len(payload[newBatchSize:]) > 0 { + if rightErr = jgv.batchWrite(ctx, payload[newBatchSize:]); rightErr != nil { + var bwe *batchWriterError + if errors.As(rightErr, &bwe) && bwe.retryable { + return jgv.splitAndRetry(ctx, retryCount+1, payload[newBatchSize:]) + } + } + } + + // Return the first error encountered. + switch { + case leftErr != nil && rightErr != nil: + return fmt.Errorf("left: %w, right: %w", leftErr, rightErr) + case leftErr != nil: + return leftErr + case rightErr != nil: + return rightErr } return nil } func (jgv *JanusGraphVertexWriter) Close(ctx context.Context) error { - close(jgv.consumerChan) + if jgv.cache != nil { + if err := jgv.cache.Close(ctx); err != nil { + return fmt.Errorf("closing cache: %w", err) + } + } - return jgv.cache.Close(ctx) + return nil } // Flush triggers writes of any remaining items in the queue. @@ -166,29 +247,17 @@ func (jgv *JanusGraphVertexWriter) Flush(ctx context.Context) error { var err error defer func() { span.Finish(tracer.WithError(err)) }() - jgv.mu.Lock() - defer jgv.mu.Unlock() - if jgv.traversalSource == nil { return errors.New("janusGraph traversalSource is not initialized") } - if len(jgv.inserts) != 0 { - _ = statsd.Incr(ctx, metric.FlushWriterCall, jgv.tags, 1) - - jgv.writingInFlight.Add(1) - err = jgv.batchWrite(ctx, jgv.inserts) - if err != nil { - log.Trace(ctx).Errorf("batch write %s: %+v", jgv.builder, err) - jgv.writingInFlight.Wait() - - return err - } - - log.Trace(ctx).Debugf("Done flushing %s writes. clearing the queue", jgv.builder) - jgv.inserts = nil + // Flush the micro batcher. + err = jgv.mb.Flush(ctx) + if err != nil { + return fmt.Errorf("micro batcher flush: %w", err) } + // Wait for all writes to complete. jgv.writingInFlight.Wait() err = jgv.cache.Flush(ctx) @@ -203,23 +272,7 @@ func (jgv *JanusGraphVertexWriter) Flush(ctx context.Context) error { } func (jgv *JanusGraphVertexWriter) Queue(ctx context.Context, v any) error { - jgv.mu.Lock() - defer jgv.mu.Unlock() - atomic.AddInt32(&jgv.qcounter, 1) - jgv.inserts = append(jgv.inserts, v) - - if len(jgv.inserts) > jgv.batchSize { - copied := make([]any, len(jgv.inserts)) - copy(copied, jgv.inserts) - - jgv.writingInFlight.Add(1) - jgv.consumerChan <- copied - _ = statsd.Incr(ctx, metric.QueueSize, jgv.tags, 1) - // cleanup the ops array after we have copied it to the channel - jgv.inserts = nil - } - - return nil + return jgv.mb.Enqueue(ctx, v) } diff --git a/pkg/kubehound/storage/graphdb/microbatcher.go b/pkg/kubehound/storage/graphdb/microbatcher.go new file mode 100644 index 000000000..889622ac2 --- /dev/null +++ b/pkg/kubehound/storage/graphdb/microbatcher.go @@ -0,0 +1,183 @@ +package graphdb + +import ( + "context" + "errors" + "sync" + "sync/atomic" + + "github.com/DataDog/KubeHound/pkg/telemetry/log" +) + +// batchItem is a single item in the batch writer queue that contains the data +// to be written and the number of retries. +type batchItem struct { + data []any + retryCount int +} + +// microBatcher is a utility to batch items and flush them when the batch is full. +type microBatcher struct { + // batchSize is the maximum number of items to batch. + batchSize int + // items is the current item accumulator for the batch. This is reset after + // the batch is flushed. + items []any + // flush is the function to call to flush the batch. + flushFunc func(context.Context, []any) error + // itemChan is the channel to receive items to batch. + itemChan chan any + // batchChan is the channel to send batches to. + batchChan chan batchItem + // workerCount is the number of workers to process the batch. + workerCount int + // workerGroup is the worker group to wait for the workers to finish. + workerGroup *sync.WaitGroup + // shuttingDown is a flag to indicate if the batcher is shutting down. + shuttingDown atomic.Bool + // logger is the logger to use for logging. + logger log.LoggerI +} + +// NewMicroBatcher creates a new micro batcher. +func newMicroBatcher(logger log.LoggerI, batchSize int, workerCount int, flushFunc func(context.Context, []any) error) *microBatcher { + return µBatcher{ + logger: logger, + batchSize: batchSize, + items: make([]any, 0, batchSize), + flushFunc: flushFunc, + itemChan: make(chan any, batchSize), + batchChan: make(chan batchItem, batchSize), + workerCount: workerCount, + workerGroup: nil, // Set in Start. + } +} + +// Flush flushes the current batch and waits for the batch writer to finish. +func (mb *microBatcher) Flush(_ context.Context) error { + // Set the shutting down flag to true. + if !mb.shuttingDown.CompareAndSwap(false, true) { + return errors.New("batcher is already shutting down") + } + + // Closing the item channel to signal the accumulator to stop and flush the batch. + close(mb.itemChan) + + // Wait for the workers to finish. + if mb.workerGroup != nil { + mb.workerGroup.Wait() + } + + return nil +} + +// Enqueue adds an item to the batch processor. +func (mb *microBatcher) Enqueue(ctx context.Context, item any) error { + // If the batcher is shutting down, return an error immediately. + if mb.shuttingDown.Load() { + return errors.New("batcher is shutting down") + } + + select { + case <-ctx.Done(): + // If the context is cancelled, return. + return ctx.Err() + case mb.itemChan <- item: + } + + return nil +} + +// Start starts the batch processor. +func (mb *microBatcher) Start(ctx context.Context) { + if mb.workerGroup != nil { + // If the worker group is already set, return. + return + } + + var wg sync.WaitGroup + + // Start the workers. + for i := 0; i < mb.workerCount; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if err := mb.worker(ctx, mb.batchChan); err != nil { + mb.logger.Errorf("worker: %v", err) + } + }() + } + + // Start the item accumulator. + wg.Add(1) + go func() { + defer wg.Done() + if err := mb.runItemBatcher(ctx); err != nil { + mb.logger.Errorf("run item batcher: %v", err) + } + + // Close the batch channel to signal the workers to stop. + close(mb.batchChan) + }() + + // Set the worker group to wait for the workers to finish. + mb.workerGroup = &wg +} + +// startItemBatcher starts the item accumulator to batch items. +func (mb *microBatcher) runItemBatcher(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + + case item, ok := <-mb.itemChan: + if !ok { + // If the item channel is closed, send the current batch and return. + mb.batchChan <- batchItem{ + data: mb.items, + retryCount: 0, + } + + // End the accumulator. + return nil + } + + // Add the item to the batch. + mb.items = append(mb.items, item) + + // If the batch is full, send it. + if len(mb.items) == mb.batchSize { + // Send the batch to the processor. + mb.batchChan <- batchItem{ + data: mb.items, + retryCount: 0, + } + + // Reset the batch. + mb.items = mb.items[len(mb.items):] + } + } + } +} + +// startWorkers starts the workers to process the batches. +func (mb *microBatcher) worker(ctx context.Context, batchQueue <-chan batchItem) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case batch, ok := <-batchQueue: + if !ok { + return nil + } + + // Send the batch to the processor. + if len(batch.data) > 0 && mb.flushFunc != nil { + if err := mb.flushFunc(ctx, batch.data); err != nil { + mb.logger.Errorf("flush data in background batch writer: %v", err) + } + } + } + } +} diff --git a/pkg/kubehound/storage/graphdb/microbatcher_test.go b/pkg/kubehound/storage/graphdb/microbatcher_test.go new file mode 100644 index 000000000..b455c1b18 --- /dev/null +++ b/pkg/kubehound/storage/graphdb/microbatcher_test.go @@ -0,0 +1,63 @@ +package graphdb + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/DataDog/KubeHound/pkg/telemetry/log" + "github.com/stretchr/testify/assert" +) + +func microBatcherTestInstance(t *testing.T) (*microBatcher, *atomic.Int32) { + t.Helper() + + var ( + writerFuncCalledCount atomic.Int32 + ) + + underTest := newMicroBatcher(log.DefaultLogger(), 5, 1, + func(_ context.Context, _ []any) error { + writerFuncCalledCount.Add(1) + + return nil + }) + + return underTest, &writerFuncCalledCount +} + +func TestMicroBatcher_AfterBatchSize(t *testing.T) { + t.Parallel() + + underTest, writerFuncCalledCount := microBatcherTestInstance(t) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + underTest.Start(ctx) + + for i := 0; i < 10; i++ { + assert.NoError(t, underTest.Enqueue(ctx, i)) + } + + assert.NoError(t, underTest.Flush(ctx)) + + assert.Equal(t, int32(2), writerFuncCalledCount.Load()) +} + +func TestMicroBatcher_AfterFlush(t *testing.T) { + t.Parallel() + + underTest, writerFuncCalledCount := microBatcherTestInstance(t) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + underTest.Start(ctx) + + for i := 0; i < 11; i++ { + assert.NoError(t, underTest.Enqueue(ctx, i)) + } + + assert.NoError(t, underTest.Flush(ctx)) + + assert.Equal(t, int32(3), writerFuncCalledCount.Load()) +} diff --git a/pkg/kubehound/storage/graphdb/provider.go b/pkg/kubehound/storage/graphdb/provider.go index bd82a7e11..5b7feb553 100644 --- a/pkg/kubehound/storage/graphdb/provider.go +++ b/pkg/kubehound/storage/graphdb/provider.go @@ -2,6 +2,7 @@ package graphdb import ( "context" + "time" "github.com/DataDog/KubeHound/pkg/config" "github.com/DataDog/KubeHound/pkg/kubehound/graph/edge" @@ -11,8 +12,17 @@ import ( "github.com/DataDog/KubeHound/pkg/kubehound/storage/cache" ) +const ( + defaultWriterTimeout = 60 * time.Second + defaultMaxRetry = 3 + defaultWriterWorkerCount = 10 +) + type writerOptions struct { - Tags []string + Tags []string + WriterWorkerCount int + WriterTimeout time.Duration + MaxRetry int } type WriterOption func(*writerOptions) @@ -23,6 +33,27 @@ func WithTags(tags []string) WriterOption { } } +// WithWriterTimeout sets the timeout for the writer to complete the write operation. +func WithWriterTimeout(timeout time.Duration) WriterOption { + return func(wo *writerOptions) { + wo.WriterTimeout = timeout + } +} + +// WithWriterMaxRetry sets the maximum number of retries for failed writes. +func WithWriterMaxRetry(maxRetry int) WriterOption { + return func(wo *writerOptions) { + wo.MaxRetry = maxRetry + } +} + +// WithWriterWorkerCount sets the number of workers to process the batch. +func WithWriterWorkerCount(workerCount int) WriterOption { + return func(wo *writerOptions) { + wo.WriterWorkerCount = workerCount + } +} + // Provider defines the interface for implementations of the graphdb provider for storage of the calculated K8s attack graph. // //go:generate mockery --name Provider --output mocks --case underscore --filename graph_provider.go --with-expecter diff --git a/pkg/kubehound/storage/storedb/index_builder.go b/pkg/kubehound/storage/storedb/index_builder.go index d06e40eef..f7839f5df 100644 --- a/pkg/kubehound/storage/storedb/index_builder.go +++ b/pkg/kubehound/storage/storedb/index_builder.go @@ -123,6 +123,14 @@ func (ib *IndexBuilder) containers(ctx context.Context) error { }, Options: options.Index().SetName("byRun"), }, + { + Keys: bson.D{ + {Key: "k8.securitycontext.runasuser", Value: 1}, + {Key: "runtime.runID", Value: 1}, + {Key: "runtime.cluster", Value: 1}, + }, + Options: options.Index().SetName("byRunAsUser"), + }, } _, err := containers.Indexes().CreateMany(ctx, indices) diff --git a/pkg/telemetry/metric/metrics.go b/pkg/telemetry/metric/metrics.go index afb927704..1a7e97b5f 100644 --- a/pkg/telemetry/metric/metrics.go +++ b/pkg/telemetry/metric/metrics.go @@ -28,6 +28,7 @@ var ( QueueSize = "kubehound.storage.queue.size" BackgroundWriterCall = "kubehound.storage.writer.background" FlushWriterCall = "kubehound.storage.writer.flush" + RetryWriterCall = "kubehound.storage.writer.retry" ) // Cache metrics diff --git a/scripts/dashboard-demo/main.py b/scripts/dashboard-demo/main.py index aeb9535c1..6195c18ee 100644 --- a/scripts/dashboard-demo/main.py +++ b/scripts/dashboard-demo/main.py @@ -72,11 +72,11 @@ class EndpointKPI(KPI): KH_QUERY_EXTERNAL_COUNTS = "kh.endpoints().count()" KH_QUERY_DETAILS= 'kh.endpoints().criticalPaths().limit(local,1).dedup().valueMap("serviceEndpoint","port", "namespace")' KH_QUERY_EXTERNAL_CRITICAL_PATH = '''kh.V(). - hasLabel("Endpoint"). + has("class","Endpoint"). count(). aggregate("t"). V(). - hasLabel("Endpoint"). + has("class","Endpoint"). hasCriticalPath(). count(). as("e"). @@ -95,12 +95,12 @@ class IdentitiesKPI(KPI): KH_QUERY_EXTERNAL_COUNTS = "kh.identities().count()" KH_QUERY_DETAILS= 'kh.identities().criticalPaths().limit(local,1).dedup().valueMap("name","type","namespace")' KH_QUERY_EXTERNAL_CRITICAL_PATH = '''kh.V(). - hasLabel("Identity"). + has("class","Identity"). has("critical", false). count(). aggregate("t"). V(). - hasLabel("Identity"). + has("class","Identity"). has("critical", false). hasCriticalPath(). count(). @@ -121,11 +121,11 @@ class ContainersKPI(KPI): KH_QUERY_EXTERNAL_COUNTS = "kh.containers().count()" KH_QUERY_DETAILS= 'kh.containers().criticalPaths().limit(local,1).dedup().valueMap("name","image","app","namespace")' KH_QUERY_EXTERNAL_CRITICAL_PATH = '''kh.V(). - hasLabel("Container"). + has("class","Container"). count(). aggregate("t"). V(). - hasLabel("Container"). + has("class","Container"). hasCriticalPath(). count(). as("e"). @@ -146,11 +146,11 @@ class VolumesKPI(KPI): KH_QUERY_DETAILS= 'kh.volumes().criticalPaths().limit(local,1).dedup().valueMap("name","sourcePath", "namespace")' KH_QUERY_DETAILS_KEYS = ["name", "sourcePath"] KH_QUERY_EXTERNAL_CRITICAL_PATH = '''kh.V(). - hasLabel("Volume"). + has("class","Volume"). count(). aggregate("t"). V(). - hasLabel("Volume"). + has("class","Volume"). hasCriticalPath(). count(). as("e"). diff --git a/test/system/graph_edge_test.go b/test/system/graph_edge_test.go index 03286cd1b..57df54e21 100644 --- a/test/system/graph_edge_test.go +++ b/test/system/graph_edge_test.go @@ -142,7 +142,7 @@ func (suite *EdgeTestSuite) TestEdge_CE_UMH_CORE_PATTERN() { func (suite *EdgeTestSuite) TestEdge_CONTAINER_ATTACH() { // Every container should have a CONTAINER_ATTACH incoming from a pod rawCount, err := suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). Count().Next() suite.NoError(err) @@ -151,9 +151,9 @@ func (suite *EdgeTestSuite) TestEdge_CONTAINER_ATTACH() { suite.NotEqual(containerCount, 0) rawCount, err = suite.g.V(). - HasLabel("Pod"). + Has("class", "Pod"). OutE().HasLabel("CONTAINER_ATTACH"). - InV().HasLabel("Container"). + InV().Has("class", "Container"). Dedup(). Path(). Count().Next() @@ -177,9 +177,9 @@ func (suite *EdgeTestSuite) TestEdge_IDENTITY_ASSUME_Container() { // tokenlist-sa 0 7h39m results, err := suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). OutE().HasLabel("IDENTITY_ASSUME"). - InV().HasLabel("Identity"). + InV().Has("class", "Identity"). Path(). By(__.ValueMap("name")). ToList() @@ -212,9 +212,9 @@ func (suite *EdgeTestSuite) TestEdge_IDENTITY_ASSUME_Container() { func (suite *EdgeTestSuite) TestEdge_IDENTITY_ASSUME_Node() { results, err := suite.g.V(). - HasLabel("Node"). + Has("class", "Node"). OutE().HasLabel("IDENTITY_ASSUME"). - InV().HasLabel("Identity"). + InV().Has("class", "Identity"). Path(). By(__.ValueMap("name")). ToList() @@ -243,9 +243,9 @@ func (suite *EdgeTestSuite) TestEdge_POD_ATTACH() { suite.NotEqual(podCount, 0) rawCount, err = suite.g.V(). - HasLabel("Node"). + Has("class", "Node"). OutE().HasLabel("POD_ATTACH"). - InV().HasLabel("Pod"). + InV().Has("class", "Pod"). Dedup(). Path(). Count().Next() @@ -260,10 +260,10 @@ func (suite *EdgeTestSuite) TestEdge_POD_PATCH() { // We have one bespoke container running with pod/patch permissions which should reach all nodes // since they are not namespaced results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). Has("namespace", "default"). OutE().HasLabel("POD_PATCH"). - InV().HasLabel("Pod"). + InV().Has("class", "Pod"). Path(). By(__.ValueMap("name")). ToList() @@ -309,10 +309,10 @@ func (suite *EdgeTestSuite) TestEdge_POD_CREATE() { // We have one bespoke container running with pod/create permissions which should reach all nodes // since they are not namespaced results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). Has("namespace", "default"). OutE().HasLabel("POD_CREATE"). - InV().HasLabel("Node"). + InV().Has("class", "Node"). Path(). By(__.ValueMap("name")). ToList() @@ -332,10 +332,10 @@ func (suite *EdgeTestSuite) TestEdge_POD_CREATE() { func (suite *EdgeTestSuite) TestEdge_POD_EXEC() { // We have one bespoke container running with pod/exec permissions which should reach all pods in the namespace results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). Has("namespace", "default"). OutE().HasLabel("POD_EXEC"). - InV().HasLabel("Pod"). + InV().Has("class", "Pod"). Path(). By(__.ValueMap("name")). ToList() @@ -391,10 +391,10 @@ func (suite *EdgeTestSuite) TestEdge_PERMISSION_DISCOVER() { // tokenget-sa 0 7h39m // tokenlist-sa 0 7h39m results, err := suite.g.V(). - HasLabel("Identity"). + Has("class", "Identity"). Has("namespace", "default"). OutE().HasLabel("PERMISSION_DISCOVER"). - InV().HasLabel("PermissionSet"). + InV().Has("class", "PermissionSet"). Path(). By(__.ValueMap("name")). ToList() @@ -429,7 +429,7 @@ func (suite *EdgeTestSuite) TestEdge_PERMISSION_DISCOVER() { func (suite *EdgeTestSuite) TestEdge_VOLUME_ACCESS() { // Every volume should have a VOLUME_ACCESS incoming from a node rawCount, err := suite.g.V(). - HasLabel("Volume"). + Has("class", "Volume"). Count().Next() suite.NoError(err) @@ -438,9 +438,9 @@ func (suite *EdgeTestSuite) TestEdge_VOLUME_ACCESS() { suite.NotEqual(volumeCount, 0) rawCount, err = suite.g.V(). - HasLabel("Node"). + Has("class", "Node"). OutE().HasLabel("VOLUME_ACCESS"). - InV().HasLabel("Volume"). + InV().Has("class", "Volume"). Dedup(). Path(). Count().Next() @@ -454,7 +454,7 @@ func (suite *EdgeTestSuite) TestEdge_VOLUME_ACCESS() { func (suite *EdgeTestSuite) TestEdge_VOLUME_DISCOVER() { // Every volume should have a VOLUME_DISCOVER incoming from a container rawCount, err := suite.g.V(). - HasLabel("Volume"). + Has("class", "Volume"). Count().Next() suite.NoError(err) @@ -463,9 +463,9 @@ func (suite *EdgeTestSuite) TestEdge_VOLUME_DISCOVER() { suite.NotEqual(volumeCount, 0) rawCount, err = suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). OutE().HasLabel("VOLUME_DISCOVER"). - InV().HasLabel("Volume"). + InV().Has("class", "Volume"). Dedup(). Path(). Count().Next() @@ -478,10 +478,10 @@ func (suite *EdgeTestSuite) TestEdge_VOLUME_DISCOVER() { func (suite *EdgeTestSuite) TestEdge_TOKEN_BRUTEFORCE() { results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). Has("namespace", "default"). OutE().HasLabel("TOKEN_BRUTEFORCE"). - InV().HasLabel("Identity"). + InV().Has("class", "Identity"). Path(). By(__.ValueMap("name")). ToList() @@ -514,10 +514,10 @@ func (suite *EdgeTestSuite) TestEdge_TOKEN_BRUTEFORCE() { func (suite *EdgeTestSuite) TestEdge_TOKEN_LIST() { results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). Has("namespace", "default"). OutE().HasLabel("TOKEN_LIST"). - InV().HasLabel("Identity"). + InV().Has("class", "Identity"). Path(). By(__.ValueMap("name")). ToList() @@ -552,11 +552,11 @@ func (suite *EdgeTestSuite) TestEdge_TOKEN_STEAL() { // Every pod in our test cluster should have projected volume holding a token. BUT we only // save those with a non-default service account token as shown below. results, err := suite.g.V(). - HasLabel("Volume"). + Has("class", "Volume"). OutE(). HasLabel("TOKEN_STEAL"). InV(). - HasLabel("Identity"). + Has("class", "Identity"). Has("namespace", "default"). Values("name"). ToList() @@ -589,11 +589,11 @@ func (suite *EdgeTestSuite) TestEdge_TOKEN_STEAL() { func (suite *EdgeTestSuite) TestEdge_EXPLOIT_HOST_READ() { results, err := suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). OutE().HasLabel("VOLUME_DISCOVER"). - InV().HasLabel("Volume"). + InV().Has("class", "Volume"). Where(__.OutE().HasLabel("EXPLOIT_HOST_READ"). - InV().HasLabel("Node")). + InV().Has("class", "Node")). Path(). By(__.ValueMap("name")). ToList() @@ -610,11 +610,11 @@ func (suite *EdgeTestSuite) TestEdge_EXPLOIT_HOST_READ() { func (suite *EdgeTestSuite) TestEdge_EXPLOIT_HOST_WRITE() { results, err := suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). OutE().HasLabel("VOLUME_DISCOVER"). - InV().HasLabel("Volume"). + InV().Has("class", "Volume"). Where(__.OutE().HasLabel("EXPLOIT_HOST_WRITE"). - InV().HasLabel("Node")). + InV().Has("class", "Node")). Path(). By(__.ValueMap("name")). ToList() @@ -633,10 +633,10 @@ func (suite *EdgeTestSuite) TestEdge_EXPLOIT_HOST_TRAVERSE() { for _, c := range []string{"host-read-exploit-pod", "host-write-exploit-pod"} { // Find the containers on the same node as our vulnerable pod and map to their service accounts results, err := suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). Has("name", c). Values("node").As("n"). - V().HasLabel("Container"). + V().Has("class", "Container"). Has("node", __.Where(P.Eq("n"))). OutE("IDENTITY_ASSUME"). InV(). @@ -649,14 +649,14 @@ func (suite *EdgeTestSuite) TestEdge_EXPLOIT_HOST_TRAVERSE() { // Now find the identities our vulnerable pod can reach via doing a traverse to the projected token volume results, err = suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). Has("name", c). OutE().HasLabel("VOLUME_DISCOVER"). - InV().HasLabel("Volume"). + InV().Has("class", "Volume"). OutE().HasLabel("EXPLOIT_HOST_TRAVERSE"). - InV().HasLabel("Volume"). + InV().Has("class", "Volume"). OutE().HasLabel("TOKEN_STEAL"). - InV().HasLabel("Identity"). + InV().Has("class", "Identity"). Values("name"). ToList() @@ -671,12 +671,12 @@ func (suite *EdgeTestSuite) TestEdge_EXPLOIT_HOST_TRAVERSE() { func (suite *EdgeTestSuite) TestEdge_ENDPOINT_EXPLOIT_ContainerPort() { results, err := suite.g.V(). - HasLabel("Endpoint"). + Has("class", "Endpoint"). Where( __.Has("exposure", P.Eq(int(shared.EndpointExposureClusterIP))). OutE("ENDPOINT_EXPLOIT"). InV(). - HasLabel("Container")). + Has("class", "Container")). Values("serviceEndpoint"). ToList() @@ -693,12 +693,12 @@ func (suite *EdgeTestSuite) TestEdge_ENDPOINT_EXPLOIT_ContainerPort() { func (suite *EdgeTestSuite) TestEdge_ENDPOINT_EXPLOIT_NodePort() { results, err := suite.g.V(). - HasLabel("Endpoint"). + Has("class", "Endpoint"). Where( __.Has("exposure", P.Eq(int(shared.EndpointExposureNodeIP))). OutE("ENDPOINT_EXPLOIT"). InV(). - HasLabel("Container")). + Has("class", "Container")). Values("serviceEndpoint"). ToList() @@ -715,12 +715,12 @@ func (suite *EdgeTestSuite) TestEdge_ENDPOINT_EXPLOIT_NodePort() { func (suite *EdgeTestSuite) TestEdge_ENDPOINT_EXPLOIT_External() { results, err := suite.g.V(). - HasLabel("Endpoint"). + Has("class", "Endpoint"). Where( __.Has("exposure", P.Eq(int(shared.EndpointExposureExternal))). OutE("ENDPOINT_EXPLOIT"). InV(). - HasLabel("Container")). + Has("class", "Container")). Values("serviceEndpoint"). ToList() @@ -737,9 +737,9 @@ func (suite *EdgeTestSuite) TestEdge_ENDPOINT_EXPLOIT_External() { func (suite *EdgeTestSuite) TestEdge_SHARE_PS_NAMESPACE() { results, err := suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). OutE().HasLabel("SHARE_PS_NAMESPACE"). - InV().HasLabel("Container"). + InV().Has("class", "Container"). Path(). By(__.ValueMap("name")). ToList() @@ -767,10 +767,10 @@ func (suite *EdgeTestSuite) TestEdge_SHARE_PS_NAMESPACE() { // Case 1 (cf docs) func (suite *EdgeTestSuite) TestEdge_ROLE_BIND_CASE_1() { results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). Has("isNamespaced", false). OutE().HasLabel("ROLE_BIND"). - InV().HasLabel("PermissionSet"). + InV().Has("class", "PermissionSet"). Has("isNamespaced", false). // Scoping only to the roles related to the attacks to avoid dependency on the Kind Cluster default roles Has("name", gremlingo.TextP.StartingWith("rolebind")). @@ -798,10 +798,10 @@ func (suite *EdgeTestSuite) TestEdge_ROLE_BIND_CASE_1() { // Case 2 (cf docs) func (suite *EdgeTestSuite) TestEdge_ROLE_BIND_CASE_2() { results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). Has("isNamespaced", false). OutE().HasLabel("ROLE_BIND"). - InV().HasLabel("PermissionSet"). + InV().Has("class", "PermissionSet"). Has("isNamespaced", false). // Scoping only to the roles related to the attacks to avoid dependency on the Kind Cluster default roles Has("name", gremlingo.TextP.StartingWith("rolebind")). @@ -829,10 +829,10 @@ func (suite *EdgeTestSuite) TestEdge_ROLE_BIND_CASE_2() { // Case 3 (cf docs) func (suite *EdgeTestSuite) TestEdge_ROLE_BIND_CASE_3() { results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). Has("isNamespaced", true). OutE().HasLabel("ROLE_BIND"). - InV().HasLabel("PermissionSet"). + InV().Has("class", "PermissionSet"). Has("isNamespaced", true). // Scoping only to the roles related to the attacks to avoid dependency on the Kind Cluster default roles Has("name", gremlingo.TextP.StartingWith("rolebind-")). @@ -896,10 +896,10 @@ func (suite *EdgeTestSuite) TestEdge_ROLE_BIND_CASE_3() { // Case 4 (cf docs) func (suite *EdgeTestSuite) TestEdge_ROLE_BIND_CASE_4() { results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). Has("isNamespaced", true). OutE().HasLabel("ROLE_BIND"). - InV().HasLabel("PermissionSet"). + InV().Has("class", "PermissionSet"). Has("isNamespaced", false). // Scoping only to the roles related to the attacks to avoid dependency on the Kind Cluster default roles Has("name", gremlingo.TextP.StartingWith("rolebind")). @@ -919,7 +919,7 @@ func (suite *EdgeTestSuite) TestEdge_ROLE_BIND_CASE_4() { func (suite *EdgeTestSuite) Test_NoEdgeCase() { // The control pod has no interesting properties and therefore should have NO outgoing edges results, err := suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). Has("name", "control-pod"). Out(). ToList() diff --git a/test/system/graph_vertex_test.go b/test/system/graph_vertex_test.go index 1cc139565..20eaf2606 100644 --- a/test/system/graph_vertex_test.go +++ b/test/system/graph_vertex_test.go @@ -106,7 +106,7 @@ func (suite *VertexTestSuite) resultsToStringArray(results []*gremlingo.Result) } func (suite *VertexTestSuite) TestVertexContainer() { - results, err := suite.g.V().HasLabel(vertex.ContainerLabel).ElementMap().ToList() + results, err := suite.g.V().Has("class", vertex.ContainerLabel).ElementMap().ToList() suite.NoError(err) suite.Equal(len(expectedContainers), len(results)-numberOfKindDefaultContainer) @@ -183,7 +183,7 @@ func (suite *VertexTestSuite) TestVertexContainer() { } func (suite *VertexTestSuite) TestVertexNode() { - results, err := suite.g.V().HasLabel(vertex.NodeLabel).ElementMap().ToList() + results, err := suite.g.V().Has("class", vertex.NodeLabel).ElementMap().ToList() suite.NoError(err) suite.Equal(len(expectedNodes), len(results)) @@ -219,7 +219,7 @@ func (suite *VertexTestSuite) TestVertexNode() { } func (suite *VertexTestSuite) TestVertexPod() { - results, err := suite.g.V().HasLabel(vertex.PodLabel).ElementMap().ToList() + results, err := suite.g.V().Has("class", vertex.PodLabel).ElementMap().ToList() suite.NoError(err) suite.Equal(len(expectedPods), len(results)-numberOfKindDefaultPod) @@ -269,7 +269,7 @@ func (suite *VertexTestSuite) TestVertexPod() { func (suite *VertexTestSuite) TestVertexPermissionSet() { results, err := suite.g.V(). - HasLabel(vertex.PermissionSetLabel). + Has("class", vertex.PermissionSetLabel). Has("namespace", "default"). Values("name"). ToList() @@ -292,7 +292,7 @@ func (suite *VertexTestSuite) TestVertexPermissionSet() { func (suite *VertexTestSuite) TestVertexCritical() { results, err := suite.g.V(). - HasLabel(vertex.PermissionSetLabel). + Has("class", vertex.PermissionSetLabel). Has("critical", true). Values("role"). ToList() @@ -311,45 +311,45 @@ func (suite *VertexTestSuite) TestVertexCritical() { } func (suite *VertexTestSuite) TestVertexVolume() { - results, err := suite.g.V().HasLabel(vertex.VolumeLabel).ElementMap().ToList() + results, err := suite.g.V().Has("class", vertex.VolumeLabel).ElementMap().ToList() suite.NoError(err) suite.Equal(61, len(results)) - results, err = suite.g.V().HasLabel(vertex.VolumeLabel).Has("sourcePath", "/proc/sys/kernel").Has("name", "nodeproc").ElementMap().ToList() + results, err = suite.g.V().Has("class", vertex.VolumeLabel).Has("sourcePath", "/proc/sys/kernel").Has("name", "nodeproc").ElementMap().ToList() suite.NoError(err) suite.Equal(1, len(results)) - results, err = suite.g.V().HasLabel(vertex.VolumeLabel).Has("sourcePath", "/lib/modules").Has("name", "lib-modules").ElementMap().ToList() + results, err = suite.g.V().Has("class", vertex.VolumeLabel).Has("sourcePath", "/lib/modules").Has("name", "lib-modules").ElementMap().ToList() suite.NoError(err) suite.Greater(len(results), 1) // Not sure why it has "6" - results, err = suite.g.V().HasLabel(vertex.VolumeLabel).Has("sourcePath", "/var/log").Has("name", "nodelog").ElementMap().ToList() + results, err = suite.g.V().Has("class", vertex.VolumeLabel).Has("sourcePath", "/var/log").Has("name", "nodelog").ElementMap().ToList() suite.NoError(err) suite.Equal(len(results), 1) } func (suite *VertexTestSuite) TestVertexIdentity() { - results, err := suite.g.V().HasLabel(vertex.IdentityLabel).ElementMap().ToList() + results, err := suite.g.V().Has("class", vertex.IdentityLabel).ElementMap().ToList() suite.NoError(err) suite.Greater(len(results), 50) - results, err = suite.g.V().HasLabel(vertex.IdentityLabel).Has("name", "tokenget-sa").ElementMap().ToList() + results, err = suite.g.V().Has("class", vertex.IdentityLabel).Has("name", "tokenget-sa").ElementMap().ToList() suite.NoError(err) suite.Equal(len(results), 1) - results, err = suite.g.V().HasLabel(vertex.IdentityLabel).Has("name", "impersonate-sa").ElementMap().ToList() + results, err = suite.g.V().Has("class", vertex.IdentityLabel).Has("name", "impersonate-sa").ElementMap().ToList() suite.NoError(err) suite.Equal(len(results), 1) - results, err = suite.g.V().HasLabel(vertex.IdentityLabel).Has("name", "tokenlist-sa").ElementMap().ToList() + results, err = suite.g.V().Has("class", vertex.IdentityLabel).Has("name", "tokenlist-sa").ElementMap().ToList() suite.NoError(err) suite.Equal(len(results), 1) - results, err = suite.g.V().HasLabel(vertex.IdentityLabel).Has("name", "pod-patch-sa").ElementMap().ToList() + results, err = suite.g.V().Has("class", vertex.IdentityLabel).Has("name", "pod-patch-sa").ElementMap().ToList() suite.NoError(err) suite.Equal(len(results), 1) - results, err = suite.g.V().HasLabel(vertex.IdentityLabel).Has("name", "pod-create-sa").ElementMap().ToList() + results, err = suite.g.V().Has("class", vertex.IdentityLabel).Has("name", "pod-create-sa").ElementMap().ToList() suite.NoError(err) suite.Equal(len(results), 1) }