Skip to content

Commit

Permalink
Cache accounts.
Browse files Browse the repository at this point in the history
  • Loading branch information
mcdee committed Mar 30, 2024
1 parent c709c7a commit a68e875
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 56 deletions.
185 changes: 133 additions & 52 deletions grpc.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright © 2020 - 2022 Weald Technology Trading
// Copyright © 2020 - 2024 Weald Technology Trading
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
Expand Down Expand Up @@ -134,61 +134,44 @@ func (w *wallet) List(ctx context.Context, accountPath string) ([]e2wtypes.Accou
if resp.GetState() != pb.ResponseState_SUCCEEDED {
return nil, fmt.Errorf("request to list wallet accounts returned state %v", resp.GetState())
}
span.AddEvent("Obtained accounts")

walletPrefixLen := len(w.Name()) + 1
// sem := semaphore.NewWeighted(int64(runtime.GOMAXPROCS(0)))
var wg sync.WaitGroup
accounts := make([]e2wtypes.Account, 0)
for _, account := range resp.GetAccounts() {
pubKey, err := e2types.BLSPublicKeyFromBytes(account.GetPublicKey())
if err != nil {
return nil, errors.Wrap(err, "received invalid public key")
}
var accountsMu sync.Mutex
for _, respAccount := range resp.GetAccounts() {
wg.Add(1)
go func(respAccount *pb.Account, wg *sync.WaitGroup, mu *sync.Mutex) {
defer wg.Done()

var uuid uuid.UUID
err = uuid.UnmarshalBinary(account.GetUuid())
if err != nil {
return nil, errors.Wrap(err, "received invalid uuid")
}
var name string
if strings.Contains(account.GetName(), "/") {
name = account.GetName()[walletPrefixLen:]
} else {
name = account.GetName()
}
account := newAccount(w, uuid, name, pubKey, 1)
accounts = append(accounts, account)
}
for _, account := range resp.GetDistributedAccounts() {
pubKey, err := e2types.BLSPublicKeyFromBytes(account.GetPublicKey())
if err != nil {
return nil, errors.Wrap(err, "received invalid public key")
}
account, err := w.obtainAccount(respAccount)
if err != nil {
w.log.Error().Err(err).Msg("Failed to obtain account")
}

compositePubKey, err := e2types.BLSPublicKeyFromBytes(account.GetCompositePublicKey())
if err != nil {
return nil, errors.Wrap(err, "received invalid composite public key")
}
mu.Lock()
accounts = append(accounts, account)
mu.Unlock()
}(respAccount, &wg, &accountsMu)
}
for _, respAccount := range resp.GetDistributedAccounts() {
wg.Add(1)
go func(respAccount *pb.DistributedAccount, wg *sync.WaitGroup, mu *sync.Mutex) {
defer wg.Done()

var uuid uuid.UUID
err = uuid.UnmarshalBinary(account.GetUuid())
if err != nil {
return nil, errors.Wrap(err, "received invalid uuid")
}
var name string
if strings.Contains(account.GetName(), "/") {
name = account.GetName()[walletPrefixLen:]
} else {
name = account.GetName()
}
participants := make(map[uint64]*Endpoint, len(account.GetParticipants()))
for _, participant := range account.GetParticipants() {
participants[participant.GetId()] = &Endpoint{
host: participant.GetName(),
port: participant.GetPort(),
account, err := w.obtainDistributedAccount(respAccount)
if err != nil {
w.log.Error().Err(err).Msg("Failed to obtain distributed account")
}
}
account := newDistributedAccount(w, uuid, name, pubKey, compositePubKey, account.GetSigningThreshold(), participants, 1)
accounts = append(accounts, account)

mu.Lock()
accounts = append(accounts, account)
mu.Unlock()
}(respAccount, &wg, &accountsMu)
}
wg.Wait()
span.AddEvent("Processed accounts")

return accounts, nil
}
Expand Down Expand Up @@ -801,33 +784,38 @@ func (a *distributedAccount) thresholdSign(ctx context.Context, req *pb.SignRequ
return nil, errors.Wrap(err, fmt.Sprintf("failed to connect to endpoint %v", endpoint))
}
defer release()
span.AddEvent("Obtained connection")

clients[id] = pb.NewSignerClient(conn)
if clients[id] == nil {
return nil, fmt.Errorf("failed to set up signing client for %v", endpoint)
}
}
span.AddEvent("Obtained connections")

type multiSignResponse struct {
type thresholdSignResponse struct {
id uint64
resp *pb.SignResponse
}
respChannel := make(chan *multiSignResponse, len(clients))
respChannel := make(chan *thresholdSignResponse, len(clients))
errChannel := make(chan error, len(clients))

span.AddEvent("Ready to contact servers")
ctx, cancelFunc := context.WithTimeout(ctx, a.wallet.timeout)
defer cancelFunc()
for id, client := range clients {
go func(client pb.SignerClient, id uint64, req *pb.SignRequest) {
resp, err := client.Sign(ctx, req)
span.AddEvent("Received response")
if err != nil {
errChannel <- err
} else {
respChannel <- &multiSignResponse{
respChannel <- &thresholdSignResponse{
id: id,
resp: resp,
}
}
span.AddEvent("Processed response")
}(client, id, req)
}
span.AddEvent("Contacted all servers")
Expand Down Expand Up @@ -1214,3 +1202,96 @@ func blsID(id uint64) *bls.ID {

return &res
}

func (w *wallet) obtainAccount(respAccount *pb.Account) (
e2wtypes.Account,
error,
) {
var key [48]byte
copy(key[:], respAccount.GetPublicKey())
w.accountMapMu.RLock()
account, exists := w.accountMap[key]
w.accountMapMu.RUnlock()
if exists {
return account, nil
}

pubKey, err := e2types.BLSPublicKeyFromBytes(respAccount.GetPublicKey())
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("public key %#x invalid", respAccount.GetPublicKey()))
}

var uuid uuid.UUID
err = uuid.UnmarshalBinary(respAccount.GetUuid())
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("uuid %x invalid", respAccount.GetUuid()))
}

walletPrefixLen := len(w.Name()) + 1
var name string
if strings.Contains(respAccount.GetName(), "/") {
name = respAccount.GetName()[walletPrefixLen:]
} else {
name = respAccount.GetName()
}

account = newAccount(w, uuid, name, pubKey, 1)

w.accountMapMu.Lock()
w.accountMap[key] = account
w.accountMapMu.Unlock()

return account, nil
}

func (w *wallet) obtainDistributedAccount(respAccount *pb.DistributedAccount) (
e2wtypes.Account,
error,
) {
var key [48]byte
copy(key[:], respAccount.GetPublicKey())
w.accountMapMu.RLock()
account, exists := w.accountMap[key]
w.accountMapMu.RUnlock()
if exists {
return account, nil
}

pubKey, err := e2types.BLSPublicKeyFromBytes(respAccount.GetPublicKey())
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("public key %#x invalid", respAccount.GetPublicKey()))
}

compositePubKey, err := e2types.BLSPublicKeyFromBytes(respAccount.GetCompositePublicKey())
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("composite public key %#x invalid", respAccount.GetCompositePublicKey()))
}

var uuid uuid.UUID
err = uuid.UnmarshalBinary(respAccount.GetUuid())
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("uuid %x invalid", respAccount.GetUuid()))
}
var name string
walletPrefixLen := len(w.Name()) + 1
if strings.Contains(respAccount.GetName(), "/") {
name = respAccount.GetName()[walletPrefixLen:]
} else {
name = respAccount.GetName()
}
participants := make(map[uint64]*Endpoint, len(respAccount.GetParticipants()))
for _, participant := range respAccount.GetParticipants() {
participants[participant.GetId()] = &Endpoint{
host: participant.GetName(),
port: participant.GetPort(),
}
}

account = newDistributedAccount(w, uuid, name, pubKey, compositePubKey, respAccount.GetSigningThreshold(), participants, 1)

w.accountMapMu.Lock()
w.accountMap[key] = account
w.accountMapMu.Unlock()

return account, nil
}
10 changes: 10 additions & 0 deletions parameters.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ import (
"time"

"github.com/pkg/errors"
"github.com/rs/zerolog"
"google.golang.org/grpc/credentials"
)

type parameters struct {
logLevel zerolog.Level
monitor Metrics
timeout time.Duration
name string
Expand All @@ -40,6 +42,13 @@ func (f parameterFunc) apply(p *parameters) {
f(p)
}

// WithLogLevel sets the log level for the module.
func WithLogLevel(logLevel zerolog.Level) Parameter {
return parameterFunc(func(p *parameters) {
p.logLevel = logLevel
})
}

// WithMonitor sets the monitor for the wallet.
func WithMonitor(monitor Metrics) Parameter {
return parameterFunc(func(p *parameters) {
Expand Down Expand Up @@ -85,6 +94,7 @@ func WithPoolConnections(connections int32) Parameter {
// parseAndCheckParameters parses and checks parameters to ensure that mandatory parameters are present and correct.
func parseAndCheckParameters(params ...Parameter) (*parameters, error) {
parameters := parameters{
logLevel: zerolog.GlobalLevel(),
timeout: 30 * time.Second,
poolConnections: 128,
monitor: &nullMetrics{},
Expand Down
24 changes: 20 additions & 4 deletions wallet.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright © 2020, 2022 Weald Technology Trading.
// Copyright © 2020 - 2024 Weald Technology Trading.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
Expand All @@ -15,10 +15,13 @@ package dirk

import (
"context"
"sync"
"time"

"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/rs/zerolog"
zerologger "github.com/rs/zerolog/log"
e2wtypes "github.com/wealdtech/go-eth2-wallet-types/v2"
"google.golang.org/grpc/credentials"
)
Expand All @@ -29,20 +32,25 @@ const (

// wallet contains the details of a remote dirk wallet.
type wallet struct {
log zerolog.Logger
id uuid.UUID
name string
version uint
endpoints []*Endpoint
timeout time.Duration
connectionProvider ConnectionProvider

accountMap map[[48]byte]e2wtypes.Account
accountMapMu sync.RWMutex
}

// newWallet creates a new wallet.
func newWallet() *wallet {
return &wallet{
id: uuid.MustParse("00000000-0000-0000-0000-000000000000"),
timeout: 30 * time.Second,
version: 1,
id: uuid.MustParse("00000000-0000-0000-0000-000000000000"),
timeout: 30 * time.Second,
version: 1,
accountMap: make(map[[48]byte]e2wtypes.Account),
}
}

Expand All @@ -58,11 +66,18 @@ func Open(ctx context.Context,
return nil, errors.Wrap(err, "problem with parameters")
}

// Set logging.
log := zerologger.With().Str("service", "wallet").Str("impl", "dirk").Logger()
if parameters.logLevel != log.GetLevel() {
log = log.Level(parameters.logLevel)
}

if err := registerMetrics(ctx, parameters.monitor); err != nil {
return nil, errors.Wrap(err, "failed to register metrics")
}

wallet := newWallet()
wallet.log = log
wallet.name = parameters.name
wallet.timeout = parameters.timeout
wallet.endpoints = make([]*Endpoint, len(parameters.endpoints))
Expand All @@ -77,6 +92,7 @@ func Open(ctx context.Context,
port: parameters.endpoints[i].port,
}
}
wallet.log.Trace().Str("name", wallet.name).Msg("Opened wallet")

return wallet, nil
}
Expand Down

0 comments on commit a68e875

Please sign in to comment.