Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Clone for all hashers #167

Merged
merged 3 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,31 @@ func (h *evpHash) sum(out []byte) {
runtime.KeepAlive(h)
}

func (h *evpHash) clone() (*evpHash, error) {
ctx := C.go_openssl_EVP_MD_CTX_new()
if ctx == nil {
return nil, newOpenSSLError("EVP_MD_CTX_new")
}
if C.go_openssl_EVP_MD_CTX_copy_ex(ctx, h.ctx) != 1 {
C.go_openssl_EVP_MD_CTX_free(ctx)
return nil, newOpenSSLError("EVP_MD_CTX_copy")
}
ctx2 := C.go_openssl_EVP_MD_CTX_new()
if ctx2 == nil {
C.go_openssl_EVP_MD_CTX_free(ctx)
return nil, newOpenSSLError("EVP_MD_CTX_new")
}
cloned := &evpHash{
ctx: ctx,
ctx2: ctx2,
size: h.size,
blockSize: h.blockSize,
marshallable: h.marshallable,
}
runtime.SetFinalizer(cloned, (*evpHash).finalize)
return cloned, nil
}

// hashState returns a pointer to the internal hash structure.
//
// The EVP_MD_CTX memory layout has changed in OpenSSL 3
Expand Down Expand Up @@ -280,6 +305,14 @@ func (h *md4Hash) Sum(in []byte) []byte {
return append(in, h.out[:]...)
}

func (h *md4Hash) Clone() (hash.Hash, error) {
c, err := h.clone()
if err != nil {
return nil, err
}
return &md4Hash{evpHash: c}, nil
}

// NewMD5 returns a new MD5 hash.
func NewMD5() hash.Hash {
h := md5Hash{evpHash: newEvpHash(crypto.MD5)}
Expand Down Expand Up @@ -308,6 +341,14 @@ func (h *md5Hash) Sum(in []byte) []byte {
return append(in, h.out[:]...)
}

func (h *md5Hash) Clone() (hash.Hash, error) {
c, err := h.clone()
if err != nil {
return nil, err
}
return &md5Hash{evpHash: c}, nil
}

const (
md5Magic = "md5\x01"
md5MarshaledSize = len(md5Magic) + 4*4 + 64 + 8
Expand Down Expand Up @@ -377,6 +418,14 @@ func (h *sha1Hash) Sum(in []byte) []byte {
return append(in, h.out[:]...)
}

func (h *sha1Hash) Clone() (hash.Hash, error) {
c, err := h.clone()
if err != nil {
return nil, err
}
return &sha1Hash{evpHash: c}, nil
}

// sha1State layout is taken from
// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L34.
type sha1State struct {
Expand Down Expand Up @@ -457,6 +506,14 @@ func (h *sha224Hash) Sum(in []byte) []byte {
return append(in, h.out[:]...)
}

func (h *sha224Hash) Clone() (hash.Hash, error) {
c, err := h.clone()
if err != nil {
return nil, err
}
return &sha224Hash{evpHash: c}, nil
}

// NewSHA256 returns a new SHA256 hash.
func NewSHA256() hash.Hash {
h := sha256Hash{evpHash: newEvpHash(crypto.SHA256)}
Expand All @@ -476,6 +533,14 @@ func (h *sha256Hash) Sum(in []byte) []byte {
return append(in, h.out[:]...)
}

func (h *sha256Hash) Clone() (hash.Hash, error) {
c, err := h.clone()
if err != nil {
return nil, err
}
return &sha256Hash{evpHash: c}, nil
}

const (
magic224 = "sha\x02"
magic256 = "sha\x03"
Expand Down Expand Up @@ -616,6 +681,14 @@ func (h *sha384Hash) Sum(in []byte) []byte {
return append(in, h.out[:]...)
}

func (h *sha384Hash) Clone() (hash.Hash, error) {
c, err := h.clone()
if err != nil {
return nil, err
}
return &sha384Hash{evpHash: c}, nil
}

// NewSHA512 returns a new SHA512 hash.
func NewSHA512() hash.Hash {
h := sha512Hash{evpHash: newEvpHash(crypto.SHA512)}
Expand All @@ -635,6 +708,14 @@ func (h *sha512Hash) Sum(in []byte) []byte {
return append(in, h.out[:]...)
}

func (h *sha512Hash) Clone() (hash.Hash, error) {
c, err := h.clone()
if err != nil {
return nil, err
}
return &sha512Hash{evpHash: c}, nil
}

// sha512State layout is taken from
// https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L95.
type sha512State struct {
Expand Down Expand Up @@ -781,6 +862,14 @@ func (h *sha3_224Hash) Sum(in []byte) []byte {
return append(in, h.out[:]...)
}

func (h *sha3_224Hash) Clone() (hash.Hash, error) {
c, err := h.clone()
if err != nil {
return nil, err
}
return &sha3_224Hash{evpHash: c}, nil
}

// NewSHA3_256 returns a new SHA3-256 hash.
func NewSHA3_256() hash.Hash {
return &sha3_256Hash{
Expand All @@ -798,6 +887,14 @@ func (h *sha3_256Hash) Sum(in []byte) []byte {
return append(in, h.out[:]...)
}

func (h *sha3_256Hash) Clone() (hash.Hash, error) {
c, err := h.clone()
if err != nil {
return nil, err
}
return &sha3_256Hash{evpHash: c}, nil
}

// NewSHA3_384 returns a new SHA3-384 hash.
func NewSHA3_384() hash.Hash {
return &sha3_384Hash{
Expand All @@ -815,6 +912,14 @@ func (h *sha3_384Hash) Sum(in []byte) []byte {
return append(in, h.out[:]...)
}

func (h *sha3_384Hash) Clone() (hash.Hash, error) {
c, err := h.clone()
if err != nil {
return nil, err
}
return &sha3_384Hash{evpHash: c}, nil
}

// NewSHA3_512 returns a new SHA3-512 hash.
func NewSHA3_512() hash.Hash {
return &sha3_512Hash{
Expand All @@ -832,6 +937,14 @@ func (h *sha3_512Hash) Sum(in []byte) []byte {
return append(in, h.out[:]...)
}

func (h *sha3_512Hash) Clone() (hash.Hash, error) {
c, err := h.clone()
if err != nil {
return nil, err
}
return &sha3_512Hash{evpHash: c}, nil
}

// appendUint64 appends x into b as a big endian byte sequence.
func appendUint64(b []byte, x uint64) []byte {
return append(b,
Expand Down
136 changes: 107 additions & 29 deletions hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,23 @@ func cryptoToHash(h crypto.Hash) func() hash.Hash {
return nil
}

var hashes = [...]crypto.Hash{
crypto.MD4,
crypto.MD5,
crypto.SHA1,
crypto.SHA224,
crypto.SHA256,
crypto.SHA384,
crypto.SHA512,
crypto.SHA3_224,
crypto.SHA3_256,
crypto.SHA3_384,
crypto.SHA3_512,
}

func TestHash(t *testing.T) {
msg := []byte("testing")
var tests = []crypto.Hash{
crypto.MD4,
crypto.MD5,
crypto.SHA1,
crypto.SHA224,
crypto.SHA256,
crypto.SHA384,
crypto.SHA512,
crypto.SHA3_224,
crypto.SHA3_256,
crypto.SHA3_384,
crypto.SHA3_512,
}
for _, ch := range tests {
for _, ch := range hashes {
ch := ch
t.Run(ch.String(), func(t *testing.T) {
t.Parallel()
Expand All @@ -77,38 +78,115 @@ func TestHash(t *testing.T) {
if bytes.Equal(sum, initSum) {
t.Error("Write didn't change internal hash state")
}
if _, ok := h.(encoding.BinaryMarshaler); ok {
state, err := h.(encoding.BinaryMarshaler).MarshalBinary()
if err != nil {
t.Errorf("could not marshal: %v", err)
}
h2 := cryptoToHash(ch)()
if err := h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state); err != nil {
t.Errorf("could not unmarshal: %v", err)
}
if actual, actual2 := h.Sum(nil), h2.Sum(nil); !bytes.Equal(actual, actual2) {
t.Errorf("0x%x != marshaled 0x%x", actual, actual2)
}
}
h.Reset()
sum = h.Sum(nil)
if !bytes.Equal(sum, initSum) {
t.Errorf("got:%x want:%x", sum, initSum)
}
})
}
}

func TestHash_BinaryMarshaler(t *testing.T) {
msg := []byte("testing")
for _, ch := range hashes {
ch := ch
t.Run(ch.String(), func(t *testing.T) {
t.Parallel()
if !openssl.SupportsHash(ch) {
t.Skip("skipping: not supported")
}
h := cryptoToHash(ch)()
if _, ok := h.(encoding.BinaryMarshaler); !ok {
t.Skip("skipping: not supported")
}
_, err := h.Write(msg)
if err != nil {
t.Fatal(err)
}
state, err := h.(encoding.BinaryMarshaler).MarshalBinary()
if err != nil {
t.Errorf("could not marshal: %v", err)
}
h2 := cryptoToHash(ch)()
if err := h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state); err != nil {
t.Errorf("could not unmarshal: %v", err)
}
if actual, actual2 := h.Sum(nil), h2.Sum(nil); !bytes.Equal(actual, actual2) {
t.Errorf("0x%x != marshaled 0x%x", actual, actual2)
}
})
}
}

func TestHash_Clone(t *testing.T) {
msg := []byte("testing")
for _, ch := range hashes {
ch := ch
t.Run(ch.String(), func(t *testing.T) {
t.Parallel()
if !openssl.SupportsHash(ch) {
t.Skip("skipping: not supported")
}
h := cryptoToHash(ch)()
if _, ok := h.(encoding.BinaryMarshaler); !ok {
t.Skip("skipping: not supported")
}
_, err := h.Write(msg)
if err != nil {
t.Fatal(err)
}
h2, err := h.(interface{ Clone() (hash.Hash, error) }).Clone()
if err != nil {
t.Fatal(err)
}
h.Write(msg)
h2.Write(msg)
if actual, actual2 := h.Sum(nil), h2.Sum(nil); !bytes.Equal(actual, actual2) {
t.Errorf("%s(%q) = 0x%x != cloned 0x%x", ch.String(), msg, actual, actual2)
}
})
}
}

func TestHash_ByteWriter(t *testing.T) {
msg := []byte("testing")
for _, ch := range hashes {
ch := ch
t.Run(ch.String(), func(t *testing.T) {
t.Parallel()
if !openssl.SupportsHash(ch) {
t.Skip("skipping: not supported")
}
h := cryptoToHash(ch)()
initSum := h.Sum(nil)
bw := h.(io.ByteWriter)
for i := 0; i < len(msg); i++ {
bw.WriteByte(msg[i])
}
h.Reset()
sum = h.Sum(nil)
sum := h.Sum(nil)
if !bytes.Equal(sum, initSum) {
t.Errorf("got:%x want:%x", sum, initSum)
}
})
}
}

func TestHash_StringWriter(t *testing.T) {
msg := []byte("testing")
for _, ch := range hashes {
ch := ch
t.Run(ch.String(), func(t *testing.T) {
t.Parallel()
if !openssl.SupportsHash(ch) {
t.Skip("skipping: not supported")
}
h := cryptoToHash(ch)()
initSum := h.Sum(nil)
h.(io.StringWriter).WriteString(string(msg))
h.Reset()
sum = h.Sum(nil)
sum := h.Sum(nil)
if !bytes.Equal(sum, initSum) {
t.Errorf("got:%x want:%x", sum, initSum)
}
Expand Down
1 change: 1 addition & 0 deletions shims.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ DEFINEFUNC(int, RAND_bytes, (unsigned char *arg0, int arg1), (arg0, arg1)) \
DEFINEFUNC_RENAMED_1_1(GO_EVP_MD_CTX_PTR, EVP_MD_CTX_new, EVP_MD_CTX_create, (void), ()) \
DEFINEFUNC_RENAMED_1_1(void, EVP_MD_CTX_free, EVP_MD_CTX_destroy, (GO_EVP_MD_CTX_PTR ctx), (ctx)) \
DEFINEFUNC(int, EVP_MD_CTX_copy, (GO_EVP_MD_CTX_PTR out, const GO_EVP_MD_CTX_PTR in), (out, in)) \
DEFINEFUNC(int, EVP_MD_CTX_copy_ex, (GO_EVP_MD_CTX_PTR out, const GO_EVP_MD_CTX_PTR in), (out, in)) \
DEFINEFUNC(int, EVP_Digest, (const void *data, size_t count, unsigned char *md, unsigned int *size, const GO_EVP_MD_PTR type, GO_ENGINE_PTR impl), (data, count, md, size, type, impl)) \
DEFINEFUNC(int, EVP_DigestInit_ex, (GO_EVP_MD_CTX_PTR ctx, const GO_EVP_MD_PTR type, GO_ENGINE_PTR impl), (ctx, type, impl)) \
DEFINEFUNC(int, EVP_DigestInit, (GO_EVP_MD_CTX_PTR ctx, const GO_EVP_MD_PTR type), (ctx, type)) \
Expand Down
Loading