Skip to content

Commit 4ad66c0

Browse files
authored
Merge pull request #15 from go-rs/develop
Fixed context related issue
2 parents 7108e20 + e5865fd commit 4ad66c0

File tree

5 files changed

+30
-18
lines changed

5 files changed

+30
-18
lines changed

README.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,6 @@ user.Get("/:uid", func(ctx *rest.Context) {
7777
})
7878
```
7979

80-
###Pending
81-
- Stop execution on timeout/abort
82-
8380
## Documentation
8481
https://godoc.org/github.com/go-rs/rest-api-framework
8582

api.go

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ func (api *API) ServeHTTP(res http.ResponseWriter, req *http.Request) {
171171
ctx.init()
172172
defer ctx.destroy()
173173

174+
// recovery/handle any runtime error
174175
defer func() {
175176
err := recover()
176177
if err != nil {
@@ -183,9 +184,18 @@ func (api *API) ServeHTTP(res http.ResponseWriter, req *http.Request) {
183184
}
184185
}()
185186

187+
// On context done, stop execution
188+
go func() {
189+
c := req.Context()
190+
select {
191+
case <-c.Done():
192+
ctx.End()
193+
}
194+
}()
195+
186196
// STEP 2: execute all interceptors
187197
for _, task := range api.interceptors {
188-
if ctx.end || ctx.code != "" {
198+
if ctx.shouldBreak() {
189199
break
190200
}
191201

@@ -195,34 +205,35 @@ func (api *API) ServeHTTP(res http.ResponseWriter, req *http.Request) {
195205
// STEP 3: check routes
196206
urlPath := []byte(req.URL.Path)
197207
for _, route := range api.routes {
198-
if ctx.end || ctx.code != "" {
208+
if ctx.shouldBreak() {
199209
break
200210
}
201211

202212
if (route.method == "" || strings.EqualFold(route.method, req.Method)) && route.regex.Match(urlPath) {
203-
ctx.found = route.method != "" //?
204213
ctx.Params = utils.Exec(route.regex, route.params, urlPath)
205214
route.handle(ctx)
206215
}
207216
}
208217

209218
// STEP 4: check handled exceptions
210219
for _, exp := range api.exceptions {
211-
if ctx.end || ctx.code == "" {
220+
if ctx.shouldBreak() {
212221
break
213222
}
214223

215-
if exp.code == ctx.code {
224+
if strings.EqualFold(exp.code, ctx.code) {
216225
exp.handle(ctx)
217226
}
218227
}
219228

220229
// STEP 5: unhandled exceptions
221230
if !ctx.end {
222-
if ctx.code == "" && !ctx.found {
231+
// if no error and still not ended that means it NOT FOUND
232+
if ctx.code == "" {
223233
ctx.Throw(ErrCodeNotFound)
224234
}
225235

236+
// if user has custom unhandled function, then execute it
226237
if api.unhandled != nil {
227238
api.unhandled(ctx)
228239
}

context.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package rest
66

77
import (
8+
"fmt"
89
"log"
910
"net/http"
1011
"net/url"
@@ -32,7 +33,6 @@ type Context struct {
3233
code string
3334
err error
3435
status int
35-
found bool
3636
end bool
3737
requestSent bool
3838
preTasksCalled bool
@@ -62,7 +62,6 @@ func (ctx *Context) destroy() {
6262
ctx.code = ""
6363
ctx.err = nil
6464
ctx.status = 0
65-
ctx.found = false
6665
ctx.end = false
6766
ctx.requestSent = false
6867
ctx.preTasksCalled = false
@@ -154,6 +153,10 @@ func (ctx *Context) PostSend(task Task) {
154153
}
155154

156155
//////////////////////////////////////////////////
156+
func (ctx *Context) shouldBreak() (flag bool) {
157+
return ctx.end || ctx.code != ""
158+
}
159+
157160
// Send data, which uses bytes or error if any
158161
// Also, it calls pre-send and post-send registered hooks
159162
func (ctx *Context) send(data []byte, err error) {
@@ -191,7 +194,7 @@ func (ctx *Context) send(data []byte, err error) {
191194

192195
if err != nil {
193196
//TODO: debugger mode
194-
log.Println("Response Error: ", err)
197+
log.Printf("response error: %v", err)
195198
}
196199
}
197200

@@ -220,9 +223,9 @@ func (ctx *Context) unhandledException() {
220223
}
221224

222225
if ctx.code != "" || ctx.err != nil {
223-
msg := "Error Code: " + ctx.code
226+
msg := fmt.Sprintf("error code: %v", ctx.code)
224227
if ctx.err != nil {
225-
msg += "\nError Message: " + ctx.err.Error()
228+
msg += fmt.Sprintf("\nerror message: %v", ctx.err)
226229
}
227230
ctx.SetHeader("Content-Type", "text/plain;charset=UTF-8")
228231
if ctx.status < 400 {
@@ -237,9 +240,9 @@ func (ctx *Context) recover() {
237240
err := recover()
238241
if err != nil {
239242
//TODO: debugger mode
240-
log.Println("Runtime Error: ", err)
243+
log.Printf("runtime error: %v", err)
241244
if !ctx.requestSent {
242-
http.Error(ctx.Response, "Internal Server Error", http.StatusInternalServerError)
245+
http.Error(ctx.Response, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
243246
}
244247
}
245248
}

examples/server.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@ func main() {
5454

5555
fmt.Println("Starting server.")
5656

57-
tout := http.TimeoutHandler(api, 100*time.Millisecond, "timeout")
57+
//tout := http.TimeoutHandler(api, 100*time.Millisecond, "timeout")
5858

5959
server := http.Server{
6060
Addr: ":8080",
61-
Handler: tout,
61+
Handler: api, //tout,
6262
}
6363

6464
if err := server.ListenAndServe(); err != nil {

examples/user/user.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"github.com/go-rs/rest-api-framework"
55
)
66

7+
// just a simple method, can use in better way
78
func Load(api *rest.API) {
89

910
var user = rest.Extend("/user", api)

0 commit comments

Comments
 (0)