9
9
"hash"
10
10
"io"
11
11
"runtime"
12
+ "sync"
12
13
"unsafe"
13
14
)
14
15
@@ -18,88 +19,72 @@ func SupportsHKDF() bool {
18
19
case 1 :
19
20
return versionAtOrAbove (1 , 1 , 1 )
20
21
case 3 :
21
- // Some OpenSSL 3 providers don't support HKDF or don't support it via
22
- // the EVP_PKEY API, which is the one we use.
23
- // See https://github.com/golang-fips/openssl/issues/189.
24
- ctx := C .go_openssl_EVP_PKEY_CTX_new_id (C .GO_EVP_PKEY_HKDF , nil )
25
- if ctx == nil {
26
- return false
27
- }
28
- C .go_openssl_EVP_PKEY_CTX_free (ctx )
29
- return true
22
+ _ , err := fetchHKDF3 ()
23
+ return err == nil
30
24
default :
31
25
panic (errUnsupportedVersion ())
32
26
}
33
27
}
34
28
35
- func newHKDF (fh func () hash.Hash , mode C.int ) (* hkdf , error ) {
36
- if ! SupportsHKDF () {
37
- return nil , errUnsupportedVersion ()
38
- }
39
-
40
- h , err := hashFuncHash (fh )
41
- if err != nil {
42
- return nil , err
43
- }
44
- md := hashToMD (h )
45
- if md == nil {
46
- return nil , errors .New ("unsupported hash function" )
47
- }
29
+ func newHKDFCtx1 (md C.GO_EVP_MD_PTR , mode C.int , secret , salt , pseudorandomKey , info []byte ) (ctx C.GO_EVP_PKEY_CTX_PTR , err error ) {
30
+ checkMajorVersion (1 )
48
31
49
- ctx : = C .go_openssl_EVP_PKEY_CTX_new_id (C .GO_EVP_PKEY_HKDF , nil )
32
+ ctx = C .go_openssl_EVP_PKEY_CTX_new_id (C .GO_EVP_PKEY_HKDF , nil )
50
33
if ctx == nil {
51
34
return nil , newOpenSSLError ("EVP_PKEY_CTX_new_id" )
52
35
}
53
36
defer func () {
54
- C .go_openssl_EVP_PKEY_CTX_free (ctx )
37
+ if err != nil {
38
+ C .go_openssl_EVP_PKEY_CTX_free (ctx )
39
+ }
55
40
}()
56
41
57
42
if C .go_openssl_EVP_PKEY_derive_init (ctx ) != 1 {
58
- return nil , newOpenSSLError ("EVP_PKEY_derive_init" )
43
+ return ctx , newOpenSSLError ("EVP_PKEY_derive_init" )
59
44
}
60
- switch vMajor {
61
- case 3 :
62
- if C .go_openssl_EVP_PKEY_CTX_set_hkdf_mode (ctx , mode ) != 1 {
63
- return nil , newOpenSSLError ("EVP_PKEY_CTX_set_hkdf_mode" )
64
- }
65
- if C .go_openssl_EVP_PKEY_CTX_set_hkdf_md (ctx , md ) != 1 {
66
- return nil , newOpenSSLError ("EVP_PKEY_CTX_set_hkdf_md" )
67
- }
68
- case 1 :
69
- if C .go_openssl_EVP_PKEY_CTX_ctrl (ctx , - 1 , C .GO1_EVP_PKEY_OP_DERIVE ,
70
- C .GO_EVP_PKEY_CTRL_HKDF_MODE ,
71
- C .int (mode ), nil ) != 1 {
72
- return nil , newOpenSSLError ("EVP_PKEY_CTX_set_hkdf_mode" )
73
- }
74
- if C .go_openssl_EVP_PKEY_CTX_ctrl (ctx , - 1 , C .GO1_EVP_PKEY_OP_DERIVE ,
75
- C .GO_EVP_PKEY_CTRL_HKDF_MD ,
76
- 0 , unsafe .Pointer (md )) != 1 {
77
- return nil , newOpenSSLError ("EVP_PKEY_CTX_set_hkdf_md" )
45
+
46
+ ctrlSlice := func (ctrl int , data []byte ) C.int {
47
+ if len (data ) == 0 {
48
+ return 1 // No data to set.
78
49
}
50
+ return C .go_openssl_EVP_PKEY_CTX_ctrl (ctx , - 1 , C .GO1_EVP_PKEY_OP_DERIVE , C .int (ctrl ), C .int (len (data )), unsafe .Pointer (base (data )))
79
51
}
80
52
81
- c := & hkdf {ctx : ctx , hashLen : h .Size ()}
82
- ctx = nil
83
-
84
- runtime .SetFinalizer (c , (* hkdf ).finalize )
85
-
86
- return c , nil
53
+ if C .go_openssl_EVP_PKEY_CTX_ctrl (ctx , - 1 , C .GO1_EVP_PKEY_OP_DERIVE , C .GO_EVP_PKEY_CTRL_HKDF_MODE , mode , nil ) != 1 {
54
+ return ctx , newOpenSSLError ("EVP_PKEY_CTX_set_hkdf_mode" )
55
+ }
56
+ if C .go_openssl_EVP_PKEY_CTX_ctrl (ctx , - 1 , C .GO1_EVP_PKEY_OP_DERIVE , C .GO_EVP_PKEY_CTRL_HKDF_MD , 0 , unsafe .Pointer (md )) != 1 {
57
+ return ctx , newOpenSSLError ("EVP_PKEY_CTX_set_hkdf_md" )
58
+ }
59
+ if ctrlSlice (C .GO_EVP_PKEY_CTRL_HKDF_KEY , secret ) != 1 {
60
+ return ctx , newOpenSSLError ("EVP_PKEY_CTX_set1_hkdf_key" )
61
+ }
62
+ if ctrlSlice (C .GO_EVP_PKEY_CTRL_HKDF_SALT , salt ) != 1 {
63
+ return ctx , newOpenSSLError ("EVP_PKEY_CTX_set1_hkdf_salt" )
64
+ }
65
+ if ctrlSlice (C .GO_EVP_PKEY_CTRL_HKDF_KEY , pseudorandomKey ) != 1 {
66
+ return ctx , newOpenSSLError ("EVP_PKEY_CTX_set1_hkdf_key" )
67
+ }
68
+ if ctrlSlice (C .GO_EVP_PKEY_CTRL_HKDF_INFO , info ) != 1 {
69
+ return ctx , newOpenSSLError ("EVP_PKEY_CTX_add1_hkdf_info" )
70
+ }
71
+ return ctx , nil
87
72
}
88
73
89
- type hkdf struct {
74
+ type hkdf1 struct {
90
75
ctx C.GO_EVP_PKEY_CTX_PTR
91
76
92
77
hashLen int
93
78
buf []byte
94
79
}
95
80
96
- func (c * hkdf ) finalize () {
81
+ func (c * hkdf1 ) finalize () {
97
82
if c .ctx != nil {
98
83
C .go_openssl_EVP_PKEY_CTX_free (c .ctx )
99
84
}
100
85
}
101
86
102
- func (c * hkdf ) Read (p []byte ) (int , error ) {
87
+ func (c * hkdf1 ) Read (p []byte ) (int , error ) {
103
88
defer runtime .KeepAlive (c )
104
89
105
90
// EVP_PKEY_derive doesn't support incremental output, each call
@@ -125,69 +110,176 @@ func (c *hkdf) Read(p []byte) (int, error) {
125
110
}
126
111
127
112
func ExtractHKDF (h func () hash.Hash , secret , salt []byte ) ([]byte , error ) {
128
- c , err := newHKDF (h , C .GO_EVP_KDF_HKDF_MODE_EXTRACT_ONLY )
113
+ if ! SupportsHKDF () {
114
+ return nil , errUnsupportedVersion ()
115
+ }
116
+
117
+ md , err := hashFuncToMD (h )
129
118
if err != nil {
130
119
return nil , err
131
120
}
121
+
132
122
switch vMajor {
133
- case 3 :
134
- if C . go_openssl_EVP_PKEY_CTX_set1_hkdf_key ( c . ctx ,
135
- base ( secret ), C . int ( len ( secret ))) != 1 {
136
- return nil , newOpenSSLError ( "EVP_PKEY_CTX_set1_hkdf_key" )
123
+ case 1 :
124
+ ctx , err := newHKDFCtx1 ( md , C . GO_EVP_KDF_HKDF_MODE_EXTRACT_ONLY , secret , salt , nil , nil )
125
+ if err != nil {
126
+ return nil , err
137
127
}
138
- if C .go_openssl_EVP_PKEY_CTX_set1_hkdf_salt (c .ctx ,
139
- base (salt ), C .int (len (salt ))) != 1 {
140
- return nil , newOpenSSLError ("EVP_PKEY_CTX_set1_hkdf_salt" )
128
+ defer C .go_openssl_EVP_PKEY_CTX_free (ctx )
129
+ r := C .go_openssl_EVP_PKEY_derive_wrapper (ctx , nil , 0 )
130
+ if r .result != 1 {
131
+ return nil , newOpenSSLError ("EVP_PKEY_derive_init" )
141
132
}
142
- case 1 :
143
- if C .go_openssl_EVP_PKEY_CTX_ctrl (c .ctx , - 1 , C .GO1_EVP_PKEY_OP_DERIVE ,
144
- C .GO_EVP_PKEY_CTRL_HKDF_KEY ,
145
- C .int (len (secret )), unsafe .Pointer (base (secret ))) != 1 {
146
- return nil , newOpenSSLError ("EVP_PKEY_CTX_set1_hkdf_key" )
133
+ out := make ([]byte , r .keylen )
134
+ if C .go_openssl_EVP_PKEY_derive_wrapper (ctx , base (out ), r .keylen ).result != 1 {
135
+ return nil , newOpenSSLError ("EVP_PKEY_derive" )
147
136
}
148
- if C .go_openssl_EVP_PKEY_CTX_ctrl (c .ctx , - 1 , C .GO1_EVP_PKEY_OP_DERIVE ,
149
- C .GO_EVP_PKEY_CTRL_HKDF_SALT ,
150
- C .int (len (salt )), unsafe .Pointer (base (salt ))) != 1 {
151
- return nil , newOpenSSLError ("EVP_PKEY_CTX_set1_hkdf_salt" )
137
+ return out [:r .keylen ], nil
138
+ case 3 :
139
+ ctx , err := newHKDFCtx3 (md , C .GO_EVP_KDF_HKDF_MODE_EXTRACT_ONLY , secret , salt , nil , nil )
140
+ if err != nil {
141
+ return nil , err
152
142
}
143
+ defer C .go_openssl_EVP_KDF_CTX_free (ctx )
144
+ out := make ([]byte , C .go_openssl_EVP_KDF_CTX_get_kdf_size (ctx ))
145
+ if C .go_openssl_EVP_KDF_derive (ctx , base (out ), C .size_t (len (out )), nil ) != 1 {
146
+ return nil , newOpenSSLError ("EVP_KDF_derive" )
147
+ }
148
+ return out , nil
149
+ default :
150
+ panic (errUnsupportedVersion ())
153
151
}
154
- r := C .go_openssl_EVP_PKEY_derive_wrapper (c .ctx , nil , 0 )
155
- if r .result != 1 {
156
- return nil , newOpenSSLError ("EVP_PKEY_derive_init" )
157
- }
158
- out := make ([]byte , r .keylen )
159
- if C .go_openssl_EVP_PKEY_derive_wrapper (c .ctx , base (out ), r .keylen ).result != 1 {
160
- return nil , newOpenSSLError ("EVP_PKEY_derive" )
161
- }
162
- return out [:r .keylen ], nil
163
152
}
164
153
165
154
func ExpandHKDF (h func () hash.Hash , pseudorandomKey , info []byte ) (io.Reader , error ) {
166
- c , err := newHKDF (h , C .GO_EVP_KDF_HKDF_MODE_EXPAND_ONLY )
155
+ if ! SupportsHKDF () {
156
+ return nil , errUnsupportedVersion ()
157
+ }
158
+
159
+ md , err := hashFuncToMD (h )
167
160
if err != nil {
168
161
return nil , err
169
162
}
163
+
170
164
switch vMajor {
171
- case 3 :
172
- if C .go_openssl_EVP_PKEY_CTX_set1_hkdf_key (c .ctx ,
173
- base (pseudorandomKey ), C .int (len (pseudorandomKey ))) != 1 {
174
- return nil , newOpenSSLError ("EVP_PKEY_CTX_set1_hkdf_key" )
175
- }
176
- if C .go_openssl_EVP_PKEY_CTX_add1_hkdf_info (c .ctx ,
177
- base (info ), C .int (len (info ))) != 1 {
178
- return nil , newOpenSSLError ("EVP_PKEY_CTX_add1_hkdf_info" )
179
- }
180
165
case 1 :
181
- if C .go_openssl_EVP_PKEY_CTX_ctrl (c .ctx , - 1 , C .GO1_EVP_PKEY_OP_DERIVE ,
182
- C .GO_EVP_PKEY_CTRL_HKDF_KEY ,
183
- C .int (len (pseudorandomKey )), unsafe .Pointer (base (pseudorandomKey ))) != 1 {
184
- return nil , newOpenSSLError ("EVP_PKEY_CTX_set1_hkdf_key" )
166
+ ctx , err := newHKDFCtx1 (md , C .GO_EVP_KDF_HKDF_MODE_EXPAND_ONLY , nil , nil , pseudorandomKey , info )
167
+ if err != nil {
168
+ return nil , err
185
169
}
186
- if C .go_openssl_EVP_PKEY_CTX_ctrl (c .ctx , - 1 , C .GO1_EVP_PKEY_OP_DERIVE ,
187
- C .GO_EVP_PKEY_CTRL_HKDF_INFO ,
188
- C .int (len (info )), unsafe .Pointer (base (info ))) != 1 {
189
- return nil , newOpenSSLError ("EVP_PKEY_CTX_add1_hkdf_info" )
170
+ c := & hkdf1 {ctx : ctx , hashLen : int (C .go_openssl_EVP_MD_get_size (md ))}
171
+ runtime .SetFinalizer (c , (* hkdf1 ).finalize )
172
+ return c , nil
173
+ case 3 :
174
+ ctx , err := newHKDFCtx3 (md , C .GO_EVP_KDF_HKDF_MODE_EXPAND_ONLY , nil , nil , pseudorandomKey , info )
175
+ if err != nil {
176
+ return nil , err
190
177
}
178
+ c := & hkdf3 {ctx : ctx , hashLen : int (C .go_openssl_EVP_MD_get_size (md ))}
179
+ runtime .SetFinalizer (c , (* hkdf3 ).finalize )
180
+ return c , nil
181
+ default :
182
+ panic (errUnsupportedVersion ())
191
183
}
192
- return c , nil
184
+ }
185
+
186
+ type hkdf3 struct {
187
+ ctx C.GO_EVP_KDF_CTX_PTR
188
+
189
+ hashLen int
190
+ buf []byte
191
+ }
192
+
193
+ func (c * hkdf3 ) finalize () {
194
+ if c .ctx != nil {
195
+ C .go_openssl_EVP_KDF_CTX_free (c .ctx )
196
+ }
197
+ }
198
+
199
+ // fetchHKDF3 fetches the HKDF algorithm.
200
+ // It is safe to call this function concurrently.
201
+ // The returned EVP_KDF_PTR shouldn't be freed.
202
+ var fetchHKDF3 = sync .OnceValues (func () (C.GO_EVP_KDF_PTR , error ) {
203
+ checkMajorVersion (3 )
204
+
205
+ name := C .CString ("HKDF" )
206
+ kdf := C .go_openssl_EVP_KDF_fetch (nil , name , nil )
207
+ C .free (unsafe .Pointer (name ))
208
+ if kdf == nil {
209
+ return nil , newOpenSSLError ("EVP_KDF_fetch" )
210
+ }
211
+ return kdf , nil
212
+ })
213
+
214
+ // newHKDFCtx3 implements HKDF for OpenSSL 3 using the EVP_KDF API.
215
+ func newHKDFCtx3 (md C.GO_EVP_MD_PTR , mode C.int , secret , salt , pseudorandomKey , info []byte ) (_ C.GO_EVP_KDF_CTX_PTR , err error ) {
216
+ checkMajorVersion (3 )
217
+
218
+ kdf , err := fetchHKDF3 ()
219
+ if err != nil {
220
+ return nil , err
221
+ }
222
+ ctx := C .go_openssl_EVP_KDF_CTX_new (kdf )
223
+ if ctx == nil {
224
+ return nil , newOpenSSLError ("EVP_KDF_CTX_new" )
225
+ }
226
+ defer func () {
227
+ if err != nil {
228
+ C .go_openssl_EVP_KDF_CTX_free (ctx )
229
+ }
230
+ }()
231
+
232
+ bld , err := newParamBuilder ()
233
+ if err != nil {
234
+ return ctx , err
235
+ }
236
+ bld .addUTF8String (_OSSL_KDF_PARAM_DIGEST , C .go_openssl_EVP_MD_get0_name (md ), 0 )
237
+ bld .addInt32 (_OSSL_KDF_PARAM_MODE , int32 (mode ))
238
+ if len (secret ) > 0 {
239
+ bld .addOctetString (_OSSL_KDF_PARAM_KEY , secret )
240
+ }
241
+ if len (salt ) > 0 {
242
+ bld .addOctetString (_OSSL_KDF_PARAM_SALT , salt )
243
+ }
244
+ if len (pseudorandomKey ) > 0 {
245
+ bld .addOctetString (_OSSL_KDF_PARAM_KEY , pseudorandomKey )
246
+ }
247
+ if len (info ) > 0 {
248
+ bld .addOctetString (_OSSL_KDF_PARAM_INFO , info )
249
+ }
250
+ params , err := bld .build ()
251
+ if err != nil {
252
+ return ctx , err
253
+ }
254
+ defer C .go_openssl_OSSL_PARAM_free (params )
255
+
256
+ if C .go_openssl_EVP_KDF_CTX_set_params (ctx , params ) != 1 {
257
+ return ctx , newOpenSSLError ("EVP_KDF_CTX_set_params" )
258
+ }
259
+ return ctx , nil
260
+ }
261
+
262
+ func (c * hkdf3 ) Read (p []byte ) (int , error ) {
263
+ defer runtime .KeepAlive (c )
264
+
265
+ // EVP_KDF_derive doesn't support incremental output, each call
266
+ // derives the key from scratch and returns the requested bytes.
267
+ // To implement io.Reader, we need to ask for len(c.buf) + len(p)
268
+ // bytes and copy the last derived len(p) bytes to p.
269
+ // We use c.buf to know how many bytes we've already derived and
270
+ // to avoid allocating the whole output buffer on each call.
271
+ prevLen := len (c .buf )
272
+ needLen := len (p )
273
+ remains := 255 * c .hashLen - prevLen
274
+ // Check whether enough data can be generated.
275
+ if remains < needLen {
276
+ return 0 , errors .New ("hkdf: entropy limit reached" )
277
+ }
278
+ c .buf = append (c .buf , make ([]byte , needLen )... )
279
+ outLen := C .size_t (prevLen + needLen )
280
+ if C .go_openssl_EVP_KDF_derive (c .ctx , base (c .buf ), outLen , nil ) != 1 {
281
+ return 0 , newOpenSSLError ("EVP_KDF_derive" )
282
+ }
283
+ n := copy (p , c .buf [prevLen :outLen ])
284
+ return n , nil
193
285
}
0 commit comments