From 4220d59eb335a046095d6a30848296ea9ffc9d5f Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Wed, 15 Jul 2020 10:26:13 +0200 Subject: [PATCH 01/11] protection/http: finer-grained instrumentation Mainly for debugging purposes, split the existing file in two so that we can have one that is not instrumented thanks to the `sqreen:ignore` directive. --- internal/protection/http/http.go | 383 ++++------------------------ internal/protection/http/request.go | 291 +++++++++++++++++++++ 2 files changed, 339 insertions(+), 335 deletions(-) create mode 100644 internal/protection/http/request.go diff --git a/internal/protection/http/http.go b/internal/protection/http/http.go index e12fc59f..6cb6ab68 100644 --- a/internal/protection/http/http.go +++ b/internal/protection/http/http.go @@ -5,25 +5,65 @@ package http import ( - "bytes" "context" - "fmt" "io" - "net" "net/http" - "net/textproto" - "net/url" - "strings" "sync" - "github.com/pkg/errors" - "github.com/sqreen/go-agent/internal/config" "github.com/sqreen/go-agent/internal/event" protectioncontext "github.com/sqreen/go-agent/internal/protection/context" "github.com/sqreen/go-agent/internal/protection/http/types" "github.com/sqreen/go-agent/internal/sqlib/sqgls" ) +func FromContext(ctx context.Context) *RequestContext { + c, _ := protectioncontext.FromContext(ctx).(*RequestContext) + return c +} + +func FromGLS() *RequestContext { + ctx := sqgls.Get() + if ctx == nil { + return nil + } + return ctx.(*RequestContext) +} + +func NewRequestContext(ctx context.Context, agent protectioncontext.AgentFace, w types.ResponseWriter, r types.RequestReader) (*RequestContext, context.Context, context.CancelFunc) { + if agent.IsPathAllowed(r.URL().Path) { + return nil, nil, nil + } + + clientIP := r.ClientIP() + if clientIP == nil { + cfg := agent.Config() + clientIP = ClientIP(r.RemoteAddr(), r.Headers(), cfg.PrioritizedIPHeader(), cfg.PrioritizedIPHeaderFormat()) + } + + if agent.IsIPAllowed(clientIP) { + return nil, nil, nil + } + + reqCtx, cancelHandlerContextFunc := context.WithCancel(ctx) + + rr := &requestReader{ + clientIP: clientIP, + RequestReader: r, + requestParams: make(types.RequestParamMap), + } + + protCtx := &RequestContext{ + RequestContext: protectioncontext.NewRequestContext(agent), + ResponseWriter: w, + RequestReader: rr, + cancelHandlerContextFunc: cancelHandlerContextFunc, + // Keep a reference to the request param map to be able to add more params + // to it. + requestReader: rr, + } + return protCtx, reqCtx, cancelHandlerContextFunc +} + type RequestContext struct { *protectioncontext.RequestContext RequestReader types.RequestReader @@ -69,46 +109,6 @@ var _ protectioncontext.EventRecorderGetter = (*RequestContext)(nil) func (c *RequestContext) EventRecorder() protectioncontext.EventRecorder { return c } -type requestReader struct { - types.RequestReader - - // clientIP is the actual IP address of the client performing the request. - clientIP net.IP - - // requestParams is the set of HTTP request parameters taken from the HTTP - // request. The map key is the source (eg. json, query, multipart-form, etc.) - // so that we can report it and make it clearer to understand where the value - // comes from. - requestParams types.RequestParamMap - - // bodyReadBuffer is the buffers body reads - bodyReadBuffer bytes.Buffer -} - -func (r *requestReader) Body() []byte { return r.bodyReadBuffer.Bytes() } - -func (r *requestReader) ClientIP() net.IP { return r.clientIP } - -func (r *requestReader) Params() types.RequestParamMap { - params := r.RequestReader.Params() - if len(params) == 0 { - return r.requestParams - } - - if len(r.requestParams) == 0 { - return params - } - - res := make(types.RequestParamMap, len(params)+len(r.requestParams)) - for n, v := range params { - res[n] = v - } - for n, v := range r.requestParams { - res[n] = v - } - return res -} - func (c *RequestContext) AddAttackEvent(attack *event.AttackEvent) { c.events.AddAttackEvent(attack) } @@ -133,54 +133,6 @@ func (c *RequestContext) IdentifyUser(id map[string]string) error { // Static assert that the SDK interface is implemented. var _ protectioncontext.EventRecorder = &RequestContext{} -func FromContext(ctx context.Context) *RequestContext { - c, _ := protectioncontext.FromContext(ctx).(*RequestContext) - return c -} - -func FromGLS() *RequestContext { - ctx := sqgls.Get() - if ctx == nil { - return nil - } - return ctx.(*RequestContext) -} - -func NewRequestContext(ctx context.Context, agent protectioncontext.AgentFace, w types.ResponseWriter, r types.RequestReader) (*RequestContext, context.Context, context.CancelFunc) { - if agent.IsPathAllowed(r.URL().Path) { - return nil, nil, nil - } - - clientIP := r.ClientIP() - if clientIP == nil { - cfg := agent.Config() - clientIP = ClientIP(r.RemoteAddr(), r.Headers(), cfg.PrioritizedIPHeader(), cfg.PrioritizedIPHeaderFormat()) - } - - if agent.IsIPAllowed(clientIP) { - return nil, nil, nil - } - - reqCtx, cancelHandlerContextFunc := context.WithCancel(ctx) - - rr := &requestReader{ - clientIP: clientIP, - RequestReader: r, - requestParams: make(types.RequestParamMap), - } - - protCtx := &RequestContext{ - RequestContext: protectioncontext.NewRequestContext(agent), - ResponseWriter: w, - RequestReader: rr, - cancelHandlerContextFunc: cancelHandlerContextFunc, - // Keep a reference to the request param map to be able to add more params - // to it. - requestReader: rr, - } - return protCtx, reqCtx, cancelHandlerContextFunc -} - // When a non-nil error is returned, the request handler shouldn't be called // and the request should be stopped immediately by closing the RequestContext // and returning. @@ -253,26 +205,6 @@ func (c *RequestContext) isContextHandlerCanceled() bool { } -type closedRequestContext struct { - response types.ResponseFace - request types.RequestReader - events event.Recorded -} - -var _ types.ClosedRequestContextFace = (*closedRequestContext)(nil) - -func (c *closedRequestContext) Events() event.Recorded { - return c.events -} - -func (c *closedRequestContext) Request() types.RequestReader { - return c.request -} - -func (c *closedRequestContext) Response() types.ResponseFace { - return c.response -} - func (c *RequestContext) Close(response types.ResponseFace) error { // Make sure to clear the goroutine local storage to avoid keeping it if some // memory pools are used under the hood. @@ -289,68 +221,6 @@ func (c *RequestContext) Close(response types.ResponseFace) error { }) } -func copyRequest(reader types.RequestReader) types.RequestReader { - return &handledRequest{ - headers: reader.Headers(), - method: reader.Method(), - url: reader.URL(), - requestURI: reader.RequestURI(), - host: reader.Host(), - remoteAddr: reader.RemoteAddr(), - isTLS: reader.IsTLS(), - userAgent: reader.UserAgent(), - referer: reader.Referer(), - form: reader.Form(), - postForm: reader.PostForm(), - clientIP: reader.ClientIP(), - params: reader.Params(), - body: reader.Body(), - } -} - -type handledRequest struct { - headers http.Header - method string - url *url.URL - requestURI string - host string - remoteAddr string - isTLS bool - userAgent string - referer string - form url.Values - postForm url.Values - clientIP net.IP - params types.RequestParamMap - body []byte -} - -func (h *handledRequest) Headers() http.Header { return h.headers } -func (h *handledRequest) Method() string { return h.method } -func (h *handledRequest) URL() *url.URL { return h.url } -func (h *handledRequest) RequestURI() string { return h.requestURI } -func (h *handledRequest) Host() string { return h.host } -func (h *handledRequest) RemoteAddr() string { return h.remoteAddr } -func (h *handledRequest) IsTLS() bool { return h.isTLS } -func (h *handledRequest) UserAgent() string { return h.userAgent } -func (h *handledRequest) Referer() string { return h.referer } -func (h *handledRequest) Form() url.Values { return h.form } -func (h *handledRequest) PostForm() url.Values { return h.postForm } -func (h *handledRequest) ClientIP() net.IP { return h.clientIP } -func (h *handledRequest) Params() types.RequestParamMap { return h.params } -func (h *handledRequest) Body() []byte { return h.body } -func (h *handledRequest) Header(header string) (value *string) { - headers := h.headers - if headers == nil { - return nil - } - v := headers[textproto.CanonicalMIMEHeaderKey(header)] - if len(v) == 0 { - return nil - } - return &v[0] -} - // Write the default blocking response. This method only write the response, it // doesn't block nor cancel the handler context. Users of this method must // handle their @@ -385,160 +255,3 @@ func (c *RequestContext) AddRequestParam(name string, param interface{}) { params := c.requestReader.requestParams[name] c.requestReader.requestParams[name] = append(params, param) } - -type rawBodyWAF struct { - io.ReadCloser - c *RequestContext -} - -// Read buffers what has been read and ultimately calls the WAF on EOF. -func (t rawBodyWAF) Read(p []byte) (n int, err error) { - n, err = t.ReadCloser.Read(p) - if n > 0 { - t.c.requestReader.bodyReadBuffer.Write(p[:n]) - } - fmt.Println(err) - if err == io.EOF { - if wafErr := t.c.bodyWAF(); wafErr != nil { - err = wafErr - } - } - return -} - -//go:noinline -func (c *RequestContext) onEOF() error { return nil /* dynamically instrumented */ } - -func ClientIP(remoteAddr string, headers http.Header, prioritizedIPHeader string, prioritizedIPHeaderFormat string) net.IP { - var privateIP net.IP - check := func(value string) net.IP { - for _, ip := range strings.Split(value, ",") { - ipStr := strings.Trim(ip, " ") - ipStr, _ = splitHostPort(ipStr) - ip := net.ParseIP(ipStr) - if ip == nil { - return nil - } - - if isGlobal(ip) { - return ip - } - - if privateIP == nil && !ip.IsLoopback() && isPrivate(ip) { - privateIP = ip - } - } - return nil - } - - if prioritizedIPHeader != "" { - if value := headers.Get(prioritizedIPHeader); value != "" { - if prioritizedIPHeaderFormat != "" { - parsed, err := parseClientIPHeaderHeaderValue(prioritizedIPHeaderFormat, value) - if err == nil { - // Parsing ok, keep its returned value. - value = parsed - } else { - // An error occurred while parsing the header value, so ignore it. - value = "" - } - } - - if value != "" { - if ip := check(value); ip != nil { - return ip - } - } - } - } - - for _, key := range config.IPRelatedHTTPHeaders { - value := headers.Get(key) - if ip := check(value); ip != nil { - return ip - } - } - - remoteIPStr, _ := splitHostPort(remoteAddr) - if remoteIPStr == "" { - if privateIP != nil { - return privateIP - } - return nil - } - - if remoteIP := net.ParseIP(remoteIPStr); remoteIP != nil && (privateIP == nil || isGlobal(remoteIP)) { - return remoteIP - } - return privateIP -} - -func isGlobal(ip net.IP) bool { - if ipv4 := ip.To4(); ipv4 != nil && config.IPv4PublicNetwork.Contains(ipv4) { - return false - } - return !isPrivate(ip) -} - -func isPrivate(ip net.IP) bool { - var privateNetworks []*net.IPNet - // We cannot rely on `len(IP)` to know what type of IP address this is. - // `net.ParseIP()` or `net.IPv4()` can return internal 16-byte representations - // of an IP address even if it is an IPv4. So the trick is to use `IP.To4()` - // which returns nil if the address in not an IPv4 address. - if ipv4 := ip.To4(); ipv4 != nil { - privateNetworks = config.IPv4PrivateNetworks - } else { - privateNetworks = config.IPv6PrivateNetworks - } - - for _, network := range privateNetworks { - if network.Contains(ip) { - return true - } - } - return false -} - -// SplitHostPort splits a network address of the form `host:port` or -// `[host]:port` into `host` and `port`. -func splitHostPort(addr string) (host string, port string) { - i := strings.LastIndex(addr, "]:") - if i != -1 { - // ipv6 - return strings.Trim(addr[:i+1], "[]"), addr[i+2:] - } - - i = strings.LastIndex(addr, ":") - if i == -1 { - // not an address with a port number - return addr, "" - } - return addr[:i], addr[i+1:] -} - -func parseClientIPHeaderHeaderValue(format, value string) (string, error) { - // Hard-coded HA Proxy format for now: `%ci:%cp...` so we expect the value to - // start with the client IP in hexadecimal format (eg. 7F000001) separated by - // the client port number with a semicolon `:`. - sep := strings.IndexRune(value, ':') - if sep == -1 { - return "", errors.Errorf("unexpected IP address value `%s`", value) - } - - clientIPHexStr := value[:sep] - // Optimize for the best case: there will be an IP address, so allocate size - // for at least an IPv4 address. - clientIPBuf := make([]byte, 0, net.IPv4len) - _, err := fmt.Sscanf(clientIPHexStr, "%x", &clientIPBuf) - if err != nil { - return "", errors.Wrap(err, "could not parse the IP address value") - } - - switch len(clientIPBuf) { - case net.IPv4len, net.IPv6len: - return net.IP(clientIPBuf).String(), nil - default: - return "", errors.Errorf("unexpected IP address value `%s`", clientIPBuf) - } -} diff --git a/internal/protection/http/request.go b/internal/protection/http/request.go new file mode 100644 index 00000000..89820a13 --- /dev/null +++ b/internal/protection/http/request.go @@ -0,0 +1,291 @@ +// Copyright (c) 2016 - 2020 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +//sqreen:ignore + +package http + +import ( + "bytes" + "fmt" + "io" + "net" + "net/http" + "net/textproto" + "net/url" + "strings" + + "github.com/pkg/errors" + "github.com/sqreen/go-agent/internal/config" + "github.com/sqreen/go-agent/internal/event" + "github.com/sqreen/go-agent/internal/protection/http/types" +) + +type requestReader struct { + types.RequestReader + + // clientIP is the actual IP address of the client performing the request. + clientIP net.IP + + // requestParams is the set of HTTP request parameters taken from the HTTP + // request. The map key is the source (eg. json, query, multipart-form, etc.) + // so that we can report it and make it clearer to understand where the value + // comes from. + requestParams types.RequestParamMap + + // bodyReadBuffer is the buffers body reads + bodyReadBuffer bytes.Buffer +} + +func (r *requestReader) Body() []byte { return r.bodyReadBuffer.Bytes() } + +func (r *requestReader) ClientIP() net.IP { return r.clientIP } + +func (r *requestReader) Params() types.RequestParamMap { + params := r.RequestReader.Params() + if len(params) == 0 { + return r.requestParams + } + + if len(r.requestParams) == 0 { + return params + } + + res := make(types.RequestParamMap, len(params)+len(r.requestParams)) + for n, v := range params { + res[n] = v + } + for n, v := range r.requestParams { + res[n] = v + } + return res +} + +type rawBodyWAF struct { + io.ReadCloser + c *RequestContext +} + +// Read buffers what has been read and ultimately calls the WAF on EOF. +func (t rawBodyWAF) Read(p []byte) (n int, err error) { + n, err = t.ReadCloser.Read(p) + if n > 0 { + t.c.requestReader.bodyReadBuffer.Write(p[:n]) + } + + if err == io.EOF { + if wafErr := t.c.bodyWAF(); wafErr != nil { + err = wafErr + } + } + return +} + +func ClientIP(remoteAddr string, headers http.Header, prioritizedIPHeader string, prioritizedIPHeaderFormat string) net.IP { + var privateIP net.IP + check := func(value string) net.IP { + for _, ip := range strings.Split(value, ",") { + ipStr := strings.Trim(ip, " ") + ipStr, _ = splitHostPort(ipStr) + ip := net.ParseIP(ipStr) + if ip == nil { + return nil + } + + if isGlobal(ip) { + return ip + } + + if privateIP == nil && !ip.IsLoopback() && isPrivate(ip) { + privateIP = ip + } + } + return nil + } + + if prioritizedIPHeader != "" { + if value := headers.Get(prioritizedIPHeader); value != "" { + if prioritizedIPHeaderFormat != "" { + parsed, err := parseClientIPHeaderHeaderValue(prioritizedIPHeaderFormat, value) + if err == nil { + // Parsing ok, keep its returned value. + value = parsed + } else { + // An error occurred while parsing the header value, so ignore it. + value = "" + } + } + + if value != "" { + if ip := check(value); ip != nil { + return ip + } + } + } + } + + for _, key := range config.IPRelatedHTTPHeaders { + value := headers.Get(key) + if ip := check(value); ip != nil { + return ip + } + } + + remoteIPStr, _ := splitHostPort(remoteAddr) + if remoteIPStr == "" { + if privateIP != nil { + return privateIP + } + return nil + } + + if remoteIP := net.ParseIP(remoteIPStr); remoteIP != nil && (privateIP == nil || isGlobal(remoteIP)) { + return remoteIP + } + return privateIP +} + +func isGlobal(ip net.IP) bool { + if ipv4 := ip.To4(); ipv4 != nil && config.IPv4PublicNetwork.Contains(ipv4) { + return false + } + return !isPrivate(ip) +} + +func isPrivate(ip net.IP) bool { + var privateNetworks []*net.IPNet + // We cannot rely on `len(IP)` to know what type of IP address this is. + // `net.ParseIP()` or `net.IPv4()` can return internal 16-byte representations + // of an IP address even if it is an IPv4. So the trick is to use `IP.To4()` + // which returns nil if the address in not an IPv4 address. + if ipv4 := ip.To4(); ipv4 != nil { + privateNetworks = config.IPv4PrivateNetworks + } else { + privateNetworks = config.IPv6PrivateNetworks + } + + for _, network := range privateNetworks { + if network.Contains(ip) { + return true + } + } + return false +} + +// SplitHostPort splits a network address of the form `host:port` or +// `[host]:port` into `host` and `port`. +func splitHostPort(addr string) (host string, port string) { + i := strings.LastIndex(addr, "]:") + if i != -1 { + // ipv6 + return strings.Trim(addr[:i+1], "[]"), addr[i+2:] + } + + i = strings.LastIndex(addr, ":") + if i == -1 { + // not an address with a port number + return addr, "" + } + return addr[:i], addr[i+1:] +} + +func parseClientIPHeaderHeaderValue(format, value string) (string, error) { + // Hard-coded HA Proxy format for now: `%ci:%cp...` so we expect the value to + // start with the client IP in hexadecimal format (eg. 7F000001) separated by + // the client port number with a semicolon `:`. + sep := strings.IndexRune(value, ':') + if sep == -1 { + return "", errors.Errorf("unexpected IP address value `%s`", value) + } + + clientIPHexStr := value[:sep] + // Optimize for the best case: there will be an IP address, so allocate size + // for at least an IPv4 address. + clientIPBuf := make([]byte, 0, net.IPv4len) + _, err := fmt.Sscanf(clientIPHexStr, "%x", &clientIPBuf) + if err != nil { + return "", errors.Wrap(err, "could not parse the IP address value") + } + + switch len(clientIPBuf) { + case net.IPv4len, net.IPv6len: + return net.IP(clientIPBuf).String(), nil + default: + return "", errors.Errorf("unexpected IP address value `%s`", clientIPBuf) + } +} + +type handledRequest struct { + headers http.Header + method string + url *url.URL + requestURI string + host string + remoteAddr string + isTLS bool + userAgent string + referer string + form url.Values + postForm url.Values + clientIP net.IP + params types.RequestParamMap + body []byte +} + +func (h *handledRequest) Headers() http.Header { return h.headers } +func (h *handledRequest) Method() string { return h.method } +func (h *handledRequest) URL() *url.URL { return h.url } +func (h *handledRequest) RequestURI() string { return h.requestURI } +func (h *handledRequest) Host() string { return h.host } +func (h *handledRequest) RemoteAddr() string { return h.remoteAddr } +func (h *handledRequest) IsTLS() bool { return h.isTLS } +func (h *handledRequest) UserAgent() string { return h.userAgent } +func (h *handledRequest) Referer() string { return h.referer } +func (h *handledRequest) Form() url.Values { return h.form } +func (h *handledRequest) PostForm() url.Values { return h.postForm } +func (h *handledRequest) ClientIP() net.IP { return h.clientIP } +func (h *handledRequest) Params() types.RequestParamMap { return h.params } +func (h *handledRequest) Body() []byte { return h.body } +func (h *handledRequest) Header(header string) (value *string) { + headers := h.headers + if headers == nil { + return nil + } + v := headers[textproto.CanonicalMIMEHeaderKey(header)] + if len(v) == 0 { + return nil + } + return &v[0] +} + +func copyRequest(reader types.RequestReader) types.RequestReader { + return &handledRequest{ + headers: reader.Headers(), + method: reader.Method(), + url: reader.URL(), + requestURI: reader.RequestURI(), + host: reader.Host(), + remoteAddr: reader.RemoteAddr(), + isTLS: reader.IsTLS(), + userAgent: reader.UserAgent(), + referer: reader.Referer(), + form: reader.Form(), + postForm: reader.PostForm(), + clientIP: reader.ClientIP(), + params: reader.Params(), + body: reader.Body(), + } +} + +type closedRequestContext struct { + response types.ResponseFace + request types.RequestReader + events event.Recorded +} + +var _ types.ClosedRequestContextFace = (*closedRequestContext)(nil) + +func (c *closedRequestContext) Events() event.Recorded { return c.events } +func (c *closedRequestContext) Request() types.RequestReader { return c.request } +func (c *closedRequestContext) Response() types.ResponseFace { return c.response } From 95a1d9e8ca7677e90f3b8c5342e8d30d4571b087 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Thu, 16 Jul 2020 16:36:23 +0200 Subject: [PATCH 02/11] sqlib/sqhook: allow attaching multiple callbacks per hook When multiple callbacks are attached to a hook, a reflected callback is created to call each one of them. The prolog calls are stopped as soon as an error is returned, and the epilogs returned so far are called in the same order of their prolog calls. The implementation is the most straightforward and reuses already existing features (ie. return a reflected prolog callback calling every prolog callbacks) - the downside is that it is based on `reflect.MakeFunc` which is a dynamic function stub according the function type. We couldn't find anyting simple only using type assertions. For the record, we used to have an explicit callback chaining (a callback had its next callback pointer) so that it could make benefit from the type information at the callback level (a callback knows its function type) in order to avoid using `reflect` thanks to straightforward type assertions. But the downside was a very complex implementation of callbacks which had to deal with the management of the next callback - readibility was terrible. This implementation hides the chaining details in a wrapper callback. --- .../rule/callback/add-security-headers.go | 3 +- internal/rule/instrumentation.go | 2 +- internal/rule/rule_test.go | 20 +- internal/sqlib/sqassert/assert.go | 8 +- internal/sqlib/sqassert/assert_disabled.go | 6 +- internal/sqlib/sqhook/hook.go | 96 +++- internal/sqlib/sqhook/hook_test.go | 113 ++++- internal/sqlib/sqhook/hook_unit_test.go | 409 ++++++++++++++++++ internal/sqlib/sqhook/validation_test.go | 14 +- 9 files changed, 621 insertions(+), 50 deletions(-) create mode 100644 internal/sqlib/sqhook/hook_unit_test.go diff --git a/internal/rule/callback/add-security-headers.go b/internal/rule/callback/add-security-headers.go index a9c4068e..b343eb2e 100644 --- a/internal/rule/callback/add-security-headers.go +++ b/internal/rule/callback/add-security-headers.go @@ -19,8 +19,7 @@ import ( // to be attached to compatible HTTP protection middlewares such as // `protection/http`. It adds HTTP headers provided by the rule's configuration. func NewAddSecurityHeadersCallback(rule RuleFace, cfg NativeCallbackConfig) (sqhook.PrologCallback, error) { - sqassert.NotNil(rule) - sqassert.NotNil(cfg) + sqassert.NotNil(rule, cfg) var headers http.Header data, ok := cfg.Data().([]interface{}) if !ok { diff --git a/internal/rule/instrumentation.go b/internal/rule/instrumentation.go index 8afe2367..82499d2f 100644 --- a/internal/rule/instrumentation.go +++ b/internal/rule/instrumentation.go @@ -16,7 +16,7 @@ type InstrumentationFace interface { } type HookFace interface { - Attach(prolog sqhook.PrologCallback) error + Attach(prologs ...sqhook.PrologCallback) error } type defaultInstrumentationImpl struct{} diff --git a/internal/rule/rule_test.go b/internal/rule/rule_test.go index bc86122a..5d589164 100644 --- a/internal/rule/rule_test.go +++ b/internal/rule/rule_test.go @@ -34,6 +34,8 @@ func (i *instrumentationMockup) Health(expectedVersion string) error { type hookMockup struct{ mock.Mock } +var _ rule.HookFace = &hookMockup{} + func (i *instrumentationMockup) Find(symbol string) (rule.HookFace, error) { res := i.Called(symbol) err := res.Error(1) @@ -47,12 +49,22 @@ func (i *instrumentationMockup) ExpectFind(symbol string) *mock.Call { return i.On("Find", symbol) } -func (h *hookMockup) Attach(prolog sqhook.PrologCallback) error { - return h.Called(prolog).Error(0) +func (h *hookMockup) Attach(prologs ...sqhook.PrologCallback) error { + return h.Called(prologs).Error(0) } -func (h *hookMockup) ExpectAttach(prolog interface{}) *mock.Call { - return h.On("Attach", prolog) +func (h *hookMockup) ExpectAttach(prologs ...interface{}) *mock.Call { + var args interface{} + if l := len(prologs); l == 1 && prologs[0] == mock.Anything { + args = prologs[0] + } else { + prologArgs := make([]sqhook.PrologCallback, l) + for i, p := range prologs { + prologArgs[i] = p + } + args = prologArgs + } + return h.On("Attach", args) } func (h *hookMockup) PrologFuncType() reflect.Type { diff --git a/internal/sqlib/sqassert/assert.go b/internal/sqlib/sqassert/assert.go index 62abc355..305d9d09 100644 --- a/internal/sqlib/sqassert/assert.go +++ b/internal/sqlib/sqassert/assert.go @@ -20,9 +20,11 @@ func NoError(err error) { } } -func NotNil(v interface{}) { - if v == nil { - doPanic(sqerrors.New("sqassert: unexpected nil value")) +func NotNil(v ...interface{}) { + for _, v := range v { + if v == nil { + doPanic(sqerrors.New("sqassert: unexpected nil value")) + } } } diff --git a/internal/sqlib/sqassert/assert_disabled.go b/internal/sqlib/sqassert/assert_disabled.go index f0ca0426..96fa42f3 100644 --- a/internal/sqlib/sqassert/assert_disabled.go +++ b/internal/sqlib/sqassert/assert_disabled.go @@ -6,6 +6,6 @@ package sqassert -func True(bool) {} -func NoError(error) {} -func NotNil(interface{}) {} +func True(bool) {} +func NoError(error) {} +func NotNil(...interface{}) {} diff --git a/internal/sqlib/sqhook/hook.go b/internal/sqlib/sqhook/hook.go index eb876648..ea8d7792 100644 --- a/internal/sqlib/sqhook/hook.go +++ b/internal/sqlib/sqhook/hook.go @@ -186,7 +186,7 @@ func normalizedHookID(symbol string) string { // add creates the hook object for function `fn`, adds it to the find map and // returns it. It returns an error if it is not possible. -func (t symbolIndexType) add(fn, prologVar interface{}) (*Hook, error) { +func (t symbolIndexType) add(fn, prologVar interface{}) (h *Hook, err error) { // Check fn is a non-nil function value if fn == nil { return nil, sqerrors.New("unexpected function argument value `nil`") @@ -194,18 +194,25 @@ func (t symbolIndexType) add(fn, prologVar interface{}) (*Hook, error) { fnValue := reflect.ValueOf(fn) fnType := fnValue.Type() if fnType.Kind() != reflect.Func { - return nil, sqerrors.Errorf("unexpected function argument type: expecting a function value but got `%v`", fn) + return nil, sqerrors.Errorf("unexpected function argument type: expecting a function value but got `%T`", fn) } // Get the symbol name symbol := runtime.FuncForPC(fnValue.Pointer()).Name() if symbol == "" { - return nil, sqerrors.Errorf("could not read the symbol name of function `%#v`", fn) + return nil, sqerrors.Errorf("could not read the symbol name of function `%T`", fn) } // Unvendor it so that it is not prefixed by `/vendor/` symbol = sqgo.Unvendor(symbol) + // Use the symbol name for better error messages + defer func() { + if err != nil { + err = sqerrors.Wrapf(err, "symbol `%s`", symbol) + } + }() + // The hook may have been already added by a previous lookup if hook, exists := t[symbol]; exists { return hook, nil @@ -241,29 +248,44 @@ func (h *Hook) String() string { // Attach atomically attaches a prolog function to the hook. The hook can be // disabled with a `nil` prolog value. -func (h *Hook) Attach(prolog PrologCallback) error { +func (h *Hook) Attach(prologs ...PrologCallback) error { addr := h.prologVarAddr - if prolog == nil { + if l := len(prologs); l == 0 || (l == 1 && prologs[0] == nil) { // Disable atomic.StorePointer(addr, nil) - // TODO: should we check if the attach cb has a Close() method? return nil } -loop: - for { - switch actual := prolog.(type) { - case ReflectedPrologCallback: - prolog = makePrologCallback(h, actual) - case PrologCallbackGetter: - prolog = actual.PrologCallback() - default: - break loop + prologCallbacks := make([]PrologCallback, len(prologs)) + for i, prolog := range prologs { + // Loop until the prolog type is not one of the above + loop: + for { + switch actual := prolog.(type) { + case ReflectedPrologCallback: + prolog = makePrologCallback(h, actual) + case PrologCallbackGetter: + prolog = actual.PrologCallback() + default: + // Final type + break loop + } + } + + if h.prologFuncType != reflect.TypeOf(prolog) { + return sqerrors.Errorf("unexpected prolog type for hook `%s`: got `%T`, wanted `%s`", h, prolog, h.prologFuncType) } + + prologCallbacks[i] = prolog } - if h.prologFuncType != reflect.TypeOf(prolog) { - return sqerrors.Errorf("unexpected prolog type for hook `%s`: got `%T`, wanted `%s`", h, prolog, h.prologFuncType) + // Create the prolog out of the prologCallbacks + var prolog PrologCallback + if l := len(prologCallbacks); l == 1 { + prolog = prologCallbacks[0] + } else { + // Create a dynamic function calling the prolog + prolog = makeMultiPrologCallback(h, prologCallbacks) } // Create a value having type "pointer to the prolog function" @@ -275,6 +297,46 @@ loop: return nil } +func makeMultiPrologCallback(h *Hook, prologs []PrologCallback) PrologCallback { + return makePrologCallback(h, func(params []reflect.Value) (epilog ReflectedEpilogCallback, err error) { + safeCallErr := sqsafe.Call(func() error { + epilogs := make([]reflect.Value, 0, len(prologs)) + for _, prolog := range prologs { + prologValue := reflect.ValueOf(prolog) + results := prologValue.Call(params) + if r0 := results[0]; !r0.IsNil() { + epilogs = append(epilogs, r0) + } + if r1 := results[1]; !r1.IsNil() { + if len(epilogs) > 0 { + epilog = func(results []reflect.Value) { + for _, epilog := range epilogs { + epilog.Call(results) + } + } + } + err = r1.Interface().(error) + return nil + } + } + + if len(epilogs) > 0 { + epilog = func(results []reflect.Value) { + for _, epilog := range epilogs { + epilog.Call(results) + } + } + } + + return nil + }) + if safeCallErr != nil { + // TODO: log this error once + } + return epilog, err + }) +} + func makePrologCallback(h *Hook, prolog ReflectedPrologCallback) PrologCallback { prologFuncType := h.prologFuncType epilogFuncType := h.prologFuncType.Out(0) diff --git a/internal/sqlib/sqhook/hook_test.go b/internal/sqlib/sqhook/hook_test.go index cf7f21fc..18f1ec9b 100644 --- a/internal/sqlib/sqhook/hook_test.go +++ b/internal/sqlib/sqhook/hook_test.go @@ -36,14 +36,16 @@ func (example) myMethod() {} func (example) MyExportedMethod() {} func (*example) myMethodWithPointerReceiver() {} -func myFunction(_ int, _ string, _ bool) (float32, error) { return 0, nil } -func MyExportedFunction(_ int, _ string, _ bool) error { return nil } +func myFunction(_ int, _ string, _ bool) (float32, error) { return 0, nil } +func myFunction2(_ int, _ string, _ bool) (float32, error) { return 0, nil } +func MyExportedFunction(_ int, _ string, _ bool) error { return nil } var ( MyMethodSymbol = runtime.FuncForPC(reflect.ValueOf(example.myMethod).Pointer()).Name() MyMethodWithPointerRecvSymbol = runtime.FuncForPC(reflect.ValueOf((*example).myMethodWithPointerReceiver).Pointer()).Name() MyExportedMethodSymbol = runtime.FuncForPC(reflect.ValueOf(example.MyExportedMethod).Pointer()).Name() MyFunctionSymbol = runtime.FuncForPC(reflect.ValueOf(myFunction).Pointer()).Name() + MyFunction2Symbol = runtime.FuncForPC(reflect.ValueOf(myFunction2).Pointer()).Name() MyExportedFunctionSymbol = runtime.FuncForPC(reflect.ValueOf(MyExportedFunction).Pointer()).Name() ) @@ -53,6 +55,7 @@ var sortedSymbols = []string{ // Sorted by normalized name MyMethodSymbol, MyMethodWithPointerRecvSymbol, MyFunctionSymbol, + MyFunction2Symbol, } var expectedSymbols = map[string]internal.HookDescriptorFuncType{ @@ -84,6 +87,13 @@ var expectedSymbols = map[string]internal.HookDescriptorFuncType{ } }, + MyFunction2Symbol: func(d *internal.HookDescriptorType) { + *d = internal.HookDescriptorType{ + Func: myFunction2, + PrologVar: &MyFunctionProlog, + } + }, + MyExportedFunctionSymbol: func(d *internal.HookDescriptorType) { *d = internal.HookDescriptorType{ Func: MyExportedFunction, @@ -135,7 +145,7 @@ func TestGoAssumptions(t *testing.T) { require.Equal(t, (*sqhook.PrologCallback)(nil), cb) }) - t.Run("the first argument of a myMethod is the method receiver", func(t *testing.T) { + t.Run("the first argument of a method is the method receiver", func(t *testing.T) { require.Equal(t, reflect.TypeOf(example{}).Name(), reflect.TypeOf(example.myMethod).In(0).Name()) }) @@ -168,6 +178,14 @@ func TestFind(t *testing.T) { } } +type prologCallbackGetter struct { + prolog sqhook.PrologCallback +} + +func (p prologCallbackGetter) PrologCallback() sqhook.PrologCallback { + return p.prolog +} + func TestAttach(t *testing.T) { for _, tc := range []struct { Symbol string @@ -255,10 +273,16 @@ func TestAttach(t *testing.T) { return []reflect.Value{{}, {}} // not used by the test }) - checkProlog := func(t *testing.T) { + checkPrologAddr := func(t *testing.T, expected uintptr) { + // Read barrier using the prolog var + _ = atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(reflect.ValueOf(descr.PrologVar).Pointer()))) + require.Equal(t, expected, reflect.ValueOf(descr.PrologVar).Elem().Elem().Pointer()) + } + + checkPrologAddrNotNil := func(t *testing.T) { // Read barrier using the prolog var _ = atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(reflect.ValueOf(descr.PrologVar).Pointer()))) - require.Equal(t, expectedProlog.Pointer(), reflect.ValueOf(descr.PrologVar).Elem().Elem().Pointer()) + require.NotZero(t, reflect.ValueOf(descr.PrologVar).Elem().Elem().Pointer()) } t.Run(tc.Symbol, func(t *testing.T) { @@ -269,21 +293,84 @@ func TestAttach(t *testing.T) { hook, err := sqhook.Find(tc.Symbol) require.NoError(t, err) require.NotNil(t, hook) - // Attach the expected prolog function - err = hook.Attach(expectedProlog.Interface()) - require.NoError(t, err) - // Read back the prolog variable - checkProlog(t) + + t.Run("native prolog callback", func(t *testing.T) { + // Attach the expected prolog function + err = hook.Attach(expectedProlog.Interface()) + require.NoError(t, err) + // Read back the prolog variable + checkPrologAddr(t, expectedProlog.Pointer()) + }) + + t.Run("reflected prolog callback", func(t *testing.T) { + var reflected sqhook.ReflectedPrologCallback = func(params []reflect.Value) (epilog sqhook.ReflectedEpilogCallback, err error) { + return nil, nil + } + // Attach the expected prolog function + err = hook.Attach(reflected) + require.NoError(t, err) + // Read back the prolog variable + checkPrologAddrNotNil(t) + }) + + t.Run("prolog callback getter", func(t *testing.T) { + t.Run("returning a native prolog", func(t *testing.T) { + // Attach the expected prolog function + err = hook.Attach(prologCallbackGetter{prolog: expectedProlog.Interface()}) + require.NoError(t, err) + // Read back the prolog variable + checkPrologAddrNotNil(t) + }) + + t.Run("returning a reflected prolog", func(t *testing.T) { + var reflected sqhook.ReflectedPrologCallback = func(params []reflect.Value) (epilog sqhook.ReflectedEpilogCallback, err error) { + return nil, nil + } + // Attach the expected prolog function + err = hook.Attach(prologCallbackGetter{ + prolog: reflected, + }) + require.NoError(t, err) + // Read back the prolog variable + checkPrologAddrNotNil(t) + }) + }) + + t.Run("multiple prolog callbacks", func(t *testing.T) { + var reflected sqhook.ReflectedPrologCallback = func(params []reflect.Value) (epilog sqhook.ReflectedEpilogCallback, err error) { + return nil, nil + } + // Attach the expected prolog function + native := expectedProlog.Interface() + err = hook.Attach(reflected, native, reflected, native, reflected, native, prologCallbackGetter{prolog: reflected}, prologCallbackGetter{prolog: expectedProlog.Interface()}) + require.NoError(t, err) + // Read back the prolog variable + checkPrologAddrNotNil(t) + }) }) + t.Run("not expected prolog types", func(t *testing.T) { + hook, err := sqhook.Find(tc.Symbol) + require.NoError(t, err) + require.NotNil(t, hook) + require.NoError(t, hook.Attach(nil)) + for _, invalidProlog := range tc.InvalidPrologs { invalidProlog := invalidProlog t.Run(fmt.Sprintf("%T", invalidProlog), func(t *testing.T) { - hook, err := sqhook.Find(tc.Symbol) - require.NoError(t, err) - require.NotNil(t, hook) err = hook.Attach(invalidProlog) require.Error(t, err) + //checkPrologAddr(t, 0) + }) + + t.Run(fmt.Sprintf("%T along with the expected prolog callback", invalidProlog), func(t *testing.T) { + err = hook.Attach(expectedProlog.Interface(), invalidProlog) + require.Error(t, err) + //checkPrologAddr(t, 0) + + err = hook.Attach(invalidProlog, expectedProlog.Interface()) + require.Error(t, err) + //checkPrologAddr(t, 0) }) } }) diff --git a/internal/sqlib/sqhook/hook_unit_test.go b/internal/sqlib/sqhook/hook_unit_test.go new file mode 100644 index 00000000..b46eae1a --- /dev/null +++ b/internal/sqlib/sqhook/hook_unit_test.go @@ -0,0 +1,409 @@ +// Copyright (c) 2016 - 2020 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package sqhook + +import ( + "errors" + "reflect" + "testing" + + fuzz "github.com/google/gofuzz" + "github.com/stretchr/testify/require" +) + +func myFunction(int, string, bool) (float32, error) { return 0, nil } + +func TestInstrumentationError(t *testing.T) { + // Test that we would catch instrumentation mistakes - which should never + // happen + var myFunctionPrologVar *func(*int, *string, *bool) (func(*float32, *error), error) + + for _, tc := range []struct { + Name string + Fn interface{} + PrologVar interface{} + }{ + { + Name: "nil function", + Fn: nil, + PrologVar: &myFunctionPrologVar, + }, + + { + Name: "not a function", + Fn: 33, + PrologVar: &myFunctionPrologVar, + }, + + { + Name: "nil prolog var", + Fn: myFunction, + PrologVar: nil, + }, + + { + Name: "prolog var is not a pointer", + Fn: myFunction, + PrologVar: myFunctionPrologVar, + }, + } { + t.Run(tc.Name, func(t *testing.T) { + h, err := symbolIndexType{}.add(tc.Fn, tc.PrologVar) + require.Error(t, err) + require.Nil(t, h) + }) + } +} + +func TestReflectedCallback(t *testing.T) { + t.Run("", func(t *testing.T) { + type epilogType = func() + type prologType = func() (epilogType, error) + + prolog, ok := makePrologCallback(&Hook{ + prologFuncType: reflect.TypeOf(prologType(nil)), + }, func(params []reflect.Value) (epilog ReflectedEpilogCallback, err error) { + require.Len(t, params, 0) + return nil, nil + }).(prologType) + + require.True(t, ok) + require.NotNil(t, prolog) + + epilog, err := prolog() + require.NoError(t, err) + require.Nil(t, epilog) + }) + + t.Run("", func(t *testing.T) { + type epilogType = func() + type prologType = func(int, bool, string, float64, map[string]bool) (epilogType, error) + + var prologArgs struct { + A int + B bool + C string + D float64 + E map[string]bool + } + + prolog, ok := makePrologCallback(&Hook{ + prologFuncType: reflect.TypeOf(prologType(nil)), + }, func(params []reflect.Value) (epilog ReflectedEpilogCallback, err error) { + argValues := reflect.ValueOf(prologArgs) + argTypes := argValues.Type() + require.Len(t, params, argTypes.NumField()) + for i := range params { + require.Equal(t, argTypes.Field(i).Type, params[i].Type()) + require.Equal(t, argValues.Field(i).Interface(), params[i].Interface()) + } + return nil, nil + }).(prologType) + + require.True(t, ok) + require.NotNil(t, prolog) + + fuzz.New().Fuzz(&prologArgs) + epilog, err := prolog(prologArgs.A, prologArgs.B, prologArgs.C, prologArgs.D, prologArgs.E) + require.NoError(t, err) + require.Nil(t, epilog) + }) + + t.Run("", func(t *testing.T) { + type epilogType = func() + type prologType = func(int, bool, string, float64, map[string]bool) (epilogType, error) + + var prologArgs struct { + A int + B bool + C string + D float64 + E map[string]bool + } + + prologErr := errors.New("my error") + + prolog, ok := makePrologCallback(&Hook{ + prologFuncType: reflect.TypeOf(prologType(nil)), + }, func(params []reflect.Value) (epilog ReflectedEpilogCallback, err error) { + argValues := reflect.ValueOf(prologArgs) + argTypes := argValues.Type() + require.Len(t, params, argTypes.NumField()) + for i := range params { + require.Equal(t, argTypes.Field(i).Type, params[i].Type()) + require.Equal(t, argValues.Field(i).Interface(), params[i].Interface()) + } + return nil, prologErr + }).(prologType) + + require.True(t, ok) + require.NotNil(t, prolog) + + fuzz.New().Fuzz(&prologArgs) + epilog, err := prolog(prologArgs.A, prologArgs.B, prologArgs.C, prologArgs.D, prologArgs.E) + require.Error(t, err) + require.Equal(t, prologErr, err) + require.Nil(t, epilog) + }) + + t.Run("", func(t *testing.T) { + type epilogType = func(int, bool, string, float64, map[string]bool) + type prologType = func(int, bool, string, float64, map[string]bool) (epilogType, error) + + var prologArgs, epilogArgs struct { + A int + B bool + C string + D float64 + E map[string]bool + } + + prologErr := errors.New("my error") + + prolog, ok := makePrologCallback(&Hook{ + prologFuncType: reflect.TypeOf(prologType(nil)), + }, func(params []reflect.Value) (epilog ReflectedEpilogCallback, err error) { + argValues := reflect.ValueOf(prologArgs) + argTypes := argValues.Type() + require.Len(t, params, argTypes.NumField()) + for i := range params { + require.Equal(t, argTypes.Field(i).Type, params[i].Type()) + require.Equal(t, argValues.Field(i).Interface(), params[i].Interface()) + } + return func(results []reflect.Value) { + argValues := reflect.ValueOf(epilogArgs) + argTypes := argValues.Type() + require.Len(t, results, argTypes.NumField()) + for i := range results { + require.Equal(t, argTypes.Field(i).Type, results[i].Type()) + require.Equal(t, argValues.Field(i).Interface(), results[i].Interface()) + } + }, prologErr + }).(prologType) + + require.True(t, ok) + require.NotNil(t, prolog) + + f := fuzz.New() + + f.Fuzz(&prologArgs) + epilog, err := prolog(prologArgs.A, prologArgs.B, prologArgs.C, prologArgs.D, prologArgs.E) + require.Error(t, err) + require.Equal(t, prologErr, err) + + require.NotNil(t, epilog) + f.Fuzz(&epilogArgs) + epilog(epilogArgs.A, epilogArgs.B, epilogArgs.C, epilogArgs.D, epilogArgs.E) + }) + + t.Run("", func(t *testing.T) { + type epilogType = func(int, bool, string, float64, map[string]bool) + type prologType = func(int, bool, string, float64, map[string]bool) (epilogType, error) + + var prologArgs, epilogArgs struct { + A int + B bool + C string + D float64 + E map[string]bool + } + + prolog, ok := makePrologCallback(&Hook{ + prologFuncType: reflect.TypeOf(prologType(nil)), + }, func(params []reflect.Value) (epilog ReflectedEpilogCallback, err error) { + argValues := reflect.ValueOf(prologArgs) + argTypes := argValues.Type() + require.Len(t, params, argTypes.NumField()) + for i := range params { + require.Equal(t, argTypes.Field(i).Type, params[i].Type()) + require.Equal(t, argValues.Field(i).Interface(), params[i].Interface()) + } + return func(results []reflect.Value) { + argValues := reflect.ValueOf(epilogArgs) + argTypes := argValues.Type() + require.Len(t, results, argTypes.NumField()) + for i := range results { + require.Equal(t, argTypes.Field(i).Type, results[i].Type()) + require.Equal(t, argValues.Field(i).Interface(), results[i].Interface()) + } + }, nil + }).(prologType) + + require.True(t, ok) + require.NotNil(t, prolog) + + f := fuzz.New() + + f.Fuzz(&prologArgs) + epilog, err := prolog(prologArgs.A, prologArgs.B, prologArgs.C, prologArgs.D, prologArgs.E) + require.NoError(t, err) + + require.NotNil(t, epilog) + f.Fuzz(&epilogArgs) + epilog(epilogArgs.A, epilogArgs.B, epilogArgs.C, epilogArgs.D, epilogArgs.E) + }) +} + +func TestMultiCallback(t *testing.T) { + type epilogType = func(byte, rune, []string) + type prologType = func(int, bool, string, float64, map[string]bool) (epilogType, error) + + var ( + prologArgs struct { + A int + B bool + C string + D float64 + E map[string]bool + } + epilogArgs struct { + A byte + B rune + C []string + } + hook = &Hook{prologFuncType: reflect.TypeOf(prologType(nil))} + f = fuzz.New() + order []int + ) + + makePrologFunc := func(t *testing.T, expectedOrder int, epilog epilogType, prologErr error) prologType { + return func(a int, b bool, c string, d float64, e map[string]bool) (epilogType, error) { + require.Equal(t, prologArgs.A, a) + require.Equal(t, prologArgs.B, b) + require.Equal(t, prologArgs.C, c) + require.Equal(t, prologArgs.D, d) + require.Equal(t, prologArgs.E, e) + order = append(order, expectedOrder) + return epilog, prologErr + } + } + + makeEpilogFunc := func(t *testing.T, expectedOrder int) epilogType { + return func(a byte, b rune, c []string) { + require.Equal(t, epilogArgs.A, a) + require.Equal(t, epilogArgs.B, b) + require.Equal(t, epilogArgs.C, c) + order = append(order, expectedOrder) + } + } + + for _, tc := range []struct { + Prologs []PrologCallback + ExpectedOrder []int + ExpectedError error + }{ + { + Prologs: []PrologCallback{ + makePrologFunc(t, 0, makeEpilogFunc(t, 11), nil), + makePrologFunc(t, 1, makeEpilogFunc(t, 12), nil), + makePrologFunc(t, 2, makeEpilogFunc(t, 13), nil), + makePrologFunc(t, 3, makeEpilogFunc(t, 14), nil), + makePrologFunc(t, 4, makeEpilogFunc(t, 15), nil), + makePrologFunc(t, 5, makeEpilogFunc(t, 16), nil), + makePrologFunc(t, 6, makeEpilogFunc(t, 17), nil), + makePrologFunc(t, 7, makeEpilogFunc(t, 18), nil), + makePrologFunc(t, 8, makeEpilogFunc(t, 19), nil), + makePrologFunc(t, 9, makeEpilogFunc(t, 20), nil), + makePrologFunc(t, 10, makeEpilogFunc(t, 21), nil), + }, + ExpectedOrder: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}, + }, + + { + Prologs: []PrologCallback{ + makePrologFunc(t, 0, makeEpilogFunc(t, 11), nil), + makePrologFunc(t, 1, makeEpilogFunc(t, 12), nil), + makePrologFunc(t, 2, makeEpilogFunc(t, 13), nil), + makePrologFunc(t, 3, nil, nil), + makePrologFunc(t, 4, makeEpilogFunc(t, 15), nil), + makePrologFunc(t, 5, makeEpilogFunc(t, 16), nil), + makePrologFunc(t, 6, nil, nil), + makePrologFunc(t, 7, makeEpilogFunc(t, 18), nil), + makePrologFunc(t, 8, nil, nil), + makePrologFunc(t, 9, makeEpilogFunc(t, 20), nil), + makePrologFunc(t, 10, makeEpilogFunc(t, 21), nil), + }, + ExpectedOrder: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 18, 20, 21}, + }, + + { + Prologs: []PrologCallback{ + makePrologFunc(t, 0, makeEpilogFunc(t, 11), nil), + makePrologFunc(t, 1, makeEpilogFunc(t, 12), nil), + makePrologFunc(t, 2, makeEpilogFunc(t, 13), nil), + makePrologFunc(t, 3, nil, nil), + makePrologFunc(t, 4, makeEpilogFunc(t, 15), nil), + makePrologFunc(t, 5, makeEpilogFunc(t, 16), errors.New("my error")), + makePrologFunc(t, 6, nil, nil), + makePrologFunc(t, 7, makeEpilogFunc(t, 18), nil), + makePrologFunc(t, 8, nil, nil), + makePrologFunc(t, 9, makeEpilogFunc(t, 20), nil), + makePrologFunc(t, 10, makeEpilogFunc(t, 21), nil), + }, + ExpectedOrder: []int{0, 1, 2, 3, 4, 5, 11, 12, 13, 15, 16}, + ExpectedError: errors.New("my error"), + }, + + { + Prologs: []PrologCallback{ + makePrologFunc(t, 0, makeEpilogFunc(t, 11), errors.New("my error 1")), + makePrologFunc(t, 1, makeEpilogFunc(t, 12), nil), + makePrologFunc(t, 2, makeEpilogFunc(t, 13), nil), + makePrologFunc(t, 3, nil, nil), + makePrologFunc(t, 4, makeEpilogFunc(t, 15), nil), + makePrologFunc(t, 5, makeEpilogFunc(t, 16), errors.New("my error 2")), + makePrologFunc(t, 6, nil, nil), + makePrologFunc(t, 7, makeEpilogFunc(t, 18), nil), + makePrologFunc(t, 8, nil, nil), + makePrologFunc(t, 9, makeEpilogFunc(t, 20), nil), + makePrologFunc(t, 10, makeEpilogFunc(t, 21), nil), + }, + ExpectedOrder: []int{0, 11}, + ExpectedError: errors.New("my error 1"), + }, + + { + Prologs: []PrologCallback{ + makePrologFunc(t, 0, makeEpilogFunc(t, 11), nil), + makePrologFunc(t, 1, makeEpilogFunc(t, 12), nil), + makePrologFunc(t, 2, makeEpilogFunc(t, 13), nil), + makePrologFunc(t, 3, nil, nil), + makePrologFunc(t, 4, makeEpilogFunc(t, 15), nil), + makePrologFunc(t, 5, makeEpilogFunc(t, 16), nil), + makePrologFunc(t, 6, nil, nil), + makePrologFunc(t, 7, makeEpilogFunc(t, 18), nil), + makePrologFunc(t, 8, nil, nil), + makePrologFunc(t, 9, makeEpilogFunc(t, 20), nil), + makePrologFunc(t, 10, makeEpilogFunc(t, 21), errors.New("my error")), + }, + ExpectedOrder: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 18, 20, 21}, + ExpectedError: errors.New("my error"), + }, + } { + tc := tc + t.Run("", func(t *testing.T) { + prolog, ok := makeMultiPrologCallback(hook, tc.Prologs).(prologType) + require.True(t, ok) + require.NotNil(t, prolog) + + f.Fuzz(&prologArgs) + epilog, err := prolog(prologArgs.A, prologArgs.B, prologArgs.C, prologArgs.D, prologArgs.E) + + if tc.ExpectedError != nil { + require.Error(t, err) + require.Equal(t, tc.ExpectedError, err) + } else { + require.NoError(t, err) + } + + require.NotNil(t, epilog) + + epilog(epilogArgs.A, epilogArgs.B, epilogArgs.C) + + require.Equal(t, tc.ExpectedOrder, order) + order = nil // TODO: avoid this test side-effect... + }) + } +} diff --git a/internal/sqlib/sqhook/validation_test.go b/internal/sqlib/sqhook/validation_test.go index 33a319c2..1f493362 100644 --- a/internal/sqlib/sqhook/validation_test.go +++ b/internal/sqlib/sqhook/validation_test.go @@ -105,33 +105,33 @@ func TestPrologVarValidation(t *testing.T) { }, { // wrong prolog type: wrong arg count - fn: (func(int) (chan struct{}, error))(nil), + fn: (func(int) (chan struct{}, error))(nil), prolog: (func(*int) (func(*chan struct{}, *error, *int), error))(nil), }, { // wrong prolog type: wrong arg count - fn: (func(int) (chan struct{}, error))(nil), + fn: (func(int) (chan struct{}, error))(nil), prolog: (func(*int) (func(*chan struct{}), error))(nil), }, { // wrong prolog type: wrong arg count - fn: (func(int) (chan struct{}, error))(nil), + fn: (func(int) (chan struct{}, error))(nil), prolog: (func(*int) (func(), error))(nil), }, { // wrong prolog type: wrong arg types - fn: (func(int) (chan struct{}, error))(nil), + fn: (func(int) (chan struct{}, error))(nil), prolog: (func(*int) (func(interface{}, interface{}, interface{}), error))(nil), }, { // wrong prolog type: wrong arg types - fn: (func(int) (chan struct{}, error))(nil), + fn: (func(int) (chan struct{}, error))(nil), prolog: (func(*int) (func(chan struct{}, *error), error))(nil), }, { - fn: (func(int) (chan struct{}, error))(nil), - prolog: (func(*int) (func(*chan struct{}, *error), error))(nil), + fn: (func(int) (chan struct{}, error))(nil), + prolog: (func(*int) (func(*chan struct{}, *error), error))(nil), shouldSucceed: true, }, From 7ab7f1b3ef57c394b75570e3481111a530b677c0 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Fri, 17 Jul 2020 22:04:33 +0200 Subject: [PATCH 03/11] rule: add support for multiple rules per hookpoint Multiple rules per hookpoint are attached and executed in ascending priority order. --- internal/backend/api/api.go | 1 + internal/rule/rule.go | 100 ++++++++++----- internal/rule/rule_unit_test.go | 76 +++++++++++ internal/sqlib/sqerrors/errors.go | 24 ++++ internal/sqlib/sqerrors/errors_test.go | 9 ++ internal/sqlib/sqhook/hook.go | 25 ++-- internal/sqlib/sqhook/hook_unit_test.go | 144 +++++++++++++++++++++ internal/sqlib/sqhook/validation_test.go | 156 ----------------------- 8 files changed, 334 insertions(+), 201 deletions(-) create mode 100644 internal/rule/rule_unit_test.go delete mode 100644 internal/sqlib/sqhook/validation_test.go diff --git a/internal/backend/api/api.go b/internal/backend/api/api.go index 3c53f684..2ffe18dd 100644 --- a/internal/backend/api/api.go +++ b/internal/backend/api/api.go @@ -182,6 +182,7 @@ type Rule struct { Test bool `json:"test"` Block bool `json:"block"` AttackType string `json:"attack_type"` + Priority int `json:"priority"` } type RuleConditions struct{} diff --git a/internal/rule/rule.go b/internal/rule/rule.go index 66cf88b1..efb9c043 100644 --- a/internal/rule/rule.go +++ b/internal/rule/rule.go @@ -21,6 +21,7 @@ package rule import ( "crypto/ecdsa" "io" + "sort" "github.com/sqreen/go-agent/internal/backend/api" "github.com/sqreen/go-agent/internal/metrics" @@ -35,7 +36,7 @@ type Engine struct { // at run time by atomically replacing a running rule. // TODO: write a test to check two HookFaces are correctly comparable // to find back a hook - hooks hookDescriptors + hooks hookDescriptorMap packID string enabled bool metricsEngine *metrics.Engine @@ -79,7 +80,7 @@ func (e *Engine) PackID() string { // them by atomically modifying the hooks, and removing what is left. func (e *Engine) SetRules(packID string, rules []api.Rule) { // Create the new rule descriptors and replace the existing ones - var ruleDescriptors hookDescriptors + var ruleDescriptors hookDescriptorMap if len(rules) > 0 { e.logger.Debugf("security rules: loading rules from pack `%s`", packID) ruleDescriptors = newHookDescriptors(e, rules) @@ -87,7 +88,7 @@ func (e *Engine) SetRules(packID string, rules []api.Rule) { e.setRules(packID, ruleDescriptors) } -func (e *Engine) setRules(packID string, descriptors hookDescriptors) { +func (e *Engine) setRules(packID string, descriptors hookDescriptorMap) { // Firstly update already enabled hookpoints with their new callbacks in order // to avoid having a blank moment without any callback set. This case happens // when a rule is updated. @@ -96,7 +97,7 @@ func (e *Engine) setRules(packID string, descriptors hookDescriptors) { if e.enabled { // Attach the callback to the hook, possibly overwriting the previous one. e.logger.Debugf("security rules: attaching callback to `%s`", hook) - err := hook.Attach(descr.callback) + err := hook.Attach(descr.callbacks...) if err != nil { e.logger.Error(sqerrors.Wrapf(err, "security rules: could not attach the prolog callback to `%s`", hook)) continue @@ -135,11 +136,11 @@ func (e *Engine) setRules(packID string, descriptors hookDescriptors) { // newHookDescriptors walks the list of received rules and creates the map of // hook descriptors indexed by their hook pointer. A hook descriptor contains // all it takes to enable and disable rules at run time. -func newHookDescriptors(e *Engine, rules []api.Rule) hookDescriptors { +func newHookDescriptors(e *Engine, rules []api.Rule) hookDescriptorMap { logger := e.logger // Create and configure the list of callbacks according to the given rules - var hookDescriptors = make(hookDescriptors) + var hookDescriptors = make(hookDescriptorMap) for i := len(rules) - 1; i >= 0; i-- { r := rules[i] // Verify the signature @@ -168,6 +169,8 @@ func newHookDescriptors(e *Engine, rules []api.Rule) hookDescriptors { continue } + // Create the prolog callback + var prolog sqhook.PrologCallback switch hookpoint.Strategy { case "", "native": cfg, err := newNativeCallbackConfig(&r) @@ -176,26 +179,23 @@ func newHookDescriptors(e *Engine, rules []api.Rule) hookDescriptors { continue } - prolog, err := NewNativeCallback(hookpoint.Callback, callbackContext, cfg) + prolog, err = NewNativeCallback(hookpoint.Callback, callbackContext, cfg) if err != nil { logger.Error(sqerrors.Wrapf(err, "security rules: rule `%s`: callback constructor", r.Name)) continue } - // Create the descriptor with everything required to be able to enable or - // disable it afterwards. - hookDescriptors.Set(hook, prolog) case "reflected": - prolog, err := NewReflectedCallback(hookpoint.Callback, callbackContext, &r) + prolog, err = NewReflectedCallback(hookpoint.Callback, callbackContext, &r) if err != nil { logger.Error(sqerrors.Wrapf(err, "security rules: rule `%s`: callback constructor", r.Name)) continue } - // Create the descriptor with everything required to be able to enable or - // disable it afterwards. - hookDescriptors.Set(hook, prolog) } + // Create the descriptor with everything required to be able to enable or + // disable it afterwards. + hookDescriptors.Add(hook, prolog, r.Priority) } // Nothing in the end if len(hookDescriptors) == 0 { @@ -207,9 +207,8 @@ func newHookDescriptors(e *Engine, rules []api.Rule) hookDescriptors { // Enable the hooks of the ongoing configured rules. func (e *Engine) Enable() { for hook, descr := range e.hooks { - prolog := descr.callback e.logger.Debugf("security rules: attaching callback to hook `%s`", hook) - if err := hook.Attach(prolog); err != nil { + if err := hook.Attach(descr.callbacks...); err != nil { e.logger.Error(sqerrors.Wrapf(err, "security rules: could not attach the callback to hook `%v`", hook)) } } @@ -235,23 +234,66 @@ func (e *Engine) Count() int { return len(e.hooks) } -type callbackWrapper struct { - callback sqhook.PrologCallback -} +type ( + hookDescriptorMap map[HookFace]hookDescriptor -func (c callbackWrapper) Close() error { - if closer, ok := c.callback.(io.Closer); ok { - return closer.Close() + hookDescriptor struct { + priorities []int + callbacks []sqhook.PrologCallback + closers []io.Closer + } +) + +func (m hookDescriptorMap) Add(hook HookFace, callback sqhook.PrologCallback, priority int) { + d, exists := m[hook] + closer, _ := callback.(io.Closer) + + if !exists { + // First insertion + var closers []io.Closer + if closer != nil { + closers = []io.Closer{closer} + } + m[hook] = hookDescriptor{ + priorities: []int{priority}, + callbacks: []sqhook.PrologCallback{callback}, + closers: closers, + } + return } - return nil -} -type hookDescriptors map[HookFace]callbackWrapper + // Not the first insertion. + // Look for the callback position i per ascending priority order + i := sort.Search(len(d.priorities), func(i int) bool { + return d.priorities[i] > priority + }) -func (m hookDescriptors) Set(hook HookFace, prolog sqhook.PrologCallback) { - m[hook] = callbackWrapper{prolog} + // Update the list of priorities + d.priorities = append(d.priorities, 0) + copy(d.priorities[i+1:], d.priorities[i:]) + d.priorities[i] = priority + + // Update the list of closers + if closer != nil { + d.closers = append(d.closers, closer) + } + + // Update the list of callbacks + d.callbacks = append(d.callbacks, nil) + copy(d.callbacks[i+1:], d.callbacks[i:]) + d.callbacks[i] = callback + + // Update the hook descriptor map entry with the new value + m[hook] = d } -func (m hookDescriptors) Get(hook HookFace) callbackWrapper { - return m[hook] +func (d hookDescriptor) Close() error { + var errs sqerrors.ErrorCollection + for _, c := range d.closers { + err := c.Close() + if err != nil { + errs.Add(err) + } + } + return errs.ToError() } diff --git a/internal/rule/rule_unit_test.go b/internal/rule/rule_unit_test.go new file mode 100644 index 00000000..a74f1f2f --- /dev/null +++ b/internal/rule/rule_unit_test.go @@ -0,0 +1,76 @@ +// Copyright (c) 2016 - 2020 Sqreen. All Rights Reserved. +// Please refer to our terms for more information: +// https://www.sqreen.io/terms.html + +package rule + +import ( + "io" + "testing" + + "github.com/sqreen/go-agent/internal/sqlib/sqhook" + "github.com/stretchr/testify/require" +) + +type hookMockup struct{} + +func (h hookMockup) Attach(...sqhook.PrologCallback) error { + panic("should not be called") + // TODO: better API to avoid that? the map only needs a "comparable" key and + // doesn't matter about the hook interface. +} + +func TestHookDescriptors(t *testing.T) { + // Not actual callbacks but enough for this unit test. + // We need to use distinct types to correctly check the ordering. + + t.Run("multiple callbacks having the same priority", func(t *testing.T) { + var m = hookDescriptorMap{} + key := hookMockup{} + m.Add(key, 1, 1) + m.Add(key, 2, 1) + m.Add(key, 3, 1) + m.Add(key, 4, 1) + d := m[key] + require.Equal(t, []int{1, 1, 1, 1}, d.priorities) + require.Equal(t, []sqhook.PrologCallback{1, 2, 3, 4}, d.callbacks) + require.Nil(t, d.closers) + }) + + t.Run("multiple callbacks having distinct priorities", func(t *testing.T) { + var m = hookDescriptorMap{} + key := hookMockup{} + + m.Add(key, 3, 2) + m.Add(key, 5, 3) + m.Add(key, 4, 2) + m.Add(key, 1, 1) + m.Add(key, 6, 3) + m.Add(key, 2, 1) + d := m[key] + require.Equal(t, []int{1, 1, 2, 2, 3, 3}, d.priorities) + require.Equal(t, []sqhook.PrologCallback{1, 2, 3, 4, 5, 6}, d.callbacks) + require.Nil(t, d.closers) + }) + + t.Run("multiple callbacks with close methods", func(t *testing.T) { + var m = hookDescriptorMap{} + key := hookMockup{} + m.Add(key, myFakeCallback(7), 10) + m.Add(key, 3, 2) + m.Add(key, myFakeCallback(1), 1) + m.Add(key, 2, 1) + m.Add(key, myFakeCallback(5), 3) + m.Add(key, 4, 2) + m.Add(key, 6, 3) + + d := m[key] + require.Equal(t, []int{1, 1, 2, 2, 3, 3, 10}, d.priorities) + require.Equal(t, []sqhook.PrologCallback{myFakeCallback(1), 2, 3, 4, myFakeCallback(5), 6, myFakeCallback(7)}, d.callbacks) + require.Equal(t, []io.Closer{myFakeCallback(7), myFakeCallback(1), myFakeCallback(5)}, d.closers) + }) +} + +type myFakeCallback int + +func (m myFakeCallback) Close() error { return nil } diff --git a/internal/sqlib/sqerrors/errors.go b/internal/sqlib/sqerrors/errors.go index 68d2e5c1..a1d06dad 100644 --- a/internal/sqlib/sqerrors/errors.go +++ b/internal/sqlib/sqerrors/errors.go @@ -6,6 +6,7 @@ package sqerrors import ( "fmt" + "strings" "time" "github.com/pkg/errors" @@ -171,3 +172,26 @@ func Timestamp(err error) (t time.Time, ok bool) { } return time.Time{}, false } + +type ErrorCollection []error + +func (c ErrorCollection) Error() string { + var s strings.Builder + s.WriteString("multiple errors occurred:") + for i, e := range c { + fmt.Fprintf(&s, " (error %d) %s;", i+1, e.Error()) + } + // Return the build string without the trailing `;` + return s.String()[:s.Len()-1] +} + +func (c *ErrorCollection) Add(e error) { + *c = append(*c, e) +} + +func (c ErrorCollection) ToError() error { + if len(c) == 0 { + return nil + } + return c +} diff --git a/internal/sqlib/sqerrors/errors_test.go b/internal/sqlib/sqerrors/errors_test.go index f29268d2..a626afd9 100644 --- a/internal/sqlib/sqerrors/errors_test.go +++ b/internal/sqlib/sqerrors/errors_test.go @@ -46,3 +46,12 @@ func TestWithInfo(t *testing.T) { require.Equal(t, info, got) }) } + +func TestErrorCollection(t *testing.T) { + var errs sqerrors.ErrorCollection + errs.Add(errors.New("error 1")) + errs.Add(errors.New("error 2")) + errs.Add(errors.New("error 3")) + errs.Add(errors.New("error 4")) + require.Equal(t, "multiple errors occurred: (error 1) error 1; (error 2) error 2; (error 3) error 3; (error 4) error 4", errs.Error()) +} diff --git a/internal/sqlib/sqhook/hook.go b/internal/sqlib/sqhook/hook.go index ea8d7792..9083c9c9 100644 --- a/internal/sqlib/sqhook/hook.go +++ b/internal/sqlib/sqhook/hook.go @@ -301,6 +301,15 @@ func makeMultiPrologCallback(h *Hook, prologs []PrologCallback) PrologCallback { return makePrologCallback(h, func(params []reflect.Value) (epilog ReflectedEpilogCallback, err error) { safeCallErr := sqsafe.Call(func() error { epilogs := make([]reflect.Value, 0, len(prologs)) + defer func() { + if len(epilogs) > 0 { + epilog = func(results []reflect.Value) { + for _, epilog := range epilogs { + epilog.Call(results) + } + } + } + }() for _, prolog := range prologs { prologValue := reflect.ValueOf(prolog) results := prologValue.Call(params) @@ -308,26 +317,10 @@ func makeMultiPrologCallback(h *Hook, prologs []PrologCallback) PrologCallback { epilogs = append(epilogs, r0) } if r1 := results[1]; !r1.IsNil() { - if len(epilogs) > 0 { - epilog = func(results []reflect.Value) { - for _, epilog := range epilogs { - epilog.Call(results) - } - } - } err = r1.Interface().(error) return nil } } - - if len(epilogs) > 0 { - epilog = func(results []reflect.Value) { - for _, epilog := range epilogs { - epilog.Call(results) - } - } - } - return nil }) if safeCallErr != nil { diff --git a/internal/sqlib/sqhook/hook_unit_test.go b/internal/sqlib/sqhook/hook_unit_test.go index b46eae1a..37ab114a 100644 --- a/internal/sqlib/sqhook/hook_unit_test.go +++ b/internal/sqlib/sqhook/hook_unit_test.go @@ -10,6 +10,8 @@ import ( "testing" fuzz "github.com/google/gofuzz" + "github.com/sqreen/go-agent/internal/sqlib/sqhook/internal" + "github.com/sqreen/go-agent/tools/testlib" "github.com/stretchr/testify/require" ) @@ -407,3 +409,145 @@ func TestMultiCallback(t *testing.T) { }) } } + +func TestHookTableLookup(t *testing.T) { + t.Run("nil", func(t *testing.T) { + myIndex := symbolIndexType{} + found, err := hookTableLookup(nil, testlib.RandUTF8String(), myIndex) + require.NoError(t, err) + require.Nil(t, found) + }) + + t.Run("empty", func(t *testing.T) { + myIndex := symbolIndexType{} + myTable := internal.HookTableType{} + found, err := hookTableLookup(myTable, testlib.RandUTF8String(), myIndex) + require.NoError(t, err) + require.Nil(t, found) + }) + + t.Run("having instrumentation errors", func(t *testing.T) { + myIndex := symbolIndexType{} + for _, tc := range []internal.HookTableType{ + { + func(d *internal.HookDescriptorType) { + // Nil values + *d = internal.HookDescriptorType{Func: nil, PrologVar: nil} + }, + }, + + { + func(d *internal.HookDescriptorType) { + // Nil Func value - Non-nil prolog var + var prologVar *func() + *d = internal.HookDescriptorType{Func: nil, PrologVar: &prologVar} + }, + }, + } { + tc := tc + t.Run("", func(t *testing.T) { + found, err := hookTableLookup(tc, testlib.RandUTF8String(), myIndex) + require.Error(t, err) + require.Nil(t, found) + }) + } + }) +} + +func TestPrologVarValidation(t *testing.T) { + for _, tc := range []struct { + fn, prolog interface{} + shouldSucceed bool + }{ + { + fn: (func())(nil), + prolog: (func() (func(), error))(nil), + shouldSucceed: true, + }, + + { // wrong arg count + fn: (func())(nil), + prolog: (func(*int) (func(), error))(nil), + }, + + { // wrong prolog arg type: should be *int + fn: (func(int))(nil), + prolog: (func(int) (func(), error))(nil), + }, + + { // wrong prolog arg type: should be *int + fn: (func(int))(nil), + prolog: (func(*int) (func(), error))(nil), + shouldSucceed: true, + }, + + { // wrong return count + fn: (func(int))(nil), + prolog: (func(*int) error)(nil), + }, + + { // wrong return type: wrong prolog type + fn: (func(int))(nil), + prolog: (func(*int) (func(string), error))(nil), + }, + + { // wrong return type: wrong error type + fn: (func(int))(nil), + prolog: (func(*int) (func(), bool))(nil), + }, + + { // wrong return count + fn: (func(int))(nil), + prolog: (func(*int) (func(), error, bool))(nil), + }, + + { // wrong prolog type: wrong arg count + fn: (func(int) (chan struct{}, error))(nil), + prolog: (func(*int) (func(*chan struct{}, *error, *int), error))(nil), + }, + + { // wrong prolog type: wrong arg count + fn: (func(int) (chan struct{}, error))(nil), + prolog: (func(*int) (func(*chan struct{}), error))(nil), + }, + + { // wrong prolog type: wrong arg count + fn: (func(int) (chan struct{}, error))(nil), + prolog: (func(*int) (func(), error))(nil), + }, + + { // wrong prolog type: wrong arg types + fn: (func(int) (chan struct{}, error))(nil), + prolog: (func(*int) (func(interface{}, interface{}, interface{}), error))(nil), + }, + + { // wrong prolog type: wrong arg types + fn: (func(int) (chan struct{}, error))(nil), + prolog: (func(*int) (func(chan struct{}, *error), error))(nil), + }, + + { + fn: (func(int) (chan struct{}, error))(nil), + prolog: (func(*int) (func(*chan struct{}, *error), error))(nil), + shouldSucceed: true, + }, + + { // variadic func + fn: (func(...int))(nil), + prolog: (func(*[]int) (func(), error))(nil), + shouldSucceed: true, + }, + } { + tc := tc + t.Run("unexpected signatures", func(t *testing.T) { + fnType := reflect.TypeOf(tc.fn) + prologVarType := reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(tc.prolog))) + err := validatePrologVar(fnType, prologVarType) + if tc.shouldSucceed { + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) + } +} diff --git a/internal/sqlib/sqhook/validation_test.go b/internal/sqlib/sqhook/validation_test.go deleted file mode 100644 index 1f493362..00000000 --- a/internal/sqlib/sqhook/validation_test.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) 2016 - 2019 Sqreen. All Rights Reserved. -// Please refer to our terms for more information: -// https://www.sqreen.io/terms.html - -package sqhook - -import ( - "reflect" - "testing" - - "github.com/sqreen/go-agent/internal/sqlib/sqhook/internal" - "github.com/sqreen/go-agent/tools/testlib" - "github.com/stretchr/testify/require" -) - -func TestHookTableLookup(t *testing.T) { - t.Run("nil", func(t *testing.T) { - myIndex := symbolIndexType{} - found, err := hookTableLookup(nil, testlib.RandUTF8String(), myIndex) - require.NoError(t, err) - require.Nil(t, found) - }) - - t.Run("empty", func(t *testing.T) { - myIndex := symbolIndexType{} - myTable := internal.HookTableType{} - found, err := hookTableLookup(myTable, testlib.RandUTF8String(), myIndex) - require.NoError(t, err) - require.Nil(t, found) - }) - - t.Run("having instrumentation errors", func(t *testing.T) { - myIndex := symbolIndexType{} - for _, tc := range []internal.HookTableType{ - { - func(d *internal.HookDescriptorType) { - // Nil values - *d = internal.HookDescriptorType{Func: nil, PrologVar: nil} - }, - }, - - { - func(d *internal.HookDescriptorType) { - // Nil Func value - Non-nil prolog var - var prologVar *func() - *d = internal.HookDescriptorType{Func: nil, PrologVar: &prologVar} - }, - }, - } { - tc := tc - t.Run("", func(t *testing.T) { - found, err := hookTableLookup(tc, testlib.RandUTF8String(), myIndex) - require.Error(t, err) - require.Nil(t, found) - }) - } - }) -} - -func TestPrologVarValidation(t *testing.T) { - for _, tc := range []struct { - fn, prolog interface{} - shouldSucceed bool - }{ - { - fn: (func())(nil), - prolog: (func() (func(), error))(nil), - shouldSucceed: true, - }, - - { // wrong arg count - fn: (func())(nil), - prolog: (func(*int) (func(), error))(nil), - }, - - { // wrong prolog arg type: should be *int - fn: (func(int))(nil), - prolog: (func(int) (func(), error))(nil), - }, - - { // wrong prolog arg type: should be *int - fn: (func(int))(nil), - prolog: (func(*int) (func(), error))(nil), - shouldSucceed: true, - }, - - { // wrong return count - fn: (func(int))(nil), - prolog: (func(*int) error)(nil), - }, - - { // wrong return type: wrong prolog type - fn: (func(int))(nil), - prolog: (func(*int) (func(string), error))(nil), - }, - - { // wrong return type: wrong error type - fn: (func(int))(nil), - prolog: (func(*int) (func(), bool))(nil), - }, - - { // wrong return count - fn: (func(int))(nil), - prolog: (func(*int) (func(), error, bool))(nil), - }, - - { // wrong prolog type: wrong arg count - fn: (func(int) (chan struct{}, error))(nil), - prolog: (func(*int) (func(*chan struct{}, *error, *int), error))(nil), - }, - - { // wrong prolog type: wrong arg count - fn: (func(int) (chan struct{}, error))(nil), - prolog: (func(*int) (func(*chan struct{}), error))(nil), - }, - - { // wrong prolog type: wrong arg count - fn: (func(int) (chan struct{}, error))(nil), - prolog: (func(*int) (func(), error))(nil), - }, - - { // wrong prolog type: wrong arg types - fn: (func(int) (chan struct{}, error))(nil), - prolog: (func(*int) (func(interface{}, interface{}, interface{}), error))(nil), - }, - - { // wrong prolog type: wrong arg types - fn: (func(int) (chan struct{}, error))(nil), - prolog: (func(*int) (func(chan struct{}, *error), error))(nil), - }, - - { - fn: (func(int) (chan struct{}, error))(nil), - prolog: (func(*int) (func(*chan struct{}, *error), error))(nil), - shouldSucceed: true, - }, - - { // variadic func - fn: (func(...int))(nil), - prolog: (func(*[]int) (func(), error))(nil), - shouldSucceed: true, - }, - } { - tc := tc - t.Run("unexpected signatures", func(t *testing.T) { - fnType := reflect.TypeOf(tc.fn) - prologVarType := reflect.PtrTo(reflect.PtrTo(reflect.TypeOf(tc.prolog))) - err := validatePrologVar(fnType, prologVarType) - if tc.shouldSucceed { - require.NoError(t, err) - } else { - require.Error(t, err) - } - }) - } -} From ab939cd485067af7b9272b73127efefe99c9acdb Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Tue, 21 Jul 2020 11:14:03 +0200 Subject: [PATCH 04/11] sdk/sqreen-instrumentation-tool: add limited instrumentation of go.mongodb.org/mongo-driver/mongo In order to attach noSQL-injection protection to the driver. --- sdk/sqreen-instrumentation-tool/instrumentation.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sdk/sqreen-instrumentation-tool/instrumentation.go b/sdk/sqreen-instrumentation-tool/instrumentation.go index 6bdf33b3..ec34bb3b 100644 --- a/sdk/sqreen-instrumentation-tool/instrumentation.go +++ b/sdk/sqreen-instrumentation-tool/instrumentation.go @@ -202,6 +202,7 @@ var ( "os", "net/http", "github.com/gin-gonic/gin", + "go.mongodb.org/mongo-driver/mongo", } // Optional list packages, files and functions we want to only instrument for @@ -221,6 +222,10 @@ var ( // Same comment as net/http "context.go", // context.go contains the body parsers }, + "go.mongodb.org/mongo-driver/mongo": { + // Limited for performance reasons to: + "mongo.go", // mongo.go contains the bson transformation function + }, } ) From 65424ff8b040c2e2eb043a6d1fb91fc2066d7e2e Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Wed, 22 Jul 2020 17:07:41 +0200 Subject: [PATCH 05/11] agent/backend: add sqreen domains health checking The new security signal API is at `ingestion.sqreen.com`. The agent now checks if it can access it, and otherwise uses the usual `back.sqreen.com`. The agent reports the results in the login payload for debugging purposes. --- go.mod | 5 +- go.sum | 13 ++-- internal/adapter.go | 9 +++ internal/agent.go | 12 ++-- internal/backend/api/api.go | 37 ++++++++---- internal/backend/api/json_test.go | 3 +- internal/backend/client.go | 69 ++++++++++++++++++---- internal/backend/client_test.go | 11 +++- internal/client.go | 8 ++- internal/config/config.go | 3 +- internal/config/config_test.go | 3 +- internal/sqlib/sqsanitize/sanitize.go | 4 +- internal/sqlib/sqsanitize/sanitize_test.go | 28 +++------ 13 files changed, 141 insertions(+), 64 deletions(-) diff --git a/go.mod b/go.mod index 3cf33964..2b65c18f 100644 --- a/go.mod +++ b/go.mod @@ -29,8 +29,8 @@ require ( github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/viper v1.3.2 github.com/sqreen/go-libsqreen v0.7.0 - github.com/sqreen/go-sdk/signal v1.0.0 - github.com/stretchr/testify v1.5.1 + github.com/sqreen/go-sdk/signal v1.1.0 + github.com/stretchr/testify v1.6.1 golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 // indirect golang.org/x/net v0.0.0-20200513185701-a91f0712d120 golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9 // indirect @@ -40,4 +40,5 @@ require ( gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/go-playground/assert.v1 v1.2.1 gopkg.in/go-playground/validator.v8 v8.18.2 // indirect + gopkg.in/yaml.v2 v2.3.0 // indirect ) diff --git a/go.sum b/go.sum index e2cdb96d..942e906c 100644 --- a/go.sum +++ b/go.sum @@ -101,12 +101,10 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.3.2 h1:VUFqw5KcqRf7i70GOzW7N+Q7+gxVBkSSqiXB12+JQ4M= github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= -github.com/sqreen/go-libsqreen v0.6.1 h1:+SHH3h8qHhINEzgRVqTZ40YxqwDjSVxU5r4isUeg+C8= -github.com/sqreen/go-libsqreen v0.6.1/go.mod h1:D324eoKlZGfW+TF3WGg+2fUtpdrI+cEK5UYwpxfaeUc= github.com/sqreen/go-libsqreen v0.7.0 h1:MRX/KB5lX3O6ucvmTUap6iSDt27bM+76MQpuDNjL+1o= github.com/sqreen/go-libsqreen v0.7.0/go.mod h1:D324eoKlZGfW+TF3WGg+2fUtpdrI+cEK5UYwpxfaeUc= -github.com/sqreen/go-sdk/signal v1.0.0 h1:WNjufvcjKYOgSZHPCwqG0Od5eVAD8wxwmiIe6ZCqoNE= -github.com/sqreen/go-sdk/signal v1.0.0/go.mod h1:UksuO4mxxDMFw3el+R9mW9tmCgdc94WiDcGuCXU/pwU= +github.com/sqreen/go-sdk/signal v1.1.0 h1:l22lqlUNDlEaqsNjpgVelGteBCwGodZqUDPUMBOLzhE= +github.com/sqreen/go-sdk/signal v1.1.0/go.mod h1:XWJV0TzuoN6PotzRn4YSe6fhTxyw67yRpVYr9NJTzto= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= @@ -115,8 +113,8 @@ github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8 h1:3SVOIvH7Ae1KRYyQWRjXWJEA9sS/c/pjvH++55Gr648= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= @@ -190,3 +188,6 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 h1:tQIYjPdBoyREyB9XMu+nnTclpTYkz2zFM+lzLJFO4gQ= +gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/adapter.go b/internal/adapter.go index 964d4ecb..4ad08b3e 100644 --- a/internal/adapter.go +++ b/internal/adapter.go @@ -334,6 +334,15 @@ func newMetricsAPIAdapter(logger plog.ErrorLogger, expiredMetrics map[string]*me return metricsArray } +type variousInfoAPIAdapter struct { + *appInfoAPIAdapter + sqreenDomains api.SqreenDomainStatusMap +} + +func (v variousInfoAPIAdapter) GetSqreenDomains() api.SqreenDomainStatusMap { + return v.sqreenDomains +} + type appInfoAPIAdapter app.Info func (a *appInfoAPIAdapter) unwrap() *app.Info { return (*app.Info)(a) } diff --git a/internal/agent.go b/internal/agent.go index 1aec2315..bf547f51 100644 --- a/internal/agent.go +++ b/internal/agent.go @@ -228,9 +228,9 @@ func New(cfg *config.Config) *AgentType { logger.Info(message) return nil } - // TODO: agent.Health() + waf.Health() + if waf.Version() == nil { - message := fmt.Sprintf("in-app waf disabled: cgo was disabled during the program compilation while required by the in-app waf") + message := "in-app waf disabled: cgo was disabled during the program compilation while required by the in-app waf" backend.SendAgentMessage(logger, cfg, message) logger.Info("agent: ", message) } @@ -240,9 +240,11 @@ func New(cfg *config.Config) *AgentType { sdkMetricsPeriod := time.Duration(cfg.SDKMetricsPeriod()) * time.Second logger.Debugf("agent: using sdk metrics store time period of %s", sdkMetricsPeriod) - piiScrubber, err := sqsanitize.NewScrubber(cfg.StripSensitiveKeyRegexp(), cfg.StripSensitiveValueRegexp(), config.ScrubberRedactedString) + piiScrubber := sqsanitize.NewScrubber(cfg.StripSensitiveKeyRegexp(), cfg.StripSensitiveValueRegexp(), config.ScrubberRedactedString) + + client, err := backend.NewClient(cfg.BackendHTTPAPIBaseURL(), cfg.BackendHTTPAPIProxy(), logger) if err != nil { - logger.Error(sqerrors.Wrap(err, "ecdsa public key")) + logger.Error(sqerrors.Wrap(err, "agent: could not create the backend client")) return nil } @@ -264,7 +266,7 @@ func New(cfg *config.Config) *AgentType { cancel: cancel, config: cfg, appInfo: app.NewInfo(logger), - client: backend.NewClient(cfg.BackendHTTPAPIBaseURL(), cfg.BackendHTTPAPIProxy(), logger), + client: client, actors: actor.NewStore(logger), rules: rulesEngine, piiScrubber: piiScrubber, diff --git a/internal/backend/api/api.go b/internal/backend/api/api.go index 2ffe18dd..3ecb3f3e 100644 --- a/internal/backend/api/api.go +++ b/internal/backend/api/api.go @@ -11,6 +11,10 @@ import ( "github.com/sqreen/go-agent/internal/sqlib/sqsanitize" ) +type PingResponse struct { + Status bool `json:"status"` +} + type AppLoginRequest struct { BundleSignature string `json:"bundle_signature"` VariousInfos AppLoginRequest_VariousInfos `json:"various_infos"` @@ -26,19 +30,28 @@ type AppLoginRequest struct { } type AppLoginRequest_VariousInfos struct { - Time time.Time `json:"time"` - Pid uint32 `json:"pid"` - Ppid uint32 `json:"ppid"` - Euid uint32 `json:"euid"` - Egid uint32 `json:"egid"` - Uid uint32 `json:"uid"` - Gid uint32 `json:"gid"` - Name string `json:"name"` - LibSqreenVersion *string `json:"libsqreen_version"` - HasDependencies bool `json:"has_dependencies"` - HasLibsqreen bool `json:"has_libsqreen"` + Time time.Time `json:"time"` + Pid uint32 `json:"pid"` + Ppid uint32 `json:"ppid"` + Euid uint32 `json:"euid"` + Egid uint32 `json:"egid"` + Uid uint32 `json:"uid"` + Gid uint32 `json:"gid"` + Name string `json:"name"` + LibSqreenVersion *string `json:"libsqreen_version"` + HasDependencies bool `json:"has_dependencies"` + HasLibsqreen bool `json:"has_libsqreen"` + SqreenDomains SqreenDomainStatusMap `json:"sqreen_domains"` } +type ( + SqreenDomainStatusMap map[string]SqreenDomainStatus + SqreenDomainStatus struct { + Status bool `json:"status,omitempty"` + Error string `json:"error,omitempty"` + } +) + type AppLoginResponse struct { Error string `json:"error"` SessionId string `json:"session_id"` @@ -565,6 +578,7 @@ type AppLoginRequest_VariousInfosFace interface { GetLibSqreenVersion() *string GetHasDependencies() bool GetHasLibsqreen() bool + GetSqreenDomains() SqreenDomainStatusMap } func NewAppLoginRequest_VariousInfosFromFace(that AppLoginRequest_VariousInfosFace) *AppLoginRequest_VariousInfos { @@ -580,6 +594,7 @@ func NewAppLoginRequest_VariousInfosFromFace(that AppLoginRequest_VariousInfosFa this.LibSqreenVersion = that.GetLibSqreenVersion() this.HasDependencies = that.GetHasDependencies() this.HasLibsqreen = that.GetHasLibsqreen() + this.SqreenDomains = that.GetSqreenDomains() return this } diff --git a/internal/backend/api/json_test.go b/internal/backend/api/json_test.go index 828db956..216c3445 100644 --- a/internal/backend/api/json_test.go +++ b/internal/backend/api/json_test.go @@ -108,8 +108,7 @@ func TestJSON(t *testing.T) { func TestCustomScrubber(t *testing.T) { expectedMask := "scrubbed" - scrubber, err := sqsanitize.NewScrubber(regexp.MustCompile("password"), regexp.MustCompile("forbidden"), expectedMask) - require.NoError(t, err) + scrubber := sqsanitize.NewScrubber(regexp.MustCompile("password"), regexp.MustCompile("forbidden"), expectedMask) t.Run("without attack", func(t *testing.T) { rr := &api.RequestRecord{ diff --git a/internal/backend/client.go b/internal/backend/client.go index f32be8ce..f3c51fb1 100644 --- a/internal/backend/client.go +++ b/internal/backend/client.go @@ -10,7 +10,6 @@ import ( "crypto/sha1" "encoding/hex" "encoding/json" - "fmt" "io" "io/ioutil" "net" @@ -31,19 +30,20 @@ import ( type Client struct { client *http.Client - backendURL string + backendURL *url.URL logger *plog.Logger session string signalClient *client.Client infra *signal.AgentInfra + health *HealthStatus } -func NewClient(backendURL string, proxy string, logger *plog.Logger) *Client { +func NewClient(baseURL string, proxy string, logger *plog.Logger) (*Client, error) { var transport *http.Transport if proxy == "" { // No user settings. The default transport uses standard global proxy // settings *_PROXY environment variables. - dummyReq, _ := http.NewRequest("GET", backendURL, nil) + dummyReq, _ := http.NewRequest("GET", baseURL, nil) if proxyURL, _ := http.ProxyFromEnvironment(dummyReq); proxyURL != nil { logger.Infof("client: using system http proxy `%s` as indicated by the system environment variables http_proxy, https_proxy and no_proxy (or their uppercase alternatives)", proxyURL) } @@ -64,6 +64,11 @@ func NewClient(backendURL string, proxy string, logger *plog.Logger) *Client { transport.Proxy = proxy } + backendURL, err := url.Parse(baseURL) + if err != nil { + return nil, sqerrors.Wrapf(err, "could not parse the URL `%s`", backendURL) + } + client := &Client{ client: &http.Client{ Timeout: config.BackendHTTPAPIRequestTimeout, @@ -73,7 +78,41 @@ func NewClient(backendURL string, proxy string, logger *plog.Logger) *Client { logger: logger, } - return client + return client, nil +} + +type HealthStatus struct { + DomainStatus api.SqreenDomainStatusMap +} + +func (c *Client) Health() HealthStatus { + if c.health != nil { + return *c.health + } + + health := HealthStatus{ + DomainStatus: api.SqreenDomainStatusMap{}, + } + + var ( + domain = client.DefaultBaseURL + res api.PingResponse + err error + ) + req, err := http.NewRequest(config.BackendHTTPAPIEndpoint.Ping.Method, domain+config.BackendHTTPAPIEndpoint.Ping.URL, nil) + if err == nil { + err = c.Do(req, nil, &res) + } + status := api.SqreenDomainStatus{ + Status: res.Status, + } + if err != nil { + status.Error = err.Error() + } + health.DomainStatus[domain] = status + + c.health = &health + return health } func (c *Client) AppLogin(req *api.AppLoginRequest, token string, appName string, useSignalBackend bool) (*api.AppLoginResponse, error) { @@ -99,6 +138,13 @@ func (c *Client) AppLogin(req *api.AppLoginRequest, token string, appName string if useSignalBackend || res.Features.UseSignals { c.signalClient = client.NewClient(c.client, c.session) + + // If the default signal URL is not healthy, fallback to the general + // backend URL. + if !c.Health().DomainStatus[client.DefaultBaseURL].Status { + c.signalClient.BaseURL = c.backendURL + } + c.signalClient.Logger = c.logger c.infra = signal.NewAgentInfra(req.AgentVersion, req.OsType, req.Hostname, req.RuntimeVersion) } @@ -221,9 +267,9 @@ func (c *Client) Do(req *http.Request, pbs ...interface{}) error { // involved ip addresses if urlErr, ok := err.(*url.Error); ok { if netErr, ok := urlErr.Err.(*net.OpError); ok { - // TODO: update the api to pass these dropped extra details (involved - // ip addresses) as error metadata - err = sqerrors.Wrap(netErr.Err, fmt.Sprintf("%s %s", urlErr.Op, urlErr.URL)) + err = sqerrors.WithInfo(err, netErr) + } else { + err = sqerrors.WithInfo(err, urlErr) } } return err @@ -271,16 +317,19 @@ func (r *HTTPResponseStringer) String() string { // Helper method to build an API endpoint request structure. func (c *Client) newRequest(descriptor *config.HTTPAPIEndpoint) (*http.Request, error) { + url, err := c.backendURL.Parse(descriptor.URL) + if err != nil { + return nil, sqerrors.Wrap(err, "could not parse the request url") + } req, err := http.NewRequest( descriptor.Method, - c.backendURL+descriptor.URL, + url.String(), nil) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") - return req, nil } diff --git a/internal/backend/client_test.go b/internal/backend/client_test.go index 09b04459..7af15fe1 100644 --- a/internal/backend/client_test.go +++ b/internal/backend/client_test.go @@ -18,6 +18,7 @@ import ( "github.com/sqreen/go-agent/internal/config" "github.com/sqreen/go-agent/internal/plog" "github.com/sqreen/go-agent/tools/testlib" + "github.com/stretchr/testify/require" ) var ( @@ -49,7 +50,8 @@ func TestClient(t *testing.T) { server := initFakeServer(endpointCfg, request, response, statusCode, headers) defer server.Close() - client := backend.NewClient(server.URL(), "", logger) + client, err := backend.NewClient(server.URL(), "", logger) + require.NoError(t, err) res, err := client.AppLogin(request, token, appName, false) g.Expect(err).NotTo(HaveOccurred()) @@ -163,11 +165,14 @@ func initFakeServerSession(endpointCfg *config.HTTPAPIEndpoint, request, respons loginRes.Status = true server.AppendHandlers(ghttp.RespondWithJSONEncoded(http.StatusOK, loginRes)) - client = backend.NewClient(server.URL(), "", logger) + client, err := backend.NewClient(server.URL(), "", logger) + if err != nil { + panic(err) + } token := testlib.RandHTTPHeaderValue(2, 50) appName := testlib.RandHTTPHeaderValue(2, 50) - _, err := client.AppLogin(loginReq, token, appName, false) + _, err = client.AppLogin(loginReq, token, appName, false) if err != nil { panic(err) } diff --git a/internal/client.go b/internal/client.go index f4d8878e..718c40b6 100644 --- a/internal/client.go +++ b/internal/client.go @@ -59,8 +59,14 @@ func appLogin(ctx context.Context, logger *plog.Logger, client *backend.Client, logger.Error(withNotificationError{sqerrors.Wrap(err, "could not retrieve the program dependencies")}) } + backendHealth := client.Health() + variousInfoAPIAdapter := variousInfoAPIAdapter{ + appInfoAPIAdapter: (*appInfoAPIAdapter)(appInfo), + sqreenDomains: backendHealth.DomainStatus, + } + appLoginReq := api.AppLoginRequest{ - VariousInfos: *api.NewAppLoginRequest_VariousInfosFromFace((*appInfoAPIAdapter)(appInfo)), + VariousInfos: *api.NewAppLoginRequest_VariousInfosFromFace(variousInfoAPIAdapter), BundleSignature: bundleSignature, AgentType: "golang", AgentVersion: version.Version(), diff --git a/internal/config/config.go b/internal/config/config.go index 8b5cbe97..987d864d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -57,7 +57,7 @@ var ( // List of endpoint addresses, relative to the base URL. BackendHTTPAPIEndpoint = struct { AppLogin, AppLogout, AppBeat, AppException, Batch, ActionsPack, RulesPack, - Bundle, AgentMessage, AppAgentMessage HTTPAPIEndpoint + Bundle, AgentMessage, AppAgentMessage, Ping HTTPAPIEndpoint }{ AppLogin: HTTPAPIEndpoint{http.MethodPost, "/sqreen/v1/app-login"}, AppLogout: HTTPAPIEndpoint{http.MethodGet, "/sqreen/v0/app-logout"}, @@ -69,6 +69,7 @@ var ( Bundle: HTTPAPIEndpoint{http.MethodPost, "/sqreen/v0/bundle"}, AgentMessage: HTTPAPIEndpoint{http.MethodPost, "/sqreen/v0/agent_message"}, AppAgentMessage: HTTPAPIEndpoint{http.MethodPost, "/sqreen/v0/app_agent_message"}, + Ping: HTTPAPIEndpoint{http.MethodGet, "/ping"}, } // Header name of the API token. diff --git a/internal/config/config_test.go b/internal/config/config_test.go index ebc9b8a2..c73e9327 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -258,8 +258,7 @@ func TestStripRegexp(t *testing.T) { require.NotNil(t, cfg.StripSensitiveKeyRegexp()) require.NotNil(t, cfg.StripSensitiveValueRegexp()) - scrubber, err := sqsanitize.NewScrubber(cfg.StripSensitiveKeyRegexp(), cfg.StripSensitiveValueRegexp(), ScrubberRedactedString) - require.NoError(t, err) + scrubber := sqsanitize.NewScrubber(cfg.StripSensitiveKeyRegexp(), cfg.StripSensitiveValueRegexp(), ScrubberRedactedString) t.Run("the key regexp should match", func(t *testing.T) { for _, key := range []string{ diff --git a/internal/sqlib/sqsanitize/sanitize.go b/internal/sqlib/sqsanitize/sanitize.go index d7d41922..0c6af248 100644 --- a/internal/sqlib/sqsanitize/sanitize.go +++ b/internal/sqlib/sqsanitize/sanitize.go @@ -48,12 +48,12 @@ type CustomScrubber interface { // scrubbed regardless of `valueRegexp` - any string in the associated // value is replaced by `redactedValue`. // An error can be returned if the regular expressions cannot be compiled. -func NewScrubber(keyRegexp, valueRegexp *regexp.Regexp, redactedValueMask string) (*Scrubber, error) { +func NewScrubber(keyRegexp, valueRegexp *regexp.Regexp, redactedValueMask string) *Scrubber { return &Scrubber{ keyRegexp: keyRegexp, valueRegexp: valueRegexp, redactedValueMask: redactedValueMask, - }, nil + } } // RedactedValueMask returns the configured redactedValueMask diff --git a/internal/sqlib/sqsanitize/sanitize_test.go b/internal/sqlib/sqsanitize/sanitize_test.go index dd56ff6f..8531f5e8 100644 --- a/internal/sqlib/sqsanitize/sanitize_test.go +++ b/internal/sqlib/sqsanitize/sanitize_test.go @@ -72,10 +72,9 @@ func TestScrubber(t *testing.T) { redactedValueMask string } tests := []struct { - name string - args args - want *sqsanitize.Scrubber - wantErr bool + name string + args args + want *sqsanitize.Scrubber }{ { name: "no regexps", @@ -105,13 +104,8 @@ func TestScrubber(t *testing.T) { for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { - got, err := sqsanitize.NewScrubber(tc.args.keyRegexp, tc.args.valueRegexp, tc.args.redactedValueMask) - if tc.wantErr { - require.Error(t, err) - } else { - require.NoError(t, err) - require.NotNil(t, got) - } + got := sqsanitize.NewScrubber(tc.args.keyRegexp, tc.args.valueRegexp, tc.args.redactedValueMask) + require.NotNil(t, got) }) } }) @@ -168,8 +162,7 @@ func TestScrubber(t *testing.T) { valueRE = regexp.MustCompile(tc.valueRegexp) } - s, err := sqsanitize.NewScrubber(regexp.MustCompile(testlib.RandUTF8String()), valueRE, expectedMask) - require.NoError(t, err) + s := sqsanitize.NewScrubber(regexp.MustCompile(testlib.RandUTF8String()), valueRE, expectedMask) info := sqsanitize.Info{} scrubbed, err := s.Scrub(&tc.value, info) require.NoError(t, err) @@ -1086,8 +1079,7 @@ func TestScrubber(t *testing.T) { valRegex = regexp.MustCompile(valueRE) } - s, err := sqsanitize.NewScrubber(keyRegex, valRegex, expectedMask) - require.NoError(t, err) + s := sqsanitize.NewScrubber(keyRegex, valRegex, expectedMask) for _, tc := range tests { tc := tc @@ -1150,8 +1142,7 @@ func TestScrubber(t *testing.T) { }) t.Run("Usage", func(t *testing.T) { - s, err := sqsanitize.NewScrubber(regexp.MustCompile("(?i)password"), regexp.MustCompile("forbidden"), expectedMask) - require.NoError(t, err) + s := sqsanitize.NewScrubber(regexp.MustCompile("(?i)password"), regexp.MustCompile("forbidden"), expectedMask) t.Run("URL Values", func(t *testing.T) { values := url.Values{ @@ -1184,8 +1175,7 @@ func TestScrubber(t *testing.T) { }) t.Run("HTTP Request", func(t *testing.T) { - s, err := sqsanitize.NewScrubber(regexp.MustCompile(`(?i)(passw(or)?d)|(secret)|(authorization)|(api_?key)|(access_?token)`), regexp.MustCompile(`(?:\d[ -]*?){13,16}`), expectedMask) - require.NoError(t, err) + s := sqsanitize.NewScrubber(regexp.MustCompile(`(?i)(passw(or)?d)|(secret)|(authorization)|(api_?key)|(access_?token)`), regexp.MustCompile(`(?:\d[ -]*?){13,16}`), expectedMask) t.Run("zero value", func(t *testing.T) { var req http.Request From a3b6f3eea1d7296c5076ab3f336be225c6cf7898 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Thu, 23 Jul 2020 10:13:03 +0200 Subject: [PATCH 06/11] batch: avoid a possible empty batch case being sent --- internal/agent.go | 9 ++++++--- internal/backend/api/signal/signal.go | 6 +++--- internal/backend/client.go | 5 +---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/internal/agent.go b/internal/agent.go index bf547f51..09afffdc 100644 --- a/internal/agent.go +++ b/internal/agent.go @@ -588,11 +588,14 @@ func stopTimer(t *time.Timer) { func (m *eventManager) Loop(ctx context.Context, client *backend.Client) { var ( - stalenessTimer = time.NewTimer(m.maxStaleness) + // We can't create a stopped timer so we initializae it with a large value + // of 24 hours and stop it immediately. Calls to Reset() will correctly + // set the configured timer value. + stalenessTimer = time.NewTimer(24 * time.Hour) stalenessChan <-chan time.Time ) - defer stopTimer(stalenessTimer) stopTimer(stalenessTimer) + defer stopTimer(stalenessTimer) batch := make([]Event, 0, m.count) for { @@ -649,7 +652,7 @@ func (m *eventManager) sendBatch(ctx context.Context, client *backend.Client, ba if _, err := m.agent.piiScrubber.Scrub(event, nil); err != nil { // Only log this unexpected error and keep the event that may have been // partially scrubbed. - m.agent.logger.Error(errors.Wrap(err, "could not send the event batch")) + m.agent.logger.Error(errors.Wrap(err, "could not scrub the event")) } req.Batch = append(req.Batch, *api.NewBatchRequest_EventFromFace(event)) } diff --git a/internal/backend/api/signal/signal.go b/internal/backend/api/signal/signal.go index ea521da6..b76a80ee 100644 --- a/internal/backend/api/signal/signal.go +++ b/internal/backend/api/signal/signal.go @@ -60,12 +60,12 @@ func NewAgentInfra(agentVersion, osType, hostname, runtimeVersion string) *Agent func fromLegacyRequestRecord(record *legacy_api.RequestRecord, infra *AgentInfra) (*http_trace.Trace, error) { port, err := strconv.ParseUint(record.Request.Port, 10, 64) if err != nil { - return nil, sqerrors.Wrap(err, "could not parse the request port number as an int64 value") + return nil, sqerrors.Wrap(err, "could not parse the request port number as an uint64 value") } remotePort, err := strconv.ParseUint(record.Request.RemotePort, 10, 64) if err != nil { - return nil, sqerrors.Wrap(err, "could not parse the request remote port number as an int64 value") + return nil, sqerrors.Wrap(err, "could not parse the request remote port number as an uint64 value") } headers := make([][]string, len(record.Request.Headers)) @@ -182,7 +182,7 @@ func FromLegacyBatch(b []legacy_api.BatchRequest_Event, infra *AgentInfra, logge case legacy_api.RequestRecordEvent: trace, err := fromLegacyRequestRecord(evt.RequestRecord, infra) if err != nil { - logger.Error(sqerrors.Wrap(err, "could not create the HTTP trace")) + logger.Error(sqerrors.WithInfo(sqerrors.Wrap(err, "could not create the HTTP trace"), evt)) continue } signal = trace diff --git a/internal/backend/client.go b/internal/backend/client.go index f3c51fb1..a773b65b 100644 --- a/internal/backend/client.go +++ b/internal/backend/client.go @@ -199,10 +199,7 @@ func (c *Client) Batch(ctx context.Context, req *api.BatchRequest) error { return err } httpReq.Header.Set(config.BackendHTTPAPIHeaderSession, c.session) - if err := c.Do(httpReq, req); err != nil { - return err - } - return nil + return c.Do(httpReq, req) } batch := signal.FromLegacyBatch(req.Batch, c.infra, c.logger) From b99af5c0a314c139a631919a1df287b27f60b70f Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Thu, 23 Jul 2020 10:59:32 +0200 Subject: [PATCH 07/11] agent: update the version number to v0.13.0 --- internal/version/version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/version/version.go b/internal/version/version.go index 0ebfb17c..fba69777 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -4,6 +4,6 @@ package version -const version = "0.12.1" +const version = "0.13.0" func Version() string { return version } From 38545296f214015f9650b9f7bb76b8798fe96caf Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Thu, 23 Jul 2020 20:33:49 +0200 Subject: [PATCH 08/11] config: change the use_signal_backend config logic to disable_signal_backend --- internal/agent.go | 2 +- internal/backend/client.go | 4 ++-- internal/client.go | 4 ++-- internal/config/config.go | 8 ++++---- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/agent.go b/internal/agent.go index 09afffdc..d0e80703 100644 --- a/internal/agent.go +++ b/internal/agent.go @@ -343,7 +343,7 @@ func (a *AgentType) Serve() error { token := a.config.BackendHTTPAPIToken() appName := a.config.AppName() - appLoginRes, err := appLogin(a.ctx, a.logger, a.client, token, appName, a.appInfo, a.config.UseSignalBackend()) + appLoginRes, err := appLogin(a.ctx, a.logger, a.client, token, appName, a.appInfo, a.config.DisableSignalBackend()) if err != nil { if xerrors.Is(err, context.Canceled) { a.logger.Debug(err) diff --git a/internal/backend/client.go b/internal/backend/client.go index a773b65b..8be7e4c5 100644 --- a/internal/backend/client.go +++ b/internal/backend/client.go @@ -115,7 +115,7 @@ func (c *Client) Health() HealthStatus { return health } -func (c *Client) AppLogin(req *api.AppLoginRequest, token string, appName string, useSignalBackend bool) (*api.AppLoginResponse, error) { +func (c *Client) AppLogin(req *api.AppLoginRequest, token string, appName string, disableSignalBackend bool) (*api.AppLoginResponse, error) { httpReq, err := c.newRequest(&config.BackendHTTPAPIEndpoint.AppLogin) if err != nil { return nil, err @@ -136,7 +136,7 @@ func (c *Client) AppLogin(req *api.AppLoginRequest, token string, appName string c.session = res.SessionId - if useSignalBackend || res.Features.UseSignals { + if !disableSignalBackend && res.Features.UseSignals { c.signalClient = client.NewClient(c.client, c.session) // If the default signal URL is not healthy, fallback to the general diff --git a/internal/client.go b/internal/client.go index 718c40b6..f5ae051f 100644 --- a/internal/client.go +++ b/internal/client.go @@ -53,7 +53,7 @@ func (e LoginError) Unwrap() error { // Login to the backend. When the API request fails, retry for ever and after // sleeping some time. -func appLogin(ctx context.Context, logger *plog.Logger, client *backend.Client, token string, appName string, appInfo *app.Info, useSignalBackend bool) (*api.AppLoginResponse, error) { +func appLogin(ctx context.Context, logger *plog.Logger, client *backend.Client, token string, appName string, appInfo *app.Info, disableSignalBackend bool) (*api.AppLoginResponse, error) { _, bundleSignature, err := appInfo.Dependencies() if err != nil { logger.Error(withNotificationError{sqerrors.Wrap(err, "could not retrieve the program dependencies")}) @@ -83,7 +83,7 @@ func appLogin(ctx context.Context, logger *plog.Logger, client *backend.Client, case <-ctx.Done(): return nil, ctx.Err() default: - appLoginRes, err = client.AppLogin(&appLoginReq, token, appName, useSignalBackend) + appLoginRes, err = client.AppLogin(&appLoginReq, token, appName, disableSignalBackend) if err == nil && appLoginRes.Status { return appLoginRes, nil } diff --git a/internal/config/config.go b/internal/config/config.go index 987d864d..62a2d9da 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -222,7 +222,7 @@ const ( configKeyRules = `rules` configKeySDKMetricsPeriod = `sdk_metrics_period` configKeyMaxMetricsStoreLength = `max_metrics_store_length` - configKeyUseSignalBackend = `use_signal_backend` + configKeyDisableSignalBackend = `disable_signal_backend` configKeyStripSensitiveKeyRegexp = `strip_sensitive_key_regexp` configKeyStripSensitiveValueRegexp = `strip_sensitive_value_regexp` ) @@ -271,7 +271,7 @@ func New(logger *plog.Logger) (*Config, error) { {key: configKeyRules, defaultValue: "", hidden: true}, {key: configKeySDKMetricsPeriod, defaultValue: configDefaultSDKMetricsPeriod, hidden: true}, {key: configKeyMaxMetricsStoreLength, defaultValue: configDefaultMaxMetricsStoreLength, hidden: true}, - {key: configKeyUseSignalBackend, defaultValue: "", hidden: true}, + {key: configKeyDisableSignalBackend, defaultValue: "", hidden: true}, {key: configKeyStripSensitiveKeyRegexp, defaultValue: configDefaultStripSensitiveKeyRegexp}, {key: configKeyStripSensitiveValueRegexp, defaultValue: configDefaultStripSensitiveValueRegexp}, } @@ -413,8 +413,8 @@ func (c *Config) MaxMetricsStoreLength() uint { // UseSignalBackend returns true to force the agent to use the signal backend // no matter the feature flag. When false, the app-login feature flag tells // whether or not to use the signal backend. -func (c *Config) UseSignalBackend() bool { - strip := sanitizeString(c.GetString(configKeyUseSignalBackend)) +func (c *Config) DisableSignalBackend() bool { + strip := sanitizeString(c.GetString(configKeyDisableSignalBackend)) return strip != "" } From 4f48c2e086a7247f9c440bc90438a3b9c1530e15 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Thu, 23 Jul 2020 20:34:11 +0200 Subject: [PATCH 09/11] agent: always report the domain status --- internal/backend/api/api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/backend/api/api.go b/internal/backend/api/api.go index 3ecb3f3e..aa8df972 100644 --- a/internal/backend/api/api.go +++ b/internal/backend/api/api.go @@ -47,7 +47,7 @@ type AppLoginRequest_VariousInfos struct { type ( SqreenDomainStatusMap map[string]SqreenDomainStatus SqreenDomainStatus struct { - Status bool `json:"status,omitempty"` + Status bool `json:"status"` Error string `json:"error,omitempty"` } ) From 693342ece8fb0090988b18a1fd9321dcb0fc9c68 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Fri, 24 Jul 2020 12:16:36 +0200 Subject: [PATCH 10/11] scubbing: fix the waf metadata scrubbing Change the scrubbing function of the WAF metadata to sanitize substrings and not only the full string value. --- internal/backend/api/api.go | 16 ++++++--- internal/backend/api/json_test.go | 2 +- internal/sqlib/sqsanitize/sanitize.go | 11 ++---- internal/sqlib/sqsanitize/sanitize_test.go | 40 ++++++++++++++++++++++ 4 files changed, 55 insertions(+), 14 deletions(-) diff --git a/internal/backend/api/api.go b/internal/backend/api/api.go index aa8df972..4ad7c8da 100644 --- a/internal/backend/api/api.go +++ b/internal/backend/api/api.go @@ -6,6 +6,7 @@ package api import ( "encoding/json" + "strings" "time" "github.com/sqreen/go-agent/internal/sqlib/sqsanitize" @@ -451,12 +452,17 @@ func (i *WAFAttackInfo) Scrub(scrubber *sqsanitize.Scrubber, info sqsanitize.Inf redactedString := scrubber.RedactedValueMask() for e := range wafInfo { for f := range wafInfo[e].Filter { - if info.Contains(wafInfo[e].Filter[f].ResolvedValue) { - wafInfo[e].Filter[f].ResolvedValue = redactedString - if wafInfo[e].Filter[f].MatchStatus != "" { - wafInfo[e].Filter[f].MatchStatus = redactedString + for v := range info { + resolvedValue := wafInfo[e].Filter[f].ResolvedValue + newStr := strings.ReplaceAll(resolvedValue, v, redactedString) + if newStr != resolvedValue { + // The string was changed + wafInfo[e].Filter[f].ResolvedValue = newStr + if wafInfo[e].Filter[f].MatchStatus != "" { + wafInfo[e].Filter[f].MatchStatus = strings.ReplaceAll(wafInfo[e].Filter[f].MatchStatus, v, redactedString) + } + scrubbed = true } - scrubbed = true } } } diff --git a/internal/backend/api/json_test.go b/internal/backend/api/json_test.go index 216c3445..6074035f 100644 --- a/internal/backend/api/json_test.go +++ b/internal/backend/api/json_test.go @@ -220,7 +220,7 @@ func TestCustomScrubber(t *testing.T) { OperatorValue: "trigger", BindingAccessor: "#.request_params", ResolvedValue: expectedMask, - MatchStatus: expectedMask, + MatchStatus: "forbidden", }, }, }, diff --git a/internal/sqlib/sqsanitize/sanitize.go b/internal/sqlib/sqsanitize/sanitize.go index 0c6af248..619cb946 100644 --- a/internal/sqlib/sqsanitize/sanitize.go +++ b/internal/sqlib/sqsanitize/sanitize.go @@ -236,6 +236,9 @@ func (s *Scrubber) scrubMap(v reflect.Value, info Info) (scrubbed bool) { // When the current value is an interface value, we scrub its underlying // value. if hasInterfaceValueType { + if val.IsNil() { + continue + } val = val.Elem() valT = val.Type() } @@ -337,14 +340,6 @@ func (i Info) Add(value string) { i[value] = struct{}{} } -func (i Info) Contains(value string) bool { - if len(i) == 0 { - return false - } - _, exists := i[value] - return exists -} - func (i Info) Append(info Info) { if i == nil || len(info) == 0 { return diff --git a/internal/sqlib/sqsanitize/sanitize_test.go b/internal/sqlib/sqsanitize/sanitize_test.go index 8531f5e8..fe5541a9 100644 --- a/internal/sqlib/sqsanitize/sanitize_test.go +++ b/internal/sqlib/sqsanitize/sanitize_test.go @@ -935,6 +935,14 @@ func TestScrubber(t *testing.T) { }, "e": 33, "passwd": []interface{}{"everything", randString}, + "g": nil, + "h": (*string)(nil), + "i": []interface{}{nil}, + "j": []interface{}{(*string)(nil)}, + "password": map[string]interface{}{ + "a": nil, + "b": []interface{}{nil}, + }, } }, expected: expectedValues{ @@ -948,6 +956,14 @@ func TestScrubber(t *testing.T) { }, "e": 33, "passwd": []interface{}{expectedMask, randString}, + "g": nil, + "h": (*string)(nil), + "i": []interface{}{nil}, + "j": []interface{}{(*string)(nil)}, + "password": map[string]interface{}{ + "a": nil, + "b": []interface{}{nil}, + }, }, withKeyRE: map[string]interface{}{ "apikey": expectedMask, @@ -959,6 +975,14 @@ func TestScrubber(t *testing.T) { }, "e": 33, "passwd": []interface{}{expectedMask, expectedMask}, + "g": nil, + "h": (*string)(nil), + "i": []interface{}{nil}, + "j": []interface{}{(*string)(nil)}, + "password": map[string]interface{}{ + "a": nil, + "b": []interface{}{nil}, + }, }, withBothRE: map[string]interface{}{ "apikey": expectedMask, @@ -970,6 +994,14 @@ func TestScrubber(t *testing.T) { }, "e": 33, "passwd": []interface{}{expectedMask, expectedMask}, + "g": nil, + "h": (*string)(nil), + "i": []interface{}{nil}, + "j": []interface{}{(*string)(nil)}, + "password": map[string]interface{}{ + "a": nil, + "b": []interface{}{nil}, + }, }, withBothDisabled: map[string]interface{}{ "apikey": randString, @@ -981,6 +1013,14 @@ func TestScrubber(t *testing.T) { }, "e": 33, "passwd": []interface{}{"everything", randString}, + "g": nil, + "h": (*string)(nil), + "i": []interface{}{nil}, + "j": []interface{}{(*string)(nil)}, + "password": map[string]interface{}{ + "a": nil, + "b": []interface{}{nil}, + }, }, }, }, From 53e2d0532305f7ffb4bd3d16574a528a3c0855f9 Mon Sep 17 00:00:00 2001 From: Julio Guerra Date: Fri, 24 Jul 2020 14:56:51 +0200 Subject: [PATCH 11/11] repo: update the changelog --- CHANGELOG.md | 57 ++++++++++++++++++++++++++++++++++++---------------- README.md | 1 - 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4973b03b..fe4d0f72 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,27 @@ -# v0.12.1 +# v0.13.0 - 24 July 2020 + +## New Feature + +- (#137) RASP: add noSQL Injection protection support for the Go MongoDB driver + `go.mongodb.org/mongo-driver/mongo`. This protection can be configured at + . + +## Internal Changes + +- (#138) Health-check the HTTPS connectivity to the new backend API + `ingestion.sqreen.com` before using it. Fallback to the usual + `back.sqreen.com` in case of a connection issue. Therefore, the agent can take + up to 30 seconds to connect to Sqreen if the health-check timeouts. Please + make sure to add this new firewall and proxy configurations. + +- (#136) Add support to attach multiple security protections per hook point. + +## Fixes + +- (#140) Fix the In-App WAF metadata PII scrubbing to also match substrings. + + +# v0.12.1 - 13 July 2020 ## Fixes @@ -19,7 +42,7 @@ - (eeb1dca) Avoid copying the metadata returned by the In-App WAF. -# v0.12.0 +# v0.12.0 - 6 July 2020 ## New Features @@ -53,7 +76,7 @@ - (794d6e2) Allow port numbers in the `X-Forwarded-For` header. -# v0.11.0 +# v0.11.0 - 19 June 2020 ## New Features @@ -90,14 +113,14 @@ - (#114) Add Goroutine Local Storage (GLS) support through static instrumentation of the Go runtime. -# v0.10.1 +# v0.10.1 - 5 June 2020 ## Fix - (#116) Fix the instrumentation tool ignoring vendored packages, leading to missing hook points in the agent. -# v0.10.0 +# v0.10.0 - 20 May 2020 ## New Features @@ -136,7 +159,7 @@ - Document PII scrubbing configuration at . -# v0.9.1 +# v0.9.1 - 31 March 2020 ## Fixes @@ -150,7 +173,7 @@ - (#101) Prevent starting the agent when the instrumentation tool and agent versions are not the same. -# v0.9.0 +# v0.9.0 - 19 February 2020 This new major version says farewell to the `beta` and adds SQL-injection run time protection thanks the first building blocks of [RASP][RASP-Wikipedia] @@ -233,7 +256,7 @@ Because we now want a stable public API, find below the breaking changes: compiled as a Go module. This is also shown by the dashboard when the list of dependencies is empty. -# v0.1.0-beta.10 +# v0.1.0-beta.10 - 24 January 2020 ## Breaking Change @@ -264,7 +287,7 @@ Because we now want a stable public API, find below the breaking changes: - (#92) Vendoring using `go mod vendor` could lead to compilation errors due to missing files. -# v0.1.0-beta.9 +# v0.1.0-beta.9 - 19 December 2019 ## New Features @@ -283,7 +306,7 @@ Because we now want a stable public API, find below the breaking changes: - The In-App WAF has been intensively optimized so that large requests can no longer impact its execution time. (#83) -# v0.1.0-beta.8 +# v0.1.0-beta.8 - 15 October 2019 ## Internal Changes @@ -292,7 +315,7 @@ Because we now want a stable public API, find below the breaking changes: - Ignore WAF timeout errors and add more context when reporting an error (#80). - Update the libsqreen to v0.4.0 to add support for the `@pm` operator. -# v0.1.0-beta.7 +# v0.1.0-beta.7 - 26 September 2019 ## Breaking Changes @@ -319,7 +342,7 @@ Because we now want a stable public API, find below the breaking changes: - Fix a compilation error on 32-bit target architectures. -# v0.1.0-beta.6 +# v0.1.0-beta.6 - 25 July 2019 ## New Features @@ -354,7 +377,7 @@ Because we now want a stable public API, find below the breaking changes: log-level. -# v0.1.0-beta.5 +# v0.1.0-beta.5 - 23 May 2019 ## New Features @@ -380,7 +403,7 @@ Because we now want a stable public API, find below the breaking changes: processing loop. -# v0.1.0-beta.4 +# v0.1.0-beta.4 - 16 April 2019 This release adds the ability to block IP addresses or users into your Go web services by adding support for [Security Automation] according to your @@ -440,7 +463,7 @@ Note that redirecting users or IP addresses is not supported yet. - Avoid performing multiple times commands within the same command batch. (51) -# v0.1.0-beta.3 +# v0.1.0-beta.3 - 22 March 2019 ## New Features @@ -477,7 +500,7 @@ Note that redirecting users or IP addresses is not supported yet. self-managing the initializations. (#28) -# v0.1.0-beta.2 +# v0.1.0-beta.2 - 14 February 2019 ## New feature @@ -485,7 +508,7 @@ Note that redirecting users or IP addresses is not supported yet. current request. As soon as we add the support for the security reponses, it will allow to block users (#26). -# v0.1.0-beta.1 +# v0.1.0-beta.1 - 7 February 2019 This version is a new major version towards the v0.1.0 as it proposes a new and stable SDK API, that now will only be updated upon user feedback. So please, diff --git a/README.md b/README.md index 9c282455..cac850b2 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,6 @@ [![GoDoc](https://godoc.org/github.com/sqreen/go-agent?status.svg)](https://godoc.org/github.com/sqreen/go-agent) [![Go Report Card](https://goreportcard.com/badge/github.com/sqreen/go-agent)](https://goreportcard.com/report/github.com/sqreen/go-agent) [![Build Status](https://dev.azure.com/sqreenci/Go%20Agent/_apis/build/status/sqreen.go-agent?branchName=master)](https://dev.azure.com/sqreenci/Go%20Agent/_build/latest?definitionId=8&branchName=master) -[![Sourcegraph](https://sourcegraph.com/github.com/sqreen/go-agent/-/badge.svg)](https://sourcegraph.com/github.com/sqreen/go-agent?badge) After performance monitoring (APM), error and log monitoring it’s time to add a security component into your app. Sqreen’s microagent automatically monitors