Skip to content

Commit

Permalink
deduplicate code
Browse files Browse the repository at this point in the history
  • Loading branch information
qmuntal committed Sep 4, 2024
1 parent 6a23e2a commit 25e7f1b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 28 deletions.
44 changes: 22 additions & 22 deletions hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,14 @@ var isMarshallableMap sync.Map

// isHashMarshallable returns true if the memory layout of cb
// is known by this library and can therefore be marshalled.
func isHashMarshallable(cb crypto.Hash) bool {
func isHashMarshallable(ch crypto.Hash) bool {
if vMajor == 1 {
return true
}
if v, ok := isMarshallableMap.Load(cb); ok {
if v, ok := isMarshallableMap.Load(ch); ok {
return v.(bool)
}
md := cryptoHashToMD(cb)
md := cryptoHashToMD(ch)
if md == nil {
return false
}
Expand All @@ -138,7 +138,7 @@ func isHashMarshallable(cb crypto.Hash) bool {
// We only know the memory layout of the built-in providers.
// See evpHash.hashState for more details.
marshallable := name == "default" || name == "fips"
isMarshallableMap.Store(cb, marshallable)
isMarshallableMap.Store(ch, marshallable)
return marshallable
}

Expand Down Expand Up @@ -276,11 +276,11 @@ func (h *md4Hash) Sum(in []byte) []byte {

// NewMD5 returns a new MD5 hash.
func NewMD5() hash.Hash {
h := newEvpHash(crypto.MD5, 16, 64)
h := md5Hash{evpHash: newEvpHash(crypto.MD5, 16, 64)}
if isHashMarshallable(crypto.MD5) {
return &md5Marshal{md5Hash{evpHash: h}}
return &md5Marshal{h}
}
return &md5Hash{evpHash: h}
return &h
}

// md5State layout is taken from
Expand Down Expand Up @@ -354,11 +354,11 @@ func (h *md5Marshal) UnmarshalBinary(b []byte) error {

// NewSHA1 returns a new SHA1 hash.
func NewSHA1() hash.Hash {
h := newEvpHash(crypto.SHA1, 20, 64)
h := sha1Hash{evpHash: newEvpHash(crypto.SHA1, 20, 64)}
if isHashMarshallable(crypto.SHA1) {
return &sha1Marshal{sha1Hash{evpHash: h}}
return &sha1Marshal{h}
}
return &sha1Hash{evpHash: h}
return &h
}

type sha1Hash struct {
Expand Down Expand Up @@ -434,11 +434,11 @@ func (h *sha1Marshal) UnmarshalBinary(b []byte) error {

// NewSHA224 returns a new SHA224 hash.
func NewSHA224() hash.Hash {
h := newEvpHash(crypto.SHA224, 224/8, 64)
h := sha224Hash{evpHash: newEvpHash(crypto.SHA224, 224/8, 64)}
if isHashMarshallable(crypto.SHA224) {
return &sha224Marshal{sha224Hash{evpHash: h}}
return &sha224Marshal{h}
}
return &sha224Hash{evpHash: h}
return &h
}

type sha224Hash struct {
Expand All @@ -453,11 +453,11 @@ func (h *sha224Hash) Sum(in []byte) []byte {

// NewSHA256 returns a new SHA256 hash.
func NewSHA256() hash.Hash {
h := newEvpHash(crypto.SHA256, 256/8, 64)
h := sha256Hash{evpHash: newEvpHash(crypto.SHA256, 256/8, 64)}
if isHashMarshallable(crypto.SHA256) {
return &sha256Marshal{sha256Hash{evpHash: h}}
return &sha256Marshal{h}
}
return &sha256Hash{evpHash: h}
return &h
}

type sha256Hash struct {
Expand Down Expand Up @@ -593,11 +593,11 @@ func (h *sha256Marshal) UnmarshalBinary(b []byte) error {

// NewSHA384 returns a new SHA384 hash.
func NewSHA384() hash.Hash {
h := newEvpHash(crypto.SHA384, 384/8, 128)
h := sha384Hash{evpHash: newEvpHash(crypto.SHA384, 384/8, 128)}
if isHashMarshallable(crypto.SHA384) {
return &sha384Marshal{sha384Hash{evpHash: h}}
return &sha384Marshal{h}
}
return &sha384Hash{evpHash: h}
return &h
}

type sha384Hash struct {
Expand All @@ -612,11 +612,11 @@ func (h *sha384Hash) Sum(in []byte) []byte {

// NewSHA512 returns a new SHA512 hash.
func NewSHA512() hash.Hash {
h := newEvpHash(crypto.SHA512, 512/8, 128)
h := sha512Hash{evpHash: newEvpHash(crypto.SHA512, 512/8, 128)}
if isHashMarshallable(crypto.SHA512) {
return &sha512Marshal{sha512Hash{evpHash: h}}
return &sha512Marshal{h}
}
return &sha512Hash{evpHash: h}
return &h
}

type sha512Hash struct {
Expand Down
12 changes: 6 additions & 6 deletions hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ func TestHash(t *testing.T) {
crypto.SHA3_384,
crypto.SHA3_512,
}
for _, cb := range tests {
cb := cb
t.Run(cb.String(), func(t *testing.T) {
for _, ch := range tests {
ch := ch
t.Run(ch.String(), func(t *testing.T) {
t.Parallel()
if !openssl.SupportsHash(cb) {
if !openssl.SupportsHash(ch) {
t.Skip("skipping: not supported")
}
h := cryptoToHash(cb)()
h := cryptoToHash(ch)()
initSum := h.Sum(nil)
n, err := h.Write(msg)
if err != nil {
Expand All @@ -82,7 +82,7 @@ func TestHash(t *testing.T) {
if err != nil {
t.Errorf("could not marshal: %v", err)
}
h2 := cryptoToHash(cb)()
h2 := cryptoToHash(ch)()
if err := h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state); err != nil {
t.Errorf("could not unmarshal: %v", err)
}
Expand Down

0 comments on commit 25e7f1b

Please sign in to comment.