@@ -94,7 +94,7 @@ type cryptFunc func(C.GO_EVP_PKEY_CTX_PTR, *C.uchar, *C.size_t, *C.uchar, C.size
9494type verifyFunc func (C.GO_EVP_PKEY_CTX_PTR , * C.uchar , C.size_t , * C.uchar , C.size_t ) error
9595
9696func setupEVP (withKey withKeyFunc , padding C.int ,
97- h hash.Hash , label []byte , saltLen C.int , ch crypto.Hash ,
97+ h , mgfHash hash.Hash , label []byte , saltLen C.int , ch crypto.Hash ,
9898 init initFunc ) (ctx C.GO_EVP_PKEY_CTX_PTR , err error ) {
9999 defer func () {
100100 if err != nil {
@@ -132,13 +132,26 @@ func setupEVP(withKey withKeyFunc, padding C.int,
132132 if md == nil {
133133 return nil , errors .New ("crypto/rsa: unsupported hash function" )
134134 }
135+ var mgfMD C.GO_EVP_MD_PTR
136+ if mgfHash != nil {
137+ // mgfHash is optional, but if it is set it must match a supported hash function.
138+ mgfMD = hashToMD (mgfHash )
139+ if mgfMD == nil {
140+ return nil , errors .New ("crypto/rsa: unsupported hash function" )
141+ }
142+ }
135143 // setPadding must happen before setting EVP_PKEY_CTRL_RSA_OAEP_MD.
136144 if err := setPadding (); err != nil {
137145 return nil , err
138146 }
139147 if C .go_openssl_EVP_PKEY_CTX_ctrl (ctx , C .GO_EVP_PKEY_RSA , - 1 , C .GO_EVP_PKEY_CTRL_RSA_OAEP_MD , 0 , unsafe .Pointer (md )) != 1 {
140148 return nil , newOpenSSLError ("EVP_PKEY_CTX_ctrl failed" )
141149 }
150+ if mgfHash != nil {
151+ if C .go_openssl_EVP_PKEY_CTX_ctrl (ctx , C .GO_EVP_PKEY_RSA , - 1 , C .GO_EVP_PKEY_CTRL_RSA_MGF1_MD , 0 , unsafe .Pointer (mgfMD )) != 1 {
152+ return nil , newOpenSSLError ("EVP_PKEY_CTX_ctrl failed" )
153+ }
154+ }
142155 // ctx takes ownership of label, so malloc a copy for OpenSSL to free.
143156 // OpenSSL 1.1.1 and higher does not take ownership of the label if the length is zero,
144157 // so better avoid the allocation.
@@ -199,10 +212,10 @@ func setupEVP(withKey withKeyFunc, padding C.int,
199212}
200213
201214func cryptEVP (withKey withKeyFunc , padding C.int ,
202- h hash.Hash , label []byte , saltLen C.int , ch crypto.Hash ,
215+ h , mgfHash hash.Hash , label []byte , saltLen C.int , ch crypto.Hash ,
203216 init initFunc , crypt cryptFunc , in []byte ) ([]byte , error ) {
204217
205- ctx , err := setupEVP (withKey , padding , h , label , saltLen , ch , init )
218+ ctx , err := setupEVP (withKey , padding , h , mgfHash , label , saltLen , ch , init )
206219 if err != nil {
207220 return nil , err
208221 }
@@ -225,15 +238,15 @@ func verifyEVP(withKey withKeyFunc, padding C.int,
225238 init initFunc , verify verifyFunc ,
226239 sig , in []byte ) error {
227240
228- ctx , err := setupEVP (withKey , padding , h , label , saltLen , ch , init )
241+ ctx , err := setupEVP (withKey , padding , h , nil , label , saltLen , ch , init )
229242 if err != nil {
230243 return err
231244 }
232245 defer C .go_openssl_EVP_PKEY_CTX_free (ctx )
233246 return verify (ctx , base (sig ), C .size_t (len (sig )), base (in ), C .size_t (len (in )))
234247}
235248
236- func evpEncrypt (withKey withKeyFunc , padding C.int , h hash.Hash , label , msg []byte ) ([]byte , error ) {
249+ func evpEncrypt (withKey withKeyFunc , padding C.int , h , mgfHash hash.Hash , label , msg []byte ) ([]byte , error ) {
237250 encryptInit := func (ctx C.GO_EVP_PKEY_CTX_PTR ) error {
238251 if ret := C .go_openssl_EVP_PKEY_encrypt_init (ctx ); ret != 1 {
239252 return newOpenSSLError ("EVP_PKEY_encrypt_init failed" )
@@ -246,10 +259,10 @@ func evpEncrypt(withKey withKeyFunc, padding C.int, h hash.Hash, label, msg []by
246259 }
247260 return nil
248261 }
249- return cryptEVP (withKey , padding , h , label , 0 , 0 , encryptInit , encrypt , msg )
262+ return cryptEVP (withKey , padding , h , mgfHash , label , 0 , 0 , encryptInit , encrypt , msg )
250263}
251264
252- func evpDecrypt (withKey withKeyFunc , padding C.int , h hash.Hash , label , msg []byte ) ([]byte , error ) {
265+ func evpDecrypt (withKey withKeyFunc , padding C.int , h , mgfHash hash.Hash , label , msg []byte ) ([]byte , error ) {
253266 decryptInit := func (ctx C.GO_EVP_PKEY_CTX_PTR ) error {
254267 if ret := C .go_openssl_EVP_PKEY_decrypt_init (ctx ); ret != 1 {
255268 return newOpenSSLError ("EVP_PKEY_decrypt_init failed" )
@@ -262,7 +275,7 @@ func evpDecrypt(withKey withKeyFunc, padding C.int, h hash.Hash, label, msg []by
262275 }
263276 return nil
264277 }
265- return cryptEVP (withKey , padding , h , label , 0 , 0 , decryptInit , decrypt , msg )
278+ return cryptEVP (withKey , padding , h , mgfHash , label , 0 , 0 , decryptInit , decrypt , msg )
266279}
267280
268281func evpSign (withKey withKeyFunc , padding C.int , saltLen C.int , h crypto.Hash , hashed []byte ) ([]byte , error ) {
@@ -278,7 +291,7 @@ func evpSign(withKey withKeyFunc, padding C.int, saltLen C.int, h crypto.Hash, h
278291 }
279292 return nil
280293 }
281- return cryptEVP (withKey , padding , nil , nil , saltLen , h , signtInit , sign , hashed )
294+ return cryptEVP (withKey , padding , nil , nil , nil , saltLen , h , signtInit , sign , hashed )
282295}
283296
284297func evpVerify (withKey withKeyFunc , padding C.int , saltLen C.int , h crypto.Hash , sig , hashed []byte ) error {
0 commit comments