Skip to content

Commit 6e4e578

Browse files
authored
Simplify hash implementation (#237)
* simplify hash implementations * fix openssl 3 * fix AZL 3
1 parent 7b07994 commit 6e4e578

File tree

4 files changed

+248
-651
lines changed

4 files changed

+248
-651
lines changed

evp.go

Lines changed: 72 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -48,29 +48,8 @@ func hashFuncHash(fn func() hash.Hash) (h hash.Hash, err error) {
4848

4949
// hashToMD converts a hash.Hash implementation from this package to a GO_EVP_MD_PTR.
5050
func hashToMD(h hash.Hash) C.GO_EVP_MD_PTR {
51-
var ch crypto.Hash
52-
switch h.(type) {
53-
case *sha1Hash, *sha1Marshal:
54-
ch = crypto.SHA1
55-
case *sha224Hash, *sha224Marshal:
56-
ch = crypto.SHA224
57-
case *sha256Hash, *sha256Marshal:
58-
ch = crypto.SHA256
59-
case *sha384Hash, *sha384Marshal:
60-
ch = crypto.SHA384
61-
case *sha512Hash, *sha512Marshal:
62-
ch = crypto.SHA512
63-
case *sha3_224Hash:
64-
ch = crypto.SHA3_224
65-
case *sha3_256Hash:
66-
ch = crypto.SHA3_256
67-
case *sha3_384Hash:
68-
ch = crypto.SHA3_384
69-
case *sha3_512Hash:
70-
ch = crypto.SHA3_512
71-
}
72-
if ch != 0 {
73-
return cryptoHashToMD(ch)
51+
if h, ok := h.(*evpHash); ok {
52+
return h.alg.md
7453
}
7554
return nil
7655
}
@@ -89,78 +68,109 @@ func hashFuncToMD(fn func() hash.Hash) (C.GO_EVP_MD_PTR, error) {
8968
return md, nil
9069
}
9170

92-
// cryptoHashToMD converts a crypto.Hash to a EVP_MD.
93-
func cryptoHashToMD(ch crypto.Hash) C.GO_EVP_MD_PTR {
71+
type hashAlgorithm struct {
72+
md C.GO_EVP_MD_PTR
73+
ch crypto.Hash
74+
size int
75+
blockSize int
76+
marshallable bool
77+
magic string
78+
marshalledSize int
79+
}
80+
81+
// loadHash converts a crypto.Hash to a EVP_MD.
82+
func loadHash(ch crypto.Hash) *hashAlgorithm {
9483
if v, ok := cacheMD.Load(ch); ok {
95-
return v.(C.GO_EVP_MD_PTR)
84+
return v.(*hashAlgorithm)
9685
}
97-
var md C.GO_EVP_MD_PTR
86+
87+
var hash hashAlgorithm
9888
switch ch {
9989
case crypto.RIPEMD160:
100-
md = C.go_openssl_EVP_ripemd160()
90+
hash.md = C.go_openssl_EVP_ripemd160()
10191
case crypto.MD4:
102-
md = C.go_openssl_EVP_md4()
92+
hash.md = C.go_openssl_EVP_md4()
10393
case crypto.MD5:
104-
md = C.go_openssl_EVP_md5()
94+
hash.md = C.go_openssl_EVP_md5()
95+
hash.magic = md5Magic
96+
hash.marshalledSize = md5MarshaledSize
10597
case crypto.MD5SHA1:
10698
if vMajor == 1 && vMinor == 0 {
107-
md = C.go_openssl_EVP_md5_sha1_backport()
99+
hash.md = C.go_openssl_EVP_md5_sha1_backport()
108100
} else {
109-
md = C.go_openssl_EVP_md5_sha1()
101+
hash.md = C.go_openssl_EVP_md5_sha1()
110102
}
111103
case crypto.SHA1:
112-
md = C.go_openssl_EVP_sha1()
104+
hash.md = C.go_openssl_EVP_sha1()
105+
hash.magic = sha1Magic
106+
hash.marshalledSize = sha1MarshaledSize
113107
case crypto.SHA224:
114-
md = C.go_openssl_EVP_sha224()
108+
hash.md = C.go_openssl_EVP_sha224()
109+
hash.magic = magic224
110+
hash.marshalledSize = marshaledSize256
115111
case crypto.SHA256:
116-
md = C.go_openssl_EVP_sha256()
112+
hash.md = C.go_openssl_EVP_sha256()
113+
hash.magic = magic256
114+
hash.marshalledSize = marshaledSize256
117115
case crypto.SHA384:
118-
md = C.go_openssl_EVP_sha384()
116+
hash.md = C.go_openssl_EVP_sha384()
117+
hash.magic = magic384
118+
hash.marshalledSize = marshaledSize512
119119
case crypto.SHA512:
120-
md = C.go_openssl_EVP_sha512()
120+
hash.md = C.go_openssl_EVP_sha512()
121+
hash.magic = magic512
122+
hash.marshalledSize = marshaledSize512
121123
case crypto.SHA512_224:
122124
if versionAtOrAbove(1, 1, 1) {
123-
md = C.go_openssl_EVP_sha512_224()
125+
hash.md = C.go_openssl_EVP_sha512_224()
126+
hash.magic = magic512_224
127+
hash.marshalledSize = marshaledSize512
124128
}
125129
case crypto.SHA512_256:
126130
if versionAtOrAbove(1, 1, 1) {
127-
md = C.go_openssl_EVP_sha512_256()
131+
hash.md = C.go_openssl_EVP_sha512_256()
132+
hash.magic = magic512_256
133+
hash.marshalledSize = marshaledSize512
128134
}
129135
case crypto.SHA3_224:
130136
if versionAtOrAbove(1, 1, 1) {
131-
md = C.go_openssl_EVP_sha3_224()
137+
hash.md = C.go_openssl_EVP_sha3_224()
132138
}
133139
case crypto.SHA3_256:
134140
if versionAtOrAbove(1, 1, 1) {
135-
md = C.go_openssl_EVP_sha3_256()
141+
hash.md = C.go_openssl_EVP_sha3_256()
136142
}
137143
case crypto.SHA3_384:
138144
if versionAtOrAbove(1, 1, 1) {
139-
md = C.go_openssl_EVP_sha3_384()
145+
hash.md = C.go_openssl_EVP_sha3_384()
140146
}
141147
case crypto.SHA3_512:
142148
if versionAtOrAbove(1, 1, 1) {
143-
md = C.go_openssl_EVP_sha3_512()
149+
hash.md = C.go_openssl_EVP_sha3_512()
144150
}
145151
}
146-
if md == nil {
147-
cacheMD.Store(ch, nil)
152+
if hash.md == nil {
153+
cacheMD.Store(ch, (*hashAlgorithm)(nil))
148154
return nil
149155
}
156+
hash.ch = ch
157+
hash.size = int(C.go_openssl_EVP_MD_get_size(hash.md))
158+
hash.blockSize = int(C.go_openssl_EVP_MD_get_block_size(hash.md))
150159
if vMajor == 3 {
151160
// On OpenSSL 3, directly operating on a EVP_MD object
152161
// not created by EVP_MD_fetch has negative performance
153162
// implications, as digest operations will have
154163
// to fetch it on every call. Better to just fetch it once here.
155-
md1 := C.go_openssl_EVP_MD_fetch(nil, C.go_openssl_EVP_MD_get0_name(md), nil)
164+
md := C.go_openssl_EVP_MD_fetch(nil, C.go_openssl_EVP_MD_get0_name(hash.md), nil)
156165
// Don't overwrite md in case it can't be fetched, as the md may still be used
157166
// outside of EVP_MD_CTX, for example to sign and verify RSA signatures.
158-
if md1 != nil {
159-
md = md1
167+
if md != nil {
168+
hash.md = md
160169
}
161170
}
162-
cacheMD.Store(ch, md)
163-
return md
171+
hash.marshallable = hash.magic != "" && isHashMarshallable(hash.md)
172+
cacheMD.Store(ch, &hash)
173+
return &hash
164174
}
165175

166176
// generateEVPPKey generates a new EVP_PKEY with the given id and properties.
@@ -302,11 +312,11 @@ func setupEVP(withKey withKeyFunc, padding C.int,
302312
}
303313
}
304314
case C.GO_RSA_PKCS1_PSS_PADDING:
305-
md := cryptoHashToMD(ch)
306-
if md == nil {
315+
alg := loadHash(ch)
316+
if alg == nil {
307317
return nil, errors.New("crypto/rsa: unsupported hash function")
308318
}
309-
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, C.GO_EVP_PKEY_RSA, -1, C.GO_EVP_PKEY_CTRL_MD, 0, unsafe.Pointer(md)) != 1 {
319+
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, C.GO_EVP_PKEY_RSA, -1, C.GO_EVP_PKEY_CTRL_MD, 0, unsafe.Pointer(alg.md)) != 1 {
310320
return nil, newOpenSSLError("EVP_PKEY_CTX_ctrl failed")
311321
}
312322
// setPadding must happen after setting EVP_PKEY_CTRL_MD.
@@ -322,11 +332,11 @@ func setupEVP(withKey withKeyFunc, padding C.int,
322332
case C.GO_RSA_PKCS1_PADDING:
323333
if ch != 0 {
324334
// We support unhashed messages.
325-
md := cryptoHashToMD(ch)
326-
if md == nil {
335+
alg := loadHash(ch)
336+
if alg == nil {
327337
return nil, errors.New("crypto/rsa: unsupported hash function")
328338
}
329-
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, -1, C.GO_EVP_PKEY_CTRL_MD, 0, unsafe.Pointer(md)) != 1 {
339+
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, -1, C.GO_EVP_PKEY_CTRL_MD, 0, unsafe.Pointer(alg.md)) != 1 {
330340
return nil, newOpenSSLError("EVP_PKEY_CTX_ctrl failed")
331341
}
332342
if err := setPadding(); err != nil {
@@ -441,8 +451,8 @@ func evpVerify(withKey withKeyFunc, padding C.int, saltLen C.int, h crypto.Hash,
441451
}
442452

443453
func evpHashSign(withKey withKeyFunc, h crypto.Hash, msg []byte) ([]byte, error) {
444-
md := cryptoHashToMD(h)
445-
if md == nil {
454+
alg := loadHash(h)
455+
if alg == nil {
446456
return nil, errors.New("unsupported hash function: " + strconv.Itoa(int(h)))
447457
}
448458
var out []byte
@@ -453,7 +463,7 @@ func evpHashSign(withKey withKeyFunc, h crypto.Hash, msg []byte) ([]byte, error)
453463
}
454464
defer C.go_openssl_EVP_MD_CTX_free(ctx)
455465
if withKey(func(key C.GO_EVP_PKEY_PTR) C.int {
456-
return C.go_openssl_EVP_DigestSignInit(ctx, nil, md, nil, key)
466+
return C.go_openssl_EVP_DigestSignInit(ctx, nil, alg.md, nil, key)
457467
}) != 1 {
458468
return nil, newOpenSSLError("EVP_DigestSignInit failed")
459469
}
@@ -473,8 +483,8 @@ func evpHashSign(withKey withKeyFunc, h crypto.Hash, msg []byte) ([]byte, error)
473483
}
474484

475485
func evpHashVerify(withKey withKeyFunc, h crypto.Hash, msg, sig []byte) error {
476-
md := cryptoHashToMD(h)
477-
if md == nil {
486+
alg := loadHash(h)
487+
if alg == nil {
478488
return errors.New("unsupported hash function: " + strconv.Itoa(int(h)))
479489
}
480490
ctx := C.go_openssl_EVP_MD_CTX_new()
@@ -483,7 +493,7 @@ func evpHashVerify(withKey withKeyFunc, h crypto.Hash, msg, sig []byte) error {
483493
}
484494
defer C.go_openssl_EVP_MD_CTX_free(ctx)
485495
if withKey(func(key C.GO_EVP_PKEY_PTR) C.int {
486-
return C.go_openssl_EVP_DigestVerifyInit(ctx, nil, md, nil, key)
496+
return C.go_openssl_EVP_DigestVerifyInit(ctx, nil, alg.md, nil, key)
487497
}) != 1 {
488498
return newOpenSSLError("EVP_DigestVerifyInit failed")
489499
}

0 commit comments

Comments
 (0)