Skip to content

Commit 5f5cc09

Browse files
committed
Added before hooks for plugins
1 parent 9046078 commit 5f5cc09

File tree

4 files changed

+47
-10
lines changed

4 files changed

+47
-10
lines changed

api.go

+8
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"github.com/gophergala2016/dbserver/plugins"
5+
"net/http"
56
"os"
67
)
78

@@ -11,6 +12,7 @@ type Api struct {
1112
MinVersion int
1213
Routes []*Route
1314
Plugins map[string]Plugin
15+
PluginsList []string
1416
}
1517

1618
func (self *Api) IsDeprecated(version int) bool {
@@ -28,13 +30,19 @@ func (self *Api) RegisterPlugin(name string, plugin Plugin) {
2830
}
2931
plugin.ParseConfig("plugins/" + name + ".toml")
3032
self.Plugins[name] = plugin
33+
self.PluginsList = append(self.PluginsList, name)
3134
}
3235

3336
func (self *Api) GetPlugin(name string) Plugin {
3437
return self.Plugins[name]
3538
}
3639

40+
func (self *Api) GetPlugins() []string {
41+
return self.PluginsList
42+
}
43+
3744
type Plugin interface {
3845
ParseConfig(path string) error
3946
Process(data map[string]interface{}, arg map[string]interface{}) *plugins.Response
47+
ProcessBeforeHook(data map[string]interface{}, r *http.Request)
4048
}

main.go

+17-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,15 @@ func handler(api *Api, route *Route, version int) func(http.ResponseWriter, *htt
7575
urlParams[urlParam.Key] = urlParam.Value
7676
}
7777
params, err := getRequestParams(r, urlParams)
78-
sql, err := route.Sql(params, apiVersion)
78+
if err != nil {
79+
w.WriteHeader(http.StatusBadRequest)
80+
return
81+
}
82+
data := make(map[string]interface{})
83+
data["params"] = params
84+
85+
runBeforeHooks(api, data, r)
86+
sql, err := route.Sql(data, apiVersion)
7987
if err != nil && sql != "" {
8088
w.WriteHeader(http.StatusBadRequest)
8189
fmt.Fprint(w, sql)
@@ -209,3 +217,11 @@ func goThroughPipelines(api *Api,
209217
}
210218
return nil
211219
}
220+
221+
func runBeforeHooks(api *Api, data map[string]interface{}, r *http.Request) {
222+
plugins := api.GetPlugins()
223+
for _, name := range plugins {
224+
plugin := api.GetPlugin(name)
225+
plugin.ProcessBeforeHook(data, r)
226+
}
227+
}

plugins/jwt/jwt.go

+18-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"github.com/SermoDigital/jose/jws"
99
"github.com/gophergala2016/dbserver/plugins"
1010
"io/ioutil"
11+
"net/http"
12+
"strings"
1113
"time"
1214
)
1315

@@ -83,7 +85,19 @@ func (self *JWT) GenerateToken(payload map[string]interface{}) ([]byte, error) {
8385
return serializedToken, nil
8486
}
8587

86-
// app.Register(&JWT{}, "jwt") || app.Register("jwt", JWT)
87-
88-
// Hooks: 1. Before request
89-
// 2. Process - when called in pipeline
88+
func (self *JWT) ProcessBeforeHook(data map[string]interface{}, r *http.Request) {
89+
headerValue := r.Header.Get("Authorization")
90+
if headerValue == "" {
91+
return
92+
}
93+
if !strings.HasPrefix(headerValue, "Bearer ") {
94+
return
95+
}
96+
headerValue = strings.Replace(headerValue, "Bearer ", "", 1)
97+
//TODO: Verify secret
98+
token, err := jws.ParseJWT([]byte(headerValue))
99+
if err != nil {
100+
return
101+
}
102+
data["jwt"] = token.Claims()
103+
}

route.go

+4-5
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type RouteVersion struct {
3232
SqlTemplate *template.Template
3333
}
3434

35-
func (self *Route) validate(params map[string]interface{}, version int) (string, error) {
35+
func (self *Route) validate(params interface{}, version int) (string, error) {
3636
route := self.Versions[version]
3737
if route == nil {
3838
return "", fmt.Errorf("Route version %v missing from %v route", version, self.Name)
@@ -59,14 +59,14 @@ func (self *Route) validate(params map[string]interface{}, version int) (string,
5959
return "", nil
6060
}
6161

62-
func (self *Route) Sql(params map[string]interface{}, version int) (string, error) {
62+
func (self *Route) Sql(data map[string]interface{}, version int) (string, error) {
6363
version = self.GetAvailableVersion(version)
6464
route := self.Versions[version]
6565
if route == nil {
6666
return "", fmt.Errorf("Route version %v missing from %v route", version, self.Name)
6767
}
6868
var out bytes.Buffer
69-
response, err := self.validate(params, version)
69+
response, err := self.validate(data["params"], version)
7070
if err != nil {
7171
return "", err
7272
}
@@ -76,7 +76,7 @@ func (self *Route) Sql(params map[string]interface{}, version int) (string, erro
7676
if !self.Custom {
7777
out.Write([]byte("with response_table as ("))
7878
}
79-
err = route.SqlTemplate.Execute(&out, params)
79+
err = route.SqlTemplate.Execute(&out, data)
8080
if err != nil {
8181
return "", err
8282
}
@@ -124,5 +124,4 @@ func makeTemplate(t string) (*template.Template, error) {
124124
"quote": quoteString,
125125
}
126126
return template.New("").Funcs(funcMap).Parse(t)
127-
128127
}

0 commit comments

Comments
 (0)