Skip to content

Commit 4e5b3fc

Browse files
committed
allow hash.Hash for OAEP and MGF1 to be specified independently
1 parent 7f78fc5 commit 4e5b3fc

File tree

4 files changed

+63
-21
lines changed

4 files changed

+63
-21
lines changed

openssl/evp.go

+22-9
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ type cryptFunc func(C.GO_EVP_PKEY_CTX_PTR, *C.uchar, *C.size_t, *C.uchar, C.size
9494
type verifyFunc func(C.GO_EVP_PKEY_CTX_PTR, *C.uchar, C.size_t, *C.uchar, C.size_t) error
9595

9696
func setupEVP(withKey withKeyFunc, padding C.int,
97-
h hash.Hash, label []byte, saltLen C.int, ch crypto.Hash,
97+
h, mgfHash hash.Hash, label []byte, saltLen C.int, ch crypto.Hash,
9898
init initFunc) (ctx C.GO_EVP_PKEY_CTX_PTR, err error) {
9999
defer func() {
100100
if err != nil {
@@ -132,13 +132,26 @@ func setupEVP(withKey withKeyFunc, padding C.int,
132132
if md == nil {
133133
return nil, errors.New("crypto/rsa: unsupported hash function")
134134
}
135+
var mgfMD C.GO_EVP_MD_PTR
136+
if mgfHash != nil {
137+
// mgfHash is optional, but if it is set it must match a supported hash function.
138+
mgfMD = hashToMD(mgfHash)
139+
if mgfMD == nil {
140+
return nil, errors.New("crypto/rsa: unsupported hash function")
141+
}
142+
}
135143
// setPadding must happen before setting EVP_PKEY_CTRL_RSA_OAEP_MD.
136144
if err := setPadding(); err != nil {
137145
return nil, err
138146
}
139147
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, C.GO_EVP_PKEY_RSA, -1, C.GO_EVP_PKEY_CTRL_RSA_OAEP_MD, 0, unsafe.Pointer(md)) != 1 {
140148
return nil, newOpenSSLError("EVP_PKEY_CTX_ctrl failed")
141149
}
150+
if mgfHash != nil {
151+
if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, C.GO_EVP_PKEY_RSA, -1, C.GO_EVP_PKEY_CTRL_RSA_MGF1_MD, 0, unsafe.Pointer(mgfMD)) != 1 {
152+
return nil, newOpenSSLError("EVP_PKEY_CTX_ctrl failed")
153+
}
154+
}
142155
// ctx takes ownership of label, so malloc a copy for OpenSSL to free.
143156
// OpenSSL 1.1.1 and higher does not take ownership of the label if the length is zero,
144157
// so better avoid the allocation.
@@ -199,10 +212,10 @@ func setupEVP(withKey withKeyFunc, padding C.int,
199212
}
200213

201214
func cryptEVP(withKey withKeyFunc, padding C.int,
202-
h hash.Hash, label []byte, saltLen C.int, ch crypto.Hash,
215+
h, mgfHash hash.Hash, label []byte, saltLen C.int, ch crypto.Hash,
203216
init initFunc, crypt cryptFunc, in []byte) ([]byte, error) {
204217

205-
ctx, err := setupEVP(withKey, padding, h, label, saltLen, ch, init)
218+
ctx, err := setupEVP(withKey, padding, h, mgfHash, label, saltLen, ch, init)
206219
if err != nil {
207220
return nil, err
208221
}
@@ -225,15 +238,15 @@ func verifyEVP(withKey withKeyFunc, padding C.int,
225238
init initFunc, verify verifyFunc,
226239
sig, in []byte) error {
227240

228-
ctx, err := setupEVP(withKey, padding, h, label, saltLen, ch, init)
241+
ctx, err := setupEVP(withKey, padding, h, nil, label, saltLen, ch, init)
229242
if err != nil {
230243
return err
231244
}
232245
defer C.go_openssl_EVP_PKEY_CTX_free(ctx)
233246
return verify(ctx, base(sig), C.size_t(len(sig)), base(in), C.size_t(len(in)))
234247
}
235248

236-
func evpEncrypt(withKey withKeyFunc, padding C.int, h hash.Hash, label, msg []byte) ([]byte, error) {
249+
func evpEncrypt(withKey withKeyFunc, padding C.int, h, mgfHash hash.Hash, label, msg []byte) ([]byte, error) {
237250
encryptInit := func(ctx C.GO_EVP_PKEY_CTX_PTR) error {
238251
if ret := C.go_openssl_EVP_PKEY_encrypt_init(ctx); ret != 1 {
239252
return newOpenSSLError("EVP_PKEY_encrypt_init failed")
@@ -246,10 +259,10 @@ func evpEncrypt(withKey withKeyFunc, padding C.int, h hash.Hash, label, msg []by
246259
}
247260
return nil
248261
}
249-
return cryptEVP(withKey, padding, h, label, 0, 0, encryptInit, encrypt, msg)
262+
return cryptEVP(withKey, padding, h, mgfHash, label, 0, 0, encryptInit, encrypt, msg)
250263
}
251264

252-
func evpDecrypt(withKey withKeyFunc, padding C.int, h hash.Hash, label, msg []byte) ([]byte, error) {
265+
func evpDecrypt(withKey withKeyFunc, padding C.int, h, mgfHash hash.Hash, label, msg []byte) ([]byte, error) {
253266
decryptInit := func(ctx C.GO_EVP_PKEY_CTX_PTR) error {
254267
if ret := C.go_openssl_EVP_PKEY_decrypt_init(ctx); ret != 1 {
255268
return newOpenSSLError("EVP_PKEY_decrypt_init failed")
@@ -262,7 +275,7 @@ func evpDecrypt(withKey withKeyFunc, padding C.int, h hash.Hash, label, msg []by
262275
}
263276
return nil
264277
}
265-
return cryptEVP(withKey, padding, h, label, 0, 0, decryptInit, decrypt, msg)
278+
return cryptEVP(withKey, padding, h, mgfHash, label, 0, 0, decryptInit, decrypt, msg)
266279
}
267280

268281
func evpSign(withKey withKeyFunc, padding C.int, saltLen C.int, h crypto.Hash, hashed []byte) ([]byte, error) {
@@ -278,7 +291,7 @@ func evpSign(withKey withKeyFunc, padding C.int, saltLen C.int, h crypto.Hash, h
278291
}
279292
return nil
280293
}
281-
return cryptEVP(withKey, padding, nil, nil, saltLen, h, signtInit, sign, hashed)
294+
return cryptEVP(withKey, padding, nil, nil, nil, saltLen, h, signtInit, sign, hashed)
282295
}
283296

284297
func evpVerify(withKey withKeyFunc, padding C.int, saltLen C.int, h crypto.Hash, sig, hashed []byte) error {

openssl/rsa.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -124,24 +124,24 @@ func (k *PrivateKeyRSA) withKey(f func(C.GO_EVP_PKEY_PTR) C.int) C.int {
124124
return f(k._pkey)
125125
}
126126

127-
func DecryptRSAOAEP(h hash.Hash, priv *PrivateKeyRSA, ciphertext, label []byte) ([]byte, error) {
128-
return evpDecrypt(priv.withKey, C.GO_RSA_PKCS1_OAEP_PADDING, h, label, ciphertext)
127+
func DecryptRSAOAEP(h, mgfHash hash.Hash, priv *PrivateKeyRSA, ciphertext, label []byte) ([]byte, error) {
128+
return evpDecrypt(priv.withKey, C.GO_RSA_PKCS1_OAEP_PADDING, h, mgfHash, label, ciphertext)
129129
}
130130

131-
func EncryptRSAOAEP(h hash.Hash, pub *PublicKeyRSA, msg, label []byte) ([]byte, error) {
132-
return evpEncrypt(pub.withKey, C.GO_RSA_PKCS1_OAEP_PADDING, h, label, msg)
131+
func EncryptRSAOAEP(h, mgfHash hash.Hash, pub *PublicKeyRSA, msg, label []byte) ([]byte, error) {
132+
return evpEncrypt(pub.withKey, C.GO_RSA_PKCS1_OAEP_PADDING, h, mgfHash, label, msg)
133133
}
134134

135135
func DecryptRSAPKCS1(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) {
136-
return evpDecrypt(priv.withKey, C.GO_RSA_PKCS1_PADDING, nil, nil, ciphertext)
136+
return evpDecrypt(priv.withKey, C.GO_RSA_PKCS1_PADDING, nil, nil, nil, ciphertext)
137137
}
138138

139139
func EncryptRSAPKCS1(pub *PublicKeyRSA, msg []byte) ([]byte, error) {
140-
return evpEncrypt(pub.withKey, C.GO_RSA_PKCS1_PADDING, nil, nil, msg)
140+
return evpEncrypt(pub.withKey, C.GO_RSA_PKCS1_PADDING, nil, nil, nil, msg)
141141
}
142142

143143
func DecryptRSANoPadding(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error) {
144-
ret, err := evpDecrypt(priv.withKey, C.GO_RSA_NO_PADDING, nil, nil, ciphertext)
144+
ret, err := evpDecrypt(priv.withKey, C.GO_RSA_NO_PADDING, nil, nil, nil, ciphertext)
145145
if err != nil {
146146
return nil, err
147147
}
@@ -165,7 +165,7 @@ func DecryptRSANoPadding(priv *PrivateKeyRSA, ciphertext []byte) ([]byte, error)
165165
}
166166

167167
func EncryptRSANoPadding(pub *PublicKeyRSA, msg []byte) ([]byte, error) {
168-
return evpEncrypt(pub.withKey, C.GO_RSA_NO_PADDING, nil, nil, msg)
168+
return evpEncrypt(pub.withKey, C.GO_RSA_NO_PADDING, nil, nil, nil, msg)
169169
}
170170

171171
func saltLength(saltLen int, sign bool) (C.int, error) {

openssl/rsa_test.go

+32-4
Original file line numberDiff line numberDiff line change
@@ -41,28 +41,56 @@ func TestEncryptDecryptOAEP(t *testing.T) {
4141
msg := []byte("hi!")
4242
label := []byte("ho!")
4343
priv, pub := newRSAKey(t, 2048)
44-
enc, err := openssl.EncryptRSAOAEP(sha256, pub, msg, label)
44+
enc, err := openssl.EncryptRSAOAEP(sha256, nil, pub, msg, label)
4545
if err != nil {
4646
t.Fatal(err)
4747
}
48-
dec, err := openssl.DecryptRSAOAEP(sha256, priv, enc, label)
48+
dec, err := openssl.DecryptRSAOAEP(sha256, nil, priv, enc, label)
4949
if err != nil {
5050
t.Fatal(err)
5151
}
5252
if !bytes.Equal(dec, msg) {
5353
t.Errorf("got:%x want:%x", dec, msg)
5454
}
55+
sha1 := openssl.NewSHA1()
56+
_, err = openssl.DecryptRSAOAEP(sha1, nil, priv, enc, label)
57+
if err == nil {
58+
t.Error("decrypt failure expected due to hash mismatch")
59+
}
60+
}
61+
62+
func TestEncryptDecryptOAEP_WithMGF1Hash(t *testing.T) {
63+
sha1 := openssl.NewSHA1()
64+
sha256 := openssl.NewSHA256()
65+
msg := []byte("hi!")
66+
label := []byte("ho!")
67+
priv, pub := newRSAKey(t, 2048)
68+
enc, err := openssl.EncryptRSAOAEP(sha256, sha1, pub, msg, label)
69+
if err != nil {
70+
t.Fatal(err)
71+
}
72+
dec, err := openssl.DecryptRSAOAEP(sha256, sha1, priv, enc, label)
73+
if err != nil {
74+
t.Fatal(err)
75+
}
76+
if !bytes.Equal(dec, msg) {
77+
t.Errorf("got:%x want:%x", dec, msg)
78+
}
79+
_, err = openssl.DecryptRSAOAEP(sha256, sha256, priv, enc, label)
80+
if err == nil {
81+
t.Error("decrypt failure expected due to mgf1 hash mismatch")
82+
}
5583
}
5684

5785
func TestEncryptDecryptOAEP_WrongLabel(t *testing.T) {
5886
sha256 := openssl.NewSHA256()
5987
msg := []byte("hi!")
6088
priv, pub := newRSAKey(t, 2048)
61-
enc, err := openssl.EncryptRSAOAEP(sha256, pub, msg, []byte("ho!"))
89+
enc, err := openssl.EncryptRSAOAEP(sha256, nil, pub, msg, []byte("ho!"))
6290
if err != nil {
6391
t.Fatal(err)
6492
}
65-
dec, err := openssl.DecryptRSAOAEP(sha256, priv, enc, []byte("wrong!"))
93+
dec, err := openssl.DecryptRSAOAEP(sha256, nil, priv, enc, []byte("wrong!"))
6694
if err == nil {
6795
t.Errorf("error expected")
6896
}

openssl/shims.h

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ enum {
7575
GO_EVP_PKEY_CTRL_RSA_PADDING = 0x1001,
7676
GO_EVP_PKEY_CTRL_RSA_PSS_SALTLEN = 0x1002,
7777
GO_EVP_PKEY_CTRL_RSA_KEYGEN_BITS = 0x1003,
78+
GO_EVP_PKEY_CTRL_RSA_MGF1_MD = 0x1005,
7879
GO_EVP_PKEY_CTRL_RSA_OAEP_MD = 0x1009,
7980
GO_EVP_PKEY_CTRL_RSA_OAEP_LABEL = 0x100A
8081
};

0 commit comments

Comments
 (0)