Skip to content

Commit 4fb8ffc

Browse files
qmuntaldagood
andauthored
Implement TLS1 PRF using the EVP_KDF API in OpenSSL 3 (#191)
* implement TLS1 PRF using the EVP_KDF API in OpenSSL 3 * use uppercase for addUTF8String * finalize param builders earlier * remove redundant condition in SupportsTLS1PRF * Apply suggestions from code review Co-authored-by: Davis Goodin <[email protected]> * code review suggestions --------- Co-authored-by: Davis Goodin <[email protected]>
1 parent 731002f commit 4fb8ffc

File tree

4 files changed

+209
-48
lines changed

4 files changed

+209
-48
lines changed

Diff for: params.go

+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
//go:build !cmd_go_bootstrap
2+
3+
package openssl
4+
5+
// #include "goopenssl.h"
6+
import "C"
7+
import (
8+
"runtime"
9+
"unsafe"
10+
)
11+
12+
var (
13+
// KDF parameters
14+
_OSSL_KDF_PARAM_DIGEST = C.CString("digest")
15+
_OSSL_KDF_PARAM_SECRET = C.CString("secret")
16+
_OSSL_KDF_PARAM_SEED = C.CString("seed")
17+
)
18+
19+
// paramBuilder is a helper for building OSSL_PARAMs.
20+
// If an error occurs when adding a new parameter,
21+
// subsequent calls to add parameters are ignored
22+
// and build() will return the error.
23+
type paramBuilder struct {
24+
bld C.GO_OSSL_PARAM_BLD_PTR
25+
pinner runtime.Pinner
26+
27+
err error
28+
}
29+
30+
// newParamBuilder creates a new paramBuilder.
31+
func newParamBuilder() (*paramBuilder, error) {
32+
bld := C.go_openssl_OSSL_PARAM_BLD_new()
33+
if bld == nil {
34+
return nil, newOpenSSLError("OSSL_PARAM_BLD_new")
35+
}
36+
pb := &paramBuilder{bld: bld}
37+
runtime.SetFinalizer(pb, (*paramBuilder).finalize)
38+
return pb, nil
39+
}
40+
41+
// finalize frees the builder.
42+
func (b *paramBuilder) finalize() {
43+
if b.bld != nil {
44+
b.pinner.Unpin()
45+
C.go_openssl_OSSL_PARAM_BLD_free(b.bld)
46+
b.bld = nil
47+
}
48+
}
49+
50+
// check is used internally to enforce invariants and should not be called by users of paramBuilder.
51+
// Returns true if it's ok to add parameters to the builder or build it.
52+
// Returns false if there has been an error while adding a parameter.
53+
// Panics if the paramBuilder has been freed, e.g. if it has already been built.
54+
func (b *paramBuilder) check() bool {
55+
if b.err != nil {
56+
return false
57+
}
58+
if b.bld == nil {
59+
panic("openssl: paramBuilder has been freed")
60+
}
61+
return true
62+
}
63+
64+
// build creates an OSSL_PARAM from the builder.
65+
// The returned OSSL_PARAM must be freed with OSSL_PARAM_free.
66+
// If an error occurred while adding parameters, the error is returned
67+
// and the OSSL_PARAM is nil. Once build() is called, the builder is finalized
68+
// and cannot be reused.
69+
func (b *paramBuilder) build() (C.GO_OSSL_PARAM_PTR, error) {
70+
defer b.finalize()
71+
if !b.check() {
72+
return nil, b.err
73+
}
74+
param := C.go_openssl_OSSL_PARAM_BLD_to_param(b.bld)
75+
if param == nil {
76+
return nil, newOpenSSLError("OSSL_PARAM_BLD_build")
77+
}
78+
return param, nil
79+
}
80+
81+
// addUTF8String adds a NUL-terminated UTF-8 string to the builder.
82+
// size should not include the terminating NUL byte. If size is zero, then it will be calculated.
83+
func (b *paramBuilder) addUTF8String(name *C.char, value *C.char, size C.size_t) {
84+
if !b.check() {
85+
return
86+
}
87+
// OSSL_PARAM_BLD_push_utf8_string calculates the size if it is zero.
88+
if C.go_openssl_OSSL_PARAM_BLD_push_utf8_string(b.bld, name, value, size) != 1 {
89+
b.err = newOpenSSLError("OSSL_PARAM_BLD_push_utf8_string(" + C.GoString(name) + ")")
90+
}
91+
}
92+
93+
// addOctetString adds an octet string to the builder.
94+
// The value is pinned and will be unpinned when the builder is freed.
95+
func (b *paramBuilder) addOctetString(name *C.char, value []byte) {
96+
if !b.check() {
97+
return
98+
}
99+
if len(value) != 0 {
100+
b.pinner.Pin(&value[0])
101+
}
102+
if C.go_openssl_OSSL_PARAM_BLD_push_octet_string(b.bld, name, unsafe.Pointer(sbase(value)), C.size_t(len(value))) != 1 {
103+
b.err = newOpenSSLError("OSSL_PARAM_BLD_push_octet_string(" + C.GoString(name) + ")")
104+
}
105+
}

Diff for: shims.h

+7-3
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ typedef void* GO_OSSL_PARAM_PTR;
106106
typedef void* GO_CRYPTO_THREADID_PTR;
107107
typedef void* GO_EVP_SIGNATURE_PTR;
108108
typedef void* GO_DSA_PTR;
109+
typedef void* GO_EVP_KDF_PTR;
110+
typedef void* GO_EVP_KDF_CTX_PTR;
109111

110112
// #include <openssl/md5.h>
111113
typedef void* GO_MD5_CTX_PTR;
@@ -375,9 +377,6 @@ DEFINEFUNC_3_0(int, EVP_PKEY_up_ref, (GO_EVP_PKEY_PTR key), (key)) \
375377
DEFINEFUNC_LEGACY_1(int, EVP_PKEY_set1_EC_KEY, (GO_EVP_PKEY_PTR pkey, GO_EC_KEY_PTR key), (pkey, key)) \
376378
DEFINEFUNC_3_0(int, EVP_PKEY_CTX_set0_rsa_oaep_label, (GO_EVP_PKEY_CTX_PTR ctx, void *label, int len), (ctx, label, len)) \
377379
DEFINEFUNC(int, PKCS5_PBKDF2_HMAC, (const char *pass, int passlen, const unsigned char *salt, int saltlen, int iter, const GO_EVP_MD_PTR digest, int keylen, unsigned char *out), (pass, passlen, salt, saltlen, iter, digest, keylen, out)) \
378-
DEFINEFUNC_3_0(int, EVP_PKEY_CTX_set_tls1_prf_md, (GO_EVP_PKEY_CTX_PTR arg0, const GO_EVP_MD_PTR arg1), (arg0, arg1)) \
379-
DEFINEFUNC_3_0(int, EVP_PKEY_CTX_set1_tls1_prf_secret, (GO_EVP_PKEY_CTX_PTR arg0, const unsigned char *arg1, int arg2), (arg0, arg1, arg2)) \
380-
DEFINEFUNC_3_0(int, EVP_PKEY_CTX_add1_tls1_prf_seed, (GO_EVP_PKEY_CTX_PTR arg0, const unsigned char *arg1, int arg2), (arg0, arg1, arg2)) \
381380
DEFINEFUNC_1_1_1(int, EVP_PKEY_get_raw_public_key, (const GO_EVP_PKEY_PTR pkey, unsigned char *pub, size_t *len), (pkey, pub, len)) \
382381
DEFINEFUNC_1_1_1(int, EVP_PKEY_get_raw_private_key, (const GO_EVP_PKEY_PTR pkey, unsigned char *priv, size_t *len), (pkey, priv, len)) \
383382
DEFINEFUNC_3_0(GO_EVP_SIGNATURE_PTR, EVP_SIGNATURE_fetch, (GO_OSSL_LIB_CTX_PTR ctx, const char *algorithm, const char *properties), (ctx, algorithm, properties)) \
@@ -389,4 +388,9 @@ DEFINEFUNC_LEGACY_1_1(void, DSA_get0_pqg, (const GO_DSA_PTR d, const GO_BIGNUM_P
389388
DEFINEFUNC_LEGACY_1_1(int, DSA_set0_pqg, (GO_DSA_PTR d, GO_BIGNUM_PTR p, GO_BIGNUM_PTR q, GO_BIGNUM_PTR g), (d, p, q, g)) \
390389
DEFINEFUNC_LEGACY_1_1(void, DSA_get0_key, (const GO_DSA_PTR d, const GO_BIGNUM_PTR *pub_key, const GO_BIGNUM_PTR *priv_key), (d, pub_key, priv_key)) \
391390
DEFINEFUNC_LEGACY_1_1(int, DSA_set0_key, (GO_DSA_PTR d, GO_BIGNUM_PTR pub_key, GO_BIGNUM_PTR priv_key), (d, pub_key, priv_key)) \
391+
DEFINEFUNC_3_0(GO_EVP_KDF_PTR, EVP_KDF_fetch, (GO_OSSL_LIB_CTX_PTR libctx, const char *algorithm, const char *properties), (libctx, algorithm, properties)) \
392+
DEFINEFUNC_3_0(void, EVP_KDF_free, (GO_EVP_KDF_PTR kdf), (kdf)) \
393+
DEFINEFUNC_3_0(GO_EVP_KDF_CTX_PTR, EVP_KDF_CTX_new, (GO_EVP_KDF_PTR kdf), (kdf)) \
394+
DEFINEFUNC_3_0(void, EVP_KDF_CTX_free, (GO_EVP_KDF_CTX_PTR ctx), (ctx)) \
395+
DEFINEFUNC_3_0(int, EVP_KDF_derive, (GO_EVP_KDF_CTX_PTR ctx, unsigned char *key, size_t keylen, const GO_OSSL_PARAM_PTR params), (ctx, key, keylen, params)) \
392396

Diff for: tls1prf.go

+96-44
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,20 @@ import (
88
"crypto"
99
"errors"
1010
"hash"
11+
"sync"
1112
"unsafe"
1213
)
1314

1415
func SupportsTLS1PRF() bool {
15-
return vMajor > 1 ||
16-
(vMajor >= 1 && vMinor >= 1)
16+
switch vMajor {
17+
case 1:
18+
return vMinor >= 1
19+
case 3:
20+
_, err := fetchTLS1PRF3()
21+
return err == nil
22+
default:
23+
panic(errUnsupportedVersion())
24+
}
1725
}
1826

1927
// TLS1PRF implements the TLS 1.0/1.1 pseudo-random function if h is nil,
@@ -39,6 +47,20 @@ func TLS1PRF(result, secret, label, seed []byte, fh func() hash.Hash) error {
3947
return errors.New("unsupported hash function")
4048
}
4149

50+
switch vMajor {
51+
case 1:
52+
return tls1PRF1(result, secret, label, seed, md)
53+
case 3:
54+
return tls1PRF3(result, secret, label, seed, md)
55+
default:
56+
return errUnsupportedVersion()
57+
}
58+
}
59+
60+
// tls1PRF1 implements TLS1PRF for OpenSSL 1 using the EVP_PKEY API.
61+
func tls1PRF1(result, secret, label, seed []byte, md C.GO_EVP_MD_PTR) error {
62+
checkMajorVersion(1)
63+
4264
ctx := C.go_openssl_EVP_PKEY_CTX_new_id(C.GO_EVP_PKEY_TLS1_PRF, nil)
4365
if ctx == nil {
4466
return newOpenSSLError("EVP_PKEY_CTX_new_id")
@@ -50,48 +72,29 @@ func TLS1PRF(result, secret, label, seed []byte, fh func() hash.Hash) error {
5072
if C.go_openssl_EVP_PKEY_derive_init(ctx) != 1 {
5173
return newOpenSSLError("EVP_PKEY_derive_init")
5274
}
53-
switch vMajor {
54-
case 3:
55-
if C.go_openssl_EVP_PKEY_CTX_set_tls1_prf_md(ctx, md) != 1 {
56-
return newOpenSSLError("EVP_PKEY_CTX_set_tls1_prf_md")
57-
}
58-
if C.go_openssl_EVP_PKEY_CTX_set1_tls1_prf_secret(ctx,
59-
base(secret), C.int(len(secret))) != 1 {
60-
return newOpenSSLError("EVP_PKEY_CTX_set1_tls1_prf_secret")
61-
}
62-
if C.go_openssl_EVP_PKEY_CTX_add1_tls1_prf_seed(ctx,
63-
base(label), C.int(len(label))) != 1 {
64-
return newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed")
65-
}
66-
if C.go_openssl_EVP_PKEY_CTX_add1_tls1_prf_seed(ctx,
67-
base(seed), C.int(len(seed))) != 1 {
68-
return newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed")
69-
}
70-
case 1:
71-
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1,
72-
C.GO1_EVP_PKEY_OP_DERIVE,
73-
C.GO_EVP_PKEY_CTRL_TLS_MD,
74-
0, unsafe.Pointer(md)) != 1 {
75-
return newOpenSSLError("EVP_PKEY_CTX_set_tls1_prf_md")
76-
}
77-
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1,
78-
C.GO1_EVP_PKEY_OP_DERIVE,
79-
C.GO_EVP_PKEY_CTRL_TLS_SECRET,
80-
C.int(len(secret)), unsafe.Pointer(base(secret))) != 1 {
81-
return newOpenSSLError("EVP_PKEY_CTX_set1_tls1_prf_secret")
82-
}
83-
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1,
84-
C.GO1_EVP_PKEY_OP_DERIVE,
85-
C.GO_EVP_PKEY_CTRL_TLS_SEED,
86-
C.int(len(label)), unsafe.Pointer(base(label))) != 1 {
87-
return newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed")
88-
}
89-
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1,
90-
C.GO1_EVP_PKEY_OP_DERIVE,
91-
C.GO_EVP_PKEY_CTRL_TLS_SEED,
92-
C.int(len(seed)), unsafe.Pointer(base(seed))) != 1 {
93-
return newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed")
94-
}
75+
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1,
76+
C.GO1_EVP_PKEY_OP_DERIVE,
77+
C.GO_EVP_PKEY_CTRL_TLS_MD,
78+
0, unsafe.Pointer(md)) != 1 {
79+
return newOpenSSLError("EVP_PKEY_CTX_set_tls1_prf_md")
80+
}
81+
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1,
82+
C.GO1_EVP_PKEY_OP_DERIVE,
83+
C.GO_EVP_PKEY_CTRL_TLS_SECRET,
84+
C.int(len(secret)), unsafe.Pointer(base(secret))) != 1 {
85+
return newOpenSSLError("EVP_PKEY_CTX_set1_tls1_prf_secret")
86+
}
87+
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1,
88+
C.GO1_EVP_PKEY_OP_DERIVE,
89+
C.GO_EVP_PKEY_CTRL_TLS_SEED,
90+
C.int(len(label)), unsafe.Pointer(base(label))) != 1 {
91+
return newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed")
92+
}
93+
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1,
94+
C.GO1_EVP_PKEY_OP_DERIVE,
95+
C.GO_EVP_PKEY_CTRL_TLS_SEED,
96+
C.int(len(seed)), unsafe.Pointer(base(seed))) != 1 {
97+
return newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed")
9598
}
9699
outLen := C.size_t(len(result))
97100
if C.go_openssl_EVP_PKEY_derive_wrapper(ctx, base(result), outLen).result != 1 {
@@ -106,3 +109,52 @@ func TLS1PRF(result, secret, label, seed []byte, fh func() hash.Hash) error {
106109
}
107110
return nil
108111
}
112+
113+
// fetchTLS1PRF3 fetches the TLS1-PRF KDF algorithm.
114+
// It is safe to call this function concurrently.
115+
// The returned EVP_KDF_PTR shouldn't be freed.
116+
var fetchTLS1PRF3 = sync.OnceValues(func() (C.GO_EVP_KDF_PTR, error) {
117+
checkMajorVersion(3)
118+
119+
name := C.CString("TLS1-PRF")
120+
kdf := C.go_openssl_EVP_KDF_fetch(nil, name, nil)
121+
C.free(unsafe.Pointer(name))
122+
if kdf == nil {
123+
return nil, newOpenSSLError("EVP_KDF_fetch")
124+
}
125+
return kdf, nil
126+
})
127+
128+
// tls1PRF3 implements TLS1PRF for OpenSSL 3 using the EVP_KDF API.
129+
func tls1PRF3(result, secret, label, seed []byte, md C.GO_EVP_MD_PTR) error {
130+
checkMajorVersion(3)
131+
132+
kdf, err := fetchTLS1PRF3()
133+
if err != nil {
134+
return err
135+
}
136+
ctx := C.go_openssl_EVP_KDF_CTX_new(kdf)
137+
if ctx == nil {
138+
return newOpenSSLError("EVP_KDF_CTX_new")
139+
}
140+
defer C.go_openssl_EVP_KDF_CTX_free(ctx)
141+
142+
bld, err := newParamBuilder()
143+
if err != nil {
144+
return err
145+
}
146+
bld.addUTF8String(_OSSL_KDF_PARAM_DIGEST, C.go_openssl_EVP_MD_get0_name(md), 0)
147+
bld.addOctetString(_OSSL_KDF_PARAM_SECRET, secret)
148+
bld.addOctetString(_OSSL_KDF_PARAM_SEED, label)
149+
bld.addOctetString(_OSSL_KDF_PARAM_SEED, seed)
150+
params, err := bld.build()
151+
if err != nil {
152+
return err
153+
}
154+
defer C.go_openssl_OSSL_PARAM_free(params)
155+
156+
if C.go_openssl_EVP_KDF_derive(ctx, base(result), C.size_t(len(result)), params) != 1 {
157+
return newOpenSSLError("EVP_KDF_derive")
158+
}
159+
return nil
160+
}

Diff for: tls1prf_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ var tls1prfTests = []tls1prfTest{
147147

148148
func TestTLS1PRF(t *testing.T) {
149149
if !openssl.SupportsTLS1PRF() {
150-
t.Skip("TLS PRF is not supported")
150+
t.Skip("TLS1 PRF is not supported")
151151
}
152152
for _, tt := range tls1prfTests {
153153
t.Run(tt.hash.String(), func(t *testing.T) {

0 commit comments

Comments
 (0)