From 95753d3c2f21fa0a32aa63d1ab68666eba5fa5e0 Mon Sep 17 00:00:00 2001 From: Thibault Normand Date: Thu, 28 Nov 2024 14:55:28 +0100 Subject: [PATCH 01/18] feat(graphdb): add deadline/retry/split behavioural patterns to batch writer. --- .../graphdb/janusgraph_vertex_writer.go | 150 +++++++++++++++--- pkg/kubehound/storage/graphdb/provider.go | 19 ++- pkg/telemetry/metric/metrics.go | 1 + 3 files changed, 146 insertions(+), 24 deletions(-) diff --git a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go index 43396746..85c2e72e 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go @@ -6,6 +6,7 @@ import ( "fmt" "sync" "sync/atomic" + "time" "github.com/DataDog/KubeHound/pkg/kubehound/graph/types" "github.com/DataDog/KubeHound/pkg/kubehound/graph/vertex" @@ -29,20 +30,51 @@ type JanusGraphVertexWriter struct { traversalSource *gremlin.GraphTraversalSource // Transacted graph traversal source inserts []any // Object data to be inserted in the graph mu sync.Mutex // Mutex protecting access to the inserts array - consumerChan chan []any // Channel consuming inserts for async writing + consumerChan chan batchItem // Channel consuming inserts for async writing writingInFlight *sync.WaitGroup // Wait group tracking current unfinished writes batchSize int // Batchsize of graph DB inserts qcounter int32 // Track items queued wcounter int32 // Track items writtn tags []string // Telemetry tags cache cache.AsyncWriter // Cache writer to cache store id -> vertex id mappings + writerTimeout time.Duration // Timeout for the writer + maxRetry int // Maximum number of retries for failed writes +} + +// batchItem is a single item in the batch writer queue that contains the data +// to be written and the number of retries. +type batchItem struct { + data []any + retryCount int +} + +// errBatchWriter is an error type that wraps an error and indicates whether the +// error is retryable. +type errBatchWriter struct { + err error + retryable bool +} + +func (e errBatchWriter) Error() string { + if e.err == nil { + return fmt.Sprintf("batch writer error (retriable:%v)", e.retryable) + } + + return fmt.Sprintf("batch writer error (retriable:%v): %v", e.retryable, e.err.Error()) +} + +func (e errBatchWriter) Unwrap() error { + return e.err } // NewJanusGraphAsyncVertexWriter creates a new bulk vertex writer instance. 
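A note on the errBatchWriter type added above: its only job is to carry a retryable flag across the goroutine boundary so the background worker can decide whether to split and requeue a failed batch. A minimal, self-contained sketch of that contract (the type mirrors the patch; the main function is illustrative only):

```go
package main

import (
	"errors"
	"fmt"
)

// errBatchWriter mirrors the patch: it wraps a cause and flags whether the
// caller may retry the write with a smaller batch.
type errBatchWriter struct {
	err       error
	retryable bool
}

func (e errBatchWriter) Error() string {
	return fmt.Sprintf("batch writer error (retriable:%v): %v", e.retryable, e.err)
}

func (e errBatchWriter) Unwrap() error { return e.err }

func main() {
	// batchWrite wraps a pointer (&errBatchWriter{...}); that is what makes
	// the errors.As check with a pointer target succeed below.
	err := fmt.Errorf("janusgraph batch write: %w",
		&errBatchWriter{err: errors.New("write operation timed out"), retryable: true})

	var e *errBatchWriter
	if errors.As(err, &e) && e.retryable {
		fmt.Println("retry with a smaller batch:", e.Unwrap())
	}
}
```

Since the methods use value receivers, the pointer detail matters: if batchWrite ever returned the value instead of &errBatchWriter{...}, the errors.As target would stop matching and the retry path would silently disable itself. The constructor that follows threads the new timeout and retry settings into the writer.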
func NewJanusGraphAsyncVertexWriter(ctx context.Context, drc *gremlin.DriverRemoteConnection, - v vertex.Builder, c cache.CacheProvider, opts ...WriterOption) (*JanusGraphVertexWriter, error) { - - options := &writerOptions{} + v vertex.Builder, c cache.CacheProvider, opts ...WriterOption, +) (*JanusGraphVertexWriter, error) { + options := &writerOptions{ + WriterTimeout: 60 * time.Second, + MaxRetry: 3, + } for _, opt := range opts { opt(options) } @@ -60,9 +92,11 @@ func NewJanusGraphAsyncVertexWriter(ctx context.Context, drc *gremlin.DriverRemo traversalSource: gremlin.Traversal_().WithRemote(drc), batchSize: v.BatchSize(), writingInFlight: &sync.WaitGroup{}, - consumerChan: make(chan []any, v.BatchSize()*channelSizeBatchFactor), + consumerChan: make(chan batchItem, v.BatchSize()*channelSizeBatchFactor), tags: append(options.Tags, tag.Label(v.Label()), tag.Builder(v.Label())), cache: cw, + writerTimeout: options.WriterTimeout, + maxRetry: options.MaxRetry, } jw.startBackgroundWriter(ctx) @@ -75,16 +109,52 @@ func (jgv *JanusGraphVertexWriter) startBackgroundWriter(ctx context.Context) { go func() { for { select { - case data := <-jgv.consumerChan: - // closing the channel shoud stop the go routine - if data == nil { + case batch, ok := <-jgv.consumerChan: + // If the channel is closed, return. + if !ok { + log.Trace(ctx).Info("Closed background janusgraph worker on channel close") + return + } + + // If the batch is empty, return. + if len(batch.data) == 0 { + log.Trace(ctx).Warn("Empty batch received in background janusgraph worker, skipping") return } _ = statsd.Count(ctx, metric.BackgroundWriterCall, 1, jgv.tags, 1) - err := jgv.batchWrite(ctx, data) + err := jgv.batchWrite(ctx, batch.data) if err != nil { - log.Trace(ctx).Errorf("Write data in background batch writer: %v", err) + var e *errBatchWriter + if errors.As(err, &e) && e.retryable { + // If the context deadline is exceeded, retry the write operation with a smaller batch. + if batch.retryCount < jgv.maxRetry { + // Compute the new batch size. + newBatchSize := len(batch.data) / 2 + batch.retryCount++ + + log.Trace(ctx).Warnf("Retrying write operation with smaller batch (n:%d -> %d, r:%d): %v", len(batch.data), newBatchSize, batch.retryCount, e.Unwrap()) + + // Split the batch into smaller chunks and requeue them. + if len(batch.data[:newBatchSize]) > 0 { + jgv.consumerChan <- batchItem{ + data: batch.data[:newBatchSize], + retryCount: batch.retryCount, + } + } + if len(batch.data[newBatchSize:]) > 0 { + jgv.consumerChan <- batchItem{ + data: batch.data[newBatchSize:], + retryCount: batch.retryCount, + } + } + continue + } + + log.Trace(ctx).Errorf("Retry limit exceeded for write operation: %v", err) + } + + log.Trace(ctx).Errorf("Write data in background batch writer, data will be lost: %v", err) } _ = statsd.Decr(ctx, metric.QueueSize, jgv.tags, 1) @@ -134,19 +204,50 @@ func (jgv *JanusGraphVertexWriter) batchWrite(ctx context.Context, data []any) e log.Trace(ctx).Debugf("Batch write JanusGraphVertexWriter with %d elements", datalen) atomic.AddInt32(&jgv.wcounter, int32(datalen)) //nolint:gosec // disable G115 - op := jgv.gremlin(jgv.traversalSource, data) - raw, err := op.Project("id", "storeID"). - By(gremlin.T.Id). - By("storeID"). - ToList() - if err != nil { - return fmt.Errorf("%s vertex insert: %w", jgv.builder, err) - } + // Create a channel to signal the completion of the write operation. 
+	errChan := make(chan error, 1)
+
+	// We need to ensure that the write operation is completed within a certain
+	// time frame to avoid blocking the writer indefinitely if the backend
+	// is unresponsive.
+	go func() {
+		// Create a new gremlin operation to insert the data into the graph.
+		op := jgv.gremlin(jgv.traversalSource, data)
+		raw, err := op.Project("id", "storeID").
+			By(gremlin.T.Id).
+			By("storeID").
+			ToList()
+		if err != nil {
+			errChan <- fmt.Errorf("%s vertex insert: %w", jgv.builder, err)
+			return
+		}
+
+		// Gremlin will return a list of maps containing the vertex id and store
+		// id values for each vertex inserted.
+		// We need to parse each map entry and add to our cache.
+		if err = jgv.cacheIds(ctx, raw); err != nil {
+			errChan <- fmt.Errorf("cache ids: %w", err)
+			return
+		}
+
+		errChan <- nil
+	}()
 
-	// Gremlin will return a list of maps containing and vertex id and store id values for each vertex inserted.
-	// We need to parse each map entry and add to our cache.
-	if err = jgv.cacheIds(ctx, raw); err != nil {
-		return err
+	// Wait for the write operation to complete or timeout.
+	select {
+	case <-ctx.Done():
+		// If the context is cancelled, return the error.
+		return ctx.Err()
+	case <-time.After(jgv.writerTimeout):
+		// If the write operation takes too long, return an error.
+		return &errBatchWriter{
+			err:       errors.New("write operation timed out"),
+			retryable: true,
+		}
+	case err = <-errChan:
+		if err != nil {
+			return fmt.Errorf("janusgraph batch write: %w", err)
+		}
 	}
 
 	return nil
@@ -214,7 +315,10 @@ func (jgv *JanusGraphVertexWriter) Queue(ctx context.Context, v any) error {
 	copy(copied, jgv.inserts)
 
 	jgv.writingInFlight.Add(1)
-	jgv.consumerChan <- copied
+	jgv.consumerChan <- batchItem{
+		data:       copied,
+		retryCount: 0,
+	}
 	_ = statsd.Incr(ctx, metric.QueueSize, jgv.tags, 1)
 
 	// cleanup the ops array after we have copied it to the channel
diff --git a/pkg/kubehound/storage/graphdb/provider.go b/pkg/kubehound/storage/graphdb/provider.go
index bd82a7e1..03c80536 100644
--- a/pkg/kubehound/storage/graphdb/provider.go
+++ b/pkg/kubehound/storage/graphdb/provider.go
@@ -2,6 +2,7 @@ package graphdb
 
 import (
 	"context"
+	"time"
 
 	"github.com/DataDog/KubeHound/pkg/config"
 	"github.com/DataDog/KubeHound/pkg/kubehound/graph/edge"
@@ -12,7 +13,9 @@ import (
 )
 
 type writerOptions struct {
-	Tags []string
+	Tags          []string
+	WriterTimeout time.Duration
+	MaxRetry      int
 }
 
 type WriterOption func(*writerOptions)
@@ -23,6 +26,20 @@ func WithTags(tags []string) WriterOption {
 	}
 }
 
+// WithWriterTimeout sets the timeout for the writer to complete the write operation.
+func WithWriterTimeout(timeout time.Duration) WriterOption {
+	return func(wo *writerOptions) {
+		wo.WriterTimeout = timeout
+	}
+}
+
+// WithWriterMaxRetry sets the maximum number of retries for failed writes.
+func WithWriterMaxRetry(maxRetry int) WriterOption {
+	return func(wo *writerOptions) {
+		wo.MaxRetry = maxRetry
+	}
+}
+
 // Provider defines the interface for implementations of the graphdb provider for storage of the calculated K8s attack graph.
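The two options above slot into the existing functional-options pattern on writerOptions. A hedged call-site sketch; the constructor and option names come from this patch, while newVertexWriter itself, its arguments, and the imports are assumed from the surrounding code:

```go
// newVertexWriter is a hypothetical helper showing how the new knobs compose.
// Omitted options fall back to the defaults set in the constructor (60s, 3 retries).
func newVertexWriter(ctx context.Context, drc *gremlin.DriverRemoteConnection,
	vb vertex.Builder, cp cache.CacheProvider,
) (*JanusGraphVertexWriter, error) {
	return NewJanusGraphAsyncVertexWriter(ctx, drc, vb, cp,
		WithTags([]string{"app:kubehound"}),
		WithWriterTimeout(30*time.Second), // give up on a stuck batch after 30s
		WithWriterMaxRetry(5),             // then split and requeue up to 5 times
	)
}
```

The Provider doc comment continues below; patch 07 later wires these options from configuration instead of hard-coded values.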
// //go:generate mockery --name Provider --output mocks --case underscore --filename graph_provider.go --with-expecter diff --git a/pkg/telemetry/metric/metrics.go b/pkg/telemetry/metric/metrics.go index afb92770..1a7e97b5 100644 --- a/pkg/telemetry/metric/metrics.go +++ b/pkg/telemetry/metric/metrics.go @@ -28,6 +28,7 @@ var ( QueueSize = "kubehound.storage.queue.size" BackgroundWriterCall = "kubehound.storage.writer.background" FlushWriterCall = "kubehound.storage.writer.flush" + RetryWriterCall = "kubehound.storage.writer.retry" ) // Cache metrics From 3cfe99b4e0ed80b8bcc32967cf0dd68b1e818503 Mon Sep 17 00:00:00 2001 From: Thibault Normand Date: Thu, 28 Nov 2024 15:09:52 +0100 Subject: [PATCH 02/18] feat(graphdb): replicate resiliency patterns on edge writer. --- pkg/kubehound/storage/graphdb/errors.go | 22 +++++ .../storage/graphdb/janusgraph_edge_writer.go | 86 ++++++++++++++++--- .../graphdb/janusgraph_vertex_writer.go | 31 ++----- pkg/kubehound/storage/graphdb/provider.go | 5 ++ 4 files changed, 106 insertions(+), 38 deletions(-) create mode 100644 pkg/kubehound/storage/graphdb/errors.go diff --git a/pkg/kubehound/storage/graphdb/errors.go b/pkg/kubehound/storage/graphdb/errors.go new file mode 100644 index 00000000..6ccb2186 --- /dev/null +++ b/pkg/kubehound/storage/graphdb/errors.go @@ -0,0 +1,22 @@ +package graphdb + +import "fmt" + +// errBatchWriter is an error type that wraps an error and indicates whether the +// error is retryable. +type errBatchWriter struct { + err error + retryable bool +} + +func (e errBatchWriter) Error() string { + if e.err == nil { + return fmt.Sprintf("batch writer error (retriable:%v)", e.retryable) + } + + return fmt.Sprintf("batch writer error (retriable:%v): %v", e.retryable, e.err.Error()) +} + +func (e errBatchWriter) Unwrap() error { + return e.err +} diff --git a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go index e8c2619a..9f66e753 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go @@ -6,6 +6,7 @@ import ( "fmt" "sync" "sync/atomic" + "time" "github.com/DataDog/KubeHound/pkg/kubehound/graph/edge" "github.com/DataDog/KubeHound/pkg/kubehound/graph/types" @@ -31,19 +32,24 @@ type JanusGraphEdgeWriter struct { traversalSource *gremlingo.GraphTraversalSource // Transacted graph traversal source inserts []any // Object data to be inserted in the graph mu sync.Mutex // Mutex protecting access to the inserts array - consumerChan chan []any // Channel consuming inserts for async writing + consumerChan chan batchItem // Channel consuming inserts for async writing writingInFlight *sync.WaitGroup // Wait group tracking current unfinished writes batchSize int // Batchsize of graph DB inserts qcounter int32 // Track items queued wcounter int32 // Track items writtn tags []string // Telemetry tags + writerTimeout time.Duration // Timeout for the writer + maxRetry int // Maximum number of retries for failed writes } // NewJanusGraphAsyncEdgeWriter creates a new bulk edge writer instance. 
func NewJanusGraphAsyncEdgeWriter(ctx context.Context, drc *gremlingo.DriverRemoteConnection, - e edge.Builder, opts ...WriterOption) (*JanusGraphEdgeWriter, error) { - - options := &writerOptions{} + e edge.Builder, opts ...WriterOption, +) (*JanusGraphEdgeWriter, error) { + options := &writerOptions{ + WriterTimeout: defaultWriterTimeout, + MaxRetry: defaultMaxRetry, + } for _, opt := range opts { opt(options) } @@ -57,8 +63,10 @@ func NewJanusGraphAsyncEdgeWriter(ctx context.Context, drc *gremlingo.DriverRemo traversalSource: gremlingo.Traversal_().WithRemote(drc), batchSize: e.BatchSize(), writingInFlight: &sync.WaitGroup{}, - consumerChan: make(chan []any, e.BatchSize()*channelSizeBatchFactor), + consumerChan: make(chan batchItem, e.BatchSize()*channelSizeBatchFactor), tags: append(options.Tags, tag.Label(e.Label()), tag.Builder(builder)), + writerTimeout: options.WriterTimeout, + maxRetry: options.MaxRetry, } jw.startBackgroundWriter(ctx) @@ -71,15 +79,51 @@ func (jgv *JanusGraphEdgeWriter) startBackgroundWriter(ctx context.Context) { go func() { for { select { - case data := <-jgv.consumerChan: - // closing the channel shoud stop the go routine - if data == nil { + case batch, ok := <-jgv.consumerChan: + // If the channel is closed, return. + if !ok { + log.Trace(ctx).Info("Closed background janusgraph worker on channel close") + return + } + + // If the batch is empty, return. + if len(batch.data) == 0 { + log.Trace(ctx).Warn("Empty edge batch received in background janusgraph worker, skipping") return } _ = statsd.Count(ctx, metric.BackgroundWriterCall, 1, jgv.tags, 1) - err := jgv.batchWrite(ctx, data) + err := jgv.batchWrite(ctx, batch.data) if err != nil { + var e *errBatchWriter + if errors.As(err, &e) && e.retryable { + // If the error is retryable, retry the write operation with a smaller batch. + if batch.retryCount < jgv.maxRetry { + // Compute the new batch size. + newBatchSize := len(batch.data) / 2 + batch.retryCount++ + + log.Trace(ctx).Warnf("Retrying write operation with smaller edge batch (n:%d -> %d, r:%d): %v", len(batch.data), newBatchSize, batch.retryCount, e.Unwrap()) + + // Split the batch into smaller chunks and requeue them. + if len(batch.data[:newBatchSize]) > 0 { + jgv.consumerChan <- batchItem{ + data: batch.data[:newBatchSize], + retryCount: batch.retryCount, + } + } + if len(batch.data[newBatchSize:]) > 0 { + jgv.consumerChan <- batchItem{ + data: batch.data[newBatchSize:], + retryCount: batch.retryCount, + } + } + continue + } + + log.Trace(ctx).Errorf("Retry limit exceeded for write operation: %v", err) + } + log.Trace(ctx).Errorf("write data in background batch writer: %v", err) } @@ -109,9 +153,22 @@ func (jgv *JanusGraphEdgeWriter) batchWrite(ctx context.Context, data []any) err op := jgv.gremlin(jgv.traversalSource, data) promise := op.Iterate() - err = <-promise - if err != nil { - return fmt.Errorf("%s edge insert: %w", jgv.builder, err) + + // Wait for the write operation to complete or timeout. + select { + case <-ctx.Done(): + // If the context is cancelled, return the error. + return ctx.Err() + case <-time.After(jgv.writerTimeout): + // If the write operation takes too long, return an error. 
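(The return statement that follows builds exactly that retryable timeout error.) Pulled out of batchWrite, the deadline guard reduces to a small standalone pattern; a sketch assuming the context, errors, and time imports, with errBatchWriter as the wrapper from this patch:

```go
// runWithDeadline races a blocking write against the caller's context and a
// per-batch timer. A buffered result channel (capacity 1) lets the goroutine
// deliver its result and exit even when the timer has already fired.
func runWithDeadline(ctx context.Context, timeout time.Duration, op func() error) error {
	errChan := make(chan error, 1)
	go func() { errChan <- op() }()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-time.After(timeout):
		return &errBatchWriter{
			err:       errors.New("write operation timed out"),
			retryable: true,
		}
	case err := <-errChan:
		return err
	}
}
```

One trade-off worth noting: on timeout the underlying write keeps running and may still land server-side while the batch is also retried in smaller chunks, so slow (rather than failed) writes can be applied twice; the patch accepts that risk in exchange for bounded latency.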
+ return &errBatchWriter{ + err: errors.New("edge write operation timed out"), + retryable: true, + } + case err := <-promise: + if err != nil { + return fmt.Errorf("%s edge insert: %w", jgv.builder, err) + } } return nil @@ -174,7 +231,10 @@ func (jgv *JanusGraphEdgeWriter) Queue(ctx context.Context, v any) error { copy(copied, jgv.inserts) jgv.writingInFlight.Add(1) - jgv.consumerChan <- copied + jgv.consumerChan <- batchItem{ + data: copied, + retryCount: 0, + } _ = statsd.Incr(ctx, metric.QueueSize, jgv.tags, 1) // cleanup the ops array after we have copied it to the channel diff --git a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go index 85c2e72e..9c901cda 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go @@ -48,32 +48,13 @@ type batchItem struct { retryCount int } -// errBatchWriter is an error type that wraps an error and indicates whether the -// error is retryable. -type errBatchWriter struct { - err error - retryable bool -} - -func (e errBatchWriter) Error() string { - if e.err == nil { - return fmt.Sprintf("batch writer error (retriable:%v)", e.retryable) - } - - return fmt.Sprintf("batch writer error (retriable:%v): %v", e.retryable, e.err.Error()) -} - -func (e errBatchWriter) Unwrap() error { - return e.err -} - // NewJanusGraphAsyncVertexWriter creates a new bulk vertex writer instance. func NewJanusGraphAsyncVertexWriter(ctx context.Context, drc *gremlin.DriverRemoteConnection, v vertex.Builder, c cache.CacheProvider, opts ...WriterOption, ) (*JanusGraphVertexWriter, error) { options := &writerOptions{ - WriterTimeout: 60 * time.Second, - MaxRetry: 3, + WriterTimeout: defaultWriterTimeout, + MaxRetry: defaultMaxRetry, } for _, opt := range opts { opt(options) @@ -118,7 +99,7 @@ func (jgv *JanusGraphVertexWriter) startBackgroundWriter(ctx context.Context) { // If the batch is empty, return. if len(batch.data) == 0 { - log.Trace(ctx).Warn("Empty batch received in background janusgraph worker, skipping") + log.Trace(ctx).Warn("Empty vertex batch received in background janusgraph worker, skipping") return } @@ -127,13 +108,13 @@ func (jgv *JanusGraphVertexWriter) startBackgroundWriter(ctx context.Context) { if err != nil { var e *errBatchWriter if errors.As(err, &e) && e.retryable { - // If the context deadline is exceeded, retry the write operation with a smaller batch. + // If the error is retryable, retry the write operation with a smaller batch. if batch.retryCount < jgv.maxRetry { // Compute the new batch size. newBatchSize := len(batch.data) / 2 batch.retryCount++ - log.Trace(ctx).Warnf("Retrying write operation with smaller batch (n:%d -> %d, r:%d): %v", len(batch.data), newBatchSize, batch.retryCount, e.Unwrap()) + log.Trace(ctx).Warnf("Retrying write operation with vertex smaller batch (n:%d -> %d, r:%d): %v", len(batch.data), newBatchSize, batch.retryCount, e.Unwrap()) // Split the batch into smaller chunks and requeue them. if len(batch.data[:newBatchSize]) > 0 { @@ -241,7 +222,7 @@ func (jgv *JanusGraphVertexWriter) batchWrite(ctx context.Context, data []any) e case <-time.After(jgv.writerTimeout): // If the write operation takes too long, return an error. 
return &errBatchWriter{ - err: errors.New("write operation timed out"), + err: errors.New("vertex write operation timed out"), retryable: true, } case err = <-errChan: diff --git a/pkg/kubehound/storage/graphdb/provider.go b/pkg/kubehound/storage/graphdb/provider.go index 03c80536..120800bf 100644 --- a/pkg/kubehound/storage/graphdb/provider.go +++ b/pkg/kubehound/storage/graphdb/provider.go @@ -12,6 +12,11 @@ import ( "github.com/DataDog/KubeHound/pkg/kubehound/storage/cache" ) +const ( + defaultWriterTimeout = 60 * time.Second + defaultMaxRetry = 3 +) + type writerOptions struct { Tags []string WriterTimeout time.Duration From 1240ff080c39f2b102271de8ab37864eb09f6888 Mon Sep 17 00:00:00 2001 From: Thibault Normand Date: Thu, 28 Nov 2024 15:18:15 +0100 Subject: [PATCH 03/18] feat(graphdb): register retry metric counter. --- pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go | 2 ++ pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go | 2 ++ 2 files changed, 4 insertions(+) diff --git a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go index 9f66e753..cf0defdb 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go @@ -99,6 +99,8 @@ func (jgv *JanusGraphEdgeWriter) startBackgroundWriter(ctx context.Context) { if errors.As(err, &e) && e.retryable { // If the error is retryable, retry the write operation with a smaller batch. if batch.retryCount < jgv.maxRetry { + _ = statsd.Count(ctx, metric.RetryWriterCall, 1, jgv.tags, 1) + // Compute the new batch size. newBatchSize := len(batch.data) / 2 batch.retryCount++ diff --git a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go index 9c901cda..9c938476 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go @@ -110,6 +110,8 @@ func (jgv *JanusGraphVertexWriter) startBackgroundWriter(ctx context.Context) { if errors.As(err, &e) && e.retryable { // If the error is retryable, retry the write operation with a smaller batch. if batch.retryCount < jgv.maxRetry { + _ = statsd.Count(ctx, metric.RetryWriterCall, 1, jgv.tags, 1) + // Compute the new batch size. newBatchSize := len(batch.data) / 2 batch.retryCount++ From 3b859a41477212fa3977afe66d3c6b515154baf2 Mon Sep 17 00:00:00 2001 From: Thibault Normand Date: Thu, 28 Nov 2024 15:32:44 +0100 Subject: [PATCH 04/18] chore(lint): reduce complexity. --- .../storage/graphdb/janusgraph_edge_writer.go | 51 ++++++++++--------- .../graphdb/janusgraph_vertex_writer.go | 51 ++++++++++--------- 2 files changed, 56 insertions(+), 46 deletions(-) diff --git a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go index cf0defdb..307ca5d9 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go @@ -96,30 +96,10 @@ func (jgv *JanusGraphEdgeWriter) startBackgroundWriter(ctx context.Context) { err := jgv.batchWrite(ctx, batch.data) if err != nil { var e *errBatchWriter - if errors.As(err, &e) && e.retryable { + if errors.As(err, &e) { // If the error is retryable, retry the write operation with a smaller batch. - if batch.retryCount < jgv.maxRetry { - _ = statsd.Count(ctx, metric.RetryWriterCall, 1, jgv.tags, 1) - - // Compute the new batch size. 
- newBatchSize := len(batch.data) / 2 - batch.retryCount++ - - log.Trace(ctx).Warnf("Retrying write operation with smaller edge batch (n:%d -> %d, r:%d): %v", len(batch.data), newBatchSize, batch.retryCount, e.Unwrap()) - - // Split the batch into smaller chunks and requeue them. - if len(batch.data[:newBatchSize]) > 0 { - jgv.consumerChan <- batchItem{ - data: batch.data[:newBatchSize], - retryCount: batch.retryCount, - } - } - if len(batch.data[newBatchSize:]) > 0 { - jgv.consumerChan <- batchItem{ - data: batch.data[newBatchSize:], - retryCount: batch.retryCount, - } - } + if e.retryable && batch.retryCount < jgv.maxRetry { + jgv.retrySplitAndRequeue(ctx, &batch, e) continue } @@ -139,6 +119,31 @@ func (jgv *JanusGraphEdgeWriter) startBackgroundWriter(ctx context.Context) { }() } +// retrySplitAndRequeue will split the batch into smaller chunks and requeue them for writing. +func (jgv *JanusGraphEdgeWriter) retrySplitAndRequeue(ctx context.Context, batch *batchItem, e *errBatchWriter) { + _ = statsd.Count(ctx, metric.RetryWriterCall, 1, jgv.tags, 1) + + // Compute the new batch size. + newBatchSize := len(batch.data) / 2 + batch.retryCount++ + + log.Trace(ctx).Warnf("Retrying write operation with smaller edge batch (n:%d -> %d, r:%d): %v", len(batch.data), newBatchSize, batch.retryCount, e.Unwrap()) + + // Split the batch into smaller chunks and requeue them. + if len(batch.data[:newBatchSize]) > 0 { + jgv.consumerChan <- batchItem{ + data: batch.data[:newBatchSize], + retryCount: batch.retryCount, + } + } + if len(batch.data[newBatchSize:]) > 0 { + jgv.consumerChan <- batchItem{ + data: batch.data[newBatchSize:], + retryCount: batch.retryCount, + } + } +} + // batchWrite will write a batch of entries into the graph DB and block until the write completes. // Callers are responsible for doing an Add(1) to the writingInFlight wait group to ensure proper synchronization. func (jgv *JanusGraphEdgeWriter) batchWrite(ctx context.Context, data []any) error { diff --git a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go index 9c938476..a55efb81 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go @@ -107,30 +107,10 @@ func (jgv *JanusGraphVertexWriter) startBackgroundWriter(ctx context.Context) { err := jgv.batchWrite(ctx, batch.data) if err != nil { var e *errBatchWriter - if errors.As(err, &e) && e.retryable { + if errors.As(err, &e) { // If the error is retryable, retry the write operation with a smaller batch. - if batch.retryCount < jgv.maxRetry { - _ = statsd.Count(ctx, metric.RetryWriterCall, 1, jgv.tags, 1) - - // Compute the new batch size. - newBatchSize := len(batch.data) / 2 - batch.retryCount++ - - log.Trace(ctx).Warnf("Retrying write operation with vertex smaller batch (n:%d -> %d, r:%d): %v", len(batch.data), newBatchSize, batch.retryCount, e.Unwrap()) - - // Split the batch into smaller chunks and requeue them. 
- if len(batch.data[:newBatchSize]) > 0 { - jgv.consumerChan <- batchItem{ - data: batch.data[:newBatchSize], - retryCount: batch.retryCount, - } - } - if len(batch.data[newBatchSize:]) > 0 { - jgv.consumerChan <- batchItem{ - data: batch.data[newBatchSize:], - retryCount: batch.retryCount, - } - } + if e.retryable && batch.retryCount < jgv.maxRetry { + jgv.retrySplitAndRequeue(ctx, &batch, e) continue } @@ -150,6 +130,31 @@ func (jgv *JanusGraphVertexWriter) startBackgroundWriter(ctx context.Context) { }() } +// retrySplitAndRequeue will split the batch into smaller chunks and requeue them for writing. +func (jgv *JanusGraphVertexWriter) retrySplitAndRequeue(ctx context.Context, batch *batchItem, e *errBatchWriter) { + _ = statsd.Count(ctx, metric.RetryWriterCall, 1, jgv.tags, 1) + + // Compute the new batch size. + newBatchSize := len(batch.data) / 2 + batch.retryCount++ + + log.Trace(ctx).Warnf("Retrying write operation with smaller vertex batch (n:%d -> %d, r:%d): %v", len(batch.data), newBatchSize, batch.retryCount, e.Unwrap()) + + // Split the batch into smaller chunks and requeue them. + if len(batch.data[:newBatchSize]) > 0 { + jgv.consumerChan <- batchItem{ + data: batch.data[:newBatchSize], + retryCount: batch.retryCount, + } + } + if len(batch.data[newBatchSize:]) > 0 { + jgv.consumerChan <- batchItem{ + data: batch.data[newBatchSize:], + retryCount: batch.retryCount, + } + } +} + func (jgv *JanusGraphVertexWriter) cacheIds(ctx context.Context, idMap []*gremlin.Result) error { for _, r := range idMap { idMap, ok := r.GetInterface().(map[interface{}]interface{}) From 99e875f406e96362ffa15e721f4c4b84c9f19407 Mon Sep 17 00:00:00 2001 From: Thibault Normand Date: Thu, 28 Nov 2024 15:42:28 +0100 Subject: [PATCH 05/18] chore(ci): fix linter findings. --- pkg/kubehound/storage/graphdb/errors.go | 8 ++++---- .../storage/graphdb/janusgraph_edge_writer.go | 9 ++++++--- .../storage/graphdb/janusgraph_vertex_writer.go | 11 ++++++++--- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/pkg/kubehound/storage/graphdb/errors.go b/pkg/kubehound/storage/graphdb/errors.go index 6ccb2186..fe6772ee 100644 --- a/pkg/kubehound/storage/graphdb/errors.go +++ b/pkg/kubehound/storage/graphdb/errors.go @@ -2,14 +2,14 @@ package graphdb import "fmt" -// errBatchWriter is an error type that wraps an error and indicates whether the +// batchWriterError is an error type that wraps an error and indicates whether the // error is retryable. -type errBatchWriter struct { +type batchWriterError struct { err error retryable bool } -func (e errBatchWriter) Error() string { +func (e batchWriterError) Error() string { if e.err == nil { return fmt.Sprintf("batch writer error (retriable:%v)", e.retryable) } @@ -17,6 +17,6 @@ func (e errBatchWriter) Error() string { return fmt.Sprintf("batch writer error (retriable:%v): %v", e.retryable, e.err.Error()) } -func (e errBatchWriter) Unwrap() error { +func (e batchWriterError) Unwrap() error { return e.err } diff --git a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go index 307ca5d9..994d09b1 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go @@ -83,23 +83,26 @@ func (jgv *JanusGraphEdgeWriter) startBackgroundWriter(ctx context.Context) { // If the channel is closed, return. if !ok { log.Trace(ctx).Info("Closed background janusgraph worker on channel close") + return } // If the batch is empty, return. 
			if len(batch.data) == 0 {
 				log.Trace(ctx).Warn("Empty edge batch received in background janusgraph worker, skipping")
+
 				return
 			}
 
 			_ = statsd.Count(ctx, metric.BackgroundWriterCall, 1, jgv.tags, 1)
 			err := jgv.batchWrite(ctx, batch.data)
 			if err != nil {
-				var e *errBatchWriter
+				var e *batchWriterError
 				if errors.As(err, &e) {
 					// If the error is retryable, retry the write operation with a smaller batch.
 					if e.retryable && batch.retryCount < jgv.maxRetry {
 						jgv.retrySplitAndRequeue(ctx, &batch, e)
+
 						continue
 					}
@@ -120,7 +123,7 @@ func (jgv *JanusGraphEdgeWriter) startBackgroundWriter(ctx context.Context) {
 }
 
 // retrySplitAndRequeue will split the batch into smaller chunks and requeue them for writing.
-func (jgv *JanusGraphEdgeWriter) retrySplitAndRequeue(ctx context.Context, batch *batchItem, e *errBatchWriter) {
+func (jgv *JanusGraphEdgeWriter) retrySplitAndRequeue(ctx context.Context, batch *batchItem, e *batchWriterError) {
 	_ = statsd.Count(ctx, metric.RetryWriterCall, 1, jgv.tags, 1)
 
 	// Compute the new batch size.
@@ -168,7 +171,7 @@ func (jgv *JanusGraphEdgeWriter) batchWrite(ctx context.Context, data []any) err
 		return ctx.Err()
 	case <-time.After(jgv.writerTimeout):
 		// If the write operation takes too long, return an error.
-		return &errBatchWriter{
+		return &batchWriterError{
 			err:       errors.New("edge write operation timed out"),
 			retryable: true,
 		}
diff --git a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go
index a55efb81..c176c0bc 100644
--- a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go
+++ b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go
@@ -94,23 +94,26 @@ func (jgv *JanusGraphVertexWriter) startBackgroundWriter(ctx context.Context) {
 			// If the channel is closed, return.
 			if !ok {
 				log.Trace(ctx).Info("Closed background janusgraph worker on channel close")
+
 				return
 			}
 
 			// If the batch is empty, return.
 			if len(batch.data) == 0 {
 				log.Trace(ctx).Warn("Empty vertex batch received in background janusgraph worker, skipping")
+
 				return
 			}
 
 			_ = statsd.Count(ctx, metric.BackgroundWriterCall, 1, jgv.tags, 1)
 			err := jgv.batchWrite(ctx, batch.data)
 			if err != nil {
-				var e *errBatchWriter
+				var e *batchWriterError
 				if errors.As(err, &e) {
 					// If the error is retryable, retry the write operation with a smaller batch.
 					if e.retryable && batch.retryCount < jgv.maxRetry {
 						jgv.retrySplitAndRequeue(ctx, &batch, e)
+
 						continue
 					}
@@ -131,7 +134,7 @@ func (jgv *JanusGraphVertexWriter) startBackgroundWriter(ctx context.Context) {
 }
 
 // retrySplitAndRequeue will split the batch into smaller chunks and requeue them for writing.
-func (jgv *JanusGraphVertexWriter) retrySplitAndRequeue(ctx context.Context, batch *batchItem, e *errBatchWriter) {
+func (jgv *JanusGraphVertexWriter) retrySplitAndRequeue(ctx context.Context, batch *batchItem, e *batchWriterError) {
 	_ = statsd.Count(ctx, metric.RetryWriterCall, 1, jgv.tags, 1)
 
 	// Compute the new batch size.
@@ -207,6 +210,7 @@ func (jgv *JanusGraphVertexWriter) batchWrite(ctx context.Context, data []any) e
 		ToList()
 	if err != nil {
 		errChan <- fmt.Errorf("%s vertex insert: %w", jgv.builder, err)
+
 		return
 	}
 
 	// Gremlin will return a list of maps containing the vertex id and store
 	// id values for each vertex inserted.
 	// We need to parse each map entry and add to our cache.
if err = jgv.cacheIds(ctx, raw); err != nil { errChan <- fmt.Errorf("cache ids: %w", err) + return } @@ -228,7 +233,7 @@ func (jgv *JanusGraphVertexWriter) batchWrite(ctx context.Context, data []any) e return ctx.Err() case <-time.After(jgv.writerTimeout): // If the write operation takes too long, return an error. - return &errBatchWriter{ + return &batchWriterError{ err: errors.New("vertex write operation timed out"), retryable: true, } From ee79469dd5a375bd1809322b77dfe63139b7d060 Mon Sep 17 00:00:00 2001 From: Thibault Normand Date: Fri, 29 Nov 2024 10:21:40 +0100 Subject: [PATCH 06/18] refactor(graphdb): split microbatcher and retrier concerns. --- .../storage/graphdb/janusgraph_edge_writer.go | 179 ++++++--------- .../graphdb/janusgraph_vertex_writer.go | 214 +++++++----------- pkg/kubehound/storage/graphdb/microbatcher.go | 183 +++++++++++++++ .../storage/graphdb/microbatcher_test.go | 63 ++++++ pkg/kubehound/storage/graphdb/provider.go | 19 +- 5 files changed, 410 insertions(+), 248 deletions(-) create mode 100644 pkg/kubehound/storage/graphdb/microbatcher.go create mode 100644 pkg/kubehound/storage/graphdb/microbatcher_test.go diff --git a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go index 994d09b1..9d5eddcd 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go @@ -30,16 +30,13 @@ type JanusGraphEdgeWriter struct { gremlin types.EdgeTraversal // Gremlin traversal generator function drc *gremlingo.DriverRemoteConnection // Gremlin driver remote connection traversalSource *gremlingo.GraphTraversalSource // Transacted graph traversal source - inserts []any // Object data to be inserted in the graph - mu sync.Mutex // Mutex protecting access to the inserts array - consumerChan chan batchItem // Channel consuming inserts for async writing writingInFlight *sync.WaitGroup // Wait group tracking current unfinished writes - batchSize int // Batchsize of graph DB inserts qcounter int32 // Track items queued wcounter int32 // Track items writtn tags []string // Telemetry tags writerTimeout time.Duration // Timeout for the writer maxRetry int // Maximum number of retries for failed writes + mb *microBatcher // Micro batcher to batch writes } // NewJanusGraphAsyncEdgeWriter creates a new bulk edge writer instance. 
@@ -47,8 +44,9 @@ func NewJanusGraphAsyncEdgeWriter(ctx context.Context, drc *gremlingo.DriverRemo e edge.Builder, opts ...WriterOption, ) (*JanusGraphEdgeWriter, error) { options := &writerOptions{ - WriterTimeout: defaultWriterTimeout, - MaxRetry: defaultMaxRetry, + WriterTimeout: defaultWriterTimeout, + MaxRetry: defaultMaxRetry, + WriterWorkerCount: defaultWriterWorkerCount, } for _, opt := range opts { opt(options) @@ -59,101 +57,91 @@ func NewJanusGraphAsyncEdgeWriter(ctx context.Context, drc *gremlingo.DriverRemo builder: builder, gremlin: e.Traversal(), drc: drc, - inserts: make([]any, 0, e.BatchSize()), traversalSource: gremlingo.Traversal_().WithRemote(drc), - batchSize: e.BatchSize(), writingInFlight: &sync.WaitGroup{}, - consumerChan: make(chan batchItem, e.BatchSize()*channelSizeBatchFactor), tags: append(options.Tags, tag.Label(e.Label()), tag.Builder(builder)), writerTimeout: options.WriterTimeout, maxRetry: options.MaxRetry, } - jw.startBackgroundWriter(ctx) - - return &jw, nil -} - -// startBackgroundWriter starts a background go routine -func (jgv *JanusGraphEdgeWriter) startBackgroundWriter(ctx context.Context) { - go func() { - for { - select { - case batch, ok := <-jgv.consumerChan: - // If the channel is closed, return. - if !ok { - log.Trace(ctx).Info("Closed background janusgraph worker on channel close") - - return - } - - // If the batch is empty, return. - if len(batch.data) == 0 { - log.Trace(ctx).Warn("Empty edge batch received in background janusgraph worker, skipping") - - return - } - - _ = statsd.Count(ctx, metric.BackgroundWriterCall, 1, jgv.tags, 1) - err := jgv.batchWrite(ctx, batch.data) - if err != nil { - var e *batchWriterError - if errors.As(err, &e) { - // If the error is retryable, retry the write operation with a smaller batch. - if e.retryable && batch.retryCount < jgv.maxRetry { - jgv.retrySplitAndRequeue(ctx, &batch, e) - - continue - } - - log.Trace(ctx).Errorf("Retry limit exceeded for write operation: %v", err) - } - - log.Trace(ctx).Errorf("write data in background batch writer: %v", err) - } - - _ = statsd.Decr(ctx, metric.QueueSize, jgv.tags, 1) - case <-ctx.Done(): - log.Trace(ctx).Info("Closed background janusgraph worker on context cancel") - - return + // Create a new micro batcher to batch the inserts with split and retry logic. + jw.mb = newMicroBatcher(log.Trace(ctx), e.BatchSize(), options.WriterWorkerCount, func(ctx context.Context, a []any) error { + // Try to write the batch to the graph DB. + if err := jw.batchWrite(ctx, a); err != nil { + var bwe *batchWriterError + if errors.As(err, &bwe) && bwe.retryable { + // If the write operation failed and is retryable, split the batch and retry. + return jw.splitAndRetry(ctx, 0, a) } + + return err } - }() + + return nil + }) + jw.mb.Start(ctx) + + return &jw, nil } // retrySplitAndRequeue will split the batch into smaller chunks and requeue them for writing. -func (jgv *JanusGraphEdgeWriter) retrySplitAndRequeue(ctx context.Context, batch *batchItem, e *batchWriterError) { +func (jgv *JanusGraphEdgeWriter) splitAndRetry(ctx context.Context, retryCount int, payload []any) error { _ = statsd.Count(ctx, metric.RetryWriterCall, 1, jgv.tags, 1) + // If we have reached the maximum number of retries, return an error. + if retryCount >= jgv.maxRetry { + return fmt.Errorf("max retry count reached: %d", retryCount) + } + // Compute the new batch size. 
-	newBatchSize := len(batch.data) / 2
-	batch.retryCount++
+	newBatchSize := len(payload) / 2
+
+	log.Trace(ctx).Warnf("Retrying write operation with smaller edge batch (n:%d -> %d, r:%d)", len(payload), newBatchSize, retryCount)
 
-	log.Trace(ctx).Warnf("Retrying write operation with smaller edge batch (n:%d -> %d, r:%d): %v", len(batch.data), newBatchSize, batch.retryCount, e.Unwrap())
+	var leftErr, rightErr error
 
-	// Split the batch into smaller chunks and requeue them.
-	if len(batch.data[:newBatchSize]) > 0 {
-		jgv.consumerChan <- batchItem{
-			data:       batch.data[:newBatchSize],
-			retryCount: batch.retryCount,
+	// Split the batch into smaller chunks and retry them.
+	if len(payload[:newBatchSize]) > 0 {
+		if leftErr = jgv.batchWrite(ctx, payload[:newBatchSize]); leftErr != nil {
+			var bwe *batchWriterError
+			if errors.As(leftErr, &bwe) && bwe.retryable {
+				return jgv.splitAndRetry(ctx, retryCount+1, payload[:newBatchSize])
+			}
 		}
 	}
-	if len(batch.data[newBatchSize:]) > 0 {
-		jgv.consumerChan <- batchItem{
-			data:       batch.data[newBatchSize:],
-			retryCount: batch.retryCount,
+
+	// Process the right side of the batch.
+	if len(payload[newBatchSize:]) > 0 {
+		if rightErr = jgv.batchWrite(ctx, payload[newBatchSize:]); rightErr != nil {
+			var bwe *batchWriterError
+			if errors.As(rightErr, &bwe) && bwe.retryable {
+				return jgv.splitAndRetry(ctx, retryCount+1, payload[newBatchSize:])
+			}
 		}
 	}
+
+	// Return the first error encountered.
+	switch {
+	case leftErr != nil && rightErr != nil:
+		return fmt.Errorf("left: %w, right: %w", leftErr, rightErr)
+	case leftErr != nil:
+		return leftErr
+	case rightErr != nil:
+		return rightErr
+	}
+
+	return nil
 }
 
 // batchWrite will write a batch of entries into the graph DB and block until the write completes.
-// Callers are responsible for doing an Add(1) to the writingInFlight wait group to ensure proper synchronization.
 func (jgv *JanusGraphEdgeWriter) batchWrite(ctx context.Context, data []any) error {
 	span, ctx := span.SpanRunFromContext(ctx, span.JanusGraphBatchWrite)
 	span.SetTag(tag.LabelTag, jgv.builder)
 	var err error
 	defer func() { span.Finish(tracer.WithError(err)) }()
+
+	// Increment the writingInFlight wait group to track the number of writes in progress.
+	jgv.writingInFlight.Add(1)
 	defer jgv.writingInFlight.Done()
 
 	datalen := len(data)
@@ -185,8 +173,6 @@ func (jgv *JanusGraphEdgeWriter) batchWrite(ctx context.Context, data []any) err
 }
 
 func (jgv *JanusGraphEdgeWriter) Close(ctx context.Context) error {
-	close(jgv.consumerChan)
-
 	return nil
 }
 
@@ -198,29 +184,17 @@ func (jgv *JanusGraphEdgeWriter) Flush(ctx context.Context) error {
 	var err error
 	defer func() { span.Finish(tracer.WithError(err)) }()
 
-	jgv.mu.Lock()
-	defer jgv.mu.Unlock()
-
 	if jgv.traversalSource == nil {
 		return errors.New("janusGraph traversalSource is not initialized")
 	}
 
-	if len(jgv.inserts) != 0 {
-		_ = statsd.Incr(ctx, metric.FlushWriterCall, jgv.tags, 1)
-
-		jgv.writingInFlight.Add(1)
-		err = jgv.batchWrite(ctx, jgv.inserts)
-		if err != nil {
-			log.Trace(ctx).Errorf("batch write %s: %+v", jgv.builder, err)
-			jgv.writingInFlight.Wait()
-
-			return err
-		}
-
-		log.Trace(ctx).Debugf("Done flushing %s writes. clearing the queue", jgv.builder)
-		jgv.inserts = nil
+	// Flush the micro batcher.
+	err = jgv.mb.Flush(ctx)
+	if err != nil {
+		return fmt.Errorf("micro batcher flush: %w", err)
 	}
 
+	// Wait for all writes to complete.
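(The Wait call that follows blocks until those in-flight writes drain.) Stepping back, the new splitAndRetry is a bounded, recursive binary split on the batch. Reduced to a skeleton, with write standing in for batchWrite and isRetryable for the errors.As check, the idea is:

```go
// splitAndRetryStrategy skeleton: halve the payload, write each half, and
// recurse on any half that fails with a retryable error, at most maxRetry
// levels deep.
func splitAndRetryStrategy(ctx context.Context, depth, maxRetry int, payload []any,
	write func(context.Context, []any) error,
) error {
	if depth >= maxRetry {
		return fmt.Errorf("max retry count reached: %d", depth)
	}
	mid := len(payload) / 2
	for _, half := range [][]any{payload[:mid], payload[mid:]} {
		if len(half) == 0 {
			continue
		}
		if err := write(ctx, half); err != nil {
			if !isRetryable(err) {
				return err
			}
			if err := splitAndRetryStrategy(ctx, depth+1, maxRetry, half, write); err != nil {
				return err
			}
		}
	}
	return nil
}

// isRetryable stands in for the patch's errors.As(..., &bwe) && bwe.retryable test.
func isRetryable(err error) bool {
	var bwe *batchWriterError
	return errors.As(err, &bwe) && bwe.retryable
}
```

One behavioural difference to flag: the patch returns as soon as the left half needs a retry, so the right half of that payload is never attempted on that code path, whereas the skeleton above processes both halves. The latter is likely the intended semantics.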
jgv.writingInFlight.Wait() log.Trace(ctx).Debugf("Edge writer %d %s queued", jgv.qcounter, jgv.builder) @@ -230,26 +204,5 @@ func (jgv *JanusGraphEdgeWriter) Flush(ctx context.Context) error { } func (jgv *JanusGraphEdgeWriter) Queue(ctx context.Context, v any) error { - jgv.mu.Lock() - defer jgv.mu.Unlock() - - atomic.AddInt32(&jgv.qcounter, 1) - jgv.inserts = append(jgv.inserts, v) - - if len(jgv.inserts) > jgv.batchSize { - copied := make([]any, len(jgv.inserts)) - copy(copied, jgv.inserts) - - jgv.writingInFlight.Add(1) - jgv.consumerChan <- batchItem{ - data: copied, - retryCount: 0, - } - _ = statsd.Incr(ctx, metric.QueueSize, jgv.tags, 1) - - // cleanup the ops array after we have copied it to the channel - jgv.inserts = nil - } - - return nil + return jgv.mb.Enqueue(ctx, v) } diff --git a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go index c176c0bc..ce44eeb2 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go @@ -28,24 +28,14 @@ type JanusGraphVertexWriter struct { gremlin types.VertexTraversal // Gremlin traversal generator function drc *gremlin.DriverRemoteConnection // Gremlin driver remote connection traversalSource *gremlin.GraphTraversalSource // Transacted graph traversal source - inserts []any // Object data to be inserted in the graph - mu sync.Mutex // Mutex protecting access to the inserts array - consumerChan chan batchItem // Channel consuming inserts for async writing writingInFlight *sync.WaitGroup // Wait group tracking current unfinished writes - batchSize int // Batchsize of graph DB inserts qcounter int32 // Track items queued wcounter int32 // Track items writtn tags []string // Telemetry tags cache cache.AsyncWriter // Cache writer to cache store id -> vertex id mappings writerTimeout time.Duration // Timeout for the writer maxRetry int // Maximum number of retries for failed writes -} - -// batchItem is a single item in the batch writer queue that contains the data -// to be written and the number of retries. -type batchItem struct { - data []any - retryCount int + mb *microBatcher // Micro batcher to batch writes } // NewJanusGraphAsyncVertexWriter creates a new bulk vertex writer instance. 
@@ -53,8 +43,9 @@ func NewJanusGraphAsyncVertexWriter(ctx context.Context, drc *gremlin.DriverRemo v vertex.Builder, c cache.CacheProvider, opts ...WriterOption, ) (*JanusGraphVertexWriter, error) { options := &writerOptions{ - WriterTimeout: defaultWriterTimeout, - MaxRetry: defaultMaxRetry, + WriterTimeout: defaultWriterTimeout, + MaxRetry: defaultMaxRetry, + WriterWorkerCount: defaultWriterWorkerCount, } for _, opt := range opts { opt(options) @@ -69,93 +60,32 @@ func NewJanusGraphAsyncVertexWriter(ctx context.Context, drc *gremlin.DriverRemo builder: v.Label(), gremlin: v.Traversal(), drc: drc, - inserts: make([]any, 0, v.BatchSize()), traversalSource: gremlin.Traversal_().WithRemote(drc), - batchSize: v.BatchSize(), writingInFlight: &sync.WaitGroup{}, - consumerChan: make(chan batchItem, v.BatchSize()*channelSizeBatchFactor), tags: append(options.Tags, tag.Label(v.Label()), tag.Builder(v.Label())), cache: cw, writerTimeout: options.WriterTimeout, maxRetry: options.MaxRetry, } - jw.startBackgroundWriter(ctx) - - return &jw, nil -} - -// startBackgroundWriter starts a background go routine -func (jgv *JanusGraphVertexWriter) startBackgroundWriter(ctx context.Context) { - go func() { - for { - select { - case batch, ok := <-jgv.consumerChan: - // If the channel is closed, return. - if !ok { - log.Trace(ctx).Info("Closed background janusgraph worker on channel close") - - return - } - - // If the batch is empty, return. - if len(batch.data) == 0 { - log.Trace(ctx).Warn("Empty vertex batch received in background janusgraph worker, skipping") - - return - } - - _ = statsd.Count(ctx, metric.BackgroundWriterCall, 1, jgv.tags, 1) - err := jgv.batchWrite(ctx, batch.data) - if err != nil { - var e *batchWriterError - if errors.As(err, &e) { - // If the error is retryable, retry the write operation with a smaller batch. - if e.retryable && batch.retryCount < jgv.maxRetry { - jgv.retrySplitAndRequeue(ctx, &batch, e) - - continue - } - - log.Trace(ctx).Errorf("Retry limit exceeded for write operation: %v", err) - } - - log.Trace(ctx).Errorf("Write data in background batch writer, data will be lost: %v", err) - } - - _ = statsd.Decr(ctx, metric.QueueSize, jgv.tags, 1) - case <-ctx.Done(): - log.Trace(ctx).Info("Closed background janusgraph worker on context cancel") - - return + // Create a new micro batcher to batch the inserts with split and retry logic. + jw.mb = newMicroBatcher(log.Trace(ctx), v.BatchSize(), options.WriterWorkerCount, func(ctx context.Context, a []any) error { + // Try to write the batch to the graph DB. + if err := jw.batchWrite(ctx, a); err != nil { + var bwe *batchWriterError + if errors.As(err, &bwe) && bwe.retryable { + // If the write operation failed and is retryable, split the batch and retry. + return jw.splitAndRetry(ctx, 0, a) } - } - }() -} - -// retrySplitAndRequeue will split the batch into smaller chunks and requeue them for writing. -func (jgv *JanusGraphVertexWriter) retrySplitAndRequeue(ctx context.Context, batch *batchItem, e *batchWriterError) { - _ = statsd.Count(ctx, metric.RetryWriterCall, 1, jgv.tags, 1) - // Compute the new batch size. - newBatchSize := len(batch.data) / 2 - batch.retryCount++ + return err + } - log.Trace(ctx).Warnf("Retrying write operation with smaller vertex batch (n:%d -> %d, r:%d): %v", len(batch.data), newBatchSize, batch.retryCount, e.Unwrap()) + return nil + }) + jw.mb.Start(ctx) - // Split the batch into smaller chunks and requeue them. 
-	if len(batch.data[:newBatchSize]) > 0 {
-		jgv.consumerChan <- batchItem{
-			data:       batch.data[:newBatchSize],
-			retryCount: batch.retryCount,
-		}
-	}
-	if len(batch.data[newBatchSize:]) > 0 {
-		jgv.consumerChan <- batchItem{
-			data:       batch.data[newBatchSize:],
-			retryCount: batch.retryCount,
-		}
-	}
+		return nil
+	})
+	jw.mb.Start(ctx)
+
+	return &jw, nil
 }
 
 func (jgv *JanusGraphVertexWriter) cacheIds(ctx context.Context, idMap []*gremlin.Result) error {
@@ -182,12 +112,16 @@ func (jgv *JanusGraphVertexWriter) cacheIds(ctx context.Context, idMap []*gremli
 }
 
 // batchWrite will write a batch of entries into the graph DB and block until the write completes.
-// Callers are responsible for doing an Add(1) to the writingInFlight wait group to ensure proper synchronization.
 func (jgv *JanusGraphVertexWriter) batchWrite(ctx context.Context, data []any) error {
+	_ = statsd.Count(ctx, metric.BackgroundWriterCall, 1, jgv.tags, 1)
+
 	span, ctx := span.SpanRunFromContext(ctx, span.JanusGraphBatchWrite)
 	span.SetTag(tag.LabelTag, jgv.builder)
 	var err error
 	defer func() { span.Finish(tracer.WithError(err)) }()
+
+	// Increment the writingInFlight wait group to track the number of writes in progress.
+	jgv.writingInFlight.Add(1)
 	defer jgv.writingInFlight.Done()
 
 	datalen := len(data)
@@ -246,10 +180,63 @@ func (jgv *JanusGraphVertexWriter) batchWrite(ctx context.Context, data []any) e
 	return nil
 }
 
+// splitAndRetry will split the batch into smaller chunks and retry them, halving the batch on each attempt.
+func (jgv *JanusGraphVertexWriter) splitAndRetry(ctx context.Context, retryCount int, payload []any) error {
+	_ = statsd.Count(ctx, metric.RetryWriterCall, 1, jgv.tags, 1)
+
+	// If we have reached the maximum number of retries, return an error.
+	if retryCount >= jgv.maxRetry {
+		return fmt.Errorf("max retry count reached: %d", retryCount)
+	}
+
+	// Compute the new batch size.
+	newBatchSize := len(payload) / 2
+
+	log.Trace(ctx).Warnf("Retrying write operation with smaller vertex batch (n:%d -> %d, r:%d)", len(payload), newBatchSize, retryCount)
+
+	var leftErr, rightErr error
+
+	// Split the batch into smaller chunks and retry them.
+	if len(payload[:newBatchSize]) > 0 {
+		if leftErr = jgv.batchWrite(ctx, payload[:newBatchSize]); leftErr != nil {
+			var bwe *batchWriterError
+			if errors.As(leftErr, &bwe) && bwe.retryable {
+				return jgv.splitAndRetry(ctx, retryCount+1, payload[:newBatchSize])
+			}
+		}
+	}
+
+	// Process the right side of the batch.
+	if len(payload[newBatchSize:]) > 0 {
+		if rightErr = jgv.batchWrite(ctx, payload[newBatchSize:]); rightErr != nil {
+			var bwe *batchWriterError
+			if errors.As(rightErr, &bwe) && bwe.retryable {
+				return jgv.splitAndRetry(ctx, retryCount+1, payload[newBatchSize:])
+			}
+		}
+	}
+
+	// Return the first error encountered.
+	switch {
+	case leftErr != nil && rightErr != nil:
+		return fmt.Errorf("left: %w, right: %w", leftErr, rightErr)
+	case leftErr != nil:
+		return leftErr
+	case rightErr != nil:
+		return rightErr
+	}
+
+	return nil
+}
+
 func (jgv *JanusGraphVertexWriter) Close(ctx context.Context) error {
-	close(jgv.consumerChan)
+	if jgv.cache != nil {
+		if err := jgv.cache.Close(ctx); err != nil {
+			return fmt.Errorf("closing cache: %w", err)
+		}
+	}
 
-	return jgv.cache.Close(ctx)
+	return nil
 }
 
 // Flush triggers writes of any remaining items in the queue.
@@ -260,29 +247,17 @@ func (jgv *JanusGraphVertexWriter) Flush(ctx context.Context) error { var err error defer func() { span.Finish(tracer.WithError(err)) }() - jgv.mu.Lock() - defer jgv.mu.Unlock() - if jgv.traversalSource == nil { return errors.New("janusGraph traversalSource is not initialized") } - if len(jgv.inserts) != 0 { - _ = statsd.Incr(ctx, metric.FlushWriterCall, jgv.tags, 1) - - jgv.writingInFlight.Add(1) - err = jgv.batchWrite(ctx, jgv.inserts) - if err != nil { - log.Trace(ctx).Errorf("batch write %s: %+v", jgv.builder, err) - jgv.writingInFlight.Wait() - - return err - } - - log.Trace(ctx).Debugf("Done flushing %s writes. clearing the queue", jgv.builder) - jgv.inserts = nil + // Flush the micro batcher. + err = jgv.mb.Flush(ctx) + if err != nil { + return fmt.Errorf("micro batcher flush: %w", err) } + // Wait for all writes to complete. jgv.writingInFlight.Wait() err = jgv.cache.Flush(ctx) @@ -297,26 +272,5 @@ func (jgv *JanusGraphVertexWriter) Flush(ctx context.Context) error { } func (jgv *JanusGraphVertexWriter) Queue(ctx context.Context, v any) error { - jgv.mu.Lock() - defer jgv.mu.Unlock() - - atomic.AddInt32(&jgv.qcounter, 1) - jgv.inserts = append(jgv.inserts, v) - - if len(jgv.inserts) > jgv.batchSize { - copied := make([]any, len(jgv.inserts)) - copy(copied, jgv.inserts) - - jgv.writingInFlight.Add(1) - jgv.consumerChan <- batchItem{ - data: copied, - retryCount: 0, - } - _ = statsd.Incr(ctx, metric.QueueSize, jgv.tags, 1) - - // cleanup the ops array after we have copied it to the channel - jgv.inserts = nil - } - - return nil + return jgv.mb.Enqueue(ctx, v) } diff --git a/pkg/kubehound/storage/graphdb/microbatcher.go b/pkg/kubehound/storage/graphdb/microbatcher.go new file mode 100644 index 00000000..889622ac --- /dev/null +++ b/pkg/kubehound/storage/graphdb/microbatcher.go @@ -0,0 +1,183 @@ +package graphdb + +import ( + "context" + "errors" + "sync" + "sync/atomic" + + "github.com/DataDog/KubeHound/pkg/telemetry/log" +) + +// batchItem is a single item in the batch writer queue that contains the data +// to be written and the number of retries. +type batchItem struct { + data []any + retryCount int +} + +// microBatcher is a utility to batch items and flush them when the batch is full. +type microBatcher struct { + // batchSize is the maximum number of items to batch. + batchSize int + // items is the current item accumulator for the batch. This is reset after + // the batch is flushed. + items []any + // flush is the function to call to flush the batch. + flushFunc func(context.Context, []any) error + // itemChan is the channel to receive items to batch. + itemChan chan any + // batchChan is the channel to send batches to. + batchChan chan batchItem + // workerCount is the number of workers to process the batch. + workerCount int + // workerGroup is the worker group to wait for the workers to finish. + workerGroup *sync.WaitGroup + // shuttingDown is a flag to indicate if the batcher is shutting down. + shuttingDown atomic.Bool + // logger is the logger to use for logging. + logger log.LoggerI +} + +// NewMicroBatcher creates a new micro batcher. +func newMicroBatcher(logger log.LoggerI, batchSize int, workerCount int, flushFunc func(context.Context, []any) error) *microBatcher { + return µBatcher{ + logger: logger, + batchSize: batchSize, + items: make([]any, 0, batchSize), + flushFunc: flushFunc, + itemChan: make(chan any, batchSize), + batchChan: make(chan batchItem, batchSize), + workerCount: workerCount, + workerGroup: nil, // Set in Start. 
+	}
+}
+
+// Flush flushes the current batch and waits for the batch writer to finish.
+func (mb *microBatcher) Flush(_ context.Context) error {
+	// Set the shutting down flag to true.
+	if !mb.shuttingDown.CompareAndSwap(false, true) {
+		return errors.New("batcher is already shutting down")
+	}
+
+	// Closing the item channel to signal the accumulator to stop and flush the batch.
+	close(mb.itemChan)
+
+	// Wait for the workers to finish.
+	if mb.workerGroup != nil {
+		mb.workerGroup.Wait()
+	}
+
+	return nil
+}
+
+// Enqueue adds an item to the batch processor.
+func (mb *microBatcher) Enqueue(ctx context.Context, item any) error {
+	// If the batcher is shutting down, return an error immediately.
+	if mb.shuttingDown.Load() {
+		return errors.New("batcher is shutting down")
+	}
+
+	select {
+	case <-ctx.Done():
+		// If the context is cancelled, return.
+		return ctx.Err()
+	case mb.itemChan <- item:
+	}
+
+	return nil
+}
+
+// Start starts the batch processor.
+func (mb *microBatcher) Start(ctx context.Context) {
+	if mb.workerGroup != nil {
+		// If the worker group is already set, return.
+		return
+	}
+
+	var wg sync.WaitGroup
+
+	// Start the workers.
+	for i := 0; i < mb.workerCount; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			if err := mb.worker(ctx, mb.batchChan); err != nil {
+				mb.logger.Errorf("worker: %v", err)
+			}
+		}()
+	}
+
+	// Start the item accumulator.
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		if err := mb.runItemBatcher(ctx); err != nil {
+			mb.logger.Errorf("run item batcher: %v", err)
+		}
+
+		// Close the batch channel to signal the workers to stop.
+		close(mb.batchChan)
+	}()
+
+	// Set the worker group to wait for the workers to finish.
+	mb.workerGroup = &wg
+}
+
+// runItemBatcher runs the item accumulator, grouping incoming items into batches.
+func (mb *microBatcher) runItemBatcher(ctx context.Context) error {
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+
+		case item, ok := <-mb.itemChan:
+			if !ok {
+				// If the item channel is closed, send the current batch and return.
+				mb.batchChan <- batchItem{
+					data:       mb.items,
+					retryCount: 0,
+				}
+
+				// End the accumulator.
+				return nil
+			}
+
+			// Add the item to the batch.
+			mb.items = append(mb.items, item)
+
+			// If the batch is full, send it.
+			if len(mb.items) == mb.batchSize {
+				// Send the batch to the processor.
+				mb.batchChan <- batchItem{
+					data:       mb.items,
+					retryCount: 0,
+				}
+
+				// Reset the batch.
+				mb.items = mb.items[len(mb.items):]
+			}
+		}
+	}
+}
+
+// worker consumes batches from batchQueue and hands them to the flush function.
+func (mb *microBatcher) worker(ctx context.Context, batchQueue <-chan batchItem) error {
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case batch, ok := <-batchQueue:
+			if !ok {
+				return nil
+			}
+
+			// Flush the batch via the configured callback.
+ if len(batch.data) > 0 && mb.flushFunc != nil { + if err := mb.flushFunc(ctx, batch.data); err != nil { + mb.logger.Errorf("flush data in background batch writer: %v", err) + } + } + } + } +} diff --git a/pkg/kubehound/storage/graphdb/microbatcher_test.go b/pkg/kubehound/storage/graphdb/microbatcher_test.go new file mode 100644 index 00000000..b455c1b1 --- /dev/null +++ b/pkg/kubehound/storage/graphdb/microbatcher_test.go @@ -0,0 +1,63 @@ +package graphdb + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/DataDog/KubeHound/pkg/telemetry/log" + "github.com/stretchr/testify/assert" +) + +func microBatcherTestInstance(t *testing.T) (*microBatcher, *atomic.Int32) { + t.Helper() + + var ( + writerFuncCalledCount atomic.Int32 + ) + + underTest := newMicroBatcher(log.DefaultLogger(), 5, 1, + func(_ context.Context, _ []any) error { + writerFuncCalledCount.Add(1) + + return nil + }) + + return underTest, &writerFuncCalledCount +} + +func TestMicroBatcher_AfterBatchSize(t *testing.T) { + t.Parallel() + + underTest, writerFuncCalledCount := microBatcherTestInstance(t) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + underTest.Start(ctx) + + for i := 0; i < 10; i++ { + assert.NoError(t, underTest.Enqueue(ctx, i)) + } + + assert.NoError(t, underTest.Flush(ctx)) + + assert.Equal(t, int32(2), writerFuncCalledCount.Load()) +} + +func TestMicroBatcher_AfterFlush(t *testing.T) { + t.Parallel() + + underTest, writerFuncCalledCount := microBatcherTestInstance(t) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + underTest.Start(ctx) + + for i := 0; i < 11; i++ { + assert.NoError(t, underTest.Enqueue(ctx, i)) + } + + assert.NoError(t, underTest.Flush(ctx)) + + assert.Equal(t, int32(3), writerFuncCalledCount.Load()) +} diff --git a/pkg/kubehound/storage/graphdb/provider.go b/pkg/kubehound/storage/graphdb/provider.go index 120800bf..fbbe9b98 100644 --- a/pkg/kubehound/storage/graphdb/provider.go +++ b/pkg/kubehound/storage/graphdb/provider.go @@ -13,14 +13,16 @@ import ( ) const ( - defaultWriterTimeout = 60 * time.Second - defaultMaxRetry = 3 + defaultWriterTimeout = 60 * time.Second + defaultMaxRetry = 3 + defaultWriterWorkerCount = 1 ) type writerOptions struct { - Tags []string - WriterTimeout time.Duration - MaxRetry int + Tags []string + WriterWorkerCount int + WriterTimeout time.Duration + MaxRetry int } type WriterOption func(*writerOptions) @@ -45,6 +47,13 @@ func WithWriterMaxRetry(maxRetry int) WriterOption { } } +// WithWriterWorkerCount sets the number of workers to process the batch. +func WithWriterWorkerCount(workerCount int) WriterOption { + return func(wo *writerOptions) { + wo.WriterWorkerCount = workerCount + } +} + // Provider defines the interface for implementations of the graphdb provider for storage of the calculated K8s attack graph. // //go:generate mockery --name Provider --output mocks --case underscore --filename graph_provider.go --with-expecter From 8d0aedeccac217ac98dc59f154fde838e8d554f7 Mon Sep 17 00:00:00 2001 From: Thibault Normand Date: Fri, 29 Nov 2024 11:10:24 +0100 Subject: [PATCH 07/18] feat(app): wire configuration state to builders. 
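Before the configuration wiring below, a lifecycle recap of the micro batcher that patch 06 introduced, mirroring the unit tests above: Start wires the accumulator and workers, Enqueue feeds items, and a single Flush at shutdown drains the remainder. A sketch in which only runPipeline and the callback body are invented:

```go
func runPipeline(ctx context.Context) error {
	// Batch size 5, one worker, and a flush callback that just reports,
	// the same shape as microBatcherTestInstance in the tests.
	mb := newMicroBatcher(log.DefaultLogger(), 5, 1,
		func(_ context.Context, batch []any) error {
			fmt.Println("flushing", len(batch), "items")
			return nil
		})
	mb.Start(ctx)

	for i := 0; i < 12; i++ {
		if err := mb.Enqueue(ctx, i); err != nil {
			return err // rejected once Flush has begun or ctx is cancelled
		}
	}

	// Flush closes the intake, emits the two leftover items as a final short
	// batch, and waits for the workers to exit; Enqueue must not be used after.
	return mb.Flush(ctx)
}
```

With 12 items the callback fires three times: two full batches of five, then the remainder of two at Flush.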
--- pkg/config/config.go | 6 ++++++ pkg/config/janusgraph.go | 16 ++++++++++++++-- .../storage/graphdb/janusgraph_provider.go | 6 ++++++ 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index 8616a593..4b816428 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -119,6 +119,9 @@ func SetDefaultValues(ctx context.Context, v *viper.Viper) { // Defaults values for JanusGraph v.SetDefault(JanusGraphUrl, DefaultJanusGraphUrl) v.SetDefault(JanusGrapTimeout, DefaultConnectionTimeout) + v.SetDefault(JanusGraphWriterTimeout, defaultJanusGraphWriterTimeout) + v.SetDefault(JanusGraphWriterMaxRetry, defaultJanusGraphWriterMaxRetry) + v.SetDefault(JanusGraphWriterWorkerCount, defaultJanusGraphWriterWorkerCount) // Profiler values v.SetDefault(TelemetryProfilerPeriod, DefaultProfilerPeriod) @@ -157,6 +160,9 @@ func SetEnvOverrides(ctx context.Context, c *viper.Viper) { res = multierror.Append(res, c.BindEnv(MongoUrl, "KH_MONGODB_URL")) res = multierror.Append(res, c.BindEnv(JanusGraphUrl, "KH_JANUSGRAPH_URL")) + res = multierror.Append(res, c.BindEnv(JanusGraphWriterMaxRetry, "KH_JANUSGRAPH_WRITER_MAX_RETRY")) + res = multierror.Append(res, c.BindEnv(JanusGraphWriterTimeout, "KH_JANUSGRAPH_WRITER_TIMEOUT")) + res = multierror.Append(res, c.BindEnv(JanusGraphWriterWorkerCount, "KH_JANUSGRAPH_WRITER_WORKER_COUNT")) res = multierror.Append(res, c.BindEnv(IngestorAPIEndpoint, "KH_INGESTOR_API_ENDPOINT")) res = multierror.Append(res, c.BindEnv(IngestorAPIInsecure, "KH_INGESTOR_API_INSECURE")) diff --git a/pkg/config/janusgraph.go b/pkg/config/janusgraph.go index 726c0dd4..ae678bd0 100644 --- a/pkg/config/janusgraph.go +++ b/pkg/config/janusgraph.go @@ -7,12 +7,24 @@ import ( const ( DefaultJanusGraphUrl = "ws://localhost:8182/gremlin" - JanusGraphUrl = "janusgraph.url" - JanusGrapTimeout = "janusgraph.connection_timeout" + defaultJanusGraphWriterTimeout = 60 * time.Second + defaultJanusGraphWriterMaxRetry = 3 + defaultJanusGraphWriterWorkerCount = 1 + + JanusGraphUrl = "janusgraph.url" + JanusGrapTimeout = "janusgraph.connection_timeout" + JanusGraphWriterTimeout = "janusgraph.writer_timeout" + JanusGraphWriterMaxRetry = "janusgraph.writer_max_retry" + JanusGraphWriterWorkerCount = "janusgraph.writer_worker_count" ) // JanusGraphConfig configures JanusGraph specific parameters. 
type JanusGraphConfig struct { URL string `mapstructure:"url"` // JanusGraph specific configuration ConnectionTimeout time.Duration `mapstructure:"connection_timeout"` + + // JanusGraph vertex/edge writer configuration + WriterTimeout time.Duration `mapstructure:"writer_timeout"` + WriterMaxRetry int `mapstructure:"writer_max_retry"` + WriterWorkerCount int `mapstructure:"writer_worker_count"` } diff --git a/pkg/kubehound/storage/graphdb/janusgraph_provider.go b/pkg/kubehound/storage/graphdb/janusgraph_provider.go index 8a4d04dd..2c55e577 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_provider.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_provider.go @@ -130,6 +130,9 @@ func (jgp *JanusGraphProvider) VertexWriter(ctx context.Context, v vertex.Builde c cache.CacheProvider, opts ...WriterOption) (AsyncVertexWriter, error) { opts = append(opts, WithTags(jgp.tags)) + opts = append(opts, WithWriterWorkerCount(jgp.cfg.JanusGraph.WriterWorkerCount)) + opts = append(opts, WithWriterTimeout(jgp.cfg.JanusGraph.WriterTimeout)) + opts = append(opts, WithWriterMaxRetry(jgp.cfg.JanusGraph.WriterMaxRetry)) return NewJanusGraphAsyncVertexWriter(ctx, jgp.drc, v, c, opts...) } @@ -137,6 +140,9 @@ func (jgp *JanusGraphProvider) VertexWriter(ctx context.Context, v vertex.Builde // EdgeWriter creates a new AsyncEdgeWriter instance to enable asynchronous bulk inserts of edges. func (jgp *JanusGraphProvider) EdgeWriter(ctx context.Context, e edge.Builder, opts ...WriterOption) (AsyncEdgeWriter, error) { opts = append(opts, WithTags(jgp.tags)) + opts = append(opts, WithWriterWorkerCount(jgp.cfg.JanusGraph.WriterWorkerCount)) + opts = append(opts, WithWriterTimeout(jgp.cfg.JanusGraph.WriterTimeout)) + opts = append(opts, WithWriterMaxRetry(jgp.cfg.JanusGraph.WriterMaxRetry)) return NewJanusGraphAsyncEdgeWriter(ctx, jgp.drc, e, opts...) } From 93690d213aee69bb76ff9ec36ce7fda756e591ee Mon Sep 17 00:00:00 2001 From: Thibault Normand Date: Fri, 29 Nov 2024 11:19:46 +0100 Subject: [PATCH 08/18] test(config): add missing values. --- pkg/config/config_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 22d1b7cc..ee6021ab 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -52,6 +52,9 @@ func TestMustLoadConfig(t *testing.T) { JanusGraph: JanusGraphConfig{ URL: "ws://localhost:8182/gremlin", ConnectionTimeout: DefaultConnectionTimeout, + WriterTimeout: defaultJanusGraphWriterTimeout, + WriterMaxRetry: defaultJanusGraphWriterMaxRetry, + WriterWorkerCount: defaultJanusGraphWriterWorkerCount, }, Telemetry: TelemetryConfig{ Statsd: StatsdConfig{ @@ -126,6 +129,9 @@ func TestMustLoadConfig(t *testing.T) { JanusGraph: JanusGraphConfig{ URL: "ws://localhost:8182/gremlin", ConnectionTimeout: DefaultConnectionTimeout, + WriterTimeout: defaultJanusGraphWriterTimeout, + WriterMaxRetry: defaultJanusGraphWriterMaxRetry, + WriterWorkerCount: defaultJanusGraphWriterWorkerCount, }, Telemetry: TelemetryConfig{ Statsd: StatsdConfig{ From aee82538ec50849fc4aed77db955a5a276e23a6b Mon Sep 17 00:00:00 2001 From: Thibault Normand Date: Fri, 29 Nov 2024 11:32:49 +0100 Subject: [PATCH 09/18] feat(graphdb): restore queue metrics. 
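The move to the micro batcher dropped the increment of the queued-items
counter; this restores it in Queue() for both writers so the queue metric
lines up again with the written-items counter. A sketch of why the pair
matters; the helper below is illustrative only, not part of the patch:

```go
package example

import "sync/atomic"

// backlog approximates the number of items accepted by Queue() but not yet
// written by batchWrite(): the difference between the two counters.
func backlog(qcounter, wcounter *int32) int32 {
	return atomic.LoadInt32(qcounter) - atomic.LoadInt32(wcounter)
}
```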
--- pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go | 1 + pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go | 1 + 2 files changed, 2 insertions(+) diff --git a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go index 9d5eddcd..613f5064 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go @@ -204,5 +204,6 @@ func (jgv *JanusGraphEdgeWriter) Flush(ctx context.Context) error { } func (jgv *JanusGraphEdgeWriter) Queue(ctx context.Context, v any) error { + atomic.AddInt32(&jgv.qcounter, 1) return jgv.mb.Enqueue(ctx, v) } diff --git a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go index ce44eeb2..7d9915aa 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go @@ -272,5 +272,6 @@ func (jgv *JanusGraphVertexWriter) Flush(ctx context.Context) error { } func (jgv *JanusGraphVertexWriter) Queue(ctx context.Context, v any) error { + atomic.AddInt32(&jgv.qcounter, 1) return jgv.mb.Enqueue(ctx, v) } From a19ce985fdf468d2b75478dc8001cea802b9b014 Mon Sep 17 00:00:00 2001 From: Thibault Normand Date: Fri, 29 Nov 2024 11:41:22 +0100 Subject: [PATCH 10/18] chore(ci): fix nlreturn issues. --- pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go | 1 + pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go | 1 + 2 files changed, 2 insertions(+) diff --git a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go index 613f5064..ee2905a4 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go @@ -205,5 +205,6 @@ func (jgv *JanusGraphEdgeWriter) Flush(ctx context.Context) error { func (jgv *JanusGraphEdgeWriter) Queue(ctx context.Context, v any) error { atomic.AddInt32(&jgv.qcounter, 1) + return jgv.mb.Enqueue(ctx, v) } diff --git a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go index 7d9915aa..0fda6940 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go @@ -273,5 +273,6 @@ func (jgv *JanusGraphVertexWriter) Flush(ctx context.Context) error { func (jgv *JanusGraphVertexWriter) Queue(ctx context.Context, v any) error { atomic.AddInt32(&jgv.qcounter, 1) + return jgv.mb.Enqueue(ctx, v) } From 338526ed765f79b539ac23d2de9f19f37575485b Mon Sep 17 00:00:00 2001 From: Thibault Normand Date: Fri, 29 Nov 2024 13:07:27 +0100 Subject: [PATCH 11/18] feat(graphdb): change default settings. 
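Defaults move to smaller batches (500 -> 250) spread across more writer
workers (1 -> 10), and the writingInFlight counter is now incremented in the
micro-batcher flush callback instead of inside batchWrite, so retried and
split batches stay tracked. The new Viper bindings also allow overriding the
builder batch sizes per run; an illustrative override with hypothetical
values (the KH_BUILDER_* bindings are added in this patch, the JanusGraph
one earlier in the series):

```go
package example

import "os"

// Example environment overrides for the batch size and worker bindings.
func init() {
	os.Setenv("KH_BUILDER_VERTEX_BATCH_SIZE", "500")
	os.Setenv("KH_BUILDER_EDGE_BATCH_SIZE", "500")
	os.Setenv("KH_JANUSGRAPH_WRITER_WORKER_COUNT", "4")
}
```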
--- pkg/config/builder.go | 4 ++-- pkg/config/config.go | 5 +++++ pkg/config/janusgraph.go | 2 +- pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go | 8 ++++---- pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go | 8 ++++---- pkg/kubehound/storage/graphdb/provider.go | 2 +- 6 files changed, 17 insertions(+), 12 deletions(-) diff --git a/pkg/config/builder.go b/pkg/config/builder.go index dc82d194..a00f9b10 100644 --- a/pkg/config/builder.go +++ b/pkg/config/builder.go @@ -3,11 +3,11 @@ package config const ( DefaultEdgeWorkerPoolSize = 5 DefaultEdgeWorkerPoolCapacity = 100 - DefaultEdgeBatchSize = 500 + DefaultEdgeBatchSize = 250 DefaultEdgeBatchSizeSmall = DefaultEdgeBatchSize / 5 DefaultEdgeBatchSizeClusterImpact = 10 - DefaultVertexBatchSize = 500 + DefaultVertexBatchSize = 250 DefaultVertexBatchSizeSmall = DefaultVertexBatchSize / 5 DefaultStopOnError = false diff --git a/pkg/config/config.go b/pkg/config/config.go index 4b816428..a3569f6d 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -172,6 +172,11 @@ func SetEnvOverrides(ctx context.Context, c *viper.Viper) { res = multierror.Append(res, c.BindEnv(IngestorArchiveName, "KH_INGESTOR_ARCHIVE_NAME")) res = multierror.Append(res, c.BindEnv(IngestorBlobRegion, "KH_INGESTOR_REGION")) + res = multierror.Append(res, c.BindEnv("builder.vertex.batch_size", "KH_BUILDER_VERTEX_BATCH_SIZE")) + res = multierror.Append(res, c.BindEnv("builder.vertex.batch_size_small", "KH_BUILDER_VERTEX_BATCH_SIZE_SMALL")) + res = multierror.Append(res, c.BindEnv("builder.edge.batch_size", "KH_BUILDER_EDGE_BATCH_SIZE")) + res = multierror.Append(res, c.BindEnv("builder.edge.batch_size_small", "KH_BUILDER_EDGE_BATCH_SIZE_SMALL")) + res = multierror.Append(res, c.BindEnv(TelemetryStatsdUrl, "STATSD_URL")) res = multierror.Append(res, c.BindEnv(TelemetryTracerUrl, "TRACE_AGENT_URL")) diff --git a/pkg/config/janusgraph.go b/pkg/config/janusgraph.go index ae678bd0..1b3e294a 100644 --- a/pkg/config/janusgraph.go +++ b/pkg/config/janusgraph.go @@ -9,7 +9,7 @@ const ( defaultJanusGraphWriterTimeout = 60 * time.Second defaultJanusGraphWriterMaxRetry = 3 - defaultJanusGraphWriterWorkerCount = 1 + defaultJanusGraphWriterWorkerCount = 10 JanusGraphUrl = "janusgraph.url" JanusGrapTimeout = "janusgraph.connection_timeout" diff --git a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go index ee2905a4..47fe37ee 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_edge_writer.go @@ -66,6 +66,10 @@ func NewJanusGraphAsyncEdgeWriter(ctx context.Context, drc *gremlingo.DriverRemo // Create a new micro batcher to batch the inserts with split and retry logic. jw.mb = newMicroBatcher(log.Trace(ctx), e.BatchSize(), options.WriterWorkerCount, func(ctx context.Context, a []any) error { + // Increment the writingInFlight wait group to track the number of writes in progress. + jw.writingInFlight.Add(1) + defer jw.writingInFlight.Done() + // Try to write the batch to the graph DB. if err := jw.batchWrite(ctx, a); err != nil { var bwe *batchWriterError @@ -140,10 +144,6 @@ func (jgv *JanusGraphEdgeWriter) batchWrite(ctx context.Context, data []any) err var err error defer func() { span.Finish(tracer.WithError(err)) }() - // Increment the writingInFlight wait group to track the number of writes in progress. 
- jgv.writingInFlight.Add(1) - defer jgv.writingInFlight.Done() - datalen := len(data) _ = statsd.Count(ctx, metric.EdgeWrite, int64(datalen), jgv.tags, 1) log.Trace(ctx).Debugf("Batch write JanusGraphEdgeWriter with %d elements", datalen) diff --git a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go index 0fda6940..5facc1b7 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_vertex_writer.go @@ -70,6 +70,10 @@ func NewJanusGraphAsyncVertexWriter(ctx context.Context, drc *gremlin.DriverRemo // Create a new micro batcher to batch the inserts with split and retry logic. jw.mb = newMicroBatcher(log.Trace(ctx), v.BatchSize(), options.WriterWorkerCount, func(ctx context.Context, a []any) error { + // Increment the writingInFlight wait group to track the number of writes in progress. + jw.writingInFlight.Add(1) + defer jw.writingInFlight.Done() + // Try to write the batch to the graph DB. if err := jw.batchWrite(ctx, a); err != nil { var bwe *batchWriterError @@ -120,10 +124,6 @@ func (jgv *JanusGraphVertexWriter) batchWrite(ctx context.Context, data []any) e var err error defer func() { span.Finish(tracer.WithError(err)) }() - // Increment the writingInFlight wait group to track the number of writes in progress. - jgv.writingInFlight.Add(1) - defer jgv.writingInFlight.Done() - datalen := len(data) _ = statsd.Count(ctx, metric.VertexWrite, int64(datalen), jgv.tags, 1) log.Trace(ctx).Debugf("Batch write JanusGraphVertexWriter with %d elements", datalen) diff --git a/pkg/kubehound/storage/graphdb/provider.go b/pkg/kubehound/storage/graphdb/provider.go index fbbe9b98..5b7feb55 100644 --- a/pkg/kubehound/storage/graphdb/provider.go +++ b/pkg/kubehound/storage/graphdb/provider.go @@ -15,7 +15,7 @@ import ( const ( defaultWriterTimeout = 60 * time.Second defaultMaxRetry = 3 - defaultWriterWorkerCount = 1 + defaultWriterWorkerCount = 10 ) type writerOptions struct { From 9af486e77e91f51ac97527029c4c5e3d2464aca4 Mon Sep 17 00:00:00 2001 From: Thibault Normand Date: Fri, 29 Nov 2024 13:49:26 +0100 Subject: [PATCH 12/18] test(config): fix tests. 
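The expected values follow mechanically from the new defaults, since the
small batch sizes are derived rather than configured independently; a
worked sketch with example constant names:

```go
package example

// With DefaultVertexBatchSize/DefaultEdgeBatchSize now 250, the derived
// small sizes asserted in the tests below become 250 / 5 = 50.
const (
	exampleBatchSize      = 250
	exampleBatchSizeSmall = exampleBatchSize / 5 // == 50
)
```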
--- pkg/config/config_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index ee6021ab..6c4aee19 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -67,15 +67,15 @@ func TestMustLoadConfig(t *testing.T) { }, Builder: BuilderConfig{ Vertex: VertexBuilderConfig{ - BatchSize: 500, - BatchSizeSmall: 100, + BatchSize: 250, + BatchSizeSmall: 50, }, Edge: EdgeBuilderConfig{ LargeClusterOptimizations: DefaultLargeClusterOptimizations, WorkerPoolSize: 5, WorkerPoolCapacity: 100, - BatchSize: 500, - BatchSizeSmall: 100, + BatchSize: 250, + BatchSizeSmall: 50, BatchSizeClusterImpact: 10, }, }, @@ -145,7 +145,7 @@ func TestMustLoadConfig(t *testing.T) { Builder: BuilderConfig{ Vertex: VertexBuilderConfig{ BatchSize: 1000, - BatchSizeSmall: 100, + BatchSizeSmall: 50, }, Edge: EdgeBuilderConfig{ LargeClusterOptimizations: true, From 5cabd42174cab72bac9757cf35248bd5795f7a2d Mon Sep 17 00:00:00 2001 From: Edouard Schweisguth Date: Fri, 29 Nov 2024 18:00:52 +0100 Subject: [PATCH 13/18] Fix log (#296) --- pkg/kubehound/storage/graphdb/janusgraph_provider.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/kubehound/storage/graphdb/janusgraph_provider.go b/pkg/kubehound/storage/graphdb/janusgraph_provider.go index 8a4d04dd..109bca10 100644 --- a/pkg/kubehound/storage/graphdb/janusgraph_provider.go +++ b/pkg/kubehound/storage/graphdb/janusgraph_provider.go @@ -154,7 +154,7 @@ func (jgp *JanusGraphProvider) Clean(ctx context.Context, cluster string) error span, ctx := span.SpanRunFromContext(ctx, span.IngestorClean) defer func() { span.Finish(tracer.WithError(err)) }() l := log.Trace(ctx) - l.Infof("Cleaning cluster", log.FieldClusterKey, cluster) + l.Info("Cleaning cluster", log.String(log.FieldClusterKey, cluster)) g := gremlin.Traversal_().WithRemote(jgp.drc) tx := g.Tx() defer tx.Close() From c60598a5ff2e7fe3460b9d4c52cd256c52c1265d Mon Sep 17 00:00:00 2001 From: Thibault Normand Date: Mon, 2 Dec 2024 20:01:10 +0100 Subject: [PATCH 14/18] chore(doc): update reference configuration. 
--- configs/etc/kubehound-reference.yaml | 7 +++++-- configs/etc/kubehound.yaml | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/configs/etc/kubehound-reference.yaml b/configs/etc/kubehound-reference.yaml index f1328644..cd0179e8 100644 --- a/configs/etc/kubehound-reference.yaml +++ b/configs/etc/kubehound-reference.yaml @@ -62,6 +62,9 @@ janusgraph: # Timeout on requests to the JanusGraph DB instance connection_timeout: 30s + # Number of worker threads for the JanusGraph writer pool + writer_worker_count: 10 + # # Datadog telemetry configuration # @@ -114,10 +117,10 @@ builder: # worker_pool_capacity: 100 # # Batch size for edge inserts - # batch_size: 500 + # batch_size: 250 # # Small batch size for edge inserts - # batch_size_small: 75 + # batch_size_small: 50 # # Cluster impact batch size for edge inserts # batch_size_cluster_impact: 1 diff --git a/configs/etc/kubehound.yaml b/configs/etc/kubehound.yaml index 7271f781..0a8c6ee5 100644 --- a/configs/etc/kubehound.yaml +++ b/configs/etc/kubehound.yaml @@ -37,19 +37,22 @@ janusgraph: # Timeout on requests to the JanusGraph DB instance connection_timeout: 30s + # Number of worker threads for the JanusGraph writer pool + writer_worker_count: 10 + # Graph builder configuration builder: # Vertex builder configuration vertex: # Batch size for vertex inserts - batch_size: 500 + batch_size: 250 # Edge builder configuration edge: worker_pool_size: 2 # Batch size for edge inserts - batch_size: 500 + batch_size: 250 # Cluster impact batch size for edge inserts batch_size_cluster_impact: 10 From 0d76841423c84af2635b4a0878f4603fd3ba116f Mon Sep 17 00:00:00 2001 From: jt-dd <112463504+jt-dd@users.noreply.github.com> Date: Tue, 3 Dec 2024 10:02:27 +0100 Subject: [PATCH 15/18] Fix umh core pattern attacks (#298) * fix umh core pattern * adding index for mongodb query --- .../graph/edge/escape_umh_core_pattern.go | 30 +++++++++---------- .../storage/storedb/index_builder.go | 8 +++++ 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/pkg/kubehound/graph/edge/escape_umh_core_pattern.go b/pkg/kubehound/graph/edge/escape_umh_core_pattern.go index 22340ce5..40ee8909 100644 --- a/pkg/kubehound/graph/edge/escape_umh_core_pattern.go +++ b/pkg/kubehound/graph/edge/escape_umh_core_pattern.go @@ -54,32 +54,32 @@ func (e *EscapeCorePattern) Stream(ctx context.Context, store storedb.Provider, }, { "$lookup": bson.M{ - "as": "procMountContainers", - "from": "volumes", - "let": bson.M{ - "rootContainerId": "$container_id", - }, + "as": "procMountContainers", + "from": "volumes", + "foreignField": "pod_id", + "localField": "pod_id", "pipeline": []bson.M{ { "$match": bson.M{ "$and": bson.A{ - bson.M{"$expr": bson.M{ - "$eq": bson.A{ - "$container_id", "$$rootContainerId", - }, + bson.M{"type": shared.VolumeTypeHost}, + bson.M{"source": bson.M{ + "$in": ProcMountList, }}, + bson.M{"runtime.runID": e.runtime.RunID.String()}, + bson.M{"runtime.cluster": e.runtime.ClusterName}, }, - "type": shared.VolumeTypeHost, - "source": bson.M{ - "$in": ProcMountList, - }, - "runtime.runID": e.runtime.RunID.String(), - "runtime.cluster": e.runtime.ClusterName, }, }, }, }, }, + { + "$unwind": bson.M{ + "path": "$procMountContainers", + "preserveNullAndEmptyArrays": false, + }, + }, { "$project": bson.M{ "_id": 1, diff --git a/pkg/kubehound/storage/storedb/index_builder.go b/pkg/kubehound/storage/storedb/index_builder.go index d06e40ee..f7839f5d 100644 --- a/pkg/kubehound/storage/storedb/index_builder.go +++ 
b/pkg/kubehound/storage/storedb/index_builder.go @@ -123,6 +123,14 @@ func (ib *IndexBuilder) containers(ctx context.Context) error { }, Options: options.Index().SetName("byRun"), }, + { + Keys: bson.D{ + {Key: "k8.securitycontext.runasuser", Value: 1}, + {Key: "runtime.runID", Value: 1}, + {Key: "runtime.cluster", Value: 1}, + }, + Options: options.Index().SetName("byRunAsUser"), + }, } _, err := containers.Indexes().CreateMany(ctx, indices) From 7dd7f32a1ce7f62fafcf842243b9e4adda5095e3 Mon Sep 17 00:00:00 2001 From: jt-dd <112463504+jt-dd@users.noreply.github.com> Date: Tue, 3 Dec 2024 10:02:40 +0100 Subject: [PATCH 16/18] Fix role-bind attack (#299) * fix role-bind attack * Fixing non large cluster optimization (limiting to runid) --- pkg/kubehound/graph/edge/pod_create.go | 3 ++- pkg/kubehound/graph/edge/pod_exec.go | 3 ++- pkg/kubehound/graph/edge/pod_patch.go | 3 ++- .../graph/edge/role_bind_crb_cr_cr.go | 18 ++++++++---------- pkg/kubehound/graph/edge/role_bind_crb_cr_r.go | 18 ++++++++---------- pkg/kubehound/graph/edge/token_bruteforce.go | 7 +++++-- pkg/kubehound/graph/edge/token_list.go | 7 +++++-- 7 files changed, 32 insertions(+), 27 deletions(-) diff --git a/pkg/kubehound/graph/edge/pod_create.go b/pkg/kubehound/graph/edge/pod_create.go index e0450ef5..fd31f3bf 100644 --- a/pkg/kubehound/graph/edge/pod_create.go +++ b/pkg/kubehound/graph/edge/pod_create.go @@ -86,7 +86,8 @@ func (e *PodCreate) Traversal() types.EdgeTraversal { } else { // In smaller clusters we can still show the (large set of) attack paths generated by this attack g.V(). - HasLabel("Node"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). Has("class", "Node"). As("n"). V(inserts...). diff --git a/pkg/kubehound/graph/edge/pod_exec.go b/pkg/kubehound/graph/edge/pod_exec.go index a9a3d6b9..e200b922 100644 --- a/pkg/kubehound/graph/edge/pod_exec.go +++ b/pkg/kubehound/graph/edge/pod_exec.go @@ -86,7 +86,8 @@ func (e *PodExec) Traversal() types.EdgeTraversal { } else { // In smaller clusters we can still show the (large set of) attack paths generated by this attack g.V(). - HasLabel("Pod"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). Has("class", "Pod"). As("p"). V(inserts...). diff --git a/pkg/kubehound/graph/edge/pod_patch.go b/pkg/kubehound/graph/edge/pod_patch.go index ff4627bf..18910368 100644 --- a/pkg/kubehound/graph/edge/pod_patch.go +++ b/pkg/kubehound/graph/edge/pod_patch.go @@ -86,7 +86,8 @@ func (e *PodPatch) Traversal() types.EdgeTraversal { } else { // In smaller clusters we can still show the (large set of) attack paths generated by this attack g.V(). - HasLabel("Pod"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). Has("class", "Pod"). As("p"). V(inserts...). 
diff --git a/pkg/kubehound/graph/edge/role_bind_crb_cr_cr.go b/pkg/kubehound/graph/edge/role_bind_crb_cr_cr.go index ca84dfac..c0c2c149 100644 --- a/pkg/kubehound/graph/edge/role_bind_crb_cr_cr.go +++ b/pkg/kubehound/graph/edge/role_bind_crb_cr_cr.go @@ -7,7 +7,6 @@ import ( "github.com/DataDog/KubeHound/pkg/kubehound/graph/adapter" "github.com/DataDog/KubeHound/pkg/kubehound/graph/types" "github.com/DataDog/KubeHound/pkg/kubehound/models/converter" - "github.com/DataDog/KubeHound/pkg/kubehound/risk" "github.com/DataDog/KubeHound/pkg/kubehound/storage/cache" "github.com/DataDog/KubeHound/pkg/kubehound/storage/storedb" "github.com/DataDog/KubeHound/pkg/kubehound/store/collections" @@ -53,19 +52,16 @@ func (e *RoleBindCrbCrCr) Traversal() types.EdgeTraversal { return func(source *gremlin.GraphTraversalSource, inserts []any) *gremlin.GraphTraversal { g := source.GetGraphTraversal() - // Gathering all sensitives roles - sensitiveRoles := make([]string, 0, len(risk.CriticalRoleMap)) - for k := range risk.CriticalRoleMap { - sensitiveRoles = append(sensitiveRoles, k) - } - if e.cfg.LargeClusterOptimizations { // For larger clusters simply target specific roles to reduce number of attack paths g.V(). - HasLabel("PermissionSet"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). + Has("class", "PermissionSet"). Has("isNamespaced", false). // Temporary measure, until we scan and flag for sensitive roles - Has("role", P.Within(sensitiveRoles)). + Has("critical", true). + // Has("role", P.Within(sensitiveRoles)). As("r"). V(inserts...). Has("critical", false). @@ -75,7 +71,9 @@ func (e *RoleBindCrbCrCr) Traversal() types.EdgeTraversal { } else { // In smaller clusters we can still show the (large set of) attack paths generated by this attack g.V(). - HasLabel("PermissionSet"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). + Has("class", "PermissionSet"). Has("isNamespaced", false). As("i"). V(inserts...). diff --git a/pkg/kubehound/graph/edge/role_bind_crb_cr_r.go b/pkg/kubehound/graph/edge/role_bind_crb_cr_r.go index ecb09c7d..3f2ea670 100644 --- a/pkg/kubehound/graph/edge/role_bind_crb_cr_r.go +++ b/pkg/kubehound/graph/edge/role_bind_crb_cr_r.go @@ -7,7 +7,6 @@ import ( "github.com/DataDog/KubeHound/pkg/kubehound/graph/adapter" "github.com/DataDog/KubeHound/pkg/kubehound/graph/types" "github.com/DataDog/KubeHound/pkg/kubehound/models/converter" - "github.com/DataDog/KubeHound/pkg/kubehound/risk" "github.com/DataDog/KubeHound/pkg/kubehound/storage/cache" "github.com/DataDog/KubeHound/pkg/kubehound/storage/storedb" "github.com/DataDog/KubeHound/pkg/kubehound/store/collections" @@ -53,19 +52,16 @@ func (e *RoleBindCrbCrR) Traversal() types.EdgeTraversal { return func(source *gremlin.GraphTraversalSource, inserts []any) *gremlin.GraphTraversal { g := source.GetGraphTraversal() - // Gathering all sensitives roles - sensitiveRoles := make([]string, 0, len(risk.CriticalRoleMap)) - for k := range risk.CriticalRoleMap { - sensitiveRoles = append(sensitiveRoles, k) - } - if e.cfg.LargeClusterOptimizations { // For larger clusters simply target specific roles to reduce number of attack paths g.V(). - HasLabel("PermissionSet"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). + Has("class", "PermissionSet"). Has("isNamespaced", true). // Temporary measure, until we scan and flag for sensitive roles - Has("role", P.Within(sensitiveRoles)). + Has("critical", true). + // Has("role", P.Within(sensitiveRoles)). 
As("r"). V(inserts...). Has("critical", false). @@ -75,7 +71,9 @@ func (e *RoleBindCrbCrR) Traversal() types.EdgeTraversal { } else { // In smaller clusters we can still show the (large set of) attack paths generated by this attack g.V(). - HasLabel("PermissionSet"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). + Has("class", "PermissionSet"). Has("isNamespaced", true). As("i"). V(inserts...). diff --git a/pkg/kubehound/graph/edge/token_bruteforce.go b/pkg/kubehound/graph/edge/token_bruteforce.go index e40fc36f..2f9b2a30 100644 --- a/pkg/kubehound/graph/edge/token_bruteforce.go +++ b/pkg/kubehound/graph/edge/token_bruteforce.go @@ -64,7 +64,9 @@ func (e *TokenBruteforce) Traversal() types.EdgeTraversal { if e.cfg.LargeClusterOptimizations { // For larger clusters simply target the system:masters group to reduce redundant attack paths g.V(). - HasLabel("Identity"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). + Has("class", "Identity"). Has("name", "system:masters"). As("i"). V(inserts...). @@ -75,7 +77,8 @@ func (e *TokenBruteforce) Traversal() types.EdgeTraversal { } else { // In smaller clusters we can still show the (large set of) attack paths generated by this attack g.V(). - HasLabel("Identity"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). Has("class", "Identity"). As("i"). V(inserts...). diff --git a/pkg/kubehound/graph/edge/token_list.go b/pkg/kubehound/graph/edge/token_list.go index beeea199..e3b97b9f 100644 --- a/pkg/kubehound/graph/edge/token_list.go +++ b/pkg/kubehound/graph/edge/token_list.go @@ -64,7 +64,9 @@ func (e *TokenList) Traversal() types.EdgeTraversal { if e.cfg.LargeClusterOptimizations { // For larger clusters simply target the system:masters group to reduce redundant attack paths g.V(). - HasLabel("Identity"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). + Has("class", "Identity"). Has("name", "system:masters"). As("i"). V(inserts...). @@ -75,7 +77,8 @@ func (e *TokenList) Traversal() types.EdgeTraversal { } else { // In smaller clusters we can still show the (large set of) attack paths generated by this attack g.V(). - HasLabel("Identity"). + Has("runID", e.runtime.RunID.String()). + Has("cluster", e.runtime.ClusterName). Has("class", "Identity"). As("i"). V(inserts...). 
From 5965f9bc81bd3f5d0da2db88ffb1245d82969ecd Mon Sep 17 00:00:00 2001 From: jt-dd <112463504+jt-dd@users.noreply.github.com> Date: Tue, 3 Dec 2024 10:02:53 +0100 Subject: [PATCH 17/18] removed the use of labels (#300) --- .../ase/kubehound/KubeHoundTraversalDsl.java | 182 ++++++++++-------- .../KubeHoundTraversalSourceDsl.java | 147 +++++++------- docs/queries/dsl.md | 30 +-- docs/queries/gremlin.md | 97 +++++----- .../graph/edge/escape_var_log_symlink.go | 6 +- scripts/dashboard-demo/main.py | 16 +- test/system/graph_edge_test.go | 116 +++++------ test/system/graph_vertex_test.go | 30 +-- 8 files changed, 330 insertions(+), 294 deletions(-) diff --git a/deployments/kubehound/graph/dsl/kubehound/src/main/java/com/datadog/ase/kubehound/KubeHoundTraversalDsl.java b/deployments/kubehound/graph/dsl/kubehound/src/main/java/com/datadog/ase/kubehound/KubeHoundTraversalDsl.java index a759b4b9..05d998c9 100644 --- a/deployments/kubehound/graph/dsl/kubehound/src/main/java/com/datadog/ase/kubehound/KubeHoundTraversalDsl.java +++ b/deployments/kubehound/graph/dsl/kubehound/src/main/java/com/datadog/ase/kubehound/KubeHoundTraversalDsl.java @@ -34,16 +34,21 @@ import static org.apache.tinkerpop.gremlin.process.traversal.Scope.local; import static org.apache.tinkerpop.gremlin.structure.Column.values; - /** - * This KubeHound DSL is meant to be used with the Kubernetes attack graph created by the KubeHound application. + * This KubeHound DSL is meant to be used with the Kubernetes attack graph + * created by the KubeHound application. *

- * All DSLs should extend {@code GraphTraversal.Admin} and be suffixed with "TraversalDsl". Simply add DSL traversal - * methods to this interface. Use Gremlin's steps to build the underlying traversal in these methods to ensure - * compatibility with the rest of the TinkerPop stack and provider implementations. + * All DSLs should extend {@code GraphTraversal.Admin} and be suffixed with + * "TraversalDsl". Simply add DSL traversal + * methods to this interface. Use Gremlin's steps to build the underlying + * traversal in these methods to ensure + * compatibility with the rest of the TinkerPop stack and provider + * implementations. *

- * Arguments provided to the {@code GremlinDsl} annotation are all optional. In this case, a {@code traversalSource} is - * specified which points to a specific implementation to use. Had that argument not been specified then a default + * Arguments provided to the {@code GremlinDsl} annotation are all optional. In + * this case, a {@code traversalSource} is + * specified which points to a specific implementation to use. Had that argument + * not been specified then a default * {@code TraversalSource} would have been generated. */ @GremlinDsl(traversalSource = "com.datadog.ase.kubehound.KubeHoundTraversalSourceDsl") @@ -54,7 +59,8 @@ public interface KubeHoundTraversalDsl extends GraphTraversal.Admin public static final int PATH_HOPS_MIN_DEFAULT = 6; /** - * From a {@code Vertex} traverse immediate edges to display the next set of possible attacks and targets. + * From a {@code Vertex} traverse immediate edges to display the next set of + * possible attacks and targets. * */ public default GraphTraversal attacks() { @@ -62,78 +68,85 @@ public default GraphTraversal attacks() { } /** - * From a {@code Vertex} filter on whether incoming vertices are critical assets. + * From a {@code Vertex} filter on whether incoming vertices are critical + * assets. */ - @GremlinDsl.AnonymousMethod(returnTypeParameters = {"A", "A"}, methodTypeParameters = {"A"}) + @GremlinDsl.AnonymousMethod(returnTypeParameters = { "A", "A" }, methodTypeParameters = { "A" }) public default GraphTraversal critical() { return has("critical", true); } /** - * From a {@code Vertex} traverse edges until {@code maxHops} is exceeded or a critical asset is reached and return all paths. + * From a {@code Vertex} traverse edges until {@code maxHops} is exceeded or a + * critical asset is reached and return all paths. * * @param maxHops the maximum number of hops in an attack path */ public default GraphTraversal criticalPaths(int maxHops) { - if (maxHops < PATH_HOPS_MIN) throw new IllegalArgumentException(String.format("maxHops must be >= %d", PATH_HOPS_MIN)); - if (maxHops > PATH_HOPS_MAX) throw new IllegalArgumentException(String.format("maxHops must be <= %d", PATH_HOPS_MAX)); + if (maxHops < PATH_HOPS_MIN) + throw new IllegalArgumentException(String.format("maxHops must be >= %d", PATH_HOPS_MIN)); + if (maxHops > PATH_HOPS_MAX) + throw new IllegalArgumentException(String.format("maxHops must be <= %d", PATH_HOPS_MAX)); - return repeat(( - (KubeHoundTraversalDsl) __.outE()) + return repeat(((KubeHoundTraversalDsl) __.outE()) .inV() - .simplePath() - ).until( - __.has("critical", true) - .or() - .loops() - .is(maxHops) - ).has("critical", true) - .path(); + .simplePath()).until( + __.has("critical", true) + .or() + .loops() + .is(maxHops)) + .has("critical", true) + .path(); } /** - * From a {@code Vertex} traverse edges until a critical asset is reached and return all paths. + * From a {@code Vertex} traverse edges until a critical asset is reached and + * return all paths. */ public default GraphTraversal criticalPaths() { return criticalPaths(PATH_HOPS_DEFAULT); } /** - * From a {@code Vertex} traverse edges EXCLUDING labels provided in {@code exclusions} until {@code maxHops} is exceeded or - * a critical asset is reached and return all paths. + * From a {@code Vertex} traverse edges EXCLUDING labels provided in + * {@code exclusions} until {@code maxHops} is exceeded or + * a critical asset is reached and return all paths. 
* - * @param maxHops the maximum number of hops in an attack path + * @param maxHops the maximum number of hops in an attack path * @param exclusions edge labels to exclude from paths */ public default GraphTraversal criticalPathsFilter(int maxHops, String... exclusions) { - if (exclusions.length <= 0) throw new IllegalArgumentException("exclusions must be provided (otherwise use criticalPaths())"); - if (maxHops < PATH_HOPS_MIN) throw new IllegalArgumentException(String.format("maxHops must be >= %d", PATH_HOPS_MIN)); - if (maxHops > PATH_HOPS_MAX) throw new IllegalArgumentException(String.format("maxHops must be <= %d", PATH_HOPS_MAX)); - - return repeat(( - (KubeHoundTraversalDsl) __.outE()) - .hasLabel(P.not(P.within(exclusions))) + if (exclusions.length <= 0) + throw new IllegalArgumentException("exclusions must be provided (otherwise use criticalPaths())"); + if (maxHops < PATH_HOPS_MIN) + throw new IllegalArgumentException(String.format("maxHops must be >= %d", PATH_HOPS_MIN)); + if (maxHops > PATH_HOPS_MAX) + throw new IllegalArgumentException(String.format("maxHops must be <= %d", PATH_HOPS_MAX)); + + return repeat(((KubeHoundTraversalDsl) __.outE()) + .has("class", P.not(P.within(exclusions))) .inV() - .simplePath() - ).until( - __.has("critical", true) - .or() - .loops() - .is(maxHops) - ).has("critical", true) - .path(); + .simplePath()).until( + __.has("critical", true) + .or() + .loops() + .is(maxHops)) + .has("critical", true) + .path(); } /** - * From a {@code Vertex} filter on whether incoming vertices have at least one path to a critical asset. + * From a {@code Vertex} filter on whether incoming vertices have at least one + * path to a critical asset. */ - @GremlinDsl.AnonymousMethod(returnTypeParameters = {"A", "A"}, methodTypeParameters = {"A"}) + @GremlinDsl.AnonymousMethod(returnTypeParameters = { "A", "A" }, methodTypeParameters = { "A" }) public default GraphTraversal hasCriticalPath() { - return where(__.criticalPaths().limit(1)); + return where(__.criticalPaths().limit(1)); } /** - * From a {@code Vertex} returns the hop count of the shortest path to a critical asset. + * From a {@code Vertex} returns the hop count of the shortest path to a + * critical asset. * */ public default GraphTraversal minHopsToCritical() { @@ -141,61 +154,66 @@ public default GraphTraversal minHopsToCritical() } /** - * From a {@code Vertex} returns the hop count of the shortest path to a critical asset. - * + * From a {@code Vertex} returns the hop count of the shortest path to a + * critical asset. 
+ * * @param maxHops the maximum number of hops in an attack path to consider * */ public default GraphTraversal minHopsToCritical(int maxHops) { - if (maxHops < PATH_HOPS_MIN) throw new IllegalArgumentException(String.format("maxHops must be >= %d", PATH_HOPS_MIN)); - if (maxHops > PATH_HOPS_MAX) throw new IllegalArgumentException(String.format("maxHops must be <= %d", PATH_HOPS_MAX)); - - return repeat(( - (KubeHoundTraversalDsl) __.out()) - .simplePath() - ).until( - __.has("critical", true) - .or() - .loops() - .is(maxHops) - ).has("critical", true) - .path() - .count(local) - .min(); + if (maxHops < PATH_HOPS_MIN) + throw new IllegalArgumentException(String.format("maxHops must be >= %d", PATH_HOPS_MIN)); + if (maxHops > PATH_HOPS_MAX) + throw new IllegalArgumentException(String.format("maxHops must be <= %d", PATH_HOPS_MAX)); + + return repeat(((KubeHoundTraversalDsl) __.out()) + .simplePath()).until( + __.has("critical", true) + .or() + .loops() + .is(maxHops)) + .has("critical", true) + .path() + .count(local) + .min(); } /** - * From a {@code Vertex} returns a group count (by label) of paths to a critical asset. + * From a {@code Vertex} returns a group count (by label) of paths to a critical + * asset. * */ public default GraphTraversal> criticalPathsFreq() { - return criticalPathsFreq(PATH_HOPS_DEFAULT); + return criticalPathsFreq(PATH_HOPS_DEFAULT); } /** - * From a {@code Vertex} returns a group count (by label) of paths to a critical asset. + * From a {@code Vertex} returns a group count (by label) of paths to a critical + * asset. * * @param maxHops the maximum number of hops in an attack path */ public default GraphTraversal> criticalPathsFreq(int maxHops) { - if (maxHops < PATH_HOPS_MIN) throw new IllegalArgumentException(String.format("maxHops must be >= %d", PATH_HOPS_MIN)); - if (maxHops > PATH_HOPS_MAX) throw new IllegalArgumentException(String.format("maxHops must be <= %d", PATH_HOPS_MAX)); + if (maxHops < PATH_HOPS_MIN) + throw new IllegalArgumentException(String.format("maxHops must be >= %d", PATH_HOPS_MIN)); + if (maxHops > PATH_HOPS_MAX) + throw new IllegalArgumentException(String.format("maxHops must be <= %d", PATH_HOPS_MAX)); return repeat( (KubeHoundTraversalDsl) __.outE() - .inV() - .simplePath() - ).emit() - .until( - __.has("critical", true) - .or() - .loops() - .is(maxHops) - ).has("critical", true) - .path() - .by(T.label) - .groupCount() - .order(local) - .by(__.select(values), Order.desc); + .inV() + .simplePath()) + .emit() + .until( + __.has("critical", true) + .or() + .loops() + .is(maxHops)) + .has("critical", true) + .path() + .by(T.label) + .groupCount() + .order(local) + .by(__.select(values), Order.desc); } } diff --git a/deployments/kubehound/graph/dsl/kubehound/src/main/java/com/datadog/ase/kubehound/KubeHoundTraversalSourceDsl.java b/deployments/kubehound/graph/dsl/kubehound/src/main/java/com/datadog/ase/kubehound/KubeHoundTraversalSourceDsl.java index 78b29fac..6352602e 100644 --- a/deployments/kubehound/graph/dsl/kubehound/src/main/java/com/datadog/ase/kubehound/KubeHoundTraversalSourceDsl.java +++ b/deployments/kubehound/graph/dsl/kubehound/src/main/java/com/datadog/ase/kubehound/KubeHoundTraversalSourceDsl.java @@ -60,13 +60,14 @@ public GraphTraversal cluster(String... 
names) { if (names.length > 0) { traversal = traversal.has("cluster", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices from the specified KubeHound run(s) + * Starts a traversal that finds all vertices from the specified KubeHound + * run(s) * * @param ids list of run ids to filter on */ @@ -75,13 +76,14 @@ public GraphTraversal run(String... ids) { if (ids.length > 0) { traversal = traversal.has("runID", P.within(ids)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices with a "Container" label and optionally allows filtering of those + * Starts a traversal that finds all vertices with a "Container" label and + * optionally allows filtering of those * vertices on the "name" property. * * @param names list of container names to filter on @@ -89,16 +91,17 @@ public GraphTraversal run(String... ids) { public GraphTraversal containers(String... names) { GraphTraversal traversal = this.clone().V(); - traversal = traversal.hasLabel("Container"); + traversal = traversal.has("class", "Container"); if (names.length > 0) { traversal = traversal.has("name", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices with a "Pod" label and optionally allows filtering of those + * Starts a traversal that finds all vertices with a "Pod" label and optionally + * allows filtering of those * vertices on the "name" property. * * @param names list of pod names to filter on @@ -106,16 +109,17 @@ public GraphTraversal containers(String... names) { public GraphTraversal pods(String... names) { GraphTraversal traversal = this.clone().V(); - traversal = traversal.hasLabel("Pod"); + traversal = traversal.has("class", "Pod"); if (names.length > 0) { traversal = traversal.has("name", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices with a "Node" label and optionally allows filtering of those + * Starts a traversal that finds all vertices with a "Node" label and optionally + * allows filtering of those * vertices on the "name" property. * * @param names list of node names to filter on @@ -123,33 +127,35 @@ public GraphTraversal pods(String... names) { public GraphTraversal nodes(String... names) { GraphTraversal traversal = this.clone().V(); - traversal = traversal.hasLabel("Node"); + traversal = traversal.has("class", "Node"); if (names.length > 0) { traversal = traversal.has("name", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all container escape edges from a Container vertex to a Node vertex - * and optionally allows filtering of those vertices on the "nodeNames" property. + * Starts a traversal that finds all container escape edges from a Container + * vertex to a Node vertex + * and optionally allows filtering of those vertices on the "nodeNames" + * property. * * @param nodeNames list of node names to filter on - + * */ public GraphTraversal escapes(String... nodeNames) { GraphTraversal traversal = this.clone().V(); traversal = traversal - .hasLabel("Container") - .outE() - .inV() - .hasLabel("Node"); + .has("class", "Container") + .outE() + .inV() + .has("class", "Node"); if (nodeNames.length > 0) { traversal = traversal.has("name", P.within(nodeNames)); - } + } return traversal.path(); } @@ -159,183 +165,194 @@ public GraphTraversal escapes(String... 
nodeNames) { */ public GraphTraversal endpoints() { GraphTraversal traversal = this.clone().V(); - - traversal = traversal.hasLabel("Endpoint"); + + traversal = traversal.has("class", "Endpoint"); return traversal; } /** - * Starts a traversal that finds all vertices with a "Endpoint" label and optionally allows filtering of those + * Starts a traversal that finds all vertices with a "Endpoint" label and + * optionally allows filtering of those * vertices on the "exposure" property. * * @param exposure EndpointExposure enum value to filter on */ public GraphTraversal endpoints(EndpointExposure exposure) { if (exposure.ordinal() > EndpointExposure.Max.ordinal()) { - throw new IllegalArgumentException(String.format("invalid exposure value (must be <= %d)", EndpointExposure.Max.ordinal())); + throw new IllegalArgumentException( + String.format("invalid exposure value (must be <= %d)", EndpointExposure.Max.ordinal())); } if (exposure.ordinal() < EndpointExposure.None.ordinal()) { - throw new IllegalArgumentException(String.format("invalid exposure value (must be >= %d)", EndpointExposure.None.ordinal())); + throw new IllegalArgumentException( + String.format("invalid exposure value (must be >= %d)", EndpointExposure.None.ordinal())); } GraphTraversal traversal = this.clone().V(); - + traversal = traversal - .hasLabel("Endpoint") - .has("exposure", P.gte(exposure.ordinal())); + .has("class", "Endpoint") + .has("exposure", P.gte(exposure.ordinal())); return traversal; } /** - * Starts a traversal that finds all vertices with a "Endpoint" label exposed OUTSIDE the cluster as a service + * Starts a traversal that finds all vertices with a "Endpoint" label exposed + * OUTSIDE the cluster as a service * and optionally allows filtering of those vertices on the "portName" property. * * @param portNames list of port names to filter on */ public GraphTraversal services(String... portNames) { GraphTraversal traversal = this.clone().V(); - + traversal = traversal - .hasLabel("Endpoint") - .has("exposure", P.gte(EndpointExposure.External.ordinal())); + .has("class", "Endpoint") + .has("exposure", P.gte(EndpointExposure.External.ordinal())); if (portNames.length > 0) { traversal = traversal.has("portName", P.within(portNames)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices with a "Volume" label and optionally allows filtering of those + * Starts a traversal that finds all vertices with a "Volume" label and + * optionally allows filtering of those * vertices on the "name" property. * * @param names list of volume names to filter on */ public GraphTraversal volumes(String... names) { GraphTraversal traversal = this.clone().V(); - - traversal = traversal.hasLabel("Volume"); + + traversal = traversal.has("class", "Volume"); if (names.length > 0) { traversal = traversal.has("name", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices representing volume host mounts and optionally allows filtering of those + * Starts a traversal that finds all vertices representing volume host mounts + * and optionally allows filtering of those * vertices on the "sourcePath" property. * * @param sourcePaths list of host source paths to filter on */ public GraphTraversal hostMounts(String... 
sourcePaths) { GraphTraversal traversal = this.clone().V(); - + traversal = traversal - .hasLabel("Volume") - .has("type", "HostPath"); + .has("class", "Volume") + .has("type", "HostPath"); if (sourcePaths.length > 0) { traversal = traversal.has("sourcePath", P.within(sourcePaths)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices with a "Identity" label and optionally allows filtering of those + * Starts a traversal that finds all vertices with a "Identity" label and + * optionally allows filtering of those * vertices on the "name" property. * * @param names list of identity names to filter on */ public GraphTraversal identities(String... names) { GraphTraversal traversal = this.clone().V(); - - traversal = traversal.hasLabel("Identity"); + + traversal = traversal.has("class", "Identity"); if (names.length > 0) { traversal = traversal.has("name", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices representing service accounts and optionally allows filtering of those + * Starts a traversal that finds all vertices representing service accounts and + * optionally allows filtering of those * vertices on the "name" property. * * @param names list of service account names to filter on */ public GraphTraversal sas(String... names) { GraphTraversal traversal = this.clone().V(); - + traversal = traversal - .hasLabel("Identity") - .has("type", "ServiceAccount"); + .has("class", "Identity") + .has("type", "ServiceAccount"); if (names.length > 0) { traversal = traversal.has("name", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices representing users and optionally allows filtering of those + * Starts a traversal that finds all vertices representing users and optionally + * allows filtering of those * vertices on the "name" property. * * @param names list of user names to filter on */ public GraphTraversal users(String... names) { GraphTraversal traversal = this.clone().V(); - + traversal = traversal - .hasLabel("Identity") - .has("type", "User"); + .has("class", "Identity") + .has("type", "User"); if (names.length > 0) { traversal = traversal.has("name", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices representing groups and optionally allows filtering of those + * Starts a traversal that finds all vertices representing groups and optionally + * allows filtering of those * vertices on the "name" property. * * @param names list of groups names to filter on */ public GraphTraversal groups(String... names) { GraphTraversal traversal = this.clone().V(); - + traversal = traversal - .hasLabel("Identity") - .has("type", "Group"); + .has("class", "Identity") + .has("type", "Group"); if (names.length > 0) { traversal = traversal.has("name", P.within(names)); - } + } return traversal; } /** - * Starts a traversal that finds all vertices with a "PermissionSet" label and optionally allows filtering of those + * Starts a traversal that finds all vertices with a "PermissionSet" label and + * optionally allows filtering of those * vertices on the "role" property. * * @param roles list of underlying role names to filter on */ public GraphTraversal permissions(String... 
roles) {
 GraphTraversal traversal = this.clone().V();
-
- traversal = traversal.hasLabel("PermissionSet");
+
+ traversal = traversal.has("class", "PermissionSet");
 if (roles.length > 0) {
 traversal = traversal.has("role", P.within(roles));
- }
+ }

 return traversal;
 }

diff --git a/docs/queries/dsl.md b/docs/queries/dsl.md
index 7a627bde..3092f7df 100644
--- a/docs/queries/dsl.md
+++ b/docs/queries/dsl.md
@@ -17,21 +17,21 @@ _DSL definition code available [here](https://github.com/DataDog/KubeHound/blob/

### Retrieve cluster data

-| Method | Gremlin equivalent |
-| --------------------------- | ----------------------------------------------------- |
-| `.cluster([string...])` | `.hasLabel("Cluster")` |
-| `.containers([string...])` | `.hasLabel("Container")` |
-| `.endpoints([int])` | `.hasLabel("Endpoint")` |
-| `.groups([string...])` | `.hasLabel("Group")` |
-| `.hostMounts([string...])` | `.hasLabel("Volume").has("type", "HostPath")` |
-| `.nodes([string...])` | `.hasLabel("Node")` |
-| `.permissions([string...])` | `.hasLabel("PermissionSet")` |
-| `.pods([string...])` | `.hasLabel("Pod")` |
-| `.run([string...])` | `.has("runID", P.within(ids)` |
-| `.sas([string...])` | `.hasLabel("Identity").has("type", "ServiceAccount")` |
-| `.services([string...])` | `.hasLabel("Endpoint").has("exposure", EXTERNAL)` |
-| `.users([string...])` | `.hasLabel("Identity").has("type", "User")` |
-| `.volumes([string...])` | `.hasLabel("Volume")` |
+| Method | Gremlin equivalent |
+| --------------------------- | -------------------------------------------------------- |
+| `.cluster([string...])` | `.has("class","Cluster")` |
+| `.containers([string...])` | `.has("class","Container")` |
+| `.endpoints([int])` | `.has("class","Endpoint")` |
+| `.groups([string...])` | `.has("class","Identity").has("type", "Group")` |
+| `.hostMounts([string...])` | `.has("class","Volume").has("type", "HostPath")` |
+| `.nodes([string...])` | `.has("class","Node")` |
+| `.permissions([string...])` | `.has("class","PermissionSet")` |
+| `.pods([string...])` | `.has("class","Pod")` |
+| `.run([string...])` | `.has("runID", P.within(ids))` |
+| `.sas([string...])` | `.has("class","Identity").has("type", "ServiceAccount")` |
+| `.services([string...])` | `.has("class","Endpoint").has("exposure", EXTERNAL)` |
+| `.users([string...])` | `.has("class","Identity").has("type", "User")` |
+| `.volumes([string...])` | `.has("class","Volume")` |

### Retrieving attack oriented data

diff --git a/docs/queries/gremlin.md b/docs/queries/gremlin.md
index fed428b7..bd94b47c 100644
--- a/docs/queries/gremlin.md
+++ b/docs/queries/gremlin.md
@@ -1,92 +1,92 @@
-# Queries
+# Queries

You can query KubeHound data stored in the JanusGraph database by using the [Gremlin query language](https://docs.janusgraph.org/getting-started/gremlin/).
## Basic queries

-``` java title="Count the number of pods in the cluster"
-g.V().hasLabel("Pod").count()
+```java title="Count the number of pods in the cluster"
+g.V().has("class","Pod").count()
```

-``` java title="View all possible container escapes in the cluster"
-g.V().hasLabel("Container").outE().inV().hasLabel("Node").path()
+```java title="View all possible container escapes in the cluster"
+g.V().has("class","Container").outE().inV().has("class","Node").path()
```

-``` java title="List the names of all possible attacks in the cluster"
+```java title="List the names of all possible attacks in the cluster"
g.E().groupCount().by(label)
```

-``` java title="View all the mounted host path volumes in the cluster"
-g.V().hasLabel("Volume").has("type", "HostPath").groupCount().by("sourcePath")
+```java title="View all the mounted host path volumes in the cluster"
+g.V().has("class","Volume").has("type", "HostPath").groupCount().by("sourcePath")
```

-``` java title="View host path mounts that can be exploited to escape to a node"
-g.E().hasLabel("EXPLOIT_HOST_READ", "EXPLOIT_HOST_WRITE").outV().groupCount().by("sourcePath")
+```java title="View host path mounts that can be exploited to escape to a node"
+g.E().has("class", within("EXPLOIT_HOST_READ", "EXPLOIT_HOST_WRITE")).outV().groupCount().by("sourcePath")
```

-``` java title="View all service endpoints by service name in the cluster"
+```java title="View all service endpoints by service name in the cluster"
// Leveraging the "EndpointExposureType" enum value to filter only on services
// c.f. https://github.com/DataDog/KubeHound/blob/main/pkg/kubehound/models/shared/constants.go
-g.V().hasLabel("Endpoint").has("exposure", 3).groupCount().by("serviceEndpoint")
+g.V().has("class","Endpoint").has("exposure", 3).groupCount().by("serviceEndpoint")
```

## Basic attack paths

-``` java title="All paths between an endpoint and a node"
-g.V().hasLabel("Endpoint").repeat(out().simplePath()).until(hasLabel("Node")).path()
+```java title="All paths between an endpoint and a node"
+g.V().has("class","Endpoint").repeat(out().simplePath()).until(has("class","Node")).path()
```

-``` java title="All paths (up to 5 hops) between a container and a node"
-g.V().hasLabel("Container").repeat(out().simplePath()).until(hasLabel("Node").or().loops().is(5)).hasLabel("Node").path()
+```java title="All paths (up to 5 hops) between a container and a node"
+g.V().has("class","Container").repeat(out().simplePath()).until(has("class","Node").or().loops().is(5)).has("class","Node").path()
```

-``` java title="All attack paths (up to 6 hops) from any compomised identity (e.g. service account) to a critical asset"
-g.V().hasLabel("Identity").repeat(out().simplePath()).until(has("critical", true).or().loops().is(6)).has("critical", true).path().limit(5)
+```java title="All attack paths (up to 6 hops) from any compromised identity (e.g. service account) to a critical asset"
+g.V().has("class","Identity").repeat(out().simplePath()).until(has("critical", true).or().loops().is(6)).has("critical", true).path().limit(5)
```

## Attack paths from compromised assets

### Containers

-``` java title="Attack paths (up to 10 hops) from a known breached container to any critical asset"
-g.V().hasLabel("Container").has("name", "nsenter-pod").repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).path()
+```java title="Attack paths (up to 10 hops) from a known breached container to any critical asset"
+g.V().has("class","Container").has("name", "nsenter-pod").repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).path()
```

-``` java title="Attack paths (up to 10 hops) from a known backdoored container image to any critical asset"
-g.V().hasLabel("Container").has("image", TextP.containing("malicious-image")).repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).path()
+```java title="Attack paths (up to 10 hops) from a known backdoored container image to any critical asset"
+g.V().has("class","Container").has("image", TextP.containing("malicious-image")).repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).path()
```

### Credentials

-``` java title="Attack paths (up to 10 hops) from a known breached identity to a critical asset"
-g.V().hasLabel("Identity").has("name", "compromised-sa").repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).path()
+```java title="Attack paths (up to 10 hops) from a known breached identity to a critical asset"
+g.V().has("class","Identity").has("name", "compromised-sa").repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).path()
```

### Endpoints

-``` java title="Attack paths (up to 6 hops) from any endpoint to a critical asset:"
-g.V().hasLabel("Endpoint").repeat(out().simplePath()).until(has("critical", true).or().loops().is(6)).has("critical", true).path().limit(5)
+```java title="Attack paths (up to 6 hops) from any endpoint to a critical asset:"
+g.V().has("class","Endpoint").repeat(out().simplePath()).until(has("critical", true).or().loops().is(6)).has("critical", true).path().limit(5)
```

-``` java title="Attack paths (up to 10 hops) from a known risky endpoint (e.g JMX) to a critical asset"
-g.V().hasLabel("Endpoint").has("portName", "jmx").repeat(out().simplePath()).until(has("critical", true).or().loops().is(6)).has("critical", true).path().limit(5)
+```java title="Attack paths (up to 6 hops) from a known risky endpoint (e.g. JMX) to a critical asset"
+g.V().has("class","Endpoint").has("portName", "jmx").repeat(out().simplePath()).until(has("critical", true).or().loops().is(6)).has("critical", true).path().limit(5)
```

## Risk assessment

-``` java title="What is the shortest exploitable path between an exposed service and a critical asset?"
-g.V().hasLabel("Endpoint").has("exposure", gte(3)).repeat(out().simplePath()).until(has("critical", true).or().loops().is(7)).has("critical", true).path().count(local).min()
+```java title="What is the shortest exploitable path between an exposed service and a critical asset?"
+g.V().has("class","Endpoint").has("exposure", gte(3)).repeat(out().simplePath()).until(has("critical", true).or().loops().is(7)).has("critical", true).path().count(local).min()
 ```
 
-``` java title="What percentage of external facing services have an exploitable path to a critical asset?"
+```java title="What percentage of external facing services have an exploitable path to a critical asset?"
 // Leveraging the "EndpointExposureType" enum value to filter only on services
 // c.f. https://github.com/DataDog/KubeHound/blob/main/pkg/kubehound/models/shared/constants.go
 
 // Base case
-g.V().hasLabel("Endpoint").has("exposure", gte(3)).count()
+g.V().has("class","Endpoint").has("exposure", gte(3)).count()
 
 // Has a critical path
-g.V().hasLabel("Endpoint").has("exposure", gte(3)).where(repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).limit(1)).count()
+g.V().has("class","Endpoint").has("exposure", gte(3)).where(repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).limit(1)).count()
 ```
 
 ## CVE impact assessment
 
@@ -96,30 +96,30 @@ You can also use KubeHound to determine if workloads in your cluster may be vuln
 First, evaluate if a known vulnerable image is running in the cluster:
 
 ```java
-g.V().hasLabel("Container").has("image", TextP.containing("elasticsearch")).groupCount().by("image")
+g.V().has("class","Container").has("image", TextP.containing("elasticsearch")).groupCount().by("image")
 ```
 
 Then, check any exposed services that could be affected and have a path to a critical asset. This helps prioritize patching and remediation.
 
 ```java
-g.V().hasLabel("Container").has("image", "dockerhub.com/elasticsearch:7.1.4").where(inE("ENDPOINT_EXPLOIT").outV().has("exposure", gte(3))).where(repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).limit(1))
+g.V().has("class","Container").has("image", "dockerhub.com/elasticsearch:7.1.4").where(inE("ENDPOINT_EXPLOIT").outV().has("exposure", gte(3))).where(repeat(out().simplePath()).until(has("critical", true).or().loops().is(10)).has("critical", true).limit(1))
 ```
 
 ## Assessing the value of implementing new security controls
 
 Concrete impact can be verified by comparing the key risk metrics above before and after the control change. To simulate the impact of introducing a control (e.g. to evaluate ROI), we can add conditions to our path queries. For example, if we wanted to evaluate the impact of adding a Gatekeeper rule that denies the use of `hostPID`, we can use the following:
 
-``` java title="What percentage level of attack path reduction was achieved by the introduction of a control?"
+```java title="What percentage level of attack path reduction was achieved by the introduction of a control?"
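+// The percentage reduction is (base - mitigated) / base * 100, using the two
+// counts computed below (hypothetical example: 200 -> 80 paths is a 60% reduction).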
 // Calculate the base case
-g.V().hasLabel("Endpoint").has("exposure", gte(3)).repeat(out().simplePath()).until(has("critical", true).or().loops().is(6)).has("critical", true).path().count()
+g.V().has("class","Endpoint").has("exposure", gte(3)).repeat(out().simplePath()).until(has("critical", true).or().loops().is(6)).has("critical", true).path().count()
 
 // Calculate the impact of preventing CE_NSENTER attack
-g.V().hasLabel("Endpoint").has("exposure", gte(3)).repeat(outE().not(hasLabel("CE_NSENTER")).inV().simplePath()).emit().until(has("critical", true).or().loops().is(6)).has("critical", true).path().count()
+g.V().has("class","Endpoint").has("exposure", gte(3)).repeat(outE().not(has("class","CE_NSENTER")).inV().simplePath()).emit().until(has("critical", true).or().loops().is(6)).has("critical", true).path().count()
 ```
 
-``` java title="What type of control would cut off the largest number of attack paths to a specific asset in the cluster?"
+```java title="What type of control would cut off the largest number of attack paths to a specific asset in the cluster?"
 // We count the number of instances of unique attack paths using the query below
-g.V().hasLabel("Container").repeat(outE().inV().simplePath()).emit()
+g.V().has("class","Container").repeat(outE().inV().simplePath()).emit()
 .until(has("critical", true).or().loops().is(6)).has("critical", true)
 .path().by(label).groupCount().order(local).by(select(values), desc)
 
@@ -136,15 +136,15 @@ g.V().hasLabel("Container").repeat(outE().inV().simplePath()).emit()
 
 ## Threat modelling
 
-``` java title="All unique attack paths by labels to a specific asset (here, the cluster-admin role)"
-g.V().hasLabel("Container", "Identity")
+```java title="All unique attack paths by labels to a specific asset (here, the cluster-admin role)"
+g.V().has("class", within("Container", "Identity"))
 .repeat(out().simplePath())
 .until(has("name", "cluster-admin").or().loops().is(5))
-.has("name", "cluster-admin").hasLabel("Role").path().as("p").by(label).dedup().select("p").path()
+.has("name", "cluster-admin").has("class","Role").path().as("p").by(label).dedup().select("p").path()
 ```
 
-``` java title="All unique attack paths by labels to a any critical asset"
-g.V().hasLabel("Container", "Identity")
+```java title="All unique attack paths by labels to any critical asset"
+g.V().has("class", within("Container", "Identity"))
 .repeat(out().simplePath())
 .until(has("critical", true).or().loops().is(5))
 .has("critical", true).path().as("p").by(label).dedup().select("p").path()
 ```
 
 ## Gremlin resources
 
 To get started with Gremlin, have a look at the following tutorials:
 
 For large clusters, it is recommended to add a `limit()` step to **all** queries where the graph output will be examined in the UI to prevent overloading it.
 An example looking for possible attack paths from a sample of 5 containers would look like:
 ```java
-g.V().hasLabel("Container").limit(5).outE()
+g.V().has("class","Container").limit(5).outE()
 ```
 
 Additional tips:
+
 - For queries to be displayed in the UI, try to limit the output to 1000 elements or less
 - Enable `large cluster optimizations` via configuration file if queries are returning too slowly
-- Try to filter the initial element of queries by namespace/service/app to avoid generating too many results, for instance `g.V().hasLabel("Container").has("namespace", "your-namespace")`
+- Try to filter the initial element of queries by namespace/service/app to avoid generating too many results, for instance `g.V().has("class","Container").has("namespace", "your-namespace")`
diff --git a/pkg/kubehound/graph/edge/escape_var_log_symlink.go b/pkg/kubehound/graph/edge/escape_var_log_symlink.go
index e8e993a9..da2744c1 100644
--- a/pkg/kubehound/graph/edge/escape_var_log_symlink.go
+++ b/pkg/kubehound/graph/edge/escape_var_log_symlink.go
@@ -60,13 +60,13 @@ func (e *EscapeVarLogSymlink) Traversal() types.EdgeTraversal {
 	return func(source *gremlin.GraphTraversalSource, inserts []any) *gremlin.GraphTraversal {
 		g := source.GetGraphTraversal()
 		// reduce the graph to only these permission sets
-		g.V(inserts...).HasLabel("PermissionSet").
+		g.V(inserts...).Has("class", "PermissionSet").
 			// get identity vertices
 			InE("PERMISSION_DISCOVER").OutV().
 			// get container vertices
 			InE("IDENTITY_ASSUME").OutV().
 			// save container vertices as "c" so we can link it to the node via CE_VAR_LOG_SYMLINK
-			HasLabel("Container").As("c").
+			Has("class", "Container").As("c").
 			// Get all the volumes
 			OutE("VOLUME_DISCOVER").InV().
 			Has("type", shared.VolumeTypeHost).
@@ -74,7 +74,7 @@ func (e *EscapeVarLogSymlink) Traversal() types.EdgeTraversal {
 			Has("sourcePath", P.Within("/", "/var", "/var/log")).
 			// get the node related to that volume mount
 			InE("VOLUME_ACCESS").OutV().
-			HasLabel("Node").As("n").
+			Has("class", "Node").As("n").
 			AddE("CE_VAR_LOG_SYMLINK").From("c").To("n").
 			Barrier().Limit(0)
diff --git a/scripts/dashboard-demo/main.py b/scripts/dashboard-demo/main.py
index aeb9535c..6195c18e 100644
--- a/scripts/dashboard-demo/main.py
+++ b/scripts/dashboard-demo/main.py
@@ -72,11 +72,11 @@ class EndpointKPI(KPI):
     KH_QUERY_EXTERNAL_COUNTS = "kh.endpoints().count()"
     KH_QUERY_DETAILS= 'kh.endpoints().criticalPaths().limit(local,1).dedup().valueMap("serviceEndpoint","port", "namespace")'
     KH_QUERY_EXTERNAL_CRITICAL_PATH = '''kh.V().
-    hasLabel("Endpoint").
+    has("class","Endpoint").
     count().
     aggregate("t").
     V().
-    hasLabel("Endpoint").
+    has("class","Endpoint").
     hasCriticalPath().
     count().
     as("e").
@@ -95,12 +95,12 @@ class IdentitiesKPI(KPI):
     KH_QUERY_EXTERNAL_COUNTS = "kh.identities().count()"
     KH_QUERY_DETAILS= 'kh.identities().criticalPaths().limit(local,1).dedup().valueMap("name","type","namespace")'
     KH_QUERY_EXTERNAL_CRITICAL_PATH = '''kh.V().
-    hasLabel("Identity").
+    has("class","Identity").
     has("critical", false).
     count().
     aggregate("t").
     V().
-    hasLabel("Identity").
+    has("class","Identity").
     has("critical", false).
     hasCriticalPath().
     count().
@@ -121,11 +121,11 @@ class ContainersKPI(KPI):
     KH_QUERY_EXTERNAL_COUNTS = "kh.containers().count()"
     KH_QUERY_DETAILS= 'kh.containers().criticalPaths().limit(local,1).dedup().valueMap("name","image","app","namespace")'
     KH_QUERY_EXTERNAL_CRITICAL_PATH = '''kh.V().
-    hasLabel("Container").
+ has("class","Container"). hasCriticalPath(). count(). as("e"). @@ -146,11 +146,11 @@ class VolumesKPI(KPI): KH_QUERY_DETAILS= 'kh.volumes().criticalPaths().limit(local,1).dedup().valueMap("name","sourcePath", "namespace")' KH_QUERY_DETAILS_KEYS = ["name", "sourcePath"] KH_QUERY_EXTERNAL_CRITICAL_PATH = '''kh.V(). - hasLabel("Volume"). + has("class","Volume"). count(). aggregate("t"). V(). - hasLabel("Volume"). + has("class","Volume"). hasCriticalPath(). count(). as("e"). diff --git a/test/system/graph_edge_test.go b/test/system/graph_edge_test.go index 03286cd1..57df54e2 100644 --- a/test/system/graph_edge_test.go +++ b/test/system/graph_edge_test.go @@ -142,7 +142,7 @@ func (suite *EdgeTestSuite) TestEdge_CE_UMH_CORE_PATTERN() { func (suite *EdgeTestSuite) TestEdge_CONTAINER_ATTACH() { // Every container should have a CONTAINER_ATTACH incoming from a pod rawCount, err := suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). Count().Next() suite.NoError(err) @@ -151,9 +151,9 @@ func (suite *EdgeTestSuite) TestEdge_CONTAINER_ATTACH() { suite.NotEqual(containerCount, 0) rawCount, err = suite.g.V(). - HasLabel("Pod"). + Has("class", "Pod"). OutE().HasLabel("CONTAINER_ATTACH"). - InV().HasLabel("Container"). + InV().Has("class", "Container"). Dedup(). Path(). Count().Next() @@ -177,9 +177,9 @@ func (suite *EdgeTestSuite) TestEdge_IDENTITY_ASSUME_Container() { // tokenlist-sa 0 7h39m results, err := suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). OutE().HasLabel("IDENTITY_ASSUME"). - InV().HasLabel("Identity"). + InV().Has("class", "Identity"). Path(). By(__.ValueMap("name")). ToList() @@ -212,9 +212,9 @@ func (suite *EdgeTestSuite) TestEdge_IDENTITY_ASSUME_Container() { func (suite *EdgeTestSuite) TestEdge_IDENTITY_ASSUME_Node() { results, err := suite.g.V(). - HasLabel("Node"). + Has("class", "Node"). OutE().HasLabel("IDENTITY_ASSUME"). - InV().HasLabel("Identity"). + InV().Has("class", "Identity"). Path(). By(__.ValueMap("name")). ToList() @@ -243,9 +243,9 @@ func (suite *EdgeTestSuite) TestEdge_POD_ATTACH() { suite.NotEqual(podCount, 0) rawCount, err = suite.g.V(). - HasLabel("Node"). + Has("class", "Node"). OutE().HasLabel("POD_ATTACH"). - InV().HasLabel("Pod"). + InV().Has("class", "Pod"). Dedup(). Path(). Count().Next() @@ -260,10 +260,10 @@ func (suite *EdgeTestSuite) TestEdge_POD_PATCH() { // We have one bespoke container running with pod/patch permissions which should reach all nodes // since they are not namespaced results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). Has("namespace", "default"). OutE().HasLabel("POD_PATCH"). - InV().HasLabel("Pod"). + InV().Has("class", "Pod"). Path(). By(__.ValueMap("name")). ToList() @@ -309,10 +309,10 @@ func (suite *EdgeTestSuite) TestEdge_POD_CREATE() { // We have one bespoke container running with pod/create permissions which should reach all nodes // since they are not namespaced results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). Has("namespace", "default"). OutE().HasLabel("POD_CREATE"). - InV().HasLabel("Node"). + InV().Has("class", "Node"). Path(). By(__.ValueMap("name")). ToList() @@ -332,10 +332,10 @@ func (suite *EdgeTestSuite) TestEdge_POD_CREATE() { func (suite *EdgeTestSuite) TestEdge_POD_EXEC() { // We have one bespoke container running with pod/exec permissions which should reach all pods in the namespace results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). Has("namespace", "default"). 
OutE().HasLabel("POD_EXEC"). - InV().HasLabel("Pod"). + InV().Has("class", "Pod"). Path(). By(__.ValueMap("name")). ToList() @@ -391,10 +391,10 @@ func (suite *EdgeTestSuite) TestEdge_PERMISSION_DISCOVER() { // tokenget-sa 0 7h39m // tokenlist-sa 0 7h39m results, err := suite.g.V(). - HasLabel("Identity"). + Has("class", "Identity"). Has("namespace", "default"). OutE().HasLabel("PERMISSION_DISCOVER"). - InV().HasLabel("PermissionSet"). + InV().Has("class", "PermissionSet"). Path(). By(__.ValueMap("name")). ToList() @@ -429,7 +429,7 @@ func (suite *EdgeTestSuite) TestEdge_PERMISSION_DISCOVER() { func (suite *EdgeTestSuite) TestEdge_VOLUME_ACCESS() { // Every volume should have a VOLUME_ACCESS incoming from a node rawCount, err := suite.g.V(). - HasLabel("Volume"). + Has("class", "Volume"). Count().Next() suite.NoError(err) @@ -438,9 +438,9 @@ func (suite *EdgeTestSuite) TestEdge_VOLUME_ACCESS() { suite.NotEqual(volumeCount, 0) rawCount, err = suite.g.V(). - HasLabel("Node"). + Has("class", "Node"). OutE().HasLabel("VOLUME_ACCESS"). - InV().HasLabel("Volume"). + InV().Has("class", "Volume"). Dedup(). Path(). Count().Next() @@ -454,7 +454,7 @@ func (suite *EdgeTestSuite) TestEdge_VOLUME_ACCESS() { func (suite *EdgeTestSuite) TestEdge_VOLUME_DISCOVER() { // Every volume should have a VOLUME_DISCOVER incoming from a container rawCount, err := suite.g.V(). - HasLabel("Volume"). + Has("class", "Volume"). Count().Next() suite.NoError(err) @@ -463,9 +463,9 @@ func (suite *EdgeTestSuite) TestEdge_VOLUME_DISCOVER() { suite.NotEqual(volumeCount, 0) rawCount, err = suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). OutE().HasLabel("VOLUME_DISCOVER"). - InV().HasLabel("Volume"). + InV().Has("class", "Volume"). Dedup(). Path(). Count().Next() @@ -478,10 +478,10 @@ func (suite *EdgeTestSuite) TestEdge_VOLUME_DISCOVER() { func (suite *EdgeTestSuite) TestEdge_TOKEN_BRUTEFORCE() { results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). Has("namespace", "default"). OutE().HasLabel("TOKEN_BRUTEFORCE"). - InV().HasLabel("Identity"). + InV().Has("class", "Identity"). Path(). By(__.ValueMap("name")). ToList() @@ -514,10 +514,10 @@ func (suite *EdgeTestSuite) TestEdge_TOKEN_BRUTEFORCE() { func (suite *EdgeTestSuite) TestEdge_TOKEN_LIST() { results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). Has("namespace", "default"). OutE().HasLabel("TOKEN_LIST"). - InV().HasLabel("Identity"). + InV().Has("class", "Identity"). Path(). By(__.ValueMap("name")). ToList() @@ -552,11 +552,11 @@ func (suite *EdgeTestSuite) TestEdge_TOKEN_STEAL() { // Every pod in our test cluster should have projected volume holding a token. BUT we only // save those with a non-default service account token as shown below. results, err := suite.g.V(). - HasLabel("Volume"). + Has("class", "Volume"). OutE(). HasLabel("TOKEN_STEAL"). InV(). - HasLabel("Identity"). + Has("class", "Identity"). Has("namespace", "default"). Values("name"). ToList() @@ -589,11 +589,11 @@ func (suite *EdgeTestSuite) TestEdge_TOKEN_STEAL() { func (suite *EdgeTestSuite) TestEdge_EXPLOIT_HOST_READ() { results, err := suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). OutE().HasLabel("VOLUME_DISCOVER"). - InV().HasLabel("Volume"). + InV().Has("class", "Volume"). Where(__.OutE().HasLabel("EXPLOIT_HOST_READ"). - InV().HasLabel("Node")). + InV().Has("class", "Node")). Path(). By(__.ValueMap("name")). 
ToList() @@ -610,11 +610,11 @@ func (suite *EdgeTestSuite) TestEdge_EXPLOIT_HOST_READ() { func (suite *EdgeTestSuite) TestEdge_EXPLOIT_HOST_WRITE() { results, err := suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). OutE().HasLabel("VOLUME_DISCOVER"). - InV().HasLabel("Volume"). + InV().Has("class", "Volume"). Where(__.OutE().HasLabel("EXPLOIT_HOST_WRITE"). - InV().HasLabel("Node")). + InV().Has("class", "Node")). Path(). By(__.ValueMap("name")). ToList() @@ -633,10 +633,10 @@ func (suite *EdgeTestSuite) TestEdge_EXPLOIT_HOST_TRAVERSE() { for _, c := range []string{"host-read-exploit-pod", "host-write-exploit-pod"} { // Find the containers on the same node as our vulnerable pod and map to their service accounts results, err := suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). Has("name", c). Values("node").As("n"). - V().HasLabel("Container"). + V().Has("class", "Container"). Has("node", __.Where(P.Eq("n"))). OutE("IDENTITY_ASSUME"). InV(). @@ -649,14 +649,14 @@ func (suite *EdgeTestSuite) TestEdge_EXPLOIT_HOST_TRAVERSE() { // Now find the identities our vulnerable pod can reach via doing a traverse to the projected token volume results, err = suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). Has("name", c). OutE().HasLabel("VOLUME_DISCOVER"). - InV().HasLabel("Volume"). + InV().Has("class", "Volume"). OutE().HasLabel("EXPLOIT_HOST_TRAVERSE"). - InV().HasLabel("Volume"). + InV().Has("class", "Volume"). OutE().HasLabel("TOKEN_STEAL"). - InV().HasLabel("Identity"). + InV().Has("class", "Identity"). Values("name"). ToList() @@ -671,12 +671,12 @@ func (suite *EdgeTestSuite) TestEdge_EXPLOIT_HOST_TRAVERSE() { func (suite *EdgeTestSuite) TestEdge_ENDPOINT_EXPLOIT_ContainerPort() { results, err := suite.g.V(). - HasLabel("Endpoint"). + Has("class", "Endpoint"). Where( __.Has("exposure", P.Eq(int(shared.EndpointExposureClusterIP))). OutE("ENDPOINT_EXPLOIT"). InV(). - HasLabel("Container")). + Has("class", "Container")). Values("serviceEndpoint"). ToList() @@ -693,12 +693,12 @@ func (suite *EdgeTestSuite) TestEdge_ENDPOINT_EXPLOIT_ContainerPort() { func (suite *EdgeTestSuite) TestEdge_ENDPOINT_EXPLOIT_NodePort() { results, err := suite.g.V(). - HasLabel("Endpoint"). + Has("class", "Endpoint"). Where( __.Has("exposure", P.Eq(int(shared.EndpointExposureNodeIP))). OutE("ENDPOINT_EXPLOIT"). InV(). - HasLabel("Container")). + Has("class", "Container")). Values("serviceEndpoint"). ToList() @@ -715,12 +715,12 @@ func (suite *EdgeTestSuite) TestEdge_ENDPOINT_EXPLOIT_NodePort() { func (suite *EdgeTestSuite) TestEdge_ENDPOINT_EXPLOIT_External() { results, err := suite.g.V(). - HasLabel("Endpoint"). + Has("class", "Endpoint"). Where( __.Has("exposure", P.Eq(int(shared.EndpointExposureExternal))). OutE("ENDPOINT_EXPLOIT"). InV(). - HasLabel("Container")). + Has("class", "Container")). Values("serviceEndpoint"). ToList() @@ -737,9 +737,9 @@ func (suite *EdgeTestSuite) TestEdge_ENDPOINT_EXPLOIT_External() { func (suite *EdgeTestSuite) TestEdge_SHARE_PS_NAMESPACE() { results, err := suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). OutE().HasLabel("SHARE_PS_NAMESPACE"). - InV().HasLabel("Container"). + InV().Has("class", "Container"). Path(). By(__.ValueMap("name")). ToList() @@ -767,10 +767,10 @@ func (suite *EdgeTestSuite) TestEdge_SHARE_PS_NAMESPACE() { // Case 1 (cf docs) func (suite *EdgeTestSuite) TestEdge_ROLE_BIND_CASE_1() { results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). 
Has("isNamespaced", false). OutE().HasLabel("ROLE_BIND"). - InV().HasLabel("PermissionSet"). + InV().Has("class", "PermissionSet"). Has("isNamespaced", false). // Scoping only to the roles related to the attacks to avoid dependency on the Kind Cluster default roles Has("name", gremlingo.TextP.StartingWith("rolebind")). @@ -798,10 +798,10 @@ func (suite *EdgeTestSuite) TestEdge_ROLE_BIND_CASE_1() { // Case 2 (cf docs) func (suite *EdgeTestSuite) TestEdge_ROLE_BIND_CASE_2() { results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). Has("isNamespaced", false). OutE().HasLabel("ROLE_BIND"). - InV().HasLabel("PermissionSet"). + InV().Has("class", "PermissionSet"). Has("isNamespaced", false). // Scoping only to the roles related to the attacks to avoid dependency on the Kind Cluster default roles Has("name", gremlingo.TextP.StartingWith("rolebind")). @@ -829,10 +829,10 @@ func (suite *EdgeTestSuite) TestEdge_ROLE_BIND_CASE_2() { // Case 3 (cf docs) func (suite *EdgeTestSuite) TestEdge_ROLE_BIND_CASE_3() { results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). Has("isNamespaced", true). OutE().HasLabel("ROLE_BIND"). - InV().HasLabel("PermissionSet"). + InV().Has("class", "PermissionSet"). Has("isNamespaced", true). // Scoping only to the roles related to the attacks to avoid dependency on the Kind Cluster default roles Has("name", gremlingo.TextP.StartingWith("rolebind-")). @@ -896,10 +896,10 @@ func (suite *EdgeTestSuite) TestEdge_ROLE_BIND_CASE_3() { // Case 4 (cf docs) func (suite *EdgeTestSuite) TestEdge_ROLE_BIND_CASE_4() { results, err := suite.g.V(). - HasLabel("PermissionSet"). + Has("class", "PermissionSet"). Has("isNamespaced", true). OutE().HasLabel("ROLE_BIND"). - InV().HasLabel("PermissionSet"). + InV().Has("class", "PermissionSet"). Has("isNamespaced", false). // Scoping only to the roles related to the attacks to avoid dependency on the Kind Cluster default roles Has("name", gremlingo.TextP.StartingWith("rolebind")). @@ -919,7 +919,7 @@ func (suite *EdgeTestSuite) TestEdge_ROLE_BIND_CASE_4() { func (suite *EdgeTestSuite) Test_NoEdgeCase() { // The control pod has no interesting properties and therefore should have NO outgoing edges results, err := suite.g.V(). - HasLabel("Container"). + Has("class", "Container"). Has("name", "control-pod"). Out(). 
ToList() diff --git a/test/system/graph_vertex_test.go b/test/system/graph_vertex_test.go index 1cc13956..20eaf260 100644 --- a/test/system/graph_vertex_test.go +++ b/test/system/graph_vertex_test.go @@ -106,7 +106,7 @@ func (suite *VertexTestSuite) resultsToStringArray(results []*gremlingo.Result) } func (suite *VertexTestSuite) TestVertexContainer() { - results, err := suite.g.V().HasLabel(vertex.ContainerLabel).ElementMap().ToList() + results, err := suite.g.V().Has("class", vertex.ContainerLabel).ElementMap().ToList() suite.NoError(err) suite.Equal(len(expectedContainers), len(results)-numberOfKindDefaultContainer) @@ -183,7 +183,7 @@ func (suite *VertexTestSuite) TestVertexContainer() { } func (suite *VertexTestSuite) TestVertexNode() { - results, err := suite.g.V().HasLabel(vertex.NodeLabel).ElementMap().ToList() + results, err := suite.g.V().Has("class", vertex.NodeLabel).ElementMap().ToList() suite.NoError(err) suite.Equal(len(expectedNodes), len(results)) @@ -219,7 +219,7 @@ func (suite *VertexTestSuite) TestVertexNode() { } func (suite *VertexTestSuite) TestVertexPod() { - results, err := suite.g.V().HasLabel(vertex.PodLabel).ElementMap().ToList() + results, err := suite.g.V().Has("class", vertex.PodLabel).ElementMap().ToList() suite.NoError(err) suite.Equal(len(expectedPods), len(results)-numberOfKindDefaultPod) @@ -269,7 +269,7 @@ func (suite *VertexTestSuite) TestVertexPod() { func (suite *VertexTestSuite) TestVertexPermissionSet() { results, err := suite.g.V(). - HasLabel(vertex.PermissionSetLabel). + Has("class", vertex.PermissionSetLabel). Has("namespace", "default"). Values("name"). ToList() @@ -292,7 +292,7 @@ func (suite *VertexTestSuite) TestVertexPermissionSet() { func (suite *VertexTestSuite) TestVertexCritical() { results, err := suite.g.V(). - HasLabel(vertex.PermissionSetLabel). + Has("class", vertex.PermissionSetLabel). Has("critical", true). Values("role"). 
ToList() @@ -311,45 +311,45 @@ func (suite *VertexTestSuite) TestVertexCritical() { } func (suite *VertexTestSuite) TestVertexVolume() { - results, err := suite.g.V().HasLabel(vertex.VolumeLabel).ElementMap().ToList() + results, err := suite.g.V().Has("class", vertex.VolumeLabel).ElementMap().ToList() suite.NoError(err) suite.Equal(61, len(results)) - results, err = suite.g.V().HasLabel(vertex.VolumeLabel).Has("sourcePath", "/proc/sys/kernel").Has("name", "nodeproc").ElementMap().ToList() + results, err = suite.g.V().Has("class", vertex.VolumeLabel).Has("sourcePath", "/proc/sys/kernel").Has("name", "nodeproc").ElementMap().ToList() suite.NoError(err) suite.Equal(1, len(results)) - results, err = suite.g.V().HasLabel(vertex.VolumeLabel).Has("sourcePath", "/lib/modules").Has("name", "lib-modules").ElementMap().ToList() + results, err = suite.g.V().Has("class", vertex.VolumeLabel).Has("sourcePath", "/lib/modules").Has("name", "lib-modules").ElementMap().ToList() suite.NoError(err) suite.Greater(len(results), 1) // Not sure why it has "6" - results, err = suite.g.V().HasLabel(vertex.VolumeLabel).Has("sourcePath", "/var/log").Has("name", "nodelog").ElementMap().ToList() + results, err = suite.g.V().Has("class", vertex.VolumeLabel).Has("sourcePath", "/var/log").Has("name", "nodelog").ElementMap().ToList() suite.NoError(err) suite.Equal(len(results), 1) } func (suite *VertexTestSuite) TestVertexIdentity() { - results, err := suite.g.V().HasLabel(vertex.IdentityLabel).ElementMap().ToList() + results, err := suite.g.V().Has("class", vertex.IdentityLabel).ElementMap().ToList() suite.NoError(err) suite.Greater(len(results), 50) - results, err = suite.g.V().HasLabel(vertex.IdentityLabel).Has("name", "tokenget-sa").ElementMap().ToList() + results, err = suite.g.V().Has("class", vertex.IdentityLabel).Has("name", "tokenget-sa").ElementMap().ToList() suite.NoError(err) suite.Equal(len(results), 1) - results, err = suite.g.V().HasLabel(vertex.IdentityLabel).Has("name", "impersonate-sa").ElementMap().ToList() + results, err = suite.g.V().Has("class", vertex.IdentityLabel).Has("name", "impersonate-sa").ElementMap().ToList() suite.NoError(err) suite.Equal(len(results), 1) - results, err = suite.g.V().HasLabel(vertex.IdentityLabel).Has("name", "tokenlist-sa").ElementMap().ToList() + results, err = suite.g.V().Has("class", vertex.IdentityLabel).Has("name", "tokenlist-sa").ElementMap().ToList() suite.NoError(err) suite.Equal(len(results), 1) - results, err = suite.g.V().HasLabel(vertex.IdentityLabel).Has("name", "pod-patch-sa").ElementMap().ToList() + results, err = suite.g.V().Has("class", vertex.IdentityLabel).Has("name", "pod-patch-sa").ElementMap().ToList() suite.NoError(err) suite.Equal(len(results), 1) - results, err = suite.g.V().HasLabel(vertex.IdentityLabel).Has("name", "pod-create-sa").ElementMap().ToList() + results, err = suite.g.V().Has("class", vertex.IdentityLabel).Has("name", "pod-create-sa").ElementMap().ToList() suite.NoError(err) suite.Equal(len(results), 1) } From 30084a5dd9a3ce09af07cc288ef66e56c2083cc9 Mon Sep 17 00:00:00 2001 From: jt-dd <112463504+jt-dd@users.noreply.github.com> Date: Mon, 9 Dec 2024 14:19:10 +0100 Subject: [PATCH 18/18] Fix flags for rootcmd (#301) * Fix flags for rootcmd * Pushing new config path * Fix * add cfgFile to config command * fix logs * fixing remote chain in ingest command * fix unit tests --- cmd/kubehound/config.go | 2 +- cmd/kubehound/dumper.go | 6 +++--- cmd/kubehound/ingest.go | 4 ++-- cmd/kubehound/root.go | 4 ++-- pkg/config/config.go | 3 ++- 
 pkg/config/config_test.go | 4 ++--
 pkg/config/ingestor.go    | 2 +-
 7 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/cmd/kubehound/config.go b/cmd/kubehound/config.go
index f64b9590..f1ebc1e6 100644
--- a/cmd/kubehound/config.go
+++ b/cmd/kubehound/config.go
@@ -20,7 +20,7 @@ var (
 		Short: "Show the current configuration",
 		Long:  `[devOnly] Show the current configuration`,
 		PreRunE: func(cobraCmd *cobra.Command, args []string) error {
-			return cmd.InitializeKubehoundConfig(cobraCmd.Context(), "", true, true)
+			return cmd.InitializeKubehoundConfig(cobraCmd.Context(), cfgFile, true, true)
 		},
 		RunE: func(cobraCmd *cobra.Command, args []string) error {
 			// Adding datadog setup
diff --git a/cmd/kubehound/dumper.go b/cmd/kubehound/dumper.go
index b4e4b505..e0d7de07 100644
--- a/cmd/kubehound/dumper.go
+++ b/cmd/kubehound/dumper.go
@@ -36,7 +36,7 @@ var (
 			viper.BindPFlag(config.IngestorAPIEndpoint, cobraCmd.Flags().Lookup("khaas-server")) //nolint: errcheck
 			viper.BindPFlag(config.IngestorAPIInsecure, cobraCmd.Flags().Lookup("insecure"))     //nolint: errcheck
 
-			return cmd.InitializeKubehoundConfig(cobraCmd.Context(), "", true, true)
+			return cmd.InitializeKubehoundConfig(cobraCmd.Context(), cfgFile, true, true)
 		},
 		RunE: func(cobraCmd *cobra.Command, args []string) error {
 			// using compress feature
@@ -62,7 +62,7 @@ var (
 				return fmt.Errorf("dump core: %w", err)
 			}
 			// Running the ingestion on KHaaS
-			if cobraCmd.Flags().Lookup("khaas-server").Value.String() != "" {
+			if khCfg.Ingestor.API.Endpoint != "" {
 				return core.CoreClientGRPCIngest(cobraCmd.Context(), khCfg.Ingestor, khCfg.Dynamic.ClusterName, khCfg.Dynamic.RunID.String())
 			}
 
@@ -77,7 +77,7 @@ var (
 		PreRunE: func(cobraCmd *cobra.Command, args []string) error {
 			viper.Set(config.CollectorFileDirectory, args[0])
 
-			return cmd.InitializeKubehoundConfig(cobraCmd.Context(), "", true, true)
+			return cmd.InitializeKubehoundConfig(cobraCmd.Context(), cfgFile, true, true)
 		},
 		RunE: func(cobraCmd *cobra.Command, args []string) error {
 			// Passing the Kubehound config from viper
diff --git a/cmd/kubehound/ingest.go b/cmd/kubehound/ingest.go
index fe2423d2..8f436086 100644
--- a/cmd/kubehound/ingest.go
+++ b/cmd/kubehound/ingest.go
@@ -29,7 +29,7 @@ var (
 		PreRunE: func(cobraCmd *cobra.Command, args []string) error {
 			cmd.BindFlagCluster(cobraCmd)
 
-			return cmd.InitializeKubehoundConfig(cobraCmd.Context(), "", true, true)
+			return cmd.InitializeKubehoundConfig(cobraCmd.Context(), cfgFile, true, true)
 		},
 		RunE: func(cobraCmd *cobra.Command, args []string) error {
 			// Passing the Kubehound config from viper
@@ -56,7 +56,7 @@ var (
 				cobraCmd.MarkFlagRequired("cluster") //nolint: errcheck
 			}
 
-			return cmd.InitializeKubehoundConfig(cobraCmd.Context(), "", false, true)
+			return cmd.InitializeKubehoundConfig(cobraCmd.Context(), cfgFile, false, true)
 		},
 		RunE: func(cobraCmd *cobra.Command, args []string) error {
 			// Passing the Kubehound config from viper
diff --git a/cmd/kubehound/root.go b/cmd/kubehound/root.go
index fedc496c..fa82e1ea 100644
--- a/cmd/kubehound/root.go
+++ b/cmd/kubehound/root.go
@@ -76,9 +76,9 @@ var (
 )
 
 func init() {
-	rootCmd.Flags().StringVarP(&cfgFile, "config", "c", cfgFile, "application config file")
+	rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", cfgFile, "application config file")
 
-	rootCmd.Flags().BoolVar(&skipBackend, "skip-backend", skipBackend, "skip the auto deployment of the backend stack (janusgraph, mongodb, and UI)")
+	rootCmd.PersistentFlags().BoolVar(&skipBackend, "skip-backend", skipBackend, "skip the auto deployment of the backend stack (janusgraph, mongodb, and UI)")
 
 	cmd.InitRootCmd(rootCmd)
 }
diff --git a/pkg/config/config.go b/pkg/config/config.go
index a3569f6d..d00018c1 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -80,7 +80,6 @@ func NewKubehoundConfig(ctx context.Context, configPath string, inLine bool) *Ku
 	var cfg *KubehoundConfig
 	switch {
 	case len(configPath) != 0:
-		l.Info("Loading application configuration from file", log.String("path", configPath))
 		cfg = MustLoadConfig(ctx, configPath)
 	case inLine:
 		l.Info("Loading application from inline command")
@@ -207,10 +206,12 @@ func unmarshalConfig(v *viper.Viper) (*KubehoundConfig, error) {
 
 // NewConfig creates a new config instance from the provided file using viper.
 func NewConfig(ctx context.Context, v *viper.Viper, configPath string) (*KubehoundConfig, error) {
+	l := log.Logger(ctx)
 	// Configure default values
 	SetDefaultValues(ctx, v)
 
 	// Loading inLine config path
+	l.Info("Loading application configuration from file", log.String("path", configPath))
 	v.SetConfigType(DefaultConfigType)
 	v.SetConfigFile(configPath)
 
diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go
index 6c4aee19..6b91dbd8 100644
--- a/pkg/config/config_test.go
+++ b/pkg/config/config_test.go
@@ -81,7 +81,7 @@ func TestMustLoadConfig(t *testing.T) {
 				},
 				Ingestor: IngestorConfig{
 					API: IngestorAPIConfig{
-						Endpoint: "127.0.0.1:9000",
+						Endpoint: "",
 						Insecure: false,
 					},
 					Blob: &BlobConfig{
@@ -158,7 +158,7 @@ func TestMustLoadConfig(t *testing.T) {
 				},
 				Ingestor: IngestorConfig{
 					API: IngestorAPIConfig{
-						Endpoint: "127.0.0.1:9000",
+						Endpoint: "",
 						Insecure: false,
 					},
 					Blob: &BlobConfig{
diff --git a/pkg/config/ingestor.go b/pkg/config/ingestor.go
index 324ea762..5fdf73b1 100644
--- a/pkg/config/ingestor.go
+++ b/pkg/config/ingestor.go
@@ -1,7 +1,7 @@
 package config
 
 const (
-	DefaultIngestorAPIEndpoint = "127.0.0.1:9000"
+	DefaultIngestorAPIEndpoint = ""
 	DefaultIngestorAPIInsecure = false
 	DefaultBucketName          = "" // we want to leave it empty because we can easily abort if it's not configured
 	DefaultTempDir             = "/tmp/kubehound"