Skip to content

Commit 4dfbfd8

Browse files
authored
Merge pull request #191 from ipfs/fix/method-handling
change HandledMethods to AllowGet and cleanup method handling
2 parents ef934e8 + 3093cad commit 4dfbfd8

File tree

7 files changed

+97
-59
lines changed

7 files changed

+97
-59
lines changed

http/config.go

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@ type ServerConfig struct {
2222
// Headers is an optional map of headers that is written out.
2323
Headers map[string][]string
2424

25-
// HandledMethods set which methods will be handled for the HTTP
26-
// requests. Other methods will return 405. This is different from CORS
27-
// AllowedMethods (the API may handle GET and POST, but only allow GETs
28-
// for CORS-enabled requests via AllowedMethods).
29-
HandledMethods []string
25+
// AllowGet indicates whether or not this server accepts GET requests.
26+
// When unset, the server only accepts POST, HEAD, and OPTIONS.
27+
//
28+
// This is different from CORS AllowedMethods. The API may allow GET
29+
// requests in general, but reject them in CORS. That will allow
30+
// websites to include resources from the API but not _read_ them.
31+
AllowGet bool
3032

3133
// corsOpts is a set of options for CORS headers.
3234
corsOpts *cors.Options
@@ -38,7 +40,6 @@ type ServerConfig struct {
3840
func NewServerConfig() *ServerConfig {
3941
cfg := new(ServerConfig)
4042
cfg.corsOpts = new(cors.Options)
41-
cfg.HandledMethods = []string{http.MethodPost}
4243
return cfg
4344
}
4445

@@ -149,16 +150,3 @@ func allowReferer(r *http.Request, cfg *ServerConfig) bool {
149150

150151
return false
151152
}
152-
153-
// handleRequestMethod returns true if the request method is among
154-
// HandledMethods.
155-
func handleRequestMethod(r *http.Request, cfg *ServerConfig) bool {
156-
// For very small slices as these, this should be faster than
157-
// a map lookup.
158-
for _, m := range cfg.HandledMethods {
159-
if r.Method == m {
160-
return true
161-
}
162-
}
163-
return false
164-
}

http/errors_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ func TestErrors(t *testing.T) {
116116

117117
mkTest := func(tc testcase) func(*testing.T) {
118118
return func(t *testing.T) {
119-
_, srv := getTestServer(t, nil, nil) // handler_test:/^func getTestServer/
119+
_, srv := getTestServer(t, nil, false) // handler_test:/^func getTestServer/
120120
c := NewClient(srv.URL)
121121
req, err := cmds.NewRequest(context.Background(), tc.path, tc.opts, nil, nil, cmdRoot)
122122
if err != nil {
@@ -161,11 +161,11 @@ func TestErrors(t *testing.T) {
161161

162162
func TestUnhandledMethod(t *testing.T) {
163163
tc := httpTestCase{
164-
Method: "GET",
165-
HandledMethods: []string{"POST"},
166-
Code: http.StatusMethodNotAllowed,
164+
Method: "GET",
165+
AllowGet: false,
166+
Code: http.StatusMethodNotAllowed,
167167
ResHeaders: map[string]string{
168-
"Allow": "POST",
168+
"Allow": "POST, HEAD, OPTIONS",
169169
},
170170
}
171171
tc.test(t)

http/handler.go

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,27 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
9797

9898
// First of all, check if we are allowed to handle the request method
9999
// or we are configured not to.
100-
if !handleRequestMethod(r, h.cfg) {
101-
setAllowedHeaders(w, h.cfg.HandledMethods)
100+
//
101+
// Always allow OPTIONS, POST
102+
switch r.Method {
103+
case http.MethodOptions:
104+
// If we get here, this is a normal (non-preflight) request.
105+
// The CORS library handles all other requests.
106+
107+
// Tell the user the allowed methods, and return.
108+
setAllowedHeaders(w, h.cfg.AllowGet)
109+
w.WriteHeader(http.StatusNoContent)
110+
return
111+
case http.MethodPost:
112+
case http.MethodGet, http.MethodHead:
113+
if h.cfg.AllowGet {
114+
break
115+
}
116+
fallthrough
117+
default:
118+
setAllowedHeaders(w, h.cfg.AllowGet)
102119
http.Error(w, "405 - Method Not Allowed", http.StatusMethodNotAllowed)
103-
log.Warnf("The IPFS API does not support %s requests. All requests must use %s", h.cfg.HandledMethods)
120+
log.Warnf("The IPFS API does not support %s requests.", r.Method)
104121
return
105122
}
106123

@@ -139,6 +156,13 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
139156
return
140157
}
141158

159+
// set user's headers first.
160+
for k, v := range h.cfg.Headers {
161+
if !skipAPIHeader(k) {
162+
w.Header()[k] = v
163+
}
164+
}
165+
142166
// Handle the timeout up front.
143167
var cancel func()
144168
if timeoutStr, ok := req.Options[cmds.TimeoutOpt]; ok {
@@ -163,13 +187,6 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
163187
defer done()
164188
}
165189

166-
// set user's headers first.
167-
for k, v := range h.cfg.Headers {
168-
if !skipAPIHeader(k) {
169-
w.Header()[k] = v
170-
}
171-
}
172-
173190
h.root.Call(req, re, h.env)
174191
}
175192

@@ -180,8 +197,11 @@ func sanitizedErrStr(err error) string {
180197
return s
181198
}
182199

183-
func setAllowedHeaders(w http.ResponseWriter, methods []string) {
184-
for _, m := range methods {
185-
w.Header().Add("Allow", m)
200+
func setAllowedHeaders(w http.ResponseWriter, allowGet bool) {
201+
w.Header().Add("Allow", http.MethodHead)
202+
w.Header().Add("Allow", http.MethodOptions)
203+
w.Header().Add("Allow", http.MethodPost)
204+
if allowGet {
205+
w.Header().Add("Allow", http.MethodGet)
186206
}
187207
}

http/handler_test.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ var (
292292
}
293293
)
294294

295-
func getTestServer(t *testing.T, origins []string, handledMethods []string) (cmds.Environment, *httptest.Server) {
295+
func getTestServer(t *testing.T, origins []string, allowGet bool) (cmds.Environment, *httptest.Server) {
296296
if len(origins) == 0 {
297297
origins = defaultOrigins
298298
}
@@ -306,12 +306,7 @@ func getTestServer(t *testing.T, origins []string, handledMethods []string) (cmd
306306
}
307307

308308
srvCfg := originCfg(origins)
309-
310-
if len(handledMethods) == 0 {
311-
srvCfg.HandledMethods = []string{"GET", "POST"}
312-
} else {
313-
srvCfg.HandledMethods = handledMethods
314-
}
309+
srvCfg.AllowGet = allowGet
315310

316311
return env, httptest.NewServer(NewHandler(env, cmdRoot, srvCfg))
317312
}

http/http_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ func TestHTTP(t *testing.T) {
8888

8989
mkTest := func(tc testcase) func(*testing.T) {
9090
return func(t *testing.T) {
91-
env, srv := getTestServer(t, nil, nil) // handler_test:/^func getTestServer/
91+
env, srv := getTestServer(t, nil, true) // handler_test:/^func getTestServer/
9292
c := NewClient(srv.URL)
9393
req, err := cmds.NewRequest(context.Background(), tc.path, nil, nil, nil, cmdRoot)
9494
if err != nil {

http/reforigin_test.go

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,42 @@ import (
44
"fmt"
55
"net/http"
66
"net/url"
7+
"strings"
78
"testing"
89

910
cmds "github.com/ipfs/go-ipfs-cmds"
1011
)
1112

1213
func assertHeaders(t *testing.T, resHeaders http.Header, reqHeaders map[string]string) {
14+
t.Helper()
15+
t.Logf("headers: %v", resHeaders)
1316
for name, value := range reqHeaders {
14-
if resHeaders.Get(name) != value {
15-
t.Errorf("Invalid header '%s', wanted '%s', got '%s'", name, value, resHeaders.Get(name))
17+
header := resHeaders[http.CanonicalHeaderKey(name)]
18+
switch len(header) {
19+
case 0:
20+
if value != "" {
21+
t.Errorf("expected a header for %s", name)
22+
}
23+
case 1:
24+
if header[0] != value {
25+
t.Errorf("Invalid header '%s', wanted '%s', got '%s'", name, value, header[0])
26+
}
27+
default:
28+
values := strings.Split(value, ",")
29+
set := make(map[string]bool, len(values))
30+
for _, v := range values {
31+
set[strings.Trim(v, " ")] = true
32+
}
33+
for _, got := range header {
34+
if !set[got] {
35+
t.Errorf("found unexpected value %s in header %s", got, name)
36+
continue
37+
}
38+
delete(set, got)
39+
}
40+
for missing := range set {
41+
t.Errorf("missing value %s in header %s", missing, name)
42+
}
1643
}
1744
}
1845
}
@@ -27,7 +54,7 @@ func originCfg(origins []string) *ServerConfig {
2754
cfg := NewServerConfig()
2855
cfg.SetAllowedOrigins(origins...)
2956
cfg.SetAllowedMethods("GET", "PUT", "POST")
30-
cfg.HandledMethods = []string{"GET", "POST"}
57+
cfg.AllowGet = true
3158
return cfg
3259
}
3360

@@ -39,18 +66,19 @@ var defaultOrigins = []string{
3966
}
4067

4168
type httpTestCase struct {
42-
Method string
43-
Path string
44-
Code int
45-
Origin string
46-
Referer string
47-
AllowOrigins []string
48-
HandledMethods []string
49-
ReqHeaders map[string]string
50-
ResHeaders map[string]string
69+
Method string
70+
Path string
71+
Code int
72+
Origin string
73+
Referer string
74+
AllowOrigins []string
75+
AllowGet bool
76+
ReqHeaders map[string]string
77+
ResHeaders map[string]string
5178
}
5279

5380
func (tc *httpTestCase) test(t *testing.T) {
81+
t.Helper()
5482
// defaults
5583
method := tc.Method
5684
if method == "" {
@@ -85,7 +113,7 @@ func (tc *httpTestCase) test(t *testing.T) {
85113
}
86114

87115
// server
88-
_, server := getTestServer(t, tc.AllowOrigins, tc.HandledMethods)
116+
_, server := getTestServer(t, tc.AllowOrigins, tc.AllowGet)
89117
if server == nil {
90118
return
91119
}
@@ -114,6 +142,7 @@ func TestDisallowedOrigins(t *testing.T) {
114142
return httpTestCase{
115143
Origin: origin,
116144
AllowOrigins: allowedOrigins,
145+
AllowGet: true,
117146
ResHeaders: map[string]string{
118147
ACAOrigin: "",
119148
ACAMethods: "",
@@ -144,6 +173,7 @@ func TestAllowedOrigins(t *testing.T) {
144173
return httpTestCase{
145174
Origin: origin,
146175
AllowOrigins: allowedOrigins,
176+
AllowGet: true,
147177
ResHeaders: map[string]string{
148178
ACAOrigin: origin,
149179
ACAMethods: "",
@@ -171,6 +201,7 @@ func TestWildcardOrigin(t *testing.T) {
171201
gtc := func(origin string, allowedOrigins []string) httpTestCase {
172202
return httpTestCase{
173203
Origin: origin,
204+
AllowGet: true,
174205
AllowOrigins: allowedOrigins,
175206
ResHeaders: map[string]string{
176207
ACAOrigin: "*",
@@ -204,6 +235,7 @@ func TestDisallowedReferer(t *testing.T) {
204235
return httpTestCase{
205236
Origin: "http://localhost",
206237
Referer: referer,
238+
AllowGet: true,
207239
AllowOrigins: allowedOrigins,
208240
ResHeaders: map[string]string{
209241
ACAOrigin: "http://localhost",
@@ -232,6 +264,7 @@ func TestAllowedReferer(t *testing.T) {
232264
return httpTestCase{
233265
Origin: "http://localhost",
234266
AllowOrigins: allowedOrigins,
267+
AllowGet: true,
235268
ResHeaders: map[string]string{
236269
ACAOrigin: "http://localhost",
237270
ACAMethods: "",
@@ -260,6 +293,7 @@ func TestWildcardReferer(t *testing.T) {
260293
return httpTestCase{
261294
Origin: origin,
262295
AllowOrigins: allowedOrigins,
296+
AllowGet: true,
263297
ResHeaders: map[string]string{
264298
ACAOrigin: "*",
265299
ACAMethods: "",
@@ -338,6 +372,7 @@ func TestEncoding(t *testing.T) {
338372
return httpTestCase{
339373
Method: "GET",
340374
Path: path,
375+
AllowGet: true,
341376
Origin: "http://localhost",
342377
AllowOrigins: []string{"*"},
343378
ReqHeaders: map[string]string{

http/responseemitter.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func (re *responseEmitter) Emit(value interface{}) error {
106106
var err error
107107

108108
// return immediately if this is a head request
109-
if re.method == "HEAD" {
109+
if re.method == http.MethodHead {
110110
return nil
111111
}
112112

0 commit comments

Comments
 (0)