Skip to content

Commit 493baa2

Browse files
authored
Panic if DES encryption or decryption fails (#169)
* panic if DES encryption or decryption fails * fix 3DES CBC mode support detection * test that DES conforms to the cipher.Block interface * move TestBlock to internal/cryptotest
1 parent 47676da commit 493baa2

File tree

6 files changed

+322
-37
lines changed

6 files changed

+322
-37
lines changed

aes_test.go

+9-23
Original file line numberDiff line numberDiff line change
@@ -3,35 +3,21 @@ package openssl_test
33
import (
44
"bytes"
55
"crypto/cipher"
6+
"fmt"
67
"math"
78
"testing"
89

910
"github.com/golang-fips/openssl/v2"
11+
"github.com/golang-fips/openssl/v2/internal/cryptotest"
1012
)
1113

12-
func TestAESShortBlocks(t *testing.T) {
13-
bytes := func(n int) []byte { return make([]byte, n) }
14-
15-
c, _ := openssl.NewAESCipher(bytes(16))
16-
17-
mustPanic(t, "crypto/aes: input not full block", func() { c.Encrypt(bytes(1), bytes(1)) })
18-
mustPanic(t, "crypto/aes: input not full block", func() { c.Decrypt(bytes(1), bytes(1)) })
19-
mustPanic(t, "crypto/aes: input not full block", func() { c.Encrypt(bytes(100), bytes(1)) })
20-
mustPanic(t, "crypto/aes: input not full block", func() { c.Decrypt(bytes(100), bytes(1)) })
21-
mustPanic(t, "crypto/aes: output not full block", func() { c.Encrypt(bytes(1), bytes(100)) })
22-
mustPanic(t, "crypto/aes: output not full block", func() { c.Decrypt(bytes(1), bytes(100)) })
23-
}
24-
25-
func mustPanic(t *testing.T, msg string, f func()) {
26-
defer func() {
27-
err := recover()
28-
if err == nil {
29-
t.Errorf("function did not panic, wanted %q", msg)
30-
} else if err != msg {
31-
t.Errorf("got panic %v, wanted %q", err, msg)
32-
}
33-
}()
34-
f()
14+
// Test AES against the general cipher.Block interface tester.
15+
func TestAESBlock(t *testing.T) {
16+
for _, keylen := range []int{128, 192, 256} {
17+
t.Run(fmt.Sprintf("AES-%d", keylen), func(t *testing.T) {
18+
cryptotest.TestBlock(t, keylen/8, openssl.NewAESCipher)
19+
})
20+
}
3521
}
3622

3723
func TestNewGCMNonce(t *testing.T) {

des.go

+15-14
Original file line numberDiff line numberDiff line change
@@ -33,27 +33,22 @@ func NewDESCipher(key []byte) (cipher.Block, error) {
3333
if len(key) != 8 {
3434
return nil, errors.New("crypto/des: invalid key size")
3535
}
36-
c, err := newEVPCipher(key, cipherDES)
37-
if err != nil {
38-
return nil, err
39-
}
40-
// Should always be true for stock OpenSSL.
41-
if loadCipher(cipherDES, cipherModeCBC) == nil {
42-
return &desCipherWithoutCBC{c}, nil
43-
}
44-
return &desCipher{c}, nil
36+
return newDESCipher(key, cipherDES)
4537
}
4638

4739
func NewTripleDESCipher(key []byte) (cipher.Block, error) {
4840
if len(key) != 24 {
4941
return nil, errors.New("crypto/des: invalid key size")
5042
}
51-
c, err := newEVPCipher(key, cipherDES3)
43+
return newDESCipher(key, cipherDES3)
44+
}
45+
46+
func newDESCipher(key []byte, kind cipherKind) (cipher.Block, error) {
47+
c, err := newEVPCipher(key, kind)
5248
if err != nil {
5349
return nil, err
5450
}
55-
// Should always be true for stock OpenSSL.
56-
if loadCipher(cipherDES, cipherModeCBC) != nil {
51+
if loadCipher(kind, cipherModeCBC) == nil {
5752
return &desCipherWithoutCBC{c}, nil
5853
}
5954
return &desCipher{c}, nil
@@ -105,9 +100,15 @@ func (c *desCipherWithoutCBC) BlockSize() int {
105100
}
106101

107102
func (c *desCipherWithoutCBC) Encrypt(dst, src []byte) {
108-
c.encrypt(dst, src)
103+
if err := c.encrypt(dst, src); err != nil {
104+
// crypto/des expects that the panic message starts with "crypto/des: ".
105+
panic("crypto/des: " + err.Error())
106+
}
109107
}
110108

111109
func (c *desCipherWithoutCBC) Decrypt(dst, src []byte) {
112-
c.decrypt(dst, src)
110+
if err := c.decrypt(dst, src); err != nil {
111+
// crypto/des expects that the panic message starts with "crypto/des: ".
112+
panic("crypto/des: " + err.Error())
113+
}
113114
}

des_test.go

+18
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"crypto/cipher"
66
"testing"
77

8+
"github.com/golang-fips/openssl/v2/internal/cryptotest"
89
"github.com/golang-fips/openssl/v2"
910
)
1011

@@ -1665,6 +1666,23 @@ func TestDESCBCDecryptSimple(t *testing.T) {
16651666
}
16661667
}
16671668

1669+
// Test DES against the general cipher.Block interface tester
1670+
func TestDESBlock(t *testing.T) {
1671+
t.Run("DES", func(t *testing.T) {
1672+
if !openssl.SupportsDESCipher() {
1673+
t.Skip("DES is not supported")
1674+
}
1675+
cryptotest.TestBlock(t, 8, openssl.NewDESCipher)
1676+
})
1677+
1678+
t.Run("TripleDES", func(t *testing.T) {
1679+
if !openssl.SupportsTripleDESCipher() {
1680+
t.Skip("3DES is not supported")
1681+
}
1682+
cryptotest.TestBlock(t, 24, openssl.NewTripleDESCipher)
1683+
})
1684+
}
1685+
16681686
func BenchmarkEncrypt(b *testing.B) {
16691687
if !openssl.SupportsDESCipher() {
16701688
b.Skip("DES is not supported")

internal/cryptotest/README.md

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
The `internal/cryptotest` package provides a set of tests for cryptographic primitives.
2+
3+
`internal/cryptotest` has been copied from the Go standard library's [crypo/internal/cryptotest](https://github.com/golang/go/tree/807e01db4840e25e4d98911b28a8fa54244b8dfa/src/crypto/internal/cryptotest)
4+
package and trimmed down to only include the tests that are relevant for the `openssl` package.
5+
6+
[1]: https://github.com/golang/go/tree/807e01db4840e25e4d98911b28a8fa54244b8dfa/src/crypto/internal/cryptotest

internal/cryptotest/block.go

+256
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
// Copyright 2024 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package cryptotest
6+
7+
import (
8+
"bytes"
9+
"crypto/cipher"
10+
"testing"
11+
)
12+
13+
// This file is a copy of https://github.com/golang/go/blob/9e9b1f57c26a6d13fdaebef67136718b8042cdba/src/crypto/internal/cryptotest/block.go.
14+
15+
type MakeBlock func(key []byte) (cipher.Block, error)
16+
17+
// TestBlock performs a set of tests on cipher.Block implementations, checking
18+
// the documented requirements of BlockSize, Encrypt, and Decrypt.
19+
func TestBlock(t *testing.T, keySize int, mb MakeBlock) {
20+
// Generate random key
21+
key := make([]byte, keySize)
22+
newRandReader(t).Read(key)
23+
t.Logf("Cipher key: 0x%x", key)
24+
25+
block, err := mb(key)
26+
if err != nil {
27+
t.Fatal(err)
28+
}
29+
30+
blockSize := block.BlockSize()
31+
32+
t.Run("Encryption", func(t *testing.T) {
33+
testCipher(t, block.Encrypt, blockSize)
34+
})
35+
36+
t.Run("Decryption", func(t *testing.T) {
37+
testCipher(t, block.Decrypt, blockSize)
38+
})
39+
40+
// Checks baseline Encrypt/Decrypt functionality. More thorough
41+
// implementation-specific characterization/golden tests should be done
42+
// for each block cipher implementation.
43+
t.Run("Roundtrip", func(t *testing.T) {
44+
rng := newRandReader(t)
45+
46+
// Check Decrypt inverts Encrypt
47+
before, ciphertext, after := make([]byte, blockSize), make([]byte, blockSize), make([]byte, blockSize)
48+
49+
rng.Read(before)
50+
51+
block.Encrypt(ciphertext, before)
52+
block.Decrypt(after, ciphertext)
53+
54+
if !bytes.Equal(after, before) {
55+
t.Errorf("plaintext is different after an encrypt/decrypt cycle; got %x, want %x", after, before)
56+
}
57+
58+
// Check Encrypt inverts Decrypt (assumes block ciphers are deterministic)
59+
before, plaintext, after := make([]byte, blockSize), make([]byte, blockSize), make([]byte, blockSize)
60+
61+
rng.Read(before)
62+
63+
block.Decrypt(plaintext, before)
64+
block.Encrypt(after, plaintext)
65+
66+
if !bytes.Equal(after, before) {
67+
t.Errorf("ciphertext is different after a decrypt/encrypt cycle; got %x, want %x", after, before)
68+
}
69+
})
70+
71+
}
72+
73+
func testCipher(t *testing.T, cipher func(dst, src []byte), blockSize int) {
74+
t.Run("AlterInput", func(t *testing.T) {
75+
rng := newRandReader(t)
76+
77+
// Make long src that shouldn't be modified at all, within block
78+
// size scope or beyond it
79+
src, before := make([]byte, blockSize*2), make([]byte, blockSize*2)
80+
rng.Read(src)
81+
copy(before, src)
82+
83+
dst := make([]byte, blockSize)
84+
85+
cipher(dst, src)
86+
if !bytes.Equal(src, before) {
87+
t.Errorf("block cipher modified src; got %x, want %x", src, before)
88+
}
89+
})
90+
91+
t.Run("Aliasing", func(t *testing.T) {
92+
rng := newRandReader(t)
93+
94+
buff, expectedOutput := make([]byte, blockSize), make([]byte, blockSize)
95+
96+
// Record what output is when src and dst are different
97+
rng.Read(buff)
98+
cipher(expectedOutput, buff)
99+
100+
// Check that the same output is generated when src=dst alias to the same
101+
// memory
102+
cipher(buff, buff)
103+
if !bytes.Equal(buff, expectedOutput) {
104+
t.Errorf("block cipher produced different output when dst = src; got %x, want %x", buff, expectedOutput)
105+
}
106+
})
107+
108+
t.Run("OutOfBoundsWrite", func(t *testing.T) {
109+
rng := newRandReader(t)
110+
111+
src := make([]byte, blockSize)
112+
rng.Read(src)
113+
114+
// Make a buffer with dst in the middle and data on either end
115+
buff := make([]byte, blockSize*3)
116+
endOfPrefix, startOfSuffix := blockSize, blockSize*2
117+
rng.Read(buff[:endOfPrefix])
118+
rng.Read(buff[startOfSuffix:])
119+
dst := buff[endOfPrefix:startOfSuffix]
120+
121+
// Record the prefix and suffix data to make sure they aren't written to
122+
initPrefix, initSuffix := make([]byte, blockSize), make([]byte, blockSize)
123+
copy(initPrefix, buff[:endOfPrefix])
124+
copy(initSuffix, buff[startOfSuffix:])
125+
126+
// Write to dst (the middle of the buffer) and make sure it doesn't write
127+
// beyond the dst slice
128+
cipher(dst, src)
129+
if !bytes.Equal(buff[startOfSuffix:], initSuffix) {
130+
t.Errorf("block cipher did out of bounds write after end of dst slice; got %x, want %x", buff[startOfSuffix:], initSuffix)
131+
}
132+
if !bytes.Equal(buff[:endOfPrefix], initPrefix) {
133+
t.Errorf("block cipher did out of bounds write before beginning of dst slice; got %x, want %x", buff[:endOfPrefix], initPrefix)
134+
}
135+
136+
// Check that dst isn't written to beyond BlockSize even if there is room
137+
// in the slice
138+
dst = buff[endOfPrefix:] // Extend dst to include suffix
139+
cipher(dst, src)
140+
if !bytes.Equal(buff[startOfSuffix:], initSuffix) {
141+
t.Errorf("block cipher modified dst past BlockSize bytes; got %x, want %x", buff[startOfSuffix:], initSuffix)
142+
}
143+
})
144+
145+
// Check that output of cipher isn't affected by adjacent data beyond input
146+
// slice scope
147+
// For encryption, this assumes block ciphers encrypt deterministically
148+
t.Run("OutOfBoundsRead", func(t *testing.T) {
149+
rng := newRandReader(t)
150+
151+
src := make([]byte, blockSize)
152+
rng.Read(src)
153+
expectedDst := make([]byte, blockSize)
154+
cipher(expectedDst, src)
155+
156+
// Make a buffer with src in the middle and data on either end
157+
buff := make([]byte, blockSize*3)
158+
endOfPrefix, startOfSuffix := blockSize, blockSize*2
159+
160+
copy(buff[endOfPrefix:startOfSuffix], src)
161+
rng.Read(buff[:endOfPrefix])
162+
rng.Read(buff[startOfSuffix:])
163+
164+
testDst := make([]byte, blockSize)
165+
cipher(testDst, buff[endOfPrefix:startOfSuffix])
166+
if !bytes.Equal(testDst, expectedDst) {
167+
t.Errorf("block cipher affected by data outside of src slice bounds; got %x, want %x", testDst, expectedDst)
168+
}
169+
170+
// Check that src isn't read from beyond BlockSize even if the slice is
171+
// longer and contains data in the suffix
172+
cipher(testDst, buff[endOfPrefix:]) // Input long src
173+
if !bytes.Equal(testDst, expectedDst) {
174+
t.Errorf("block cipher affected by src data beyond BlockSize bytes; got %x, want %x", buff[startOfSuffix:], expectedDst)
175+
}
176+
})
177+
178+
t.Run("NonZeroDst", func(t *testing.T) {
179+
rng := newRandReader(t)
180+
181+
// Record what the cipher writes into a destination of zeroes
182+
src := make([]byte, blockSize)
183+
rng.Read(src)
184+
expectedDst := make([]byte, blockSize)
185+
186+
cipher(expectedDst, src)
187+
188+
// Make nonzero dst
189+
dst := make([]byte, blockSize*2)
190+
rng.Read(dst)
191+
192+
// Remember the random suffix which shouldn't be written to
193+
expectedDst = append(expectedDst, dst[blockSize:]...)
194+
195+
cipher(dst, src)
196+
if !bytes.Equal(dst, expectedDst) {
197+
t.Errorf("block cipher behavior differs when given non-zero dst; got %x, want %x", dst, expectedDst)
198+
}
199+
})
200+
201+
t.Run("BufferOverlap", func(t *testing.T) {
202+
rng := newRandReader(t)
203+
204+
buff := make([]byte, blockSize*2)
205+
rng.Read((buff))
206+
207+
// Make src and dst slices point to same array with inexact overlap
208+
src := buff[:blockSize]
209+
dst := buff[1 : blockSize+1]
210+
mustPanic(t, "invalid buffer overlap", func() { cipher(dst, src) })
211+
212+
// Only overlap on one byte
213+
src = buff[:blockSize]
214+
dst = buff[blockSize-1 : 2*blockSize-1]
215+
mustPanic(t, "invalid buffer overlap", func() { cipher(dst, src) })
216+
217+
// src comes after dst with one byte overlap
218+
src = buff[blockSize-1 : 2*blockSize-1]
219+
dst = buff[:blockSize]
220+
mustPanic(t, "invalid buffer overlap", func() { cipher(dst, src) })
221+
})
222+
223+
// Test short input/output.
224+
// Assembly used to not notice.
225+
// See issue 7928.
226+
t.Run("ShortBlock", func(t *testing.T) {
227+
// Returns slice of n bytes of an n+1 length array. Lets us test that a
228+
// slice is still considered too short even if the underlying array it
229+
// points to is large enough
230+
byteSlice := func(n int) []byte { return make([]byte, n+1)[0:n] }
231+
232+
// Off by one byte
233+
mustPanic(t, "input not full block", func() { cipher(byteSlice(blockSize), byteSlice(blockSize-1)) })
234+
mustPanic(t, "output not full block", func() { cipher(byteSlice(blockSize-1), byteSlice(blockSize)) })
235+
236+
// Small slices
237+
mustPanic(t, "input not full block", func() { cipher(byteSlice(1), byteSlice(1)) })
238+
mustPanic(t, "input not full block", func() { cipher(byteSlice(100), byteSlice(1)) })
239+
mustPanic(t, "output not full block", func() { cipher(byteSlice(1), byteSlice(100)) })
240+
})
241+
}
242+
243+
func mustPanic(t *testing.T, msg string, f func()) {
244+
t.Helper()
245+
246+
defer func() {
247+
t.Helper()
248+
249+
err := recover()
250+
251+
if err == nil {
252+
t.Errorf("function did not panic for %q", msg)
253+
}
254+
}()
255+
f()
256+
}

0 commit comments

Comments
 (0)