Skip to content

Commit

Permalink
Only omit RSA primes if precomputed values are missing in OpenSSL 3.0…
Browse files Browse the repository at this point in the history
… and 3.1 (#163)

* only allow omission of precomputed RSA values in OpenSSL 3.0 and 3.1

* check if rsa primes and precomputed values exists

* make logic more clear

* Update rsa.go

Co-authored-by: Davis Goodin <[email protected]>

* add tests

* exhaustively tests messing precomputed values

---------

Co-authored-by: Davis Goodin <[email protected]>
  • Loading branch information
qmuntal and dagood authored Sep 10, 2024
1 parent 9dbdc19 commit 6c842d7
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 20 deletions.
25 changes: 17 additions & 8 deletions rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,15 +396,24 @@ func newRSAKey3(isPriv bool, n, e, d, p, q, dp, dq, qinv BigInt) (C.GO_EVP_PKEY_
}
comps = append(comps, required[:]...)

// OpenSSL 3.0 and 3.1 required all the precomputed values if
// P and Q are present. See:
// https://github.com/openssl/openssl/pull/22334
if p != nil && q != nil && dp != nil && dq != nil && qinv != nil {
precomputed := [...]bigIntParam{
{OSSL_PKEY_PARAM_RSA_FACTOR1, p}, {OSSL_PKEY_PARAM_RSA_FACTOR2, q},
{OSSL_PKEY_PARAM_RSA_EXPONENT1, dp}, {OSSL_PKEY_PARAM_RSA_EXPONENT2, dq}, {OSSL_PKEY_PARAM_RSA_COEFFICIENT1, qinv},
if p != nil && q != nil {
allPrecomputedExists := dp != nil && dq != nil && qinv != nil
// The precomputed values should only be passed if P and Q are present
// and every precomputed value is present. (If any precomputed value is
// missing, don't pass any of them.)
//
// In OpenSSL 3.0 and 3.1, we must also omit P and Q if any precomputed
// value is missing. See https://github.com/openssl/openssl/pull/22334
if vMinor >= 2 || allPrecomputedExists {
comps = append(comps, bigIntParam{OSSL_PKEY_PARAM_RSA_FACTOR1, p}, bigIntParam{OSSL_PKEY_PARAM_RSA_FACTOR2, q})
}
if allPrecomputedExists {
comps = append(comps,
bigIntParam{OSSL_PKEY_PARAM_RSA_EXPONENT1, dp},
bigIntParam{OSSL_PKEY_PARAM_RSA_EXPONENT2, dq},
bigIntParam{OSSL_PKEY_PARAM_RSA_COEFFICIENT1, qinv},
)
}
comps = append(comps, precomputed[:]...)
}

for _, comp := range comps {
Expand Down
97 changes: 85 additions & 12 deletions rsa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"crypto"
"crypto/rsa"
"fmt"
"math/big"
"strconv"
"testing"
Expand All @@ -14,9 +15,78 @@ import (

func TestRSAKeyGeneration(t *testing.T) {
for _, size := range []int{2048, 3072} {
t.Run(strconv.Itoa(size), func(t *testing.T) {
t.Parallel()
_, _, _, _, _, _, _, _, err := openssl.GenerateKeyRSA(size)
if err != nil {
t.Fatal(err)
}
})
}
}

func testRSAEncryptDecryptPKCS1(t *testing.T, priv *openssl.PrivateKeyRSA, pub *openssl.PublicKeyRSA) {
msg := []byte("hi!")
enc, err := openssl.EncryptRSAPKCS1(pub, msg)
if err != nil {
t.Fatalf("EncryptPKCS1v15: %v", err)
}
dec, err := openssl.DecryptRSAPKCS1(priv, enc)
if err != nil {
t.Fatalf("DecryptPKCS1v15: %v", err)
}
if !bytes.Equal(dec, msg) {
t.Fatalf("got:%x want:%x", dec, msg)
}
}

func TestRSAEncryptDecryptPKCS1(t *testing.T) {
for _, size := range []int{2048, 3072} {
size := size
t.Run(strconv.Itoa(size), func(t *testing.T) {
t.Parallel()
priv, pub := newRSAKey(t, size)
testRSAEncryptDecryptPKCS1(t, priv, pub)
})
}
}

func TestRSAEncryptDecryptPKCS1_MissingPrecomputedValues(t *testing.T) {
n, e, d, p, q, dp, dq, qinv, err := openssl.GenerateKeyRSA(2048)
if err != nil {
t.Fatalf("GenerateKeyRSA: %v", err)
}
tt := []struct {
withDp bool
withDq bool
withQinv bool
}{
{true, true, false},
{true, false, true},
{false, true, true},
{false, false, false},
{false, false, true},
{false, true, false},
{true, false, false},
{true, true, true},
}
for _, tt := range tt {
tt := tt
t.Run(fmt.Sprintf("dp=%v,dq=%v,qinv=%v", tt.withDp, tt.withDq, tt.withQinv), func(t *testing.T) {
t.Parallel()
dp1, dq1, qinv1 := dp, dq, qinv
if !tt.withDp {
dp1 = nil
}
if !tt.withDq {
dq1 = nil
}
if !tt.withQinv {
qinv1 = nil
}

priv, pub := newRSAKeyFromParams(t, n, e, d, p, q, dp1, dq1, qinv1)
testRSAEncryptDecryptPKCS1(t, priv, pub)
msg := []byte("hi!")
enc, err := openssl.EncryptRSAPKCS1(pub, msg)
if err != nil {
Expand All @@ -33,7 +103,7 @@ func TestRSAKeyGeneration(t *testing.T) {
}
}

func TestEncryptDecryptOAEP(t *testing.T) {
func TestRSAEncryptDecryptOAEP(t *testing.T) {
sha256 := openssl.NewSHA256()
msg := []byte("hi!")
label := []byte("ho!")
Expand All @@ -56,7 +126,7 @@ func TestEncryptDecryptOAEP(t *testing.T) {
}
}

func TestEncryptDecryptOAEP_EmptyLabel(t *testing.T) {
func TestRSAEncryptDecryptOAEP_EmptyLabel(t *testing.T) {
sha256 := openssl.NewSHA256()
msg := []byte("hi!")
label := []byte("")
Expand All @@ -79,7 +149,7 @@ func TestEncryptDecryptOAEP_EmptyLabel(t *testing.T) {
}
}

func TestEncryptDecryptOAEP_WithMGF1Hash(t *testing.T) {
func TestRSAEncryptDecryptOAEP_WithMGF1Hash(t *testing.T) {
if openssl.SymCryptProviderAvailable() {
t.Skip("SymCrypt provider does not support MGF1 hash")
}
Expand All @@ -106,7 +176,7 @@ func TestEncryptDecryptOAEP_WithMGF1Hash(t *testing.T) {
}
}

func TestEncryptDecryptOAEP_WrongLabel(t *testing.T) {
func TestRSAEncryptDecryptOAEP_WrongLabel(t *testing.T) {
sha256 := openssl.NewSHA256()
msg := []byte("hi!")
priv, pub := newRSAKey(t, 2048)
Expand All @@ -123,7 +193,7 @@ func TestEncryptDecryptOAEP_WrongLabel(t *testing.T) {
}
}

func TestSignVerifyPKCS1v15(t *testing.T) {
func TestRSASignVerifyPKCS1v15(t *testing.T) {
sha256 := openssl.NewSHA256()
priv, pub := newRSAKey(t, 2048)
msg := []byte("hi!")
Expand All @@ -150,7 +220,7 @@ func TestSignVerifyPKCS1v15(t *testing.T) {
}
}

func TestSignVerifyPKCS1v15_Unhashed(t *testing.T) {
func TestRSASignVerifyPKCS1v15_Unhashed(t *testing.T) {
if openssl.SymCryptProviderAvailable() {
t.Skip("SymCrypt provider does not support unhashed PKCS1v15")
}
Expand All @@ -167,7 +237,7 @@ func TestSignVerifyPKCS1v15_Unhashed(t *testing.T) {
}
}

func TestSignVerifyPKCS1v15_Invalid(t *testing.T) {
func TestRSASignVerifyPKCS1v15_Invalid(t *testing.T) {
sha256 := openssl.NewSHA256()
msg := []byte("hi!")
priv, pub := newRSAKey(t, 2048)
Expand All @@ -183,7 +253,7 @@ func TestSignVerifyPKCS1v15_Invalid(t *testing.T) {
}
}

func TestSignVerifyRSAPSS(t *testing.T) {
func TestRSASignVerifyRSAPSS(t *testing.T) {
// Test cases taken from
// https://github.com/golang/go/blob/54182ff54a687272dd7632c3a963e036ce03cb7c/src/crypto/rsa/pss_test.go#L200.
const keyBits = 2048
Expand Down Expand Up @@ -224,15 +294,18 @@ func newRSAKey(t *testing.T, size int) (*openssl.PrivateKeyRSA, *openssl.PublicK
if err != nil {
t.Fatalf("GenerateKeyRSA(%d): %v", size, err)
}
// Exercise omission of precomputed value
Dp = nil
return newRSAKeyFromParams(t, N, E, D, P, Q, Dp, Dq, Qinv)
}

func newRSAKeyFromParams(t *testing.T, N, E, D, P, Q, Dp, Dq, Qinv openssl.BigInt) (*openssl.PrivateKeyRSA, *openssl.PublicKeyRSA) {
t.Helper()
priv, err := openssl.NewPrivateKeyRSA(N, E, D, P, Q, Dp, Dq, Qinv)
if err != nil {
t.Fatalf("NewPrivateKeyRSA(%d): %v", size, err)
t.Fatalf("NewPrivateKeyRSA: %v", err)
}
pub, err := openssl.NewPublicKeyRSA(N, E)
if err != nil {
t.Fatalf("NewPublicKeyRSA(%d): %v", size, err)
t.Fatalf("NewPublicKeyRSA: %v", err)
}
return priv, pub
}
Expand Down

0 comments on commit 6c842d7

Please sign in to comment.