@@ -7,6 +7,7 @@ import "C"
7
7
import (
8
8
"errors"
9
9
"runtime"
10
+ "slices"
10
11
"unsafe"
11
12
)
12
13
@@ -20,49 +21,44 @@ func (k *PublicKeyECDH) finalize() {
20
21
}
21
22
22
23
type PrivateKeyECDH struct {
23
- _pkey C.GO_EVP_PKEY_PTR
24
- curve string
25
- hasPublicKey bool
24
+ _pkey C.GO_EVP_PKEY_PTR
25
+ curve string
26
26
}
27
27
28
28
func (k * PrivateKeyECDH ) finalize () {
29
29
C .go_openssl_EVP_PKEY_free (k ._pkey )
30
30
}
31
31
32
32
func NewPublicKeyECDH (curve string , bytes []byte ) (* PublicKeyECDH , error ) {
33
- if len (bytes ) < 1 {
34
- return nil , errors .New ("NewPublicKeyECDH: missing key" )
33
+ if len (bytes ) != 1 + 2 * curveSize ( curve ) {
34
+ return nil , errors .New ("NewPublicKeyECDH: wrong key length " )
35
35
}
36
36
pkey , err := newECDHPkey (curve , bytes , false )
37
37
if err != nil {
38
38
return nil , err
39
39
}
40
- k := & PublicKeyECDH {pkey , append ([] byte ( nil ), bytes ... )}
40
+ k := & PublicKeyECDH {pkey , slices . Clone ( bytes )}
41
41
runtime .SetFinalizer (k , (* PublicKeyECDH ).finalize )
42
42
return k , nil
43
43
}
44
44
45
45
func (k * PublicKeyECDH ) Bytes () []byte { return k .bytes }
46
46
47
47
func NewPrivateKeyECDH (curve string , bytes []byte ) (* PrivateKeyECDH , error ) {
48
+ if len (bytes ) != curveSize (curve ) {
49
+ return nil , errors .New ("NewPrivateKeyECDH: wrong key length" )
50
+ }
48
51
pkey , err := newECDHPkey (curve , bytes , true )
49
52
if err != nil {
50
53
return nil , err
51
54
}
52
- k := & PrivateKeyECDH {pkey , curve , false }
55
+ k := & PrivateKeyECDH {pkey , curve }
53
56
runtime .SetFinalizer (k , (* PrivateKeyECDH ).finalize )
54
57
return k , nil
55
58
}
56
59
57
60
func (k * PrivateKeyECDH ) PublicKey () (* PublicKeyECDH , error ) {
58
61
defer runtime .KeepAlive (k )
59
- if ! k .hasPublicKey {
60
- err := deriveEcdhPublicKey (k ._pkey , k .curve )
61
- if err != nil {
62
- return nil , err
63
- }
64
- k .hasPublicKey = true
65
- }
66
62
var pkey C.GO_EVP_PKEY_PTR
67
63
defer func () {
68
64
C .go_openssl_EVP_PKEY_free (pkey )
@@ -112,10 +108,7 @@ func (k *PrivateKeyECDH) PublicKey() (*PublicKeyECDH, error) {
112
108
}
113
109
114
110
func newECDHPkey (curve string , bytes []byte , isPrivate bool ) (C.GO_EVP_PKEY_PTR , error ) {
115
- nid , err := curveNID (curve )
116
- if err != nil {
117
- return nil , err
118
- }
111
+ nid := curveNID (curve )
119
112
switch vMajor {
120
113
case 1 :
121
114
return newECDHPkey1 (nid , bytes , isPrivate )
@@ -138,6 +131,7 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P
138
131
C .go_openssl_EC_KEY_free (key )
139
132
}
140
133
}()
134
+ group := C .go_openssl_EC_KEY_get0_group (key )
141
135
if isPrivate {
142
136
priv := C .go_openssl_BN_bin2bn (base (bytes ), C .int (len (bytes )), nil )
143
137
if priv == nil {
@@ -147,8 +141,15 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P
147
141
if C .go_openssl_EC_KEY_set_private_key (key , priv ) != 1 {
148
142
return nil , newOpenSSLError ("EC_KEY_set_private_key" )
149
143
}
144
+ pub , err := pointMult (group , priv )
145
+ if err != nil {
146
+ return nil , err
147
+ }
148
+ defer C .go_openssl_EC_POINT_free (pub )
149
+ if C .go_openssl_EC_KEY_set_public_key (key , pub ) != 1 {
150
+ return nil , newOpenSSLError ("EC_KEY_set_public_key" )
151
+ }
150
152
} else {
151
- group := C .go_openssl_EC_KEY_get0_group (key )
152
153
pub := C .go_openssl_EC_POINT_new (group )
153
154
if pub == nil {
154
155
return nil , newOpenSSLError ("EC_POINT_new" )
@@ -161,6 +162,14 @@ func newECDHPkey1(nid C.int, bytes []byte, isPrivate bool) (pkey C.GO_EVP_PKEY_P
161
162
return nil , newOpenSSLError ("EC_KEY_set_public_key" )
162
163
}
163
164
}
165
+ if C .go_openssl_EC_KEY_check_key (key ) != 1 {
166
+ // Match upstream error message.
167
+ if isPrivate {
168
+ return nil , errors .New ("crypto/ecdh: invalid private key" )
169
+ } else {
170
+ return nil , errors .New ("crypto/ecdh: invalid public key" )
171
+ }
172
+ }
164
173
return newEVPPKEY (key )
165
174
}
166
175
@@ -175,7 +184,19 @@ func newECDHPkey3(nid C.int, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, e
175
184
bld .addUTF8String (_OSSL_PKEY_PARAM_GROUP_NAME , C .go_openssl_OBJ_nid2sn (nid ), 0 )
176
185
var selection C.int
177
186
if isPrivate {
178
- bld .addBin (_OSSL_PKEY_PARAM_PRIV_KEY , bytes , true )
187
+ priv := C .go_openssl_BN_bin2bn (base (bytes ), C .int (len (bytes )), nil )
188
+ if priv == nil {
189
+ return nil , newOpenSSLError ("BN_bin2bn" )
190
+ }
191
+ defer C .go_openssl_BN_clear_free (priv )
192
+ pubBytes , err := generateAndEncodeEcPublicKey (nid , func (group C.GO_EC_GROUP_PTR ) (C.GO_EC_POINT_PTR , error ) {
193
+ return pointMult (group , priv )
194
+ })
195
+ if err != nil {
196
+ return nil , err
197
+ }
198
+ bld .addOctetString (_OSSL_PKEY_PARAM_PUB_KEY , pubBytes )
199
+ bld .addBN (_OSSL_PKEY_PARAM_PRIV_KEY , priv )
179
200
selection = C .GO_EVP_PKEY_KEYPAIR
180
201
} else {
181
202
bld .addOctetString (_OSSL_PKEY_PARAM_PUB_KEY , bytes )
@@ -187,62 +208,31 @@ func newECDHPkey3(nid C.int, bytes []byte, isPrivate bool) (C.GO_EVP_PKEY_PTR, e
187
208
return nil , err
188
209
}
189
210
defer C .go_openssl_OSSL_PARAM_free (params )
190
- return newEvpFromParams (C .GO_EVP_PKEY_EC , selection , params )
211
+ pkey , err := newEvpFromParams (C .GO_EVP_PKEY_EC , selection , params )
212
+ if err != nil {
213
+ return nil , err
214
+ }
215
+
216
+ if err := checkPkey (pkey , isPrivate ); err != nil {
217
+ C .go_openssl_EVP_PKEY_free (pkey )
218
+ return nil , errors .New ("crypto/ecdh: " + err .Error ())
219
+ }
220
+ return pkey , nil
191
221
}
192
222
193
- // deriveEcdhPublicKey sets the raw public key of pkey by deriving it from
194
- // the raw private key.
195
- func deriveEcdhPublicKey (pkey C.GO_EVP_PKEY_PTR , curve string ) error {
196
- derive := func (group C.GO_EC_GROUP_PTR , priv C.GO_BIGNUM_PTR ) (C.GO_EC_POINT_PTR , error ) {
197
- // OpenSSL does not expose any method to generate the public
198
- // key from the private key [1], so we have to calculate it here.
199
- // [1] https://github.com/openssl/openssl/issues/18437#issuecomment-1144717206
200
- pt := C .go_openssl_EC_POINT_new (group )
201
- if pt == nil {
202
- return nil , newOpenSSLError ("EC_POINT_new" )
203
- }
204
- if C .go_openssl_EC_POINT_mul (group , pt , priv , nil , nil , nil ) == 0 {
205
- C .go_openssl_EC_POINT_free (pt )
206
- return nil , newOpenSSLError ("EC_POINT_mul" )
207
- }
208
- return pt , nil
223
+ func pointMult (group C.GO_EC_GROUP_PTR , priv C.GO_BIGNUM_PTR ) (C.GO_EC_POINT_PTR , error ) {
224
+ // OpenSSL does not expose any method to generate the public
225
+ // key from the private key [1], so we have to calculate it here.
226
+ // [1] https://github.com/openssl/openssl/issues/18437#issuecomment-1144717206
227
+ pt := C .go_openssl_EC_POINT_new (group )
228
+ if pt == nil {
229
+ return nil , newOpenSSLError ("EC_POINT_new" )
209
230
}
210
- switch vMajor {
211
- case 1 :
212
- key := getECKey (pkey )
213
- priv := C .go_openssl_EC_KEY_get0_private_key (key )
214
- if priv == nil {
215
- return newOpenSSLError ("EC_KEY_get0_private_key" )
216
- }
217
- group := C .go_openssl_EC_KEY_get0_group (key )
218
- pub , err := derive (group , priv )
219
- if err != nil {
220
- return err
221
- }
222
- defer C .go_openssl_EC_POINT_free (pub )
223
- if C .go_openssl_EC_KEY_set_public_key (key , pub ) != 1 {
224
- return newOpenSSLError ("EC_KEY_set_public_key" )
225
- }
226
- case 3 :
227
- var priv C.GO_BIGNUM_PTR
228
- if C .go_openssl_EVP_PKEY_get_bn_param (pkey , _OSSL_PKEY_PARAM_PRIV_KEY , & priv ) != 1 {
229
- return newOpenSSLError ("EVP_PKEY_get_bn_param" )
230
- }
231
- defer C .go_openssl_BN_clear_free (priv )
232
- nid , _ := curveNID (curve )
233
- pubBytes , err := generateAndEncodeEcPublicKey (nid , func (group C.GO_EC_GROUP_PTR ) (C.GO_EC_POINT_PTR , error ) {
234
- return derive (group , priv )
235
- })
236
- if err != nil {
237
- return err
238
- }
239
- if C .go_openssl_EVP_PKEY_set1_encoded_public_key (pkey , base (pubBytes ), C .size_t (len (pubBytes ))) != 1 {
240
- return newOpenSSLError ("EVP_PKEY_set1_encoded_public_key" )
241
- }
242
- default :
243
- panic (errUnsupportedVersion ())
231
+ if C .go_openssl_EC_POINT_mul (group , pt , priv , nil , nil , nil ) == 0 {
232
+ C .go_openssl_EC_POINT_free (pt )
233
+ return nil , newOpenSSLError ("EC_POINT_mul" )
244
234
}
245
- return nil
235
+ return pt , nil
246
236
}
247
237
248
238
func ECDH (priv * PrivateKeyECDH , pub * PublicKeyECDH ) ([]byte , error ) {
@@ -307,7 +297,7 @@ func GenerateKeyECDH(curve string) (*PrivateKeyECDH, []byte, error) {
307
297
if err := bnToBinPad (priv , bytes ); err != nil {
308
298
return nil , nil , err
309
299
}
310
- k = & PrivateKeyECDH {pkey , curve , true }
300
+ k = & PrivateKeyECDH {pkey , curve }
311
301
runtime .SetFinalizer (k , (* PrivateKeyECDH ).finalize )
312
302
return k , bytes , nil
313
303
}
0 commit comments