Skip to content

Commit 313c54f

Browse files
authored
speed up NewShaX (#238)
1 parent d9e21e3 commit 313c54f

File tree

2 files changed

+71
-36
lines changed

2 files changed

+71
-36
lines changed

hash.go

+64-36
Original file line numberDiff line numberDiff line change
@@ -251,27 +251,42 @@ func newEvpHash(ch crypto.Hash) *evpHash {
251251
if alg == nil {
252252
panic("openssl: unsupported hash function: " + strconv.Itoa(int(ch)))
253253
}
254-
ctx := C.go_openssl_EVP_MD_CTX_new()
255-
if C.go_openssl_EVP_DigestInit_ex(ctx, alg.md, nil) != 1 {
256-
C.go_openssl_EVP_MD_CTX_free(ctx)
257-
panic(newOpenSSLError("EVP_DigestInit_ex"))
258-
}
259-
ctx2 := C.go_openssl_EVP_MD_CTX_new()
260-
h := &evpHash{
261-
alg: alg,
262-
ctx: ctx,
263-
ctx2: ctx2,
264-
}
265-
runtime.SetFinalizer(h, (*evpHash).finalize)
254+
h := &evpHash{alg: alg}
255+
// Don't call init() yet, it would be wasteful
256+
// if the caller only wants to know the hash type. This
257+
// is a common pattern in this package, as some functions
258+
// accept a `func() hash.Hash` parameter and call it just
259+
// to know the hash type.
266260
return h
267261
}
268262

269263
func (h *evpHash) finalize() {
270-
C.go_openssl_EVP_MD_CTX_free(h.ctx)
271-
C.go_openssl_EVP_MD_CTX_free(h.ctx2)
264+
if h.ctx != nil {
265+
C.go_openssl_EVP_MD_CTX_free(h.ctx)
266+
}
267+
if h.ctx2 != nil {
268+
C.go_openssl_EVP_MD_CTX_free(h.ctx2)
269+
}
270+
}
271+
272+
func (h *evpHash) init() {
273+
if h.ctx != nil {
274+
return
275+
}
276+
h.ctx = C.go_openssl_EVP_MD_CTX_new()
277+
if C.go_openssl_EVP_DigestInit_ex(h.ctx, h.alg.md, nil) != 1 {
278+
C.go_openssl_EVP_MD_CTX_free(h.ctx)
279+
panic(newOpenSSLError("EVP_DigestInit_ex"))
280+
}
281+
h.ctx2 = C.go_openssl_EVP_MD_CTX_new()
282+
runtime.SetFinalizer(h, (*evpHash).finalize)
272283
}
273284

274285
func (h *evpHash) Reset() {
286+
if h.ctx == nil {
287+
// The hash is not initialized yet, no need to reset.
288+
return
289+
}
275290
// There is no need to reset h.ctx2 because it is always reset after
276291
// use in evpHash.sum.
277292
if C.go_openssl_EVP_DigestInit_ex(h.ctx, nil, nil) != 1 {
@@ -281,22 +296,31 @@ func (h *evpHash) Reset() {
281296
}
282297

283298
func (h *evpHash) Write(p []byte) (int, error) {
284-
if len(p) > 0 && C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(&*addr(p)), C.size_t(len(p))) != 1 {
299+
if len(p) == 0 {
300+
return 0, nil
301+
}
302+
h.init()
303+
if C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(&*addr(p)), C.size_t(len(p))) != 1 {
285304
panic(newOpenSSLError("EVP_DigestUpdate"))
286305
}
287306
runtime.KeepAlive(h)
288307
return len(p), nil
289308
}
290309

291310
func (h *evpHash) WriteString(s string) (int, error) {
292-
if len(s) > 0 && C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(unsafe.StringData(s)), C.size_t(len(s))) == 0 {
311+
if len(s) == 0 {
312+
return 0, nil
313+
}
314+
h.init()
315+
if C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(unsafe.StringData(s)), C.size_t(len(s))) == 0 {
293316
panic("openssl: EVP_DigestUpdate failed")
294317
}
295318
runtime.KeepAlive(h)
296319
return len(s), nil
297320
}
298321

299322
func (h *evpHash) WriteByte(c byte) error {
323+
h.init()
300324
if C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(&c), 1) == 0 {
301325
panic("openssl: EVP_DigestUpdate failed")
302326
}
@@ -313,38 +337,38 @@ func (h *evpHash) BlockSize() int {
313337
}
314338

315339
func (h *evpHash) Sum(in []byte) []byte {
316-
defer runtime.KeepAlive(h)
340+
h.init()
317341
out := make([]byte, h.Size(), maxHashSize) // explicit cap to allow stack allocation
318342
if C.go_hash_sum(h.ctx, h.ctx2, base(out)) != 1 {
319343
panic(newOpenSSLError("go_hash_sum"))
320344
}
345+
runtime.KeepAlive(h)
321346
return append(in, out...)
322347
}
323348

324349
// Clone returns a new evpHash object that is a deep clone of itself.
325350
// The duplicate object contains all state and data contained in the
326351
// original object at the point of duplication.
327352
func (h *evpHash) Clone() (hash.Hash, error) {
328-
ctx := C.go_openssl_EVP_MD_CTX_new()
329-
if ctx == nil {
330-
return nil, newOpenSSLError("EVP_MD_CTX_new")
331-
}
332-
if C.go_openssl_EVP_MD_CTX_copy_ex(ctx, h.ctx) != 1 {
333-
C.go_openssl_EVP_MD_CTX_free(ctx)
334-
return nil, newOpenSSLError("EVP_MD_CTX_copy")
335-
}
336-
ctx2 := C.go_openssl_EVP_MD_CTX_new()
337-
if ctx2 == nil {
338-
C.go_openssl_EVP_MD_CTX_free(ctx)
339-
return nil, newOpenSSLError("EVP_MD_CTX_new")
340-
}
341-
cloned := &evpHash{
342-
alg: h.alg,
343-
ctx: ctx,
344-
ctx2: ctx2,
353+
h2 := &evpHash{alg: h.alg}
354+
if h.ctx != nil {
355+
h2.ctx = C.go_openssl_EVP_MD_CTX_new()
356+
if h2.ctx == nil {
357+
return nil, newOpenSSLError("EVP_MD_CTX_new")
358+
}
359+
if C.go_openssl_EVP_MD_CTX_copy_ex(h2.ctx, h.ctx) != 1 {
360+
C.go_openssl_EVP_MD_CTX_free(h2.ctx)
361+
return nil, newOpenSSLError("EVP_MD_CTX_copy")
362+
}
363+
h2.ctx2 = C.go_openssl_EVP_MD_CTX_new()
364+
if h2.ctx2 == nil {
365+
C.go_openssl_EVP_MD_CTX_free(h2.ctx)
366+
return nil, newOpenSSLError("EVP_MD_CTX_new")
367+
}
368+
runtime.SetFinalizer(h2, (*evpHash).finalize)
345369
}
346-
runtime.SetFinalizer(cloned, (*evpHash).finalize)
347-
return cloned, nil
370+
runtime.KeepAlive(h)
371+
return h2, nil
348372
}
349373

350374
// hashState returns a pointer to the internal hash structure.
@@ -384,6 +408,8 @@ func (d *evpHash) MarshalBinary() ([]byte, error) {
384408
}
385409

386410
func (d *evpHash) AppendBinary(buf []byte) ([]byte, error) {
411+
defer runtime.KeepAlive(d)
412+
d.init()
387413
if !d.alg.marshallable {
388414
return nil, errors.New("openssl: hash state is not marshallable")
389415
}
@@ -419,6 +445,8 @@ func (d *evpHash) AppendBinary(buf []byte) ([]byte, error) {
419445
}
420446

421447
func (d *evpHash) UnmarshalBinary(b []byte) error {
448+
defer runtime.KeepAlive(d)
449+
d.init()
422450
if !d.alg.marshallable {
423451
return errors.New("openssl: hash state is not marshallable")
424452
}

hash_test.go

+7
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,13 @@ func BenchmarkSHA256(b *testing.B) {
391391
}
392392
}
393393

394+
func BenchmarkNewSHA256(b *testing.B) {
395+
b.ReportAllocs()
396+
for i := 0; i < b.N; i++ {
397+
openssl.NewSHA256()
398+
}
399+
}
400+
394401
// stubHash is a hash.Hash implementation that does nothing.
395402
type stubHash struct{}
396403

0 commit comments

Comments
 (0)