Skip to content

Commit 731002f

Browse files
qmuntaldagood
andauthored
Make algorithms more robust to unsupported hashes (#185)
* make algorithms more robust to unsupported hashes * fix TestTLS1PRFUnsupportedHash * Apply suggestions from code review Co-authored-by: Davis Goodin <[email protected]> * preserve hashFuncHash error * add continue in TestHKDF --------- Co-authored-by: Davis Goodin <[email protected]>
1 parent 14fd570 commit 731002f

10 files changed

+131
-32
lines changed

evp.go

+24
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,30 @@ import (
1616
// cacheMD is a cache of crypto.Hash to GO_EVP_MD_PTR.
1717
var cacheMD sync.Map
1818

19+
// hashFuncHash calls fn() and returns its result.
20+
// If fn() panics, the panic is recovered and returned as an error.
21+
// This is used to avoid aborting the program when calling
22+
// an unsupported hash function. It is the caller's responsibility
23+
// to check the returned value.
24+
func hashFuncHash(fn func() hash.Hash) (h hash.Hash, err error) {
25+
defer func() {
26+
r := recover()
27+
if r == nil {
28+
return
29+
}
30+
h = nil
31+
switch e := r.(type) {
32+
case error:
33+
err = e
34+
case string:
35+
err = errors.New(e)
36+
default:
37+
err = errors.New("unsupported panic")
38+
}
39+
}()
40+
return fn(), nil
41+
}
42+
1943
// hashToMD converts a hash.Hash implementation from this package to a GO_EVP_MD_PTR.
2044
func hashToMD(h hash.Hash) C.GO_EVP_MD_PTR {
2145
var ch crypto.Hash

hash_test.go

+13
Original file line numberDiff line numberDiff line change
@@ -369,3 +369,16 @@ func BenchmarkSHA256(b *testing.B) {
369369
openssl.SHA256(buf)
370370
}
371371
}
372+
373+
// stubHash is a hash.Hash implementation that does nothing.
374+
type stubHash struct{}
375+
376+
func newStubHash() hash.Hash {
377+
return new(stubHash)
378+
}
379+
380+
func (h *stubHash) Write(p []byte) (int, error) { return 0, nil }
381+
func (h *stubHash) Sum(in []byte) []byte { return in }
382+
func (h *stubHash) Reset() {}
383+
func (h *stubHash) Size() int { return 0 }
384+
func (h *stubHash) BlockSize() int { return 0 }

hkdf.go

+7-4
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,16 @@ func SupportsHKDF() bool {
3232
}
3333
}
3434

35-
func newHKDF(h func() hash.Hash, mode C.int) (*hkdf, error) {
35+
func newHKDF(fh func() hash.Hash, mode C.int) (*hkdf, error) {
3636
if !SupportsHKDF() {
3737
return nil, errUnsupportedVersion()
3838
}
3939

40-
ch := h()
41-
md := hashToMD(ch)
40+
h, err := hashFuncHash(fh)
41+
if err != nil {
42+
return nil, err
43+
}
44+
md := hashToMD(h)
4245
if md == nil {
4346
return nil, errors.New("unsupported hash function")
4447
}
@@ -75,7 +78,7 @@ func newHKDF(h func() hash.Hash, mode C.int) (*hkdf, error) {
7578
}
7679
}
7780

78-
c := &hkdf{ctx: ctx, hashLen: ch.Size()}
81+
c := &hkdf{ctx: ctx, hashLen: h.Size()}
7982
ctx = nil
8083

8184
runtime.SetFinalizer(c, (*hkdf).finalize)

hkdf_test.go

+25-7
Original file line numberDiff line numberDiff line change
@@ -289,16 +289,16 @@ var hkdfTests = []hkdfTest{
289289
},
290290
}
291291

292-
func newHKDF(hash func() hash.Hash, secret, salt, info []byte) io.Reader {
292+
func newHKDF(hash func() hash.Hash, secret, salt, info []byte) (io.Reader, error) {
293293
prk, err := openssl.ExtractHKDF(hash, secret, salt)
294294
if err != nil {
295-
panic(err)
295+
return nil, err
296296
}
297297
r, err := openssl.ExpandHKDF(hash, prk, info)
298298
if err != nil {
299-
panic(err)
299+
return nil, err
300300
}
301-
return r
301+
return r, nil
302302
}
303303

304304
func TestHKDF(t *testing.T) {
@@ -314,7 +314,11 @@ func TestHKDF(t *testing.T) {
314314
t.Errorf("test %d: incorrect PRK: have %v, need %v.", i, prk, tt.prk)
315315
}
316316

317-
hkdf := newHKDF(tt.hash, tt.master, tt.salt, tt.info)
317+
hkdf, err := newHKDF(tt.hash, tt.master, tt.salt, tt.info)
318+
if err != nil {
319+
t.Errorf("test %d: error creating HKDF: %v.", i, err)
320+
continue
321+
}
318322
out := make([]byte, len(tt.out))
319323

320324
n, err := io.ReadFull(hkdf, out)
@@ -347,7 +351,10 @@ func TestHKDFMultiRead(t *testing.T) {
347351
t.Skip("HKDF is not supported")
348352
}
349353
for i, tt := range hkdfTests {
350-
hkdf := newHKDF(tt.hash, tt.master, tt.salt, tt.info)
354+
hkdf, err := newHKDF(tt.hash, tt.master, tt.salt, tt.info)
355+
if err != nil {
356+
t.Errorf("test %d: error creating HKDF: %v.", i, err)
357+
}
351358
out := make([]byte, len(tt.out))
352359

353360
for b := range len(tt.out) {
@@ -371,7 +378,10 @@ func TestHKDFLimit(t *testing.T) {
371378
master := []byte{0x00, 0x01, 0x02, 0x03}
372379
info := []byte{}
373380

374-
hkdf := newHKDF(hash, master, nil, info)
381+
hkdf, err := newHKDF(hash, master, nil, info)
382+
if err != nil {
383+
t.Fatalf("error creating HKDF: %v.", err)
384+
}
375385
limit := hash().Size() * 255
376386
out := make([]byte, limit)
377387

@@ -387,3 +397,11 @@ func TestHKDFLimit(t *testing.T) {
387397
t.Errorf("key expansion overflowed: n = %d, err = %v", n, err)
388398
}
389399
}
400+
401+
func TestHKDFUnsupportedHash(t *testing.T) {
402+
// Test that newHKDF returns an error instead of panicking.
403+
_, err := newHKDF(newStubHash, []byte{0x00, 0x01, 0x02, 0x03}, nil, []byte{})
404+
if err == nil {
405+
t.Error("expected error for unsupported hash")
406+
}
407+
}

hmac.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ var OSSL_MAC_PARAM_DIGEST = C.CString("digest")
1717
// The function h must return a hash implemented by
1818
// OpenSSL (for example, h could be openssl.NewSHA256).
1919
// If h is not recognized, NewHMAC returns nil.
20-
func NewHMAC(h func() hash.Hash, key []byte) hash.Hash {
21-
ch := h()
22-
md := hashToMD(ch)
20+
func NewHMAC(fh func() hash.Hash, key []byte) hash.Hash {
21+
h, _ := hashFuncHash(fh)
22+
md := hashToMD(h)
2323
if md == nil {
2424
return nil
2525
}
@@ -34,8 +34,8 @@ func NewHMAC(h func() hash.Hash, key []byte) hash.Hash {
3434
}
3535

3636
hmac := &opensslHMAC{
37-
size: ch.Size(),
38-
blockSize: ch.BlockSize(),
37+
size: h.Size(),
38+
blockSize: h.BlockSize(),
3939
}
4040

4141
switch vMajor {

hmac_test.go

+22-11
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,42 @@
1-
package openssl
1+
package openssl_test
22

33
import (
44
"bytes"
55
"hash"
66
"testing"
7+
8+
"github.com/golang-fips/openssl/v2"
79
)
810

911
func TestHMAC(t *testing.T) {
1012
var tests = []struct {
1113
name string
1214
fn func() hash.Hash
1315
}{
14-
{"sha1", NewSHA1},
15-
{"sha224", NewSHA224},
16-
{"sha256", NewSHA256},
17-
{"sha384", NewSHA384},
18-
{"sha512", NewSHA512},
16+
{"sha1", openssl.NewSHA1},
17+
{"sha224", openssl.NewSHA224},
18+
{"sha256", openssl.NewSHA256},
19+
{"sha384", openssl.NewSHA384},
20+
{"sha512", openssl.NewSHA512},
1921
}
2022
for _, tt := range tests {
2123
t.Run(tt.name, func(t *testing.T) {
2224
t.Parallel()
23-
h := NewHMAC(tt.fn, nil)
25+
h := openssl.NewHMAC(tt.fn, nil)
2426
if h == nil {
2527
t.Skip("digest not supported")
2628
}
2729
h.Write([]byte("hello"))
2830
sumHello := h.Sum(nil)
2931

30-
h = NewHMAC(tt.fn, nil)
32+
h = openssl.NewHMAC(tt.fn, nil)
3133
h.Write([]byte("hello world"))
3234
sumHelloWorld := h.Sum(nil)
3335

3436
// Test that Sum has no effect on future Sum or Write operations.
3537
// This is a bit unusual as far as usage, but it's allowed
3638
// by the definition of Go hash.Hash, and some clients expect it to work.
37-
h = NewHMAC(tt.fn, nil)
39+
h = openssl.NewHMAC(tt.fn, nil)
3840
h.Write([]byte("hello"))
3941
if sum := h.Sum(nil); !bytes.Equal(sum, sumHello) {
4042
t.Fatalf("1st Sum after hello = %x, want %x", sum, sumHello)
@@ -60,11 +62,20 @@ func TestHMAC(t *testing.T) {
6062
}
6163
}
6264

65+
func TestHMACUnsupportedHash(t *testing.T) {
66+
// Test that NewHMAC returns nil for unsupported hashes
67+
// instead of panicking.
68+
h := openssl.NewHMAC(newStubHash, nil)
69+
if h != nil {
70+
t.Errorf("returned non-nil for unsupported hash")
71+
}
72+
}
73+
6374
func BenchmarkHMACSHA256_32(b *testing.B) {
6475
b.StopTimer()
6576
key := make([]byte, 32)
6677
buf := make([]byte, 32)
67-
h := NewHMAC(NewSHA256, key)
78+
h := openssl.NewHMAC(openssl.NewSHA256, key)
6879
b.SetBytes(int64(len(buf)))
6980
b.StartTimer()
7081
b.ReportAllocs()
@@ -83,7 +94,7 @@ func BenchmarkHMACNewWriteSum(b *testing.B) {
8394
b.StartTimer()
8495
b.ReportAllocs()
8596
for i := 0; i < b.N; i++ {
86-
h := NewHMAC(NewSHA256, make([]byte, 32))
97+
h := openssl.NewHMAC(openssl.NewSHA256, make([]byte, 32))
8798
h.Write(buf)
8899
mac := h.Sum(nil)
89100
buf[0] = mac[0]

pbkdf2.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@ import (
99
"hash"
1010
)
1111

12-
func PBKDF2(password, salt []byte, iter, keyLen int, h func() hash.Hash) ([]byte, error) {
13-
md := hashToMD(h())
12+
func PBKDF2(password, salt []byte, iter, keyLen int, fh func() hash.Hash) ([]byte, error) {
13+
h, err := hashFuncHash(fh)
14+
if err != nil {
15+
return nil, err
16+
}
17+
md := hashToMD(h)
1418
if md == nil {
1519
return nil, errors.New("unsupported hash function")
1620
}

pbkdf2_test.go

+8
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,14 @@ func TestWithHMACSHA256(t *testing.T) {
158158
testHash(t, openssl.NewSHA256, "SHA256", sha256TestVectors)
159159
}
160160

161+
func TestWithUnsupportedHash(t *testing.T) {
162+
// Test that PBKDF2 returns an error for unsupported hashes instead of panicking.
163+
_, err := openssl.PBKDF2([]byte{1, 2}, []byte{3, 4}, 0, 2, newStubHash)
164+
if err == nil {
165+
t.Fatal("expected error")
166+
}
167+
}
168+
161169
var sink uint8
162170

163171
func benchmark(b *testing.B, h func() hash.Hash) {

tls1prf.go

+7-3
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,21 @@ func SupportsTLS1PRF() bool {
1919
// TLS1PRF implements the TLS 1.0/1.1 pseudo-random function if h is nil,
2020
// else it implements the TLS 1.2 pseudo-random function.
2121
// The pseudo-random number will be written to result and will be of length len(result).
22-
func TLS1PRF(result, secret, label, seed []byte, h func() hash.Hash) error {
22+
func TLS1PRF(result, secret, label, seed []byte, fh func() hash.Hash) error {
2323
var md C.GO_EVP_MD_PTR
24-
if h == nil {
24+
if fh == nil {
2525
// TLS 1.0/1.1 PRF doesn't allow to specify the hash function,
2626
// it always uses MD5SHA1. If h is nil, then assume
2727
// that the caller wants to use TLS 1.0/1.1 PRF.
2828
// OpenSSL detects this case by checking if the hash
2929
// function is MD5SHA1.
3030
md = cryptoHashToMD(crypto.MD5SHA1)
3131
} else {
32-
md = hashToMD(h())
32+
h, err := hashFuncHash(fh)
33+
if err != nil {
34+
return err
35+
}
36+
md = hashToMD(h)
3337
}
3438
if md == nil {
3539
return errors.New("unsupported hash function")

tls1prf_test.go

+14
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,17 @@ func TestTLS1PRF(t *testing.T) {
165165
})
166166
}
167167
}
168+
169+
func TestTLS1PRFUnsupportedHash(t *testing.T) {
170+
if !openssl.SupportsTLS1PRF() {
171+
t.Skip("TLS PRF is not supported")
172+
}
173+
174+
tt := tls1prfTests[0]
175+
result := make([]byte, len(tt.out))
176+
// Test that TLS1PRF returns an error for unsupported hashes instead of panicking.
177+
err := openssl.TLS1PRF(result, tt.secret, tt.label, tt.seed, newStubHash)
178+
if err == nil {
179+
t.Errorf("expected an error for unsupported hash")
180+
}
181+
}

0 commit comments

Comments
 (0)