Skip to content

Commit cee1a75

Browse files
authored
enhance: add method for recreating all credentials (#951)
Signed-off-by: Grant Linville <[email protected]>
1 parent 0d20be1 commit cee1a75

File tree

6 files changed

+105
-12
lines changed

6 files changed

+105
-12
lines changed

pkg/credentials/factory.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -72,17 +72,17 @@ func (s *StoreFactory) NewStore(credCtxs []string) (CredentialStore, error) {
7272
return nil, err
7373
}
7474
if s.file {
75-
return withOverride{
76-
target: Store{
75+
return &withOverride{
76+
target: &Store{
7777
credCtxs: credCtxs,
7878
cfg: s.cfg,
7979
},
8080
overrides: s.overrides,
8181
credContext: credCtxs,
8282
}, nil
8383
}
84-
return withOverride{
85-
target: Store{
84+
return &withOverride{
85+
target: &Store{
8686
credCtxs: credCtxs,
8787
cfg: s.cfg,
8888
program: s.program,

pkg/credentials/noop.go

+4
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,7 @@ func (s NoopStore) Remove(context.Context, string) error {
2525
func (s NoopStore) List(context.Context) ([]Credential, error) {
2626
return nil, nil
2727
}
28+
29+
func (s NoopStore) RecreateAll(context.Context) error {
30+
return nil
31+
}

pkg/credentials/overrides.go

+4
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,7 @@ func (w withOverride) List(ctx context.Context) ([]Credential, error) {
147147

148148
return creds, nil
149149
}
150+
151+
func (w withOverride) RecreateAll(ctx context.Context) error {
152+
return w.target.RecreateAll(ctx)
153+
}

pkg/credentials/store.go

+75-8
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"regexp"
77
"slices"
8+
"sync"
89

910
"github.com/docker/cli/cli/config/credentials"
1011
"github.com/docker/cli/cli/config/types"
@@ -24,15 +25,20 @@ type CredentialStore interface {
2425
Refresh(ctx context.Context, cred Credential) error
2526
Remove(ctx context.Context, toolName string) error
2627
List(ctx context.Context) ([]Credential, error)
28+
RecreateAll(ctx context.Context) error
2729
}
2830

2931
type Store struct {
30-
credCtxs []string
31-
cfg *config.CLIConfig
32-
program client.ProgramFunc
32+
credCtxs []string
33+
cfg *config.CLIConfig
34+
program client.ProgramFunc
35+
recreateAllLock sync.RWMutex
3336
}
3437

35-
func (s Store) Get(_ context.Context, toolName string) (*Credential, bool, error) {
38+
func (s *Store) Get(_ context.Context, toolName string) (*Credential, bool, error) {
39+
s.recreateAllLock.RLock()
40+
defer s.recreateAllLock.RUnlock()
41+
3642
if len(s.credCtxs) > 0 && s.credCtxs[0] == AllCredentialContexts {
3743
return nil, false, fmt.Errorf("cannot get a credential with context %q", AllCredentialContexts)
3844
}
@@ -80,7 +86,10 @@ func (s Store) Get(_ context.Context, toolName string) (*Credential, bool, error
8086

8187
// Add adds a new credential to the credential store.
8288
// Any context set on the credential object will be overwritten with the first context of the credential store.
83-
func (s Store) Add(_ context.Context, cred Credential) error {
89+
func (s *Store) Add(_ context.Context, cred Credential) error {
90+
s.recreateAllLock.RLock()
91+
defer s.recreateAllLock.RUnlock()
92+
8493
first := first(s.credCtxs)
8594
if first == AllCredentialContexts {
8695
return fmt.Errorf("cannot add a credential with context %q", AllCredentialContexts)
@@ -99,7 +108,10 @@ func (s Store) Add(_ context.Context, cred Credential) error {
99108
}
100109

101110
// Refresh updates an existing credential in the credential store.
102-
func (s Store) Refresh(_ context.Context, cred Credential) error {
111+
func (s *Store) Refresh(_ context.Context, cred Credential) error {
112+
s.recreateAllLock.RLock()
113+
defer s.recreateAllLock.RUnlock()
114+
103115
if !slices.Contains(s.credCtxs, cred.Context) {
104116
return fmt.Errorf("context %q not in list of valid contexts for this credential store", cred.Context)
105117
}
@@ -115,7 +127,10 @@ func (s Store) Refresh(_ context.Context, cred Credential) error {
115127
return store.Store(auth)
116128
}
117129

118-
func (s Store) Remove(_ context.Context, toolName string) error {
130+
func (s *Store) Remove(_ context.Context, toolName string) error {
131+
s.recreateAllLock.RLock()
132+
defer s.recreateAllLock.RUnlock()
133+
119134
first := first(s.credCtxs)
120135
if len(s.credCtxs) > 1 || first == AllCredentialContexts {
121136
return fmt.Errorf("error: credential deletion is not supported when multiple credential contexts are provided")
@@ -129,7 +144,10 @@ func (s Store) Remove(_ context.Context, toolName string) error {
129144
return store.Erase(toolNameWithCtx(toolName, first))
130145
}
131146

132-
func (s Store) List(_ context.Context) ([]Credential, error) {
147+
func (s *Store) List(_ context.Context) ([]Credential, error) {
148+
s.recreateAllLock.RLock()
149+
defer s.recreateAllLock.RUnlock()
150+
133151
store, err := s.getStore()
134152
if err != nil {
135153
return nil, err
@@ -199,6 +217,55 @@ func (s Store) List(_ context.Context) ([]Credential, error) {
199217
return maps.Values(credsByName), nil
200218
}
201219

220+
func (s *Store) RecreateAll(_ context.Context) error {
221+
store, err := s.getStore()
222+
if err != nil {
223+
return err
224+
}
225+
226+
// New credentials might be created after our GetAll, but they will be created with the current encryption configuration,
227+
// so it's okay that they are skipped by this function.
228+
s.recreateAllLock.Lock()
229+
all, err := store.GetAll()
230+
s.recreateAllLock.Unlock()
231+
if err != nil {
232+
return err
233+
}
234+
235+
// Loop through and recreate each individual credential.
236+
for serverAddress := range all {
237+
if err := s.recreateCredential(store, serverAddress); err != nil {
238+
return err
239+
}
240+
}
241+
242+
return nil
243+
}
244+
245+
func (s *Store) recreateCredential(store credentials.Store, serverAddress string) error {
246+
s.recreateAllLock.Lock()
247+
defer s.recreateAllLock.Unlock()
248+
249+
authConfig, err := store.Get(serverAddress)
250+
if err != nil {
251+
if IsCredentialsNotFoundError(err) {
252+
// This can happen if the credential was deleted between the GetAll and the Get by another thread.
253+
return nil
254+
}
255+
return err
256+
}
257+
258+
if err := store.Erase(serverAddress); err != nil {
259+
return err
260+
}
261+
262+
if err := store.Store(authConfig); err != nil {
263+
return err
264+
}
265+
266+
return nil
267+
}
268+
202269
func (s *Store) getStore() (credentials.Store, error) {
203270
if s.program != nil {
204271
return &toolCredentialStore{

pkg/sdkserver/credentials.go

+17
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,23 @@ func (s *server) initializeCredentialStore(_ context.Context, credCtxs []string)
2020
return store, nil
2121
}
2222

23+
func (s *server) recreateAllCredentials(w http.ResponseWriter, r *http.Request) {
24+
logger := gcontext.GetLogger(r.Context())
25+
26+
store, err := s.initializeCredentialStore(r.Context(), []string{credentials.AllCredentialContexts})
27+
if err != nil {
28+
writeError(logger, w, http.StatusInternalServerError, err)
29+
return
30+
}
31+
32+
if err := store.RecreateAll(r.Context()); err != nil {
33+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to recreate all credentials: %w", err))
34+
return
35+
}
36+
37+
writeResponse(logger, w, map[string]any{"stdout": "All credentials recreated successfully"})
38+
}
39+
2340
func (s *server) listCredentials(w http.ResponseWriter, r *http.Request) {
2441
logger := gcontext.GetLogger(r.Context())
2542
req := new(credentialsRequest)

pkg/sdkserver/routes.go

+1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ func (s *server) addRoutes(mux *http.ServeMux) {
7070
mux.HandleFunc("POST /credentials/create", s.createCredential)
7171
mux.HandleFunc("POST /credentials/reveal", s.revealCredential)
7272
mux.HandleFunc("POST /credentials/delete", s.deleteCredential)
73+
mux.HandleFunc("POST /credentials/recreate-all", s.recreateAllCredentials)
7374

7475
mux.HandleFunc("POST /datasets", s.listDatasets)
7576
mux.HandleFunc("POST /datasets/list-elements", s.listDatasetElements)

0 commit comments

Comments
 (0)