@@ -94,7 +94,7 @@ type cryptFunc func(C.GO_EVP_PKEY_CTX_PTR, *C.uchar, *C.size_t, *C.uchar, C.size
94
94
type verifyFunc func (C.GO_EVP_PKEY_CTX_PTR , * C.uchar , C.size_t , * C.uchar , C.size_t ) error
95
95
96
96
func setupEVP (withKey withKeyFunc , padding C.int ,
97
- h hash.Hash , label []byte , saltLen C.int , ch crypto.Hash ,
97
+ h , mgfHash hash.Hash , label []byte , saltLen C.int , ch crypto.Hash ,
98
98
init initFunc ) (ctx C.GO_EVP_PKEY_CTX_PTR , err error ) {
99
99
defer func () {
100
100
if err != nil {
@@ -132,13 +132,26 @@ func setupEVP(withKey withKeyFunc, padding C.int,
132
132
if md == nil {
133
133
return nil , errors .New ("crypto/rsa: unsupported hash function" )
134
134
}
135
+ var mgfMD C.GO_EVP_MD_PTR
136
+ if mgfHash != nil {
137
+ // mgfHash is optional, but if it is set it must match a supported hash function.
138
+ mgfMD = hashToMD (mgfHash )
139
+ if mgfMD == nil {
140
+ return nil , errors .New ("crypto/rsa: unsupported hash function" )
141
+ }
142
+ }
135
143
// setPadding must happen before setting EVP_PKEY_CTRL_RSA_OAEP_MD.
136
144
if err := setPadding (); err != nil {
137
145
return nil , err
138
146
}
139
147
if C .go_openssl_EVP_PKEY_CTX_ctrl (ctx , C .GO_EVP_PKEY_RSA , - 1 , C .GO_EVP_PKEY_CTRL_RSA_OAEP_MD , 0 , unsafe .Pointer (md )) != 1 {
140
148
return nil , newOpenSSLError ("EVP_PKEY_CTX_ctrl failed" )
141
149
}
150
+ if mgfHash != nil {
151
+ if C .go_openssl_EVP_PKEY_CTX_ctrl (ctx , C .GO_EVP_PKEY_RSA , - 1 , C .GO_EVP_PKEY_CTRL_RSA_MGF1_MD , 0 , unsafe .Pointer (mgfMD )) != 1 {
152
+ return nil , newOpenSSLError ("EVP_PKEY_CTX_ctrl failed" )
153
+ }
154
+ }
142
155
// ctx takes ownership of label, so malloc a copy for OpenSSL to free.
143
156
// OpenSSL 1.1.1 and higher does not take ownership of the label if the length is zero,
144
157
// so better avoid the allocation.
@@ -199,10 +212,10 @@ func setupEVP(withKey withKeyFunc, padding C.int,
199
212
}
200
213
201
214
func cryptEVP (withKey withKeyFunc , padding C.int ,
202
- h hash.Hash , label []byte , saltLen C.int , ch crypto.Hash ,
215
+ h , mgfHash hash.Hash , label []byte , saltLen C.int , ch crypto.Hash ,
203
216
init initFunc , crypt cryptFunc , in []byte ) ([]byte , error ) {
204
217
205
- ctx , err := setupEVP (withKey , padding , h , label , saltLen , ch , init )
218
+ ctx , err := setupEVP (withKey , padding , h , mgfHash , label , saltLen , ch , init )
206
219
if err != nil {
207
220
return nil , err
208
221
}
@@ -225,15 +238,15 @@ func verifyEVP(withKey withKeyFunc, padding C.int,
225
238
init initFunc , verify verifyFunc ,
226
239
sig , in []byte ) error {
227
240
228
- ctx , err := setupEVP (withKey , padding , h , label , saltLen , ch , init )
241
+ ctx , err := setupEVP (withKey , padding , h , nil , label , saltLen , ch , init )
229
242
if err != nil {
230
243
return err
231
244
}
232
245
defer C .go_openssl_EVP_PKEY_CTX_free (ctx )
233
246
return verify (ctx , base (sig ), C .size_t (len (sig )), base (in ), C .size_t (len (in )))
234
247
}
235
248
236
- func evpEncrypt (withKey withKeyFunc , padding C.int , h hash.Hash , label , msg []byte ) ([]byte , error ) {
249
+ func evpEncrypt (withKey withKeyFunc , padding C.int , h , mgfHash hash.Hash , label , msg []byte ) ([]byte , error ) {
237
250
encryptInit := func (ctx C.GO_EVP_PKEY_CTX_PTR ) error {
238
251
if ret := C .go_openssl_EVP_PKEY_encrypt_init (ctx ); ret != 1 {
239
252
return newOpenSSLError ("EVP_PKEY_encrypt_init failed" )
@@ -246,10 +259,10 @@ func evpEncrypt(withKey withKeyFunc, padding C.int, h hash.Hash, label, msg []by
246
259
}
247
260
return nil
248
261
}
249
- return cryptEVP (withKey , padding , h , label , 0 , 0 , encryptInit , encrypt , msg )
262
+ return cryptEVP (withKey , padding , h , mgfHash , label , 0 , 0 , encryptInit , encrypt , msg )
250
263
}
251
264
252
- func evpDecrypt (withKey withKeyFunc , padding C.int , h hash.Hash , label , msg []byte ) ([]byte , error ) {
265
+ func evpDecrypt (withKey withKeyFunc , padding C.int , h , mgfHash hash.Hash , label , msg []byte ) ([]byte , error ) {
253
266
decryptInit := func (ctx C.GO_EVP_PKEY_CTX_PTR ) error {
254
267
if ret := C .go_openssl_EVP_PKEY_decrypt_init (ctx ); ret != 1 {
255
268
return newOpenSSLError ("EVP_PKEY_decrypt_init failed" )
@@ -262,7 +275,7 @@ func evpDecrypt(withKey withKeyFunc, padding C.int, h hash.Hash, label, msg []by
262
275
}
263
276
return nil
264
277
}
265
- return cryptEVP (withKey , padding , h , label , 0 , 0 , decryptInit , decrypt , msg )
278
+ return cryptEVP (withKey , padding , h , mgfHash , label , 0 , 0 , decryptInit , decrypt , msg )
266
279
}
267
280
268
281
func evpSign (withKey withKeyFunc , padding C.int , saltLen C.int , h crypto.Hash , hashed []byte ) ([]byte , error ) {
@@ -278,7 +291,7 @@ func evpSign(withKey withKeyFunc, padding C.int, saltLen C.int, h crypto.Hash, h
278
291
}
279
292
return nil
280
293
}
281
- return cryptEVP (withKey , padding , nil , nil , saltLen , h , signtInit , sign , hashed )
294
+ return cryptEVP (withKey , padding , nil , nil , nil , saltLen , h , signtInit , sign , hashed )
282
295
}
283
296
284
297
func evpVerify (withKey withKeyFunc , padding C.int , saltLen C.int , h crypto.Hash , sig , hashed []byte ) error {
0 commit comments