From 950c5fd4111c53ff8aa1ca2eb6c568a70d3d080e Mon Sep 17 00:00:00 2001 From: Quim Muntal Date: Mon, 9 Sep 2024 11:06:56 +0200 Subject: [PATCH] Implement Clone for all hashers (#167) * implement Clone for all hashers * document clone methods * add TestHash_Clone comment --- hash.go | 149 +++++++++++++++++++++++++++++++++++++++++++++++++++ hash_test.go | 138 +++++++++++++++++++++++++++++++++++++---------- shims.h | 1 + 3 files changed, 259 insertions(+), 29 deletions(-) diff --git a/hash.go b/hash.go index bdc54b59..b020aa6a 100644 --- a/hash.go +++ b/hash.go @@ -230,6 +230,34 @@ func (h *evpHash) sum(out []byte) { runtime.KeepAlive(h) } +// clone returns a new evpHash object that is a deep clone of itself. +// The duplicate object contains all state and data contained in the +// original object at the point of duplication. +func (h *evpHash) clone() (*evpHash, error) { + ctx := C.go_openssl_EVP_MD_CTX_new() + if ctx == nil { + return nil, newOpenSSLError("EVP_MD_CTX_new") + } + if C.go_openssl_EVP_MD_CTX_copy_ex(ctx, h.ctx) != 1 { + C.go_openssl_EVP_MD_CTX_free(ctx) + return nil, newOpenSSLError("EVP_MD_CTX_copy") + } + ctx2 := C.go_openssl_EVP_MD_CTX_new() + if ctx2 == nil { + C.go_openssl_EVP_MD_CTX_free(ctx) + return nil, newOpenSSLError("EVP_MD_CTX_new") + } + cloned := &evpHash{ + ctx: ctx, + ctx2: ctx2, + size: h.size, + blockSize: h.blockSize, + marshallable: h.marshallable, + } + runtime.SetFinalizer(cloned, (*evpHash).finalize) + return cloned, nil +} + // hashState returns a pointer to the internal hash structure. // // The EVP_MD_CTX memory layout has changed in OpenSSL 3 @@ -280,6 +308,17 @@ func (h *md4Hash) Sum(in []byte) []byte { return append(in, h.out[:]...) } +// Clone returns a new [hash.Hash] object that is a deep clone of itself. +// The duplicate object contains all state and data contained in the +// original object at the point of duplication. +func (h *md4Hash) Clone() (hash.Hash, error) { + c, err := h.clone() + if err != nil { + return nil, err + } + return &md4Hash{evpHash: c}, nil +} + // NewMD5 returns a new MD5 hash. func NewMD5() hash.Hash { h := md5Hash{evpHash: newEvpHash(crypto.MD5)} @@ -308,6 +347,17 @@ func (h *md5Hash) Sum(in []byte) []byte { return append(in, h.out[:]...) } +// Clone returns a new [hash.Hash] object that is a deep clone of itself. +// The duplicate object contains all state and data contained in the +// original object at the point of duplication. +func (h *md5Hash) Clone() (hash.Hash, error) { + c, err := h.clone() + if err != nil { + return nil, err + } + return &md5Hash{evpHash: c}, nil +} + const ( md5Magic = "md5\x01" md5MarshaledSize = len(md5Magic) + 4*4 + 64 + 8 @@ -377,6 +427,17 @@ func (h *sha1Hash) Sum(in []byte) []byte { return append(in, h.out[:]...) } +// Clone returns a new [hash.Hash] object that is a deep clone of itself. +// The duplicate object contains all state and data contained in the +// original object at the point of duplication. +func (h *sha1Hash) Clone() (hash.Hash, error) { + c, err := h.clone() + if err != nil { + return nil, err + } + return &sha1Hash{evpHash: c}, nil +} + // sha1State layout is taken from // https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L34. type sha1State struct { @@ -457,6 +518,17 @@ func (h *sha224Hash) Sum(in []byte) []byte { return append(in, h.out[:]...) } +// Clone returns a new [hash.Hash] object that is a deep clone of itself. +// The duplicate object contains all state and data contained in the +// original object at the point of duplication. +func (h *sha224Hash) Clone() (hash.Hash, error) { + c, err := h.clone() + if err != nil { + return nil, err + } + return &sha224Hash{evpHash: c}, nil +} + // NewSHA256 returns a new SHA256 hash. func NewSHA256() hash.Hash { h := sha256Hash{evpHash: newEvpHash(crypto.SHA256)} @@ -476,6 +548,17 @@ func (h *sha256Hash) Sum(in []byte) []byte { return append(in, h.out[:]...) } +// Clone returns a new [hash.Hash] object that is a deep clone of itself. +// The duplicate object contains all state and data contained in the +// original object at the point of duplication. +func (h *sha256Hash) Clone() (hash.Hash, error) { + c, err := h.clone() + if err != nil { + return nil, err + } + return &sha256Hash{evpHash: c}, nil +} + const ( magic224 = "sha\x02" magic256 = "sha\x03" @@ -616,6 +699,17 @@ func (h *sha384Hash) Sum(in []byte) []byte { return append(in, h.out[:]...) } +// Clone returns a new [hash.Hash] object that is a deep clone of itself. +// The duplicate object contains all state and data contained in the +// original object at the point of duplication. +func (h *sha384Hash) Clone() (hash.Hash, error) { + c, err := h.clone() + if err != nil { + return nil, err + } + return &sha384Hash{evpHash: c}, nil +} + // NewSHA512 returns a new SHA512 hash. func NewSHA512() hash.Hash { h := sha512Hash{evpHash: newEvpHash(crypto.SHA512)} @@ -635,6 +729,17 @@ func (h *sha512Hash) Sum(in []byte) []byte { return append(in, h.out[:]...) } +// Clone returns a new [hash.Hash] object that is a deep clone of itself. +// The duplicate object contains all state and data contained in the +// original object at the point of duplication. +func (h *sha512Hash) Clone() (hash.Hash, error) { + c, err := h.clone() + if err != nil { + return nil, err + } + return &sha512Hash{evpHash: c}, nil +} + // sha512State layout is taken from // https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/include/openssl/sha.h#L95. type sha512State struct { @@ -781,6 +886,17 @@ func (h *sha3_224Hash) Sum(in []byte) []byte { return append(in, h.out[:]...) } +// Clone returns a new [hash.Hash] object that is a deep clone of itself. +// The duplicate object contains all state and data contained in the +// original object at the point of duplication. +func (h *sha3_224Hash) Clone() (hash.Hash, error) { + c, err := h.clone() + if err != nil { + return nil, err + } + return &sha3_224Hash{evpHash: c}, nil +} + // NewSHA3_256 returns a new SHA3-256 hash. func NewSHA3_256() hash.Hash { return &sha3_256Hash{ @@ -798,6 +914,17 @@ func (h *sha3_256Hash) Sum(in []byte) []byte { return append(in, h.out[:]...) } +// Clone returns a new [hash.Hash] object that is a deep clone of itself. +// The duplicate object contains all state and data contained in the +// original object at the point of duplication. +func (h *sha3_256Hash) Clone() (hash.Hash, error) { + c, err := h.clone() + if err != nil { + return nil, err + } + return &sha3_256Hash{evpHash: c}, nil +} + // NewSHA3_384 returns a new SHA3-384 hash. func NewSHA3_384() hash.Hash { return &sha3_384Hash{ @@ -815,6 +942,17 @@ func (h *sha3_384Hash) Sum(in []byte) []byte { return append(in, h.out[:]...) } +// Clone returns a new [hash.Hash] object that is a deep clone of itself. +// The duplicate object contains all state and data contained in the +// original object at the point of duplication. +func (h *sha3_384Hash) Clone() (hash.Hash, error) { + c, err := h.clone() + if err != nil { + return nil, err + } + return &sha3_384Hash{evpHash: c}, nil +} + // NewSHA3_512 returns a new SHA3-512 hash. func NewSHA3_512() hash.Hash { return &sha3_512Hash{ @@ -832,6 +970,17 @@ func (h *sha3_512Hash) Sum(in []byte) []byte { return append(in, h.out[:]...) } +// Clone returns a new [hash.Hash] object that is a deep clone of itself. +// The duplicate object contains all state and data contained in the +// original object at the point of duplication. +func (h *sha3_512Hash) Clone() (hash.Hash, error) { + c, err := h.clone() + if err != nil { + return nil, err + } + return &sha3_512Hash{evpHash: c}, nil +} + // appendUint64 appends x into b as a big endian byte sequence. func appendUint64(b []byte, x uint64) []byte { return append(b, diff --git a/hash_test.go b/hash_test.go index d8f7e78a..586960da 100644 --- a/hash_test.go +++ b/hash_test.go @@ -39,22 +39,23 @@ func cryptoToHash(h crypto.Hash) func() hash.Hash { return nil } +var hashes = [...]crypto.Hash{ + crypto.MD4, + crypto.MD5, + crypto.SHA1, + crypto.SHA224, + crypto.SHA256, + crypto.SHA384, + crypto.SHA512, + crypto.SHA3_224, + crypto.SHA3_256, + crypto.SHA3_384, + crypto.SHA3_512, +} + func TestHash(t *testing.T) { msg := []byte("testing") - var tests = []crypto.Hash{ - crypto.MD4, - crypto.MD5, - crypto.SHA1, - crypto.SHA224, - crypto.SHA256, - crypto.SHA384, - crypto.SHA512, - crypto.SHA3_224, - crypto.SHA3_256, - crypto.SHA3_384, - crypto.SHA3_512, - } - for _, ch := range tests { + for _, ch := range hashes { ch := ch t.Run(ch.String(), func(t *testing.T) { t.Parallel() @@ -77,38 +78,117 @@ func TestHash(t *testing.T) { if bytes.Equal(sum, initSum) { t.Error("Write didn't change internal hash state") } - if _, ok := h.(encoding.BinaryMarshaler); ok { - state, err := h.(encoding.BinaryMarshaler).MarshalBinary() - if err != nil { - t.Errorf("could not marshal: %v", err) - } - h2 := cryptoToHash(ch)() - if err := h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state); err != nil { - t.Errorf("could not unmarshal: %v", err) - } - if actual, actual2 := h.Sum(nil), h2.Sum(nil); !bytes.Equal(actual, actual2) { - t.Errorf("0x%x != marshaled 0x%x", actual, actual2) - } - } h.Reset() sum = h.Sum(nil) if !bytes.Equal(sum, initSum) { t.Errorf("got:%x want:%x", sum, initSum) } + }) + } +} + +func TestHash_BinaryMarshaler(t *testing.T) { + msg := []byte("testing") + for _, ch := range hashes { + ch := ch + t.Run(ch.String(), func(t *testing.T) { + t.Parallel() + if !openssl.SupportsHash(ch) { + t.Skip("skipping: not supported") + } + h := cryptoToHash(ch)() + if _, ok := h.(encoding.BinaryMarshaler); !ok { + t.Skip("skipping: not supported") + } + _, err := h.Write(msg) + if err != nil { + t.Fatal(err) + } + state, err := h.(encoding.BinaryMarshaler).MarshalBinary() + if err != nil { + t.Errorf("could not marshal: %v", err) + } + h2 := cryptoToHash(ch)() + if err := h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state); err != nil { + t.Errorf("could not unmarshal: %v", err) + } + if actual, actual2 := h.Sum(nil), h2.Sum(nil); !bytes.Equal(actual, actual2) { + t.Errorf("0x%x != marshaled 0x%x", actual, actual2) + } + }) + } +} +func TestHash_Clone(t *testing.T) { + msg := []byte("testing") + for _, ch := range hashes { + ch := ch + t.Run(ch.String(), func(t *testing.T) { + t.Parallel() + if !openssl.SupportsHash(ch) { + t.Skip("skipping: not supported") + } + h := cryptoToHash(ch)() + if _, ok := h.(encoding.BinaryMarshaler); !ok { + t.Skip("skipping: not supported") + } + _, err := h.Write(msg) + if err != nil { + t.Fatal(err) + } + // We don't define an interface for the Clone method to avoid other + // packages from depending on it. Use type assertion to call it. + h2, err := h.(interface{ Clone() (hash.Hash, error) }).Clone() + if err != nil { + t.Fatal(err) + } + h.Write(msg) + h2.Write(msg) + if actual, actual2 := h.Sum(nil), h2.Sum(nil); !bytes.Equal(actual, actual2) { + t.Errorf("%s(%q) = 0x%x != cloned 0x%x", ch.String(), msg, actual, actual2) + } + }) + } +} + +func TestHash_ByteWriter(t *testing.T) { + msg := []byte("testing") + for _, ch := range hashes { + ch := ch + t.Run(ch.String(), func(t *testing.T) { + t.Parallel() + if !openssl.SupportsHash(ch) { + t.Skip("skipping: not supported") + } + h := cryptoToHash(ch)() + initSum := h.Sum(nil) bw := h.(io.ByteWriter) for i := 0; i < len(msg); i++ { bw.WriteByte(msg[i]) } h.Reset() - sum = h.Sum(nil) + sum := h.Sum(nil) if !bytes.Equal(sum, initSum) { t.Errorf("got:%x want:%x", sum, initSum) } + }) + } +} +func TestHash_StringWriter(t *testing.T) { + msg := []byte("testing") + for _, ch := range hashes { + ch := ch + t.Run(ch.String(), func(t *testing.T) { + t.Parallel() + if !openssl.SupportsHash(ch) { + t.Skip("skipping: not supported") + } + h := cryptoToHash(ch)() + initSum := h.Sum(nil) h.(io.StringWriter).WriteString(string(msg)) h.Reset() - sum = h.Sum(nil) + sum := h.Sum(nil) if !bytes.Equal(sum, initSum) { t.Errorf("got:%x want:%x", sum, initSum) } diff --git a/shims.h b/shims.h index f3a623d1..466a925a 100644 --- a/shims.h +++ b/shims.h @@ -200,6 +200,7 @@ DEFINEFUNC(int, RAND_bytes, (unsigned char *arg0, int arg1), (arg0, arg1)) \ DEFINEFUNC_RENAMED_1_1(GO_EVP_MD_CTX_PTR, EVP_MD_CTX_new, EVP_MD_CTX_create, (void), ()) \ DEFINEFUNC_RENAMED_1_1(void, EVP_MD_CTX_free, EVP_MD_CTX_destroy, (GO_EVP_MD_CTX_PTR ctx), (ctx)) \ DEFINEFUNC(int, EVP_MD_CTX_copy, (GO_EVP_MD_CTX_PTR out, const GO_EVP_MD_CTX_PTR in), (out, in)) \ +DEFINEFUNC(int, EVP_MD_CTX_copy_ex, (GO_EVP_MD_CTX_PTR out, const GO_EVP_MD_CTX_PTR in), (out, in)) \ DEFINEFUNC(int, EVP_Digest, (const void *data, size_t count, unsigned char *md, unsigned int *size, const GO_EVP_MD_PTR type, GO_ENGINE_PTR impl), (data, count, md, size, type, impl)) \ DEFINEFUNC(int, EVP_DigestInit_ex, (GO_EVP_MD_CTX_PTR ctx, const GO_EVP_MD_PTR type, GO_ENGINE_PTR impl), (ctx, type, impl)) \ DEFINEFUNC(int, EVP_DigestInit, (GO_EVP_MD_CTX_PTR ctx, const GO_EVP_MD_PTR type), (ctx, type)) \