Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement HKDF using the EVP_KDF API in OpenSSL 3 #194

Merged
merged 5 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions evp.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ func hashToMD(h hash.Hash) C.GO_EVP_MD_PTR {
return nil
}

// hashFuncToMD converts a hash.Hash function to a GO_EVP_MD_PTR.
// See [hashFuncHash] for details on error handling.
func hashFuncToMD(fn func() hash.Hash) (C.GO_EVP_MD_PTR, error) {
h, err := hashFuncHash(fn)
if err != nil {
return nil, err
}
md := hashToMD(h)
if md == nil {
return nil, errors.New("unsupported hash function")
}
return md, nil
}

// cryptoHashToMD converts a crypto.Hash to a GO_EVP_MD_PTR.
func cryptoHashToMD(ch crypto.Hash) (md C.GO_EVP_MD_PTR) {
if v, ok := cacheMD.Load(ch); ok {
Expand Down
286 changes: 189 additions & 97 deletions hkdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"hash"
"io"
"runtime"
"sync"
"unsafe"
)

Expand All @@ -18,88 +19,72 @@ func SupportsHKDF() bool {
case 1:
return versionAtOrAbove(1, 1, 1)
case 3:
// Some OpenSSL 3 providers don't support HKDF or don't support it via
// the EVP_PKEY API, which is the one we use.
// See https://github.com/golang-fips/openssl/issues/189.
ctx := C.go_openssl_EVP_PKEY_CTX_new_id(C.GO_EVP_PKEY_HKDF, nil)
if ctx == nil {
return false
}
C.go_openssl_EVP_PKEY_CTX_free(ctx)
return true
_, err := fetchHKDF3()
return err == nil
default:
panic(errUnsupportedVersion())
}
}

func newHKDF(fh func() hash.Hash, mode C.int) (*hkdf, error) {
if !SupportsHKDF() {
return nil, errUnsupportedVersion()
}

h, err := hashFuncHash(fh)
if err != nil {
return nil, err
}
md := hashToMD(h)
if md == nil {
return nil, errors.New("unsupported hash function")
}
func newHKDFCtx1(md C.GO_EVP_MD_PTR, mode C.int, secret, salt, pseudorandomKey, info []byte) (ctx C.GO_EVP_PKEY_CTX_PTR, err error) {
checkMajorVersion(1)

ctx := C.go_openssl_EVP_PKEY_CTX_new_id(C.GO_EVP_PKEY_HKDF, nil)
ctx = C.go_openssl_EVP_PKEY_CTX_new_id(C.GO_EVP_PKEY_HKDF, nil)
if ctx == nil {
return nil, newOpenSSLError("EVP_PKEY_CTX_new_id")
}
defer func() {
C.go_openssl_EVP_PKEY_CTX_free(ctx)
if err != nil {
C.go_openssl_EVP_PKEY_CTX_free(ctx)
}
}()

if C.go_openssl_EVP_PKEY_derive_init(ctx) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
return ctx, newOpenSSLError("EVP_PKEY_derive_init")
}
switch vMajor {
case 3:
if C.go_openssl_EVP_PKEY_CTX_set_hkdf_mode(ctx, mode) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_mode")
}
if C.go_openssl_EVP_PKEY_CTX_set_hkdf_md(ctx, md) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_md")
}
case 1:
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE,
C.GO_EVP_PKEY_CTRL_HKDF_MODE,
C.int(mode), nil) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_mode")
}
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE,
C.GO_EVP_PKEY_CTRL_HKDF_MD,
0, unsafe.Pointer(md)) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_md")

ctrlSlice := func(ctrl int, data []byte) C.int {
if len(data) == 0 {
return 1 // No data to set.
}
return C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, C.int(ctrl), C.int(len(data)), unsafe.Pointer(base(data)))
}

c := &hkdf{ctx: ctx, hashLen: h.Size()}
ctx = nil

runtime.SetFinalizer(c, (*hkdf).finalize)

return c, nil
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_HKDF_MODE, mode, nil) != 1 {
return ctx, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_mode")
}
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_HKDF_MD, 0, unsafe.Pointer(md)) != 1 {
return ctx, newOpenSSLError("EVP_PKEY_CTX_set_hkdf_md")
}
if ctrlSlice(C.GO_EVP_PKEY_CTRL_HKDF_KEY, secret) != 1 {
return ctx, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key")
}
if ctrlSlice(C.GO_EVP_PKEY_CTRL_HKDF_SALT, salt) != 1 {
return ctx, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_salt")
}
if ctrlSlice(C.GO_EVP_PKEY_CTRL_HKDF_KEY, pseudorandomKey) != 1 {
return ctx, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key")
}
if ctrlSlice(C.GO_EVP_PKEY_CTRL_HKDF_INFO, info) != 1 {
return ctx, newOpenSSLError("EVP_PKEY_CTX_add1_hkdf_info")
}
return ctx, nil
}

type hkdf struct {
type hkdf1 struct {
ctx C.GO_EVP_PKEY_CTX_PTR

hashLen int
buf []byte
}

func (c *hkdf) finalize() {
func (c *hkdf1) finalize() {
if c.ctx != nil {
C.go_openssl_EVP_PKEY_CTX_free(c.ctx)
}
}

func (c *hkdf) Read(p []byte) (int, error) {
func (c *hkdf1) Read(p []byte) (int, error) {
defer runtime.KeepAlive(c)

// EVP_PKEY_derive doesn't support incremental output, each call
Expand All @@ -125,69 +110,176 @@ func (c *hkdf) Read(p []byte) (int, error) {
}

func ExtractHKDF(h func() hash.Hash, secret, salt []byte) ([]byte, error) {
c, err := newHKDF(h, C.GO_EVP_KDF_HKDF_MODE_EXTRACT_ONLY)
if !SupportsHKDF() {
return nil, errUnsupportedVersion()
}

md, err := hashFuncToMD(h)
if err != nil {
return nil, err
}

switch vMajor {
case 3:
if C.go_openssl_EVP_PKEY_CTX_set1_hkdf_key(c.ctx,
base(secret), C.int(len(secret))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key")
case 1:
ctx, err := newHKDFCtx1(md, C.GO_EVP_KDF_HKDF_MODE_EXTRACT_ONLY, secret, salt, nil, nil)
if err != nil {
return nil, err
}
if C.go_openssl_EVP_PKEY_CTX_set1_hkdf_salt(c.ctx,
base(salt), C.int(len(salt))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_salt")
defer C.go_openssl_EVP_PKEY_CTX_free(ctx)
r := C.go_openssl_EVP_PKEY_derive_wrapper(ctx, nil, 0)
if r.result != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
}
case 1:
if C.go_openssl_EVP_PKEY_CTX_ctrl(c.ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE,
C.GO_EVP_PKEY_CTRL_HKDF_KEY,
C.int(len(secret)), unsafe.Pointer(base(secret))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key")
out := make([]byte, r.keylen)
if C.go_openssl_EVP_PKEY_derive_wrapper(ctx, base(out), r.keylen).result != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive")
}
if C.go_openssl_EVP_PKEY_CTX_ctrl(c.ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE,
C.GO_EVP_PKEY_CTRL_HKDF_SALT,
C.int(len(salt)), unsafe.Pointer(base(salt))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_salt")
return out[:r.keylen], nil
case 3:
ctx, err := newHKDFCtx3(md, C.GO_EVP_KDF_HKDF_MODE_EXTRACT_ONLY, secret, salt, nil, nil)
if err != nil {
return nil, err
}
defer C.go_openssl_EVP_KDF_CTX_free(ctx)
out := make([]byte, C.go_openssl_EVP_KDF_CTX_get_kdf_size(ctx))
if C.go_openssl_EVP_KDF_derive(ctx, base(out), C.size_t(len(out)), nil) != 1 {
return nil, newOpenSSLError("EVP_KDF_derive")
}
return out, nil
default:
panic(errUnsupportedVersion())
}
r := C.go_openssl_EVP_PKEY_derive_wrapper(c.ctx, nil, 0)
if r.result != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
}
out := make([]byte, r.keylen)
if C.go_openssl_EVP_PKEY_derive_wrapper(c.ctx, base(out), r.keylen).result != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive")
}
return out[:r.keylen], nil
}

func ExpandHKDF(h func() hash.Hash, pseudorandomKey, info []byte) (io.Reader, error) {
c, err := newHKDF(h, C.GO_EVP_KDF_HKDF_MODE_EXPAND_ONLY)
if !SupportsHKDF() {
return nil, errUnsupportedVersion()
}

md, err := hashFuncToMD(h)
if err != nil {
return nil, err
}

switch vMajor {
case 3:
if C.go_openssl_EVP_PKEY_CTX_set1_hkdf_key(c.ctx,
base(pseudorandomKey), C.int(len(pseudorandomKey))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key")
}
if C.go_openssl_EVP_PKEY_CTX_add1_hkdf_info(c.ctx,
base(info), C.int(len(info))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_add1_hkdf_info")
}
case 1:
if C.go_openssl_EVP_PKEY_CTX_ctrl(c.ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE,
C.GO_EVP_PKEY_CTRL_HKDF_KEY,
C.int(len(pseudorandomKey)), unsafe.Pointer(base(pseudorandomKey))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_set1_hkdf_key")
ctx, err := newHKDFCtx1(md, C.GO_EVP_KDF_HKDF_MODE_EXPAND_ONLY, nil, nil, pseudorandomKey, info)
if err != nil {
return nil, err
}
if C.go_openssl_EVP_PKEY_CTX_ctrl(c.ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE,
C.GO_EVP_PKEY_CTRL_HKDF_INFO,
C.int(len(info)), unsafe.Pointer(base(info))) != 1 {
return nil, newOpenSSLError("EVP_PKEY_CTX_add1_hkdf_info")
c := &hkdf1{ctx: ctx, hashLen: int(C.go_openssl_EVP_MD_get_size(md))}
runtime.SetFinalizer(c, (*hkdf1).finalize)
return c, nil
case 3:
ctx, err := newHKDFCtx3(md, C.GO_EVP_KDF_HKDF_MODE_EXPAND_ONLY, nil, nil, pseudorandomKey, info)
if err != nil {
return nil, err
}
c := &hkdf3{ctx: ctx, hashLen: int(C.go_openssl_EVP_MD_get_size(md))}
runtime.SetFinalizer(c, (*hkdf3).finalize)
return c, nil
default:
panic(errUnsupportedVersion())
}
return c, nil
}

type hkdf3 struct {
ctx C.GO_EVP_KDF_CTX_PTR

hashLen int
buf []byte
}

func (c *hkdf3) finalize() {
if c.ctx != nil {
C.go_openssl_EVP_KDF_CTX_free(c.ctx)
}
}

// fetchHKDF3 fetches the HKDF algorithm.
// It is safe to call this function concurrently.
// The returned EVP_KDF_PTR shouldn't be freed.
var fetchHKDF3 = sync.OnceValues(func() (C.GO_EVP_KDF_PTR, error) {
checkMajorVersion(3)

name := C.CString("HKDF")
kdf := C.go_openssl_EVP_KDF_fetch(nil, name, nil)
C.free(unsafe.Pointer(name))
if kdf == nil {
return nil, newOpenSSLError("EVP_KDF_fetch")
}
return kdf, nil
})

// newHKDFCtx3 implements HKDF for OpenSSL 3 using the EVP_KDF API.
func newHKDFCtx3(md C.GO_EVP_MD_PTR, mode C.int, secret, salt, pseudorandomKey, info []byte) (_ C.GO_EVP_KDF_CTX_PTR, err error) {
checkMajorVersion(3)

kdf, err := fetchHKDF3()
if err != nil {
return nil, err
}
ctx := C.go_openssl_EVP_KDF_CTX_new(kdf)
if ctx == nil {
return nil, newOpenSSLError("EVP_KDF_CTX_new")
}
defer func() {
if err != nil {
C.go_openssl_EVP_KDF_CTX_free(ctx)
}
}()

bld, err := newParamBuilder()
if err != nil {
return ctx, err
}
bld.addUTF8String(_OSSL_KDF_PARAM_DIGEST, C.go_openssl_EVP_MD_get0_name(md), 0)
bld.addInt32(_OSSL_KDF_PARAM_MODE, int32(mode))
if len(secret) > 0 {
bld.addOctetString(_OSSL_KDF_PARAM_KEY, secret)
}
if len(salt) > 0 {
bld.addOctetString(_OSSL_KDF_PARAM_SALT, salt)
}
if len(pseudorandomKey) > 0 {
bld.addOctetString(_OSSL_KDF_PARAM_KEY, pseudorandomKey)
}
if len(info) > 0 {
bld.addOctetString(_OSSL_KDF_PARAM_INFO, info)
}
params, err := bld.build()
if err != nil {
return ctx, err
}
defer C.go_openssl_OSSL_PARAM_free(params)

if C.go_openssl_EVP_KDF_CTX_set_params(ctx, params) != 1 {
return ctx, newOpenSSLError("EVP_KDF_CTX_set_params")
}
return ctx, nil
}

func (c *hkdf3) Read(p []byte) (int, error) {
defer runtime.KeepAlive(c)

// EVP_KDF_derive doesn't support incremental output, each call
// derives the key from scratch and returns the requested bytes.
// To implement io.Reader, we need to ask for len(c.buf) + len(p)
// bytes and copy the last derived len(p) bytes to p.
// We use c.buf to know how many bytes we've already derived and
// to avoid allocating the whole output buffer on each call.
prevLen := len(c.buf)
needLen := len(p)
remains := 255*c.hashLen - prevLen
// Check whether enough data can be generated.
if remains < needLen {
return 0, errors.New("hkdf: entropy limit reached")
}
c.buf = append(c.buf, make([]byte, needLen)...)
outLen := C.size_t(prevLen + needLen)
if C.go_openssl_EVP_KDF_derive(c.ctx, base(c.buf), outLen, nil) != 1 {
return 0, newOpenSSLError("EVP_KDF_derive")
}
n := copy(p, c.buf[prevLen:outLen])
return n, nil
}
Loading
Loading