From 9afd8b4a9b79c72c120fbc35c9286e4123c333fb Mon Sep 17 00:00:00 2001 From: "kayos@tcp.direct" Date: Wed, 26 Jul 2023 13:08:00 -0700 Subject: [PATCH] Fix: make ishell context values safe to use concurrently --- context.go | 36 +++++++++++++++++++++++++----------- ishell.go | 8 ++++---- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/context.go b/context.go index 94df6b0..edef852 100644 --- a/context.go +++ b/context.go @@ -1,8 +1,10 @@ package ishell +import "sync" + // Context is an ishell context. It embeds ishell.Actions. type Context struct { - contextValues + *contextValues progressBar ProgressBar err error @@ -30,32 +32,44 @@ func (c *Context) ProgressBar() ProgressBar { } // contextValues is the map for values in the context. -type contextValues map[string]interface{} +type contextValues struct { + vals map[string]interface{} + *sync.RWMutex +} // Get returns the value associated with this context for key, or nil // if no value is associated with key. Successive calls to Get with // the same key returns the same result. -func (c contextValues) Get(key string) interface{} { - return c[key] +func (c *contextValues) Get(key string) interface{} { + c.RLock() + defer c.RUnlock() + return c.vals[key] } // Set sets the key in this context to value. func (c *contextValues) Set(key string, value interface{}) { - if *c == nil { - *c = make(map[string]interface{}) + if c.vals == nil { + c.vals = make(map[string]interface{}) + c.RWMutex = &sync.RWMutex{} } - (*c)[key] = value + c.Lock() + c.vals[key] = value + c.Unlock() } // Del deletes key and its value in this context. -func (c contextValues) Del(key string) { - delete(c, key) +func (c *contextValues) Del(key string) { + c.Lock() + delete(c.vals, key) + c.Unlock() } // Keys returns all keys in the context. -func (c contextValues) Keys() (keys []string) { - for key := range c { +func (c *contextValues) Keys() (keys []string) { + c.RLock() + for key := range c.vals { keys = append(keys, key) } + c.RUnlock() return } diff --git a/ishell.go b/ishell.go index f440095..2aea014 100644 --- a/ishell.go +++ b/ishell.go @@ -669,10 +669,10 @@ func newContext(s *Shell, cmd *Cmd, args []string) *Context { Args: args, RawArgs: s.rawArgs, Cmd: *cmd, - contextValues: func() contextValues { - values := contextValues{} - for k := range s.contextValues { - values[k] = s.contextValues[k] + contextValues: func() *contextValues { + values := &contextValues{} + for k := range s.contextValues.vals { + values.Set(k, s.contextValues.vals[k]) } return values }(),