diff --git a/README.md b/README.md index 1080433..ced1ed1 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ All the API endpoints are hand tested, no unit or integration CI yet. Probably n ## Key Points * Needs a backing clamd setup with a tcp socket. -* Auth using an `Authorize` header if you supply a users.yml (a yaml dict of `token: username`). +* Auth using an `Authorization` header if you supply a users.yml (a yaml dict of `token: username`). * HTTPS if either of the supplied `certfile` or `keyfile` resolve to a file. * POST /scan performs an instream scan using the post body (will correctly chunk for instream, just send your files as straight binary in the body). * GET /healthz performs a health check and calls ping on the underlying antivirus. diff --git a/go.mod b/go.mod index f5de47e..4643a32 100644 --- a/go.mod +++ b/go.mod @@ -6,5 +6,6 @@ require ( github.com/julienschmidt/httprouter v1.2.0 github.com/prometheus/client_golang v0.9.2 github.com/rs/zerolog v1.13.0 + github.com/stretchr/testify v1.3.0 gopkg.in/yaml.v2 v2.2.2 ) diff --git a/go.sum b/go.sum index 8eff60d..fde002e 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,15 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/julienschmidt/httprouter v1.2.0 h1:TDTW5Yz1mjftljbcKqRcrYhd4XeOoI98t+9HbQbYf7g= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.9.2 h1:awm861/B8OKDd2I/6o1dy3ra4BamzKhYOiGItCeZ740= github.com/prometheus/client_golang v0.9.2/go.mod h1:OsXs2jCmiKlQ1lTBmv21f2mNfw4xf/QclQDMrYNZzcM= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 h1:idejC8f05m9MGOsuEi1ATq9shN03HrxNkD/luQvxCv8= @@ -16,6 +20,10 @@ github.com/prometheus/procfs v0.0.0-20181204211112-1dc9a6cbc91a h1:9a8MnZMP0X2nL github.com/prometheus/procfs v0.0.0-20181204211112-1dc9a6cbc91a/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/rs/zerolog v1.13.0 h1:hSNcYHyxDWycfePW7pUI8swuFkcSMPKh3E63Pokg1Hk= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/main.go b/main.go index 4e6e772..e5a6697 100644 --- a/main.go +++ b/main.go @@ -67,7 +67,7 @@ func main() { r.POST("/scan", proxy.Scan) r.GET("/healthz", proxy.Ok) r.GET("/metrics", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { promhttp.Handler().ServeHTTP(w, r) }) - api := chowder.LogRequests(chowder.HeaderAuth(users, r)) + api := chowder.LogRequests(log.With().Logger(), chowder.HeaderAuth(users, r)) l.Fatal().Err(listenAndServe(l, *bind, *certFile, *keyFile, api)).Msg("closed") } diff --git a/pkg/clamav.go b/pkg/clamav.go index 23fd811..c066d07 100644 --- a/pkg/clamav.go +++ b/pkg/clamav.go @@ -32,8 +32,8 @@ var ( // VirusScanner is the interface for a virus scanning service type VirusScanner interface { - Scan(stream io.Reader) (bool, string, error) - Ok() (bool, string, error) + Scan(stream io.Reader) (infected bool, msg string, err error) + Ok() (ok bool, msg string, err error) } // ClamAV is a virus scanning service backed by a ClamAV tcp connection diff --git a/pkg/middleware.go b/pkg/middleware.go index 596fa88..6c9c1da 100644 --- a/pkg/middleware.go +++ b/pkg/middleware.go @@ -77,11 +77,11 @@ func (w *StatusWriter) Write(b []byte) (int, error) { } // LogRequests logs all requests that pass through with loglevel dependant on status code -func LogRequests(handler http.Handler) http.HandlerFunc { +func LogRequests(l zerolog.Logger, handler http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { start := time.Now() connCount.Inc() - l := log.With(). + l := l.With(). Time("start", start). Str("host", r.Host). Str("remote-address", r.RemoteAddr). @@ -90,8 +90,19 @@ func LogRequests(handler http.Handler) http.HandlerFunc { Str("proto", r.Proto). Str("user-agent", r.Header.Get("User-Agent")). Logger() - sw := &StatusWriter{ResponseWriter: w} - handler.ServeHTTP(sw, r.WithContext(setLog(r.Context(), &l))) + sw := StatusWriter{ResponseWriter: w} + defer func() { + if r := recover(); r != nil { + d := time.Now().Sub(start) + durations.Observe(d.Seconds()) + logPanic(r, logLevelFromStatus(l, sw.status). + Int("status", sw.status). + Int("content-length", sw.length). + Dur("duration", d)). + Msg("response returned") + } + }() + handler.ServeHTTP(&sw, r.WithContext(setLog(r.Context(), &l))) statusCodes.Observe(float64(sw.status)) d := time.Now().Sub(start) durations.Observe(d.Seconds()) @@ -103,6 +114,21 @@ func LogRequests(handler http.Handler) http.HandlerFunc { } } +func logPanic(p interface{}, e *zerolog.Event) *zerolog.Event { + switch p.(type) { + case error: + return e.Err(p.(error)) + case string: + return e.Str("panic-logger", p.(string)) + case fmt.Stringer: + return e.Str("panic-logger", p.(fmt.Stringer).String()) + case fmt.GoStringer: + return e.Str("panic-logger", p.(fmt.GoStringer).GoString()) + default: + return e.Interface("panic-logger", p) + } +} + func logLevelFromStatus(l zerolog.Logger, status int) *zerolog.Event { switch { case status < 200: // 100 -> 199 diff --git a/pkg/middleware_test.go b/pkg/middleware_test.go new file mode 100644 index 0000000..fe86cfb --- /dev/null +++ b/pkg/middleware_test.go @@ -0,0 +1,155 @@ +package chowder + +import ( + "errors" + "fmt" + "net/http" + "strings" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var ( + _ http.Handler = &mockHandler{} + _ http.ResponseWriter = &mockResponseWriter{} +) + +func TestLoggingMiddlewareLogsPanics(t *testing.T) { + w := &strings.Builder{} + l := zerolog.New(w) + m := &mockHandler{} + m.On("ServeHTTP", mock.Anything, mock.Anything).Once().Run(func(mock.Arguments) { + panic(errors.New("big badda boom")) + }) + r := &http.Request{} + + assert.NotPanics(t, func() { + LogRequests(l, m).ServeHTTP(&mockResponseWriter{}, r) + }) + assert.Contains(t, w.String(), "big badda boom") +} + +func TestLoggingMiddlewareLogs(t *testing.T) { + w := &strings.Builder{} + l := zerolog.New(w) + m := &mockHandler{} + m.On("ServeHTTP", mock.Anything, mock.Anything).Once() + r := &http.Request{} + + sut := LogRequests(l, m) + sut.ServeHTTP(&mockResponseWriter{}, r) + line := w.String() + + assert.Regexp(t, `{"level":"debug","start":".{20}","host":"","remote-address":"","method":"","request-uri":"","proto":"","user-agent":"","status":0,"content-length":0,"duration":.+,"message":"response returned"}`, line) +} + +func TestAuthMiddlewareAllowsValidAuth(t *testing.T) { + rw := &mockResponseWriter{} + r := &http.Request{ + Header: http.Header{ + "Authorization": []string{"password"}, + }, + } + m := &mockHandler{} + m.On("ServeHTTP", rw, r).Once() + u := map[string]string{ + "password": "user", + } + + sut := HeaderAuth(u, m) + sut.ServeHTTP(rw, r) + + rw.AssertExpectations(t) + m.AssertExpectations(t) +} + +func TestAuthMiddlewareBlocksInvalidAuth(t *testing.T) { + rw := &mockResponseWriter{} + rw.Mock.On("WriteHeader", 401).Once() + h := http.Header{} + rw.Mock.On("Header").Once().Return(h) + resp := "" + rw.Mock.On("Write", mock.Anything).Once().Run(func(args mock.Arguments) { + asByte, ok := args.Get(0).([]byte) + if !ok { + panic("wasn't a []byte") + } + resp = string(asByte) + }).Return(0, nil) + r := &http.Request{ + Header: http.Header{ + "Authorization": []string{"notpassword"}, + }} + m := &mockHandler{} + u := map[string]string{ + "password": "user", + } + + sut := HeaderAuth(u, m) + sut.ServeHTTP(rw, r) + + rw.AssertExpectations(t) + m.AssertExpectations(t) + assert.Equal(t, `{"message":"token 'notpassword' not recognised","error":"Unauthorized"}`, resp) +} + +func TestAuthMiddlewareBlocksNoAuthSupplied(t *testing.T) { + rw := &mockResponseWriter{} + rw.Mock.On("WriteHeader", 401).Once() + h := http.Header{} + rw.Mock.On("Header").Once().Return(h) + resp := "" + rw.Mock.On("Write", mock.Anything).Once().Run(func(args mock.Arguments) { + asByte, ok := args.Get(0).([]byte) + if !ok { + panic("wasn't a []byte") + } + resp = string(asByte) + }).Return(0, nil) + r := &http.Request{} + m := &mockHandler{} + u := map[string]string{ + "password": "user", + } + + sut := HeaderAuth(u, m) + sut.ServeHTTP(rw, r) + + rw.AssertExpectations(t) + m.AssertExpectations(t) + assert.Equal(t, `{"message":"no authorisation token supplied","error":"Unauthorized"}`, resp) +} + +type mockHandler struct { + mock.Mock +} + +func (m *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + m.Called(w, r) +} + +type mockResponseWriter struct { + mock.Mock +} + +func (m *mockResponseWriter) Header() http.Header { + args := m.Called() + var h http.Header + var ok bool + if h, ok = args.Get(0).(http.Header); !ok { + panic(fmt.Errorf("assert: arguments: Failed because object wasn't correct type: %v", args.Get(0))) + } + return h +} + +func (m *mockResponseWriter) Write(b []byte) (int, error) { + args := m.Called(b) + return args.Int(0), args.Error(1) +} + +func (m *mockResponseWriter) WriteHeader(statusCode int) { + m.Called(statusCode) +} diff --git a/pkg/proxy.go b/pkg/proxy.go index b42634d..8792a52 100644 --- a/pkg/proxy.go +++ b/pkg/proxy.go @@ -11,7 +11,7 @@ import ( // ScanResponse is a response with the result of a scan type ScanResponse struct { - Infected bool + Infected bool `json:"infected"` Response `json:",omitempty"` } @@ -69,6 +69,6 @@ func (p *Proxy) Ok(w http.ResponseWriter, r *http.Request, _ httprouter.Params) return } writeResponse(w, r, &Response{ - Message: "Ok", + Message: "Up", }, http.StatusOK) } diff --git a/pkg/proxy_test.go b/pkg/proxy_test.go new file mode 100644 index 0000000..ed7da6e --- /dev/null +++ b/pkg/proxy_test.go @@ -0,0 +1,109 @@ +package chowder + +import ( + "errors" + "io" + "net/http" + "testing" + + "github.com/julienschmidt/httprouter" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func setupProxyTest(header int) (rw *mockResponseWriter, r *http.Request, mav *mockAntiVirus, w *string) { + rw = &mockResponseWriter{} + rw.On("WriteHeader", header).Once() + rw.On("Header").Once().Return(http.Header{}) + resp := "" + w = &resp + rw.On("Write", mock.Anything).Once().Run(func(args mock.Arguments) { + asByte, ok := args.Get(0).([]byte) + if !ok { + panic("wasn't a []byte") + } + *w = string(asByte) + }).Return(0, nil) + r = &http.Request{} + mav = &mockAntiVirus{} + return +} + +func TestScanValidCreatesCorrectResponse(t *testing.T) { + rw, r, mav, resp := setupProxyTest(200) + mav.On("Scan", nil).Return(false, "ok", nil) + + sut := &Proxy{mav} + + sut.Scan(rw, r, httprouter.Params{}) + + rw.AssertExpectations(t) + mav.AssertExpectations(t) + assert.Equal(t, `{"infected":false,"message":"ok"}`, *resp) +} + +func TestScanErrCreatesCorrectResponse(t *testing.T) { + rw, r, mav, resp := setupProxyTest(500) + mav.On("Scan", nil).Return(false, "", errors.New("big badda boom")) + + sut := &Proxy{mav} + + sut.Scan(rw, r, httprouter.Params{}) + + rw.AssertExpectations(t) + mav.AssertExpectations(t) + assert.Equal(t, `{"error":"big badda boom"}`, *resp) +} + +func TestOkValidCreatesCorrectResponse(t *testing.T) { + rw, r, mav, resp := setupProxyTest(200) + mav.On("Ok").Return(true, "ok", nil) + + sut := &Proxy{mav} + + sut.Ok(rw, r, httprouter.Params{}) + + rw.AssertExpectations(t) + mav.AssertExpectations(t) + assert.Equal(t, `{"message":"Up"}`, *resp) +} + +func TestOkAntivirusDownCreatesCorrectResponse(t *testing.T) { + rw, r, mav, resp := setupProxyTest(500) + mav.On("Ok").Return(false, "", nil) + + sut := &Proxy{mav} + + sut.Ok(rw, r, httprouter.Params{}) + + rw.AssertExpectations(t) + mav.AssertExpectations(t) + assert.Equal(t, `{"message":"Down"}`, *resp) +} + +func TestOkErrCreatesCorrectResponse(t *testing.T) { + rw, r, mav, resp := setupProxyTest(500) + mav.On("Ok").Return(false, "", errors.New("big badda boom")) + + sut := &Proxy{mav} + + sut.Ok(rw, r, httprouter.Params{}) + + rw.AssertExpectations(t) + mav.AssertExpectations(t) + assert.Equal(t, `{"message":"Down","error":"big badda boom - daemon response: "}`, *resp) +} + +type mockAntiVirus struct { + mock.Mock +} + +func (m *mockAntiVirus) Scan(stream io.Reader) (ok bool, msg string, err error) { + args := m.Called(stream) + return args.Bool(0), args.String(1), args.Error(2) +} + +func (m *mockAntiVirus) Ok() (ok bool, msg string, err error) { + args := m.Called() + return args.Bool(0), args.String(1), args.Error(2) +} diff --git a/pkg/response.go b/pkg/response.go index 8914787..99ce98e 100644 --- a/pkg/response.go +++ b/pkg/response.go @@ -13,11 +13,14 @@ type Response struct { func writeResponse(w http.ResponseWriter, r *http.Request, resp interface{}, code int) { bytes, err := json.Marshal(resp) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + if err == nil { + w.WriteHeader(code) + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(bytes) + if err == nil { + return + } } - w.WriteHeader(code) - w.Header().Set("Content-Type", "application/json") - w.Write(bytes) + http.Error(w, err.Error(), http.StatusInternalServerError) + return }