Skip to content

Commit b7d73a8

Browse files
committed
feat: add jwks support
1 parent 0fb6962 commit b7d73a8

File tree

14 files changed

+392
-48
lines changed

14 files changed

+392
-48
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ALTER TABLE tenants ADD COLUMN IF NOT EXISTS jwks jsonb DEFAULT NULL;

package-lock.json

Lines changed: 100 additions & 18 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
"fs-extra": "^10.0.1",
5252
"fs-xattr": "^0.3.1",
5353
"ioredis": "^5.2.4",
54-
"jsonwebtoken": "^9.0.0",
54+
"jsonwebtoken": "^9.0.2",
5555
"knex": "^2.4.2",
5656
"md5-file": "^5.0.0",
5757
"pg": "^8.10.0",
@@ -69,7 +69,7 @@
6969
"@types/fs-extra": "^9.0.13",
7070
"@types/jest": "^29.2.1",
7171
"@types/js-yaml": "^4.0.5",
72-
"@types/jsonwebtoken": "^8.5.8",
72+
"@types/jsonwebtoken": "^9.0.5",
7373
"@types/mustache": "^4.2.2",
7474
"@types/node": "^18.14.6",
7575
"@types/pg": "^8.6.4",

src/auth/jwt.ts

Lines changed: 132 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
1+
import * as crypto from 'crypto'
12
import jwt from 'jsonwebtoken'
3+
4+
import { getJwtSecret as getJwtSecretForTenant } from '../database/tenant'
25
import { getConfig } from '../config'
36

4-
const { jwtAlgorithm } = getConfig()
7+
const { isMultitenant, jwtSecret, jwtAlgorithm, jwtJWKS } = getConfig()
8+
9+
const JWT_HMAC_ALGOS: jwt.Algorithm[] = ['HS256', 'HS384', 'HS512']
10+
const JWT_RSA_ALGOS: jwt.Algorithm[] = ['RS256', 'RS384', 'RS512']
11+
const JWT_ECC_ALGOS: jwt.Algorithm[] = ['ES256', 'ES384', 'ES512']
12+
const JWT_ED_ALGOS: jwt.Algorithm[] = ['EdDSA'] as unknown as jwt.Algorithm[] // types for EdDSA not yet updated
513

614
interface jwtInterface {
715
sub?: string
@@ -20,17 +28,128 @@ export type SignedUploadToken = {
2028
exp: number
2129
}
2230

31+
export function findJWKFromHeader(
32+
header: jwt.JwtHeader,
33+
secret: string,
34+
jwks: { keys: { kid?: string; kty: string }[] } | null
35+
) {
36+
if (!jwks || !jwks.keys) {
37+
return secret
38+
}
39+
40+
if (JWT_HMAC_ALGOS.indexOf(header.alg as jwt.Algorithm) > -1) {
41+
// JWT is using HS, find the proper key
42+
43+
if (!header.kid && header.alg === jwtAlgorithm) {
44+
// jwt is probably signed with the static secret
45+
return secret
46+
}
47+
48+
// find the first key without a kid or with the matching kid and the "oct" type
49+
const jwk = jwks.keys.find(
50+
(key) => (!key.kid || key.kid === header.kid) && key.kty === 'oct' && (key as any).k
51+
)
52+
53+
if (!jwk) {
54+
// jwt is probably signed with the static secret
55+
return secret
56+
}
57+
58+
return Buffer.from((jwk as any).k, 'base64')
59+
}
60+
61+
// jwt is using an asymmetric algorithm
62+
let kty = 'RSA'
63+
64+
if (JWT_ECC_ALGOS.indexOf(header.alg as jwt.Algorithm) > -1) {
65+
kty = 'EC'
66+
} else if (JWT_ED_ALGOS.indexOf(header.alg as jwt.Algorithm) > -1) {
67+
kty = 'OKP'
68+
}
69+
70+
// find the first key with a matching kid (or no kid if none is specified in the JWT header) and the correct key type
71+
const jwk = jwks.keys.find((key) => {
72+
return ((!key.kid && !header.kid) || key.kid === header.kid) && key.kty === kty
73+
})
74+
75+
if (!jwk) {
76+
// couldn't find a matching JWK, try to use the secret
77+
return secret
78+
}
79+
80+
return crypto.createPublicKey({
81+
format: 'jwk',
82+
key: jwk,
83+
})
84+
}
85+
86+
function getJWTVerificationKey(
87+
secret: string,
88+
jwks: { keys: { kid?: string; kty: string }[] } | null
89+
): jwt.GetPublicKeyOrSecret {
90+
return (header: jwt.JwtHeader, callback: jwt.SigningKeyCallback) => {
91+
let result: any = null
92+
93+
try {
94+
result = findJWKFromHeader(header, secret, jwks)
95+
} catch (e: any) {
96+
callback(e)
97+
return
98+
}
99+
100+
callback(null, result)
101+
}
102+
}
103+
104+
export function getJWTAlgorithms(
105+
secret: string,
106+
jwks: { keys: { kid?: string; kty: string }[] } | null
107+
) {
108+
let algorithms: jwt.Algorithm[]
109+
110+
if (jwks && jwks.keys && jwks.keys.length) {
111+
const hasRSA = jwks.keys.find((key) => key.kty === 'RSA')
112+
const hasECC = jwks.keys.find((key) => key.kty === 'EC')
113+
const hasED = jwks.keys.find(
114+
(key) => key.kty === 'OKP' && ((key as any).crv === 'Ed25519' || (key as any).crv === 'Ed448')
115+
)
116+
const hasHS = jwks.keys.find((key) => key.kty === 'oct' && (key as any).k)
117+
118+
algorithms = [
119+
jwtAlgorithm as jwt.Algorithm,
120+
...(hasRSA ? JWT_RSA_ALGOS : []),
121+
...(hasECC ? JWT_ECC_ALGOS : []),
122+
...(hasED ? JWT_ED_ALGOS : []),
123+
...(hasHS ? JWT_HMAC_ALGOS : []),
124+
]
125+
} else {
126+
algorithms = [jwtAlgorithm as jwt.Algorithm]
127+
}
128+
129+
return algorithms
130+
}
131+
23132
/**
24133
* Verifies if a JWT is valid
25134
* @param token
26135
* @param secret
136+
* @param jwks
27137
*/
28-
export function verifyJWT<T>(token: string, secret: string): Promise<jwt.JwtPayload & T> {
138+
export function verifyJWT<T>(
139+
token: string,
140+
secret: string,
141+
jwks?: { keys: { kid?: string; kty: string }[] } | null
142+
): Promise<jwt.JwtPayload & T> {
29143
return new Promise((resolve, reject) => {
30-
jwt.verify(token, secret, { algorithms: [jwtAlgorithm as jwt.Algorithm] }, (err, decoded) => {
31-
if (err) return reject(err)
32-
resolve(decoded as jwt.JwtPayload & T)
33-
})
144+
jwt.verify(
145+
token,
146+
getJWTVerificationKey(secret, jwks || null),
147+
{ algorithms: getJWTAlgorithms(secret, jwks || null) },
148+
(err, decoded) => {
149+
if (err) return reject(err)
150+
resolve(decoded as jwt.JwtPayload & T)
151+
}
152+
)
34153
})
35154
}
36155

@@ -62,13 +181,13 @@ export function signJWT(
62181
* Extract the owner (user) from the provided JWT
63182
* @param token
64183
* @param secret
184+
* @param jwks
65185
*/
66-
export async function getOwner(token: string, secret: string): Promise<string | undefined> {
67-
const decodedJWT = await verifyJWT(token, secret)
186+
export async function getOwner(
187+
token: string,
188+
secret: string,
189+
jwks: { keys: { kid?: string; kty: string }[] } | null
190+
): Promise<string | undefined> {
191+
const decodedJWT = await verifyJWT(token, secret, jwks)
68192
return (decodedJWT as jwtInterface)?.sub
69193
}
70-
71-
export async function getRole(token: string, secret: string): Promise<string | undefined> {
72-
const decodedJWT = await verifyJWT(token, secret)
73-
return (decodedJWT as jwtInterface)?.role
74-
}

0 commit comments

Comments
 (0)