From 4419c23f7639d515481e183a5c4c0f78511dfc72 Mon Sep 17 00:00:00 2001 From: Karl Fischer Date: Sun, 21 Feb 2021 19:11:36 +0100 Subject: [PATCH] Concurrency for traversal --- cli/command.go | 31 +++++++++++++++++++++++-------- cli/rm.go | 4 +++- client/client.go | 30 ++++++++++++++++++------------ client/traverse.go | 15 ++++++--------- 4 files changed, 50 insertions(+), 30 deletions(-) diff --git a/cli/command.go b/cli/command.go index 33a89778..cc90c888 100644 --- a/cli/command.go +++ b/cli/command.go @@ -3,6 +3,7 @@ package cli import ( "path/filepath" "strings" + "sync" "github.com/fatih/structs" "github.com/fishi0x01/vsh/client" @@ -74,17 +75,31 @@ func cmdPath(pwd string, arg string) (result string) { return result } +var numWorkers = 5 + func runCommandWithTraverseTwoPaths(client *client.Client, source string, target string, f func(string, string) error) { + c := make(chan string, numWorkers) source = filepath.Clean(source) // remove potential trailing '/' - for _, path := range client.Traverse(source) { - target := strings.Replace(path, source, target, 1) - err := f(path, target) - if err != nil { - return - } + go client.Traverse(source, c) + var wg sync.WaitGroup + for t := 0; t < numWorkers; t++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + path, ok := <-c + if !ok { + return + } + target := strings.Replace(path, source, target, 1) + err := f(path, target) + if err != nil { + return + } + } + }() } - - return + wg.Wait() } func transportSecrets(c *client.Client, source string, target string, transport func(string, string) error) int { diff --git a/cli/rm.go b/cli/rm.go index b2873a75..3a0cade0 100644 --- a/cli/rm.go +++ b/cli/rm.go @@ -72,7 +72,9 @@ func (cmd *RemoveCommand) Run() int { case client.LEAF: cmd.removeSecret(newPwd) case client.NODE: - for _, path := range cmd.client.Traverse(newPwd) { + c := make(chan string, 10) + go cmd.client.Traverse(newPwd, c) + for path := range c { err := cmd.removeSecret(path) if err != nil { return 1 diff --git a/client/client.go b/client/client.go index 17b77923..d23046d4 100644 --- a/client/client.go +++ b/client/client.go @@ -153,30 +153,36 @@ func (client *Client) GetType(absolutePath string) (kind PathKind) { return kind } -// Traverse traverses given absolutePath via DFS and returns sub-paths in array -func (client *Client) Traverse(absolutePath string) (paths []string) { +// Traverse traverses given absolutePath via DFS and pushes paths to given channel +func (client *Client) Traverse(absolutePath string, c chan<- string) { + defer close(c) if client.isTopLevelPath(absolutePath) { - paths = client.topLevelTraverse() + client.topLevelTraverse(c) } else { - paths = client.lowLevelTraverse(normalizedVaultPath(absolutePath)) + client.lowLevelTraverse(normalizedVaultPath(absolutePath), c) } - - return paths } // SubpathsForPath will return an array of absolute paths at or below path -func (client *Client) SubpathsForPath(path string) (filePaths []string, err error) { +func (client *Client) SubpathsForPath(path string) (result []string, err error) { switch t := client.GetType(path); t { case LEAF: - filePaths = append(filePaths, filepath.Clean(path)) + result = []string{filepath.Clean(path)} case NODE: - for _, traversedPath := range client.Traverse(path) { - filePaths = append(filePaths, traversedPath) + c := make(chan string, 10) + go client.Traverse(path, c) + // TODO: this is currently fully sequential to keep old behavior + for { + p, ok := <-c + if !ok { + break + } + result = append(result, p) } default: - return filePaths, fmt.Errorf("Not a valid path for operation: %s", path) + err = fmt.Errorf("Not a valid path for operation: %s", path) } - return filePaths, nil + return result, err } // ClearCache clears the list cache diff --git a/client/traverse.go b/client/traverse.go index 556f0d94..08aa4779 100644 --- a/client/traverse.go +++ b/client/traverse.go @@ -5,15 +5,13 @@ import ( "strings" ) -func (client *Client) topLevelTraverse() (result []string) { +func (client *Client) topLevelTraverse(c chan<- string) { for k := range client.KVBackends { - result = append(result, k) + c <- k } - - return result } -func (client *Client) lowLevelTraverse(path string) (result []string) { +func (client *Client) lowLevelTraverse(path string, c chan<- string) { s, err := client.cache.List(client.getKVMetaDataPath(path)) if err != nil { log.AppTrace("%+v", err) @@ -27,17 +25,16 @@ func (client *Client) lowLevelTraverse(path string) (result []string) { // prevent ambiguous dir/file to be added twice if strings.HasSuffix(val, "/") { // dir - result = append(result, client.lowLevelTraverse(path+"/"+val)...) + client.lowLevelTraverse(path+"/"+val, c) } else { // file leaf := strings.ReplaceAll("/"+path+"/"+val, "//", "/") - result = append(result, leaf) + c <- leaf } } } } else { leaf := strings.ReplaceAll("/"+path, "//", "/") - result = append(result, leaf) + c <- leaf } - return result }