-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathmiddleware.go
179 lines (146 loc) · 4.77 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
package gonertia
import (
"bytes"
"io"
"net/http"
)
// Middleware returns the Inertia middleware handler.
//
// All of your handlers that can be handled by Inertia should be placed
// under this middleware.
//
// For every request it sets the Vary header and moves validation errors and
// the clear-history flag from the flash data provider (if one is configured)
// into the request context. Non-Inertia requests (see IsInertiaRequest) are
// then handed to next unchanged. For Inertia requests, next writes into an
// in-memory wrapper and the result is adjusted before it reaches w:
//   - a GET whose asset version differs from i.version is answered with
//     i.Location for the current request URI instead of the handler's response;
//   - an empty 200 OK response is replaced by i.Back;
//   - a 302 Found for a method accepted by isSeeOtherRedirectMethod is
//     changed to 303 See Other.
func (i *Inertia) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Set header Vary to "X-Inertia" on every response, Inertia or not,
		// so caches distinguish Inertia and non-Inertia responses for the
		// same URL.
		//
		// https://github.com/inertiajs/inertia-laravel/pull/404
		setInertiaVaryInResponse(w)

		// Resolve validation errors and clear history from the flash data provider.
		{
			r = i.resolveValidationErrors(r)
			r = i.resolveClearHistory(r)
		}

		if !IsInertiaRequest(r) {
			next.ServeHTTP(w, r)
			return
		}

		// From here on we know the request was made by Inertia.
		//
		// http.ResponseWriter offers no way to read back the status code or
		// the body a handler wrote, so the handler gets a wrapper that records
		// the status and buffers the body in memory. Headers are not buffered:
		// the wrapper shares w's header map.
		//
		// Buffering the whole body is acceptable here because Inertia
		// responses are JSON and are expected to be small.
		w2 := buildInertiaResponseWrapper(w)

		// Run the rest of the handler chain against the wrapper.
		next.ServeHTTP(w2, r)

		// Determines what to do when the Inertia asset version has changed.
		// By default, we initiate a client-side location visit to force an update.
		//
		// Note that next has already run at this point. Its buffered body and
		// status are dropped (the deferred copy below is not registered yet),
		// but any headers it set remain on w because the header map is shared.
		//
		// https://inertiajs.com/asset-versioning
		if r.Method == http.MethodGet && inertiaVersionFromRequest(r) != i.version {
			i.Location(w, r, r.URL.RequestURI())
			return
		}

		// The wrapper now holds everything the handler produced. Copy it to
		// the original response writer; deferring the copy makes it run after
		// the adjustments below have been applied to w2.
		defer i.copyWrapperResponse(w, w2)

		// Determines what to do when an Inertia action returned an empty response.
		// By default, we redirect the user back to where they came from.
		if w2.StatusCode() == http.StatusOK && w2.IsEmpty() {
			i.Back(w2, r)
		}

		// PUT, PATCH and DELETE requests must not be answered with 302 Found.
		// Set the status code to 303 See Other instead.
		//
		// https://inertiajs.com/redirects#303-response-code
		if w2.StatusCode() == http.StatusFound && isSeeOtherRedirectMethod(r.Method) {
			setResponseStatus(w2, http.StatusSeeOther)
		}
	})
}
// resolveValidationErrors asks the flash data provider (when one is
// configured) for validation errors and, if it returns any, gives back a copy
// of r whose context carries them via SetValidationErrors.
//
// A provider error is logged and r is returned unchanged; so is r when there
// is no provider or no errors to attach.
func (i *Inertia) resolveValidationErrors(r *http.Request) *http.Request {
	if i.flash != nil {
		ctx := r.Context()
		found, err := i.flash.GetErrors(ctx)
		switch {
		case err != nil:
			i.logger.Printf("get validation errors from the flash data provider error: %s", err)
		case len(found) > 0:
			return r.WithContext(SetValidationErrors(ctx, found))
		}
	}
	return r
}
// resolveClearHistory asks the flash data provider (when one is configured)
// whether the browser history should be cleared and, if so, returns a copy of
// r whose context is marked via ClearHistory.
//
// A provider error is logged and r is returned unchanged.
func (i *Inertia) resolveClearHistory(r *http.Request) *http.Request {
	if i.flash == nil {
		return r
	}

	ctx := r.Context()
	shouldClear, err := i.flash.ShouldClearHistory(ctx)
	if err != nil {
		i.logger.Printf("get clear history flag from the flash data provider error: %s", err)
		return r
	}
	if !shouldClear {
		return r
	}
	return r.WithContext(ClearHistory(ctx))
}
// copyWrapperResponse writes the response recorded in src to dst.
//
// The call order matters: headers must be in place before WriteHeader sends
// them, and the body must come after the status, because with net/http a
// Write before WriteHeader implicitly sends 200 OK.
func (i *Inertia) copyWrapperResponse(dst http.ResponseWriter, src *inertiaResponseWrapper) {
	i.copyWrapperHeaders(dst, src)
	i.copyWrapperStatusCode(dst, src)
	i.copyWrapperBuffer(dst, src)
}
// copyWrapperBuffer streams the body buffered in src into dst.
// A write failure is logged rather than returned, since the status line has
// already been sent by the time this runs.
func (i *Inertia) copyWrapperBuffer(dst http.ResponseWriter, src *inertiaResponseWrapper) {
	_, err := io.Copy(dst, src.buf)
	if err == nil {
		return
	}
	i.logger.Printf("cannot copy inertia response buffer to writer: %s", err)
}
// copyWrapperStatusCode sends the status code recorded by src as dst's
// response status.
func (i *Inertia) copyWrapperStatusCode(dst http.ResponseWriter, src *inertiaResponseWrapper) {
	code := src.statusCode
	dst.WriteHeader(code)
}
// copyWrapperHeaders copies every header entry recorded by src into dst,
// replacing any existing values for the same key.
//
// buildInertiaResponseWrapper makes src.header the very same map as
// dst.Header(), so this loop usually ranges over the map it writes to. Each
// entry is therefore assigned directly instead of going through Del/Add:
//   - no keys are added to or removed from the map while ranging over it;
//   - entries set to nil by a handler to suppress an automatic header
//     (e.g. "Date", or Content-Type sniffing) keep that meaning; Del
//     followed by an empty Add loop would have dropped them;
//   - keys are kept exactly as the handler stored them; Add would
//     canonicalize a non-canonical key into a second, duplicate entry.
//
// The value slices are cloned so dst never aliases src's slices when the two
// maps are distinct.
func (i *Inertia) copyWrapperHeaders(dst http.ResponseWriter, src *inertiaResponseWrapper) {
	dstHeader := dst.Header()
	for key, values := range src.header {
		// append onto a nil slice clones values and keeps nil as nil.
		dstHeader[key] = append([]string(nil), values...)
	}
}
// inertiaResponseWrapper is an http.ResponseWriter that records the status
// code and buffers the body in memory instead of sending them. This lets the
// Inertia middleware inspect and adjust a handler's response before it is
// copied to the real writer.
//
// The header map is not owned by the wrapper: buildInertiaResponseWrapper
// points it at the underlying writer's header map, so header changes are
// visible on the real writer immediately.
type inertiaResponseWrapper struct {
	statusCode int
	buf        *bytes.Buffer
	header     http.Header
}

var _ http.ResponseWriter = (*inertiaResponseWrapper)(nil)

// StatusCode returns the most recently recorded status code.
func (w *inertiaResponseWrapper) StatusCode() int {
	return w.statusCode
}

// IsEmpty reports whether nothing has been written to the body buffer.
func (w *inertiaResponseWrapper) IsEmpty() bool {
	return w.buf.Len() == 0
}

// Header returns the header map, which is shared with the underlying writer.
func (w *inertiaResponseWrapper) Header() http.Header {
	return w.header
}

// Write appends p to the in-memory body buffer; nothing is sent yet.
func (w *inertiaResponseWrapper) Write(p []byte) (int, error) {
	return w.buf.Write(p)
}

// WriteHeader records code. Unlike a real http.ResponseWriter, later calls
// overwrite earlier ones. Keep it that way: the middleware changes the status
// after the handler has run (e.g. 302 -> 303), presumably through this method.
func (w *inertiaResponseWrapper) WriteHeader(code int) {
	w.statusCode = code
}

// buildInertiaResponseWrapper wraps w in an inertiaResponseWrapper that
// shares w's header map, starts with an empty body buffer and defaults to
// 200 OK.
//
// If w itself exposes StatusCode() (another inertiaResponseWrapper, for
// example), its status is inherited, but only when it is a valid HTTP status
// code. Writers that report 0 for "not written yet" would otherwise leave the
// wrapper at 0. That would skip the middleware's empty-200 handling, and
// net/http rejects WriteHeader codes outside 100-999 (its server panics)
// when the response is later copied out.
func buildInertiaResponseWrapper(w http.ResponseWriter) *inertiaResponseWrapper {
	w2 := &inertiaResponseWrapper{
		statusCode: http.StatusOK,
		buf:        new(bytes.Buffer),
		header:     w.Header(),
	}

	if val, ok := w.(interface{ StatusCode() int }); ok {
		if code := val.StatusCode(); code >= 100 && code <= 999 {
			w2.statusCode = code
		}
	}

	return w2
}