diff --git a/hash.go b/hash.go index a9bc91f6..bdc54b59 100644 --- a/hash.go +++ b/hash.go @@ -111,7 +111,7 @@ func SHA3_512(p []byte) (sum [64]byte) { return } -var isMarshallableMap sync.Map +var isMarshallableCache sync.Map // isHashMarshallable returns true if the memory layout of cb // is known by this library and can therefore be marshalled. @@ -119,7 +119,7 @@ func isHashMarshallable(ch crypto.Hash) bool { if vMajor == 1 { return true } - if v, ok := isMarshallableMap.Load(ch); ok { + if v, ok := isMarshallableCache.Load(ch); ok { return v.(bool) } md := cryptoHashToMD(ch) @@ -138,7 +138,7 @@ func isHashMarshallable(ch crypto.Hash) bool { // We only know the memory layout of the built-in providers. // See evpHash.hashState for more details. marshallable := name == "default" || name == "fips" - isMarshallableMap.Store(ch, marshallable) + isMarshallableCache.Store(ch, marshallable) return marshallable } @@ -148,12 +148,13 @@ type evpHash struct { // ctx2 is used in evpHash.sum to avoid changing // the state of ctx. Having it here allows reusing the // same allocated object multiple times. - ctx2 C.GO_EVP_MD_CTX_PTR - size int - blockSize int + ctx2 C.GO_EVP_MD_CTX_PTR + size int + blockSize int + marshallable bool } -func newEvpHash(ch crypto.Hash, size, blockSize int) *evpHash { +func newEvpHash(ch crypto.Hash) *evpHash { md := cryptoHashToMD(ch) if md == nil { panic("openssl: unsupported hash function: " + strconv.Itoa(int(ch))) @@ -164,11 +165,13 @@ func newEvpHash(ch crypto.Hash, size, blockSize int) *evpHash { panic(newOpenSSLError("EVP_DigestInit_ex")) } ctx2 := C.go_openssl_EVP_MD_CTX_new() + blockSize := int(C.go_openssl_EVP_MD_get_block_size(md)) h := &evpHash{ - ctx: ctx, - ctx2: ctx2, - size: size, - blockSize: blockSize, + ctx: ctx, + ctx2: ctx2, + size: ch.Size(), + blockSize: blockSize, + marshallable: isHashMarshallable(ch), } runtime.SetFinalizer(h, (*evpHash).finalize) return h @@ -232,6 +235,9 @@ func (h *evpHash) sum(out []byte) { // The EVP_MD_CTX memory layout has changed in OpenSSL 3 // and the property holding the internal structure is no longer md_data but algctx. func (h *evpHash) hashState() unsafe.Pointer { + if !h.marshallable { + panic("openssl: hash state is not marshallable") + } switch vMajor { case 1: // https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/crypto/evp/evp_local.h#L12. @@ -260,7 +266,7 @@ func (h *evpHash) hashState() unsafe.Pointer { // encoding.BinaryUnmarshaler. func NewMD4() hash.Hash { return &md4Hash{ - evpHash: newEvpHash(crypto.MD4, 16, 64), + evpHash: newEvpHash(crypto.MD4), } } @@ -276,8 +282,8 @@ func (h *md4Hash) Sum(in []byte) []byte { // NewMD5 returns a new MD5 hash. func NewMD5() hash.Hash { - h := md5Hash{evpHash: newEvpHash(crypto.MD5, 16, 64)} - if isHashMarshallable(crypto.MD5) { + h := md5Hash{evpHash: newEvpHash(crypto.MD5)} + if h.marshallable { return &md5Marshal{h} } return &h @@ -354,8 +360,8 @@ func (h *md5Marshal) UnmarshalBinary(b []byte) error { // NewSHA1 returns a new SHA1 hash. func NewSHA1() hash.Hash { - h := sha1Hash{evpHash: newEvpHash(crypto.SHA1, 20, 64)} - if isHashMarshallable(crypto.SHA1) { + h := sha1Hash{evpHash: newEvpHash(crypto.SHA1)} + if h.marshallable { return &sha1Marshal{h} } return &h @@ -434,8 +440,8 @@ func (h *sha1Marshal) UnmarshalBinary(b []byte) error { // NewSHA224 returns a new SHA224 hash. func NewSHA224() hash.Hash { - h := sha224Hash{evpHash: newEvpHash(crypto.SHA224, 224/8, 64)} - if isHashMarshallable(crypto.SHA224) { + h := sha224Hash{evpHash: newEvpHash(crypto.SHA224)} + if h.marshallable { return &sha224Marshal{h} } return &h @@ -453,8 +459,8 @@ func (h *sha224Hash) Sum(in []byte) []byte { // NewSHA256 returns a new SHA256 hash. func NewSHA256() hash.Hash { - h := sha256Hash{evpHash: newEvpHash(crypto.SHA256, 256/8, 64)} - if isHashMarshallable(crypto.SHA256) { + h := sha256Hash{evpHash: newEvpHash(crypto.SHA256)} + if h.marshallable { return &sha256Marshal{h} } return &h @@ -593,8 +599,8 @@ func (h *sha256Marshal) UnmarshalBinary(b []byte) error { // NewSHA384 returns a new SHA384 hash. func NewSHA384() hash.Hash { - h := sha384Hash{evpHash: newEvpHash(crypto.SHA384, 384/8, 128)} - if isHashMarshallable(crypto.SHA384) { + h := sha384Hash{evpHash: newEvpHash(crypto.SHA384)} + if h.marshallable { return &sha384Marshal{h} } return &h @@ -612,8 +618,8 @@ func (h *sha384Hash) Sum(in []byte) []byte { // NewSHA512 returns a new SHA512 hash. func NewSHA512() hash.Hash { - h := sha512Hash{evpHash: newEvpHash(crypto.SHA512, 512/8, 128)} - if isHashMarshallable(crypto.SHA512) { + h := sha512Hash{evpHash: newEvpHash(crypto.SHA512)} + if h.marshallable { return &sha512Marshal{h} } return &h @@ -761,7 +767,7 @@ func (h *sha512Marshal) UnmarshalBinary(b []byte) error { // NewSHA3_224 returns a new SHA3-224 hash. func NewSHA3_224() hash.Hash { return &sha3_224Hash{ - evpHash: newEvpHash(crypto.SHA3_224, 224/8, 64), + evpHash: newEvpHash(crypto.SHA3_224), } } @@ -778,7 +784,7 @@ func (h *sha3_224Hash) Sum(in []byte) []byte { // NewSHA3_256 returns a new SHA3-256 hash. func NewSHA3_256() hash.Hash { return &sha3_256Hash{ - evpHash: newEvpHash(crypto.SHA3_256, 256/8, 64), + evpHash: newEvpHash(crypto.SHA3_256), } } @@ -795,7 +801,7 @@ func (h *sha3_256Hash) Sum(in []byte) []byte { // NewSHA3_384 returns a new SHA3-384 hash. func NewSHA3_384() hash.Hash { return &sha3_384Hash{ - evpHash: newEvpHash(crypto.SHA3_384, 384/8, 128), + evpHash: newEvpHash(crypto.SHA3_384), } } @@ -812,7 +818,7 @@ func (h *sha3_384Hash) Sum(in []byte) []byte { // NewSHA3_512 returns a new SHA3-512 hash. func NewSHA3_512() hash.Hash { return &sha3_512Hash{ - evpHash: newEvpHash(crypto.SHA3_512, 512/8, 128), + evpHash: newEvpHash(crypto.SHA3_512), } } diff --git a/shims.h b/shims.h index 6ea5bc6d..cf3d7ee1 100644 --- a/shims.h +++ b/shims.h @@ -195,6 +195,7 @@ DEFINEFUNC_3_0(GO_EVP_MD_PTR, EVP_MD_fetch, (GO_OSSL_LIB_CTX_PTR ctx, const char DEFINEFUNC_3_0(void, EVP_MD_free, (GO_EVP_MD_PTR md), (md)) \ DEFINEFUNC_3_0(const char *, EVP_MD_get0_name, (const GO_EVP_MD_PTR md), (md)) \ DEFINEFUNC_3_0(const GO_OSSL_PROVIDER_PTR, EVP_MD_get0_provider, (const GO_EVP_MD_PTR md), (md)) \ +DEFINEFUNC(int, EVP_MD_get_block_size, (const GO_EVP_MD_PTR md), (md)) \ DEFINEFUNC(int, RAND_bytes, (unsigned char *arg0, int arg1), (arg0, arg1)) \ DEFINEFUNC_RENAMED_1_1(GO_EVP_MD_CTX_PTR, EVP_MD_CTX_new, EVP_MD_CTX_create, (void), ()) \ DEFINEFUNC_RENAMED_1_1(void, EVP_MD_CTX_free, EVP_MD_CTX_destroy, (GO_EVP_MD_CTX_PTR ctx), (ctx)) \