Skip to content

Commit bece620

Browse files
committed
feat: support aksk verify
1 parent 5ef7f47 commit bece620

12 files changed

+327
-153
lines changed

api/api_impl/server.go

+11-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"net/url"
1010
"strings"
1111

12+
"github.com/jiaozifs/jiaozifs/auth/aksk"
13+
1214
"github.com/hellofresh/health-go/v5"
1315

1416
"github.com/getkin/kin-openapi/openapi3"
@@ -37,7 +39,14 @@ const (
3739
extensionValidationExcludeBody = "x-validation-exclude-body"
3840
)
3941

40-
func SetupAPI(lc fx.Lifecycle, apiConfig *config.APIConfig, secretStore crypt.SecretStore, sessionStore sessions.Store, repo models.IRepo, controller APIController) error {
42+
func SetupAPI(lc fx.Lifecycle,
43+
authenticator *auth.BasicAuthenticator,
44+
apiConfig *config.APIConfig,
45+
secretStore crypt.SecretStore,
46+
sessionStore sessions.Store,
47+
repo models.IRepo,
48+
verifier aksk.Verifier,
49+
controller APIController) error {
4150
swagger, err := api.GetSwagger()
4251
if err != nil {
4352
return err
@@ -66,7 +75,7 @@ func SetupAPI(lc fx.Lifecycle, apiConfig *config.APIConfig, secretStore crypt.Se
6675
OapiRequestValidatorWithOptions(swagger, &openapi3filter.Options{
6776
AuthenticationFunc: openapi3filter.NoopAuthenticationFunc,
6877
}),
69-
auth.Middleware(swagger, nil, secretStore, repo.UserRepo(), sessionStore),
78+
auth.Middleware(swagger, authenticator, secretStore, repo.UserRepo(), repo.AkskRepo(), sessionStore, verifier),
7079
)
7180

7281
raw, err := api.RawSpec()

api/jiaozifs.gen.go

+154-81
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

api/swagger.yml

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ security:
1717
- jwt_token: [] # Default security for the entire API
1818
- basic_auth: []
1919
- cookie_auth: []
20+
- ak_sk: []
2021
components:
2122
parameters:
2223
PaginationPrefix:

auth/aksk.go

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package auth
2+
3+
import (
4+
"context"
5+
6+
"github.com/jiaozifs/jiaozifs/auth/aksk"
7+
"github.com/jiaozifs/jiaozifs/models"
8+
)
9+
10+
var _ aksk.SkGetter = (*SkGetter)(nil)
11+
12+
type SkGetter struct {
13+
akskRepo models.IAkskRepo
14+
}
15+
16+
func (s SkGetter) Get(ak string) (string, error) {
17+
aksk, err := s.akskRepo.Get(context.Background(), models.NewGetAkSkParams().SetAccessKey(ak))
18+
if err != nil {
19+
return "", err
20+
}
21+
return aksk.SecretKey, nil
22+
}
23+
24+
func NewAkskVerifier(repo models.IRepo) aksk.Verifier {
25+
return aksk.NewV0Verier(SkGetter{repo.AkskRepo()})
26+
}

auth/aksk/sign.go

+11-5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ import (
1212
)
1313

1414
const (
15+
AccessKeykey = "JiaozifsAccessKeyId"
16+
SignatureVersionKey = "SignatureVersion"
17+
SignatureMethodKey = "SignatureMethod"
18+
TimestampKey = "Timestamp"
19+
SignatureKey = "Signature"
20+
1521
signatureVersion = "0"
1622
signatureMethod = "HmacSHA256"
1723
timeFormat = "2006-01-02T15:04:05Z"
@@ -37,10 +43,10 @@ func (voSigner V0Signer) Sign(req *http.Request) error {
3743
curTime := time.Now()
3844
// set query parameter
3945
query := req.URL.Query()
40-
query.Set("AWSAccessKeyId", voSigner.accessKey)
41-
query.Set("SignatureVersion", signatureVersion)
42-
query.Set("SignatureMethod", signatureMethod)
43-
query.Set("Timestamp", curTime.UTC().Format(timeFormat))
46+
query.Set(AccessKeykey, voSigner.accessKey)
47+
query.Set(SignatureVersionKey, signatureVersion)
48+
query.Set(SignatureVersionKey, signatureMethod)
49+
query.Set(TimestampKey, curTime.UTC().Format(timeFormat))
4450

4551
req.Header.Del("Signature")
4652

@@ -80,7 +86,7 @@ func (voSigner V0Signer) Sign(req *http.Request) error {
8086
hash := hmac.New(sha256.New, []byte(voSigner.secretKey))
8187
hash.Write([]byte(stringToSign))
8288
signature := base64.StdEncoding.EncodeToString(hash.Sum(nil))
83-
query.Set("Signature", signature)
89+
query.Set(SignatureKey, signature)
8490

8591
req.URL.RawQuery = query.Encode()
8692
return nil

auth/aksk/verifier.go

+26-18
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,18 @@ import (
1212
"time"
1313
)
1414

15-
type V0Verifier interface {
16-
Verify(req *http.Request) error
15+
type Verifier interface {
16+
// IsAkskCredential check the requess is aksk credential, just check request have AccessKeykey
17+
IsAkskCredential(req *http.Request) bool
18+
// Verify verify the request and return access key
19+
Verify(req *http.Request) (string, error)
1720
}
1821

1922
type SkGetter interface {
2023
Get(ak string) (string, error)
2124
}
2225

23-
var _ V0Verifier = (*V0Verier)(nil)
26+
var _ Verifier = (*V0Verier)(nil)
2427

2528
type V0Verier struct {
2629
skGetter SkGetter
@@ -30,38 +33,43 @@ func NewV0Verier(skGetter SkGetter) *V0Verier {
3033
return &V0Verier{skGetter: skGetter}
3134
}
3235

33-
func (v *V0Verier) Verify(req *http.Request) error {
36+
func (v *V0Verier) IsAkskCredential(req *http.Request) bool {
37+
accessKey := req.URL.Query().Get(AccessKeykey)
38+
return len(accessKey) > 0
39+
}
40+
41+
func (v *V0Verier) Verify(req *http.Request) (string, error) {
3442
query := req.URL.Query()
35-
accessKey := query.Get("AWSAccessKeyId")
43+
accessKey := query.Get(AccessKeykey)
3644
if len(accessKey) == 0 {
37-
return fmt.Errorf("ak not found")
45+
return "", fmt.Errorf("ak not found")
3846
}
3947

4048
secretKey, err := v.skGetter.Get(accessKey)
4149
if err != nil {
42-
return fmt.Errorf("access key not correct")
50+
return "", fmt.Errorf("access key not correct")
4351
}
4452

45-
sigMethod := query.Get("SignatureMethod")
53+
sigMethod := query.Get(SignatureMethodKey)
4654
if sigMethod != signatureMethod {
47-
return fmt.Errorf("invalid signature method %s", sigMethod)
55+
return "", fmt.Errorf("invalid signature method %s", sigMethod)
4856
}
4957

50-
sigVersion := query.Get("SignatureVersion")
58+
sigVersion := query.Get(SignatureVersionKey)
5159
if sigVersion != signatureVersion {
52-
return fmt.Errorf("invalid signature method %s", sigMethod)
60+
return "", fmt.Errorf("invalid signature method %s", sigMethod)
5361
}
5462

55-
reqTime := query.Get("Timestamp")
63+
reqTime := query.Get(TimestampKey)
5664
t, err := time.Parse(timeFormat, reqTime)
5765
if err != nil {
58-
return fmt.Errorf("invalid timestamp %s", reqTime)
66+
return "", fmt.Errorf("invalid timestamp %s", reqTime)
5967
}
6068
if t.Before(time.Now().Add(-5 * time.Minute)) {
61-
return fmt.Errorf("request is out of data")
69+
return "", fmt.Errorf("request is out of data")
6270
}
63-
expectSignature := query.Get("Signature")
64-
query.Del("Signature")
71+
expectSignature := query.Get(SignatureKey)
72+
query.Del(SignatureKey)
6573

6674
method := req.Method
6775
host := req.URL.Host
@@ -99,7 +107,7 @@ func (v *V0Verier) Verify(req *http.Request) error {
99107
hash.Write([]byte(stringToSign))
100108
actualSig := base64.StdEncoding.EncodeToString(hash.Sum(nil))
101109
if actualSig != expectSignature {
102-
return fmt.Errorf("signature not correct")
110+
return "", fmt.Errorf("signature not correct")
103111
}
104-
return nil
112+
return accessKey, nil
105113
}

auth/aksk/verifier_test.go

+10-9
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ func TestFull(t *testing.T) {
2424
err = signer.Sign(req)
2525
require.NoError(t, err)
2626

27-
err = verifer.Verify(req)
27+
actualAk, err := verifer.Verify(req)
2828
require.NoError(t, err)
29+
require.Equal(t, ak, actualAk)
2930
})
3031

3132
t.Run("fail verify", func(t *testing.T) {
@@ -39,7 +40,7 @@ func TestFull(t *testing.T) {
3940
query := req.URL.Query()
4041
query.Add("a", "b")
4142
req.URL.RawQuery = query.Encode()
42-
err = verifer.Verify(req)
43+
_, err = verifer.Verify(req)
4344
require.Error(t, err)
4445
})
4546
t.Run("no access id", func(t *testing.T) {
@@ -53,7 +54,7 @@ func TestFull(t *testing.T) {
5354
query := req.URL.Query()
5455
query.Del("AWSAccessKeyId")
5556
req.URL.RawQuery = query.Encode()
56-
err = verifer.Verify(req)
57+
_, err = verifer.Verify(req)
5758
require.Error(t, err)
5859
})
5960
t.Run("sig method fail", func(t *testing.T) {
@@ -67,7 +68,7 @@ func TestFull(t *testing.T) {
6768
query := req.URL.Query()
6869
query.Set("SignatureMethod", "2")
6970
req.URL.RawQuery = query.Encode()
70-
err = verifer.Verify(req)
71+
_, err = verifer.Verify(req)
7172
require.Error(t, err)
7273
})
7374
t.Run("sig method fail", func(t *testing.T) {
@@ -81,7 +82,7 @@ func TestFull(t *testing.T) {
8182
query := req.URL.Query()
8283
query.Set("SignatureMethod", "md5")
8384
req.URL.RawQuery = query.Encode()
84-
err = verifer.Verify(req)
85+
_, err = verifer.Verify(req)
8586
require.Error(t, err)
8687
})
8788
t.Run("sig version fail", func(t *testing.T) {
@@ -95,7 +96,7 @@ func TestFull(t *testing.T) {
9596
query := req.URL.Query()
9697
query.Set("SignatureVersion", "2")
9798
req.URL.RawQuery = query.Encode()
98-
err = verifer.Verify(req)
99+
_, err = verifer.Verify(req)
99100
require.Error(t, err)
100101
})
101102
t.Run("no timestamp", func(t *testing.T) {
@@ -109,7 +110,7 @@ func TestFull(t *testing.T) {
109110
query := req.URL.Query()
110111
query.Del("Timestamp")
111112
req.URL.RawQuery = query.Encode()
112-
err = verifer.Verify(req)
113+
_, err = verifer.Verify(req)
113114
require.Error(t, err)
114115
})
115116

@@ -124,7 +125,7 @@ func TestFull(t *testing.T) {
124125
query := req.URL.Query()
125126
query.Set("Timestamp", time.Now().String())
126127
req.URL.RawQuery = query.Encode()
127-
err = verifer.Verify(req)
128+
_, err = verifer.Verify(req)
128129
require.Error(t, err)
129130
})
130131
t.Run("request out of date", func(t *testing.T) {
@@ -138,7 +139,7 @@ func TestFull(t *testing.T) {
138139
query := req.URL.Query()
139140
query.Set("Timestamp", time.Now().Add(-time.Minute*10).UTC().Format(timeFormat))
140141
req.URL.RawQuery = query.Encode()
141-
err = verifer.Verify(req)
142+
_, err = verifer.Verify(req)
142143
require.Error(t, err)
143144
})
144145
}

auth/auth_middleware.go

+43-11
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import (
77
"net/http"
88
"strings"
99

10+
"github.com/jiaozifs/jiaozifs/auth/aksk"
11+
1012
"github.com/golang-jwt/jwt/v5"
1113

1214
"github.com/getkin/kin-openapi/openapi3"
@@ -53,11 +55,19 @@ type CookieAuthConfig struct {
5355
AuthSource string
5456
}
5557

56-
func Middleware(swagger *openapi3.T, authenticator Authenticator, secretStore crypt.SecretStore, userRepo models.IUserRepo, sessionStore sessions.Store) func(next http.Handler) http.Handler {
58+
func Middleware(swagger *openapi3.T,
59+
authenticator *BasicAuthenticator,
60+
secretStore crypt.SecretStore,
61+
userRepo models.IUserRepo,
62+
akskRepo models.IAkskRepo,
63+
sessionStore sessions.Store,
64+
verifier aksk.Verifier,
65+
) func(next http.Handler) http.Handler {
5766
router, err := legacy.NewRouter(swagger)
5867
if err != nil {
5968
panic(err)
6069
}
70+
6171
return func(next http.Handler) http.Handler {
6272
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
6373
// if request already authenticated
@@ -72,7 +82,7 @@ func Middleware(swagger *openapi3.T, authenticator Authenticator, secretStore cr
7282
_, _ = w.Write([]byte(err.Error()))
7383
return
7484
}
75-
user, err := checkSecurityRequirements(r, securityRequirements, authenticator, sessionStore, secretStore, userRepo)
85+
user, err := checkSecurityRequirements(r, securityRequirements, authenticator, sessionStore, secretStore, verifier, userRepo, akskRepo)
7686
if err != nil {
7787
w.WriteHeader(http.StatusUnauthorized)
7888
_, _ = w.Write([]byte(err.Error()))
@@ -90,10 +100,12 @@ func Middleware(swagger *openapi3.T, authenticator Authenticator, secretStore cr
90100
// it will return nil user and error in case of no security checks to match.
91101
func checkSecurityRequirements(r *http.Request,
92102
securityRequirements openapi3.SecurityRequirements,
93-
authenticator Authenticator,
103+
authenticator *BasicAuthenticator,
94104
sessionStore sessions.Store,
95105
secretStore crypt.SecretStore,
106+
verifier aksk.Verifier,
96107
userRepo models.IUserRepo,
108+
akskRepo models.IAkskRepo,
97109
) (*models.User, error) {
98110
ctx := r.Context()
99111
var user *models.User
@@ -120,7 +132,8 @@ func checkSecurityRequirements(r *http.Request,
120132
if !ok {
121133
continue
122134
}
123-
user, err = userByAuth(ctx, authenticator, userRepo, userName, password)
135+
136+
user, err = userByAuth(ctx, authenticator, userName, password)
124137
case "cookie_auth":
125138
var internalAuthSession *sessions.Session
126139
internalAuthSession, _ = sessionStore.Get(r, InternalAuthSessionName)
@@ -132,6 +145,12 @@ func checkSecurityRequirements(r *http.Request,
132145
continue
133146
}
134147
user, err = userByToken(ctx, userRepo, secretStore.SharedSecret(), token)
148+
case "ak_sk":
149+
isAkskRequest := verifier.IsAkskCredential(r)
150+
if !isAkskRequest {
151+
continue
152+
}
153+
user, err = userByAKSK(ctx, akskRepo, userRepo, verifier, r)
135154
default:
136155
// unknown security requirement to check
137156
log.With("provider", provider).Error("Authentication middleware unknown security requirement provider")
@@ -149,6 +168,24 @@ func checkSecurityRequirements(r *http.Request,
149168
return nil, nil
150169
}
151170

171+
func userByAKSK(ctx context.Context, akskRepo models.IAkskRepo, userRepo models.IUserRepo, verifier aksk.Verifier, r *http.Request) (*models.User, error) {
172+
ak, err := verifier.Verify(r)
173+
if err != nil {
174+
return nil, err
175+
}
176+
177+
akModel, err := akskRepo.Get(ctx, models.NewGetAkSkParams().SetAccessKey(ak))
178+
if err != nil {
179+
return nil, err
180+
}
181+
182+
userModel, err := userRepo.Get(ctx, models.NewGetUserParams().SetID(akModel.UserID))
183+
if err != nil {
184+
return nil, err
185+
}
186+
return userModel, nil
187+
}
188+
152189
func userByToken(ctx context.Context, userRepo models.IUserRepo, secret []byte, tokenString string) (*models.User, error) {
153190
claims, err := VerifyToken(secret, tokenString)
154191
if err != nil {
@@ -177,16 +214,11 @@ func userByToken(ctx context.Context, userRepo models.IUserRepo, secret []byte,
177214
return userData, nil
178215
}
179216

180-
func userByAuth(ctx context.Context, authenticator Authenticator, userRepo models.IUserRepo, accessKey string, secretKey string) (*models.User, error) {
181-
username, err := authenticator.AuthenticateUser(ctx, accessKey, secretKey)
217+
func userByAuth(ctx context.Context, authenticator *BasicAuthenticator, accessKey string, secretKey string) (*models.User, error) {
218+
user, err := authenticator.AuthenticateUser(ctx, accessKey, secretKey)
182219
if err != nil {
183220
log.With("user", accessKey).Errorf("authenticate %v", err)
184221
return nil, ErrAuthenticatingRequest
185222
}
186-
user, err := userRepo.Get(ctx, models.NewGetUserParams().SetName(username))
187-
if err != nil {
188-
log.With("user_name", username).Debugf("could not find user id by credentials %s", err)
189-
return nil, ErrAuthenticatingRequest
190-
}
191223
return user, nil
192224
}

0 commit comments

Comments
 (0)