diff --git a/ociregistry/ociserver/error_test.go b/ociregistry/ociserver/error_test.go index b6f0589b..52d2f0d8 100644 --- a/ociregistry/ociserver/error_test.go +++ b/ociregistry/ociserver/error_test.go @@ -31,9 +31,9 @@ func TestCustomErrorWriter(t *testing.T) { // HTTP status code is derived from the OCI error code in preference // to the HTTPError status code. r := New(&ociregistry.Funcs{}, &Options{ - WriteError: func(w http.ResponseWriter, err error) error { + WriteError: func(w http.ResponseWriter, _ *http.Request, err error) { w.Header().Set("Some-Header", "a value") - return ociregistry.WriteError(w, err) + ociregistry.WriteError(w, err) }, }) s := httptest.NewServer(r) diff --git a/ociregistry/ociserver/registry.go b/ociregistry/ociserver/registry.go index b16d1192..9ce9a886 100644 --- a/ociregistry/ociserver/registry.go +++ b/ociregistry/ociserver/registry.go @@ -42,10 +42,12 @@ var v2 = ocispecroot.Versioned{ // Options holds options for the server. type Options struct { // WriteError is used to write error responses. It is passed the - // error an API call has returned and is responsible for writing - // it to w. If WriteError is nil, [ociregistry.WriteError] will - // be used. - WriteError func(w http.ResponseWriter, err error) error + // writer to write the error response to, the request that + // the error is in response to, and the error itself. + // + // If WriteError is nil, [ociregistry.WriteError] will + // be used and any error discarded. + WriteError func(w http.ResponseWriter, req *http.Request, err error) // DisableReferrersAPI, when true, causes the registry to behave as if // it does not understand the referrers API. @@ -140,7 +142,9 @@ func New(backend ociregistry.Interface, opts *Options) http.Handler { r.opts.DebugID = fmt.Sprintf("ociserver%d", atomic.AddInt32(&debugID, 1)) } if r.opts.WriteError == nil { - r.opts.WriteError = ociregistry.WriteError + r.opts.WriteError = func(w http.ResponseWriter, _ *http.Request, err error) { + ociregistry.WriteError(w, err) + } } return r } @@ -176,7 +180,7 @@ var handlers = []func(r *registry, ctx context.Context, w http.ResponseWriter, r func (r *registry) ServeHTTP(resp http.ResponseWriter, req *http.Request) { if rerr := r.v2(resp, req); rerr != nil { - r.opts.WriteError(resp, rerr) + r.opts.WriteError(resp, req, rerr) return } }