Skip to content

Commit

Permalink
Tweak sortcache.SortCache API
Browse files Browse the repository at this point in the history
- Adds (SortCache) Clear to allow resetting an existing cache
- Converts (SortCache) Ascend and (SortCache) Descend to return
  an iter.Seq[T] instead of taking an iteration function
- Removes (SortCache) AscendPaginated and (SortCache) DescendPaginated

Leveraging iterators aligns the API with the direction that Go is
heading in regard to iteration. Removing the paginated functions
does put the onus on callers to paginate; however, in the only
cases where these were used there are now fewer allocations. AscendPaginated
was unused and DescendPaginated was only used in two places - both
of which would benefit from migrating away from using streams in
favor of iter.Seq.

The new Clear API makes it easier for callers to reset an existing
cache. Without this change the only way to reset a cache would be
to replace the entire cache with a new one. This places the burden
on callers to apply locking and handle concurrent read/write/deletes
and swapping out the cache entirely. That all is eliminated by
Clear handling the internal locking to reset the cache state.
  • Loading branch information
rosstimothy committed Feb 24, 2025
1 parent eadbe2d commit dec6349
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 250 deletions.
26 changes: 11 additions & 15 deletions lib/services/access_request_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,9 @@ func (c *AccessRequestCache) ListMatchingAccessRequests(ctx context.Context, req
return nil, trace.Errorf("access request cache was not configured with sort index %q (this is a bug)", index)
}

traverse := cache.Ascend
accessRequests := cache.Ascend
if req.Descending {
traverse = cache.Descend
accessRequests = cache.Descend
}

limit := int(req.Limit)
Expand All @@ -219,34 +219,30 @@ func (c *AccessRequestCache) ListMatchingAccessRequests(ctx context.Context, req
var rsp proto.ListAccessRequestsResponse
now := time.Now()
var expired int
traverse(index, req.StartKey, "", func(r *types.AccessRequestV3) (continueTraversal bool) {
for r := range accessRequests(index, req.StartKey, "") {
if len(rsp.AccessRequests) == limit {
rsp.NextKey = cache.KeyOf(index, r)
break
}

if !r.Expiry().IsZero() && now.After(r.Expiry()) {
expired++
// skip requests that appear expired. some backends can take up to 48 hours to expire items,
// and access requests showing up past their expiry time is particularly confusing.
return true
continue
}
if !req.Filter.Match(r) || !match(r) {
return true
continue
}

c := r.Copy()
cr, ok := c.(*types.AccessRequestV3)
if !ok {
slog.WarnContext(ctx, "clone returned unexpected type (this is a bug)", "expected", logutils.TypeAttr(r), "got", logutils.TypeAttr(c))
return true
continue
}

rsp.AccessRequests = append(rsp.AccessRequests, cr)

// halt when we have Limit+1 items so that we can create a
// correct 'NextKey'.
return len(rsp.AccessRequests) <= limit
})

if len(rsp.AccessRequests) > limit {
rsp.NextKey = cache.KeyOf(index, rsp.AccessRequests[limit])
rsp.AccessRequests = rsp.AccessRequests[:limit]
}

if expired > 0 {
Expand Down
40 changes: 24 additions & 16 deletions lib/services/notifications_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,21 +210,25 @@ func (c *UserNotificationCache) StreamUserNotifications(ctx context.Context, use
startKey = fmt.Sprintf("%s/%s", username, startKey)
}

const limit = 50
var done bool
return stream.PageFunc(func() ([]*notificationsv1.Notification, error) {
if done {
return nil, io.EOF
}
notifications, nextKey := c.primaryCache.DescendPaginated(notificationKey, startKey, endKey, 50)
startKey = nextKey
done = nextKey == ""

// Return copies of the notification to prevent mutating the original.
clonedNotifications := make([]*notificationsv1.Notification, 0, len(notifications))
for _, notification := range notifications {
clonedNotifications = append(clonedNotifications, apiutils.CloneProtoMsg(notification))
notifications := make([]*notificationsv1.Notification, 0, limit)
for n := range c.primaryCache.Descend(notificationKey, startKey, endKey) {
if len(notifications) == limit {
startKey = c.primaryCache.KeyOf(notificationKey, n)
return notifications, nil
}

notifications = append(notifications, apiutils.CloneProtoMsg(n))
}
return clonedNotifications, nil

done = true
return notifications, nil
})
}

Expand Down Expand Up @@ -315,21 +319,25 @@ func (c *GlobalNotificationCache) StreamGlobalNotifications(ctx context.Context,
return stream.Fail[*notificationsv1.GlobalNotification](trace.Errorf("global notifications cache was not configured with index %q (this is a bug)", notificationID))
}

const limit = 50
var done bool
return stream.PageFunc(func() ([]*notificationsv1.GlobalNotification, error) {
if done {
return nil, io.EOF
}
notifications, nextKey := c.primaryCache.DescendPaginated(notificationID, startKey, "", 50)
startKey = nextKey
done = nextKey == ""

// Return copies of the notification to prevent mutating the original.
clonedNotifications := make([]*notificationsv1.GlobalNotification, 0, len(notifications))
for _, notification := range notifications {
clonedNotifications = append(clonedNotifications, apiutils.CloneProtoMsg(notification))
notifications := make([]*notificationsv1.GlobalNotification, 0, limit)
for n := range c.primaryCache.Descend(notificationID, startKey, "") {
if len(notifications) == limit {
startKey = c.primaryCache.KeyOf(notificationID, n)
return notifications, nil
}

notifications = append(notifications, apiutils.CloneProtoMsg(n))
}
return clonedNotifications, nil

done = true
return notifications, nil
})
}

Expand Down
158 changes: 70 additions & 88 deletions lib/utils/sortcache/sortcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package sortcache

import (
"iter"
"sync"

"github.com/google/btree"
Expand Down Expand Up @@ -156,6 +157,20 @@ func (c *SortCache[T]) Put(value T) (evicted int) {
return
}

// Clear wipes all items from the cache and returns
// the cache to its initial empty state.
func (c *SortCache[T]) Clear() {
c.rw.Lock()
defer c.rw.Unlock()

for _, tree := range c.trees {
tree.Clear(true)
}

clear(c.values)
c.counter = 0
}

// Delete deletes the value associated with the specified index/key if one exists.
func (c *SortCache[T]) Delete(index, key string) {
c.rw.Lock()
Expand Down Expand Up @@ -183,111 +198,78 @@ func (c *SortCache[T]) deleteValue(ref uint64) {
}
}

// Ascend iterates the specified range from least to greatest. iteration is terminated early if the
// supplied closure returns false. if this method is being used to read a range, it is strongly recommended
// that all values retained be cloned. any mutation that results in changing a value's index keys will put
// the sort cache into a permanently bad state. empty strings are treated as "open" bounds. passing an empty
// string for both the start and stop bounds iterates all values.
// Ascend iterates the specified range from least to greatest. if this method is being used to read a range,
// it is strongly recommended that all values retained be cloned. any mutation that results in changing a
// value's index keys will put the sort cache into a permanently bad state. empty strings are treated as
// "open" bounds. passing an empty string for both the start and stop bounds iterates all values.
//
// NOTE: ascending ranges are equivalent to the default range logic used across most of teleport, so
// common helpers like `backend.RangeEnd` will function as expected with this method.
func (c *SortCache[T]) Ascend(index, start, stop string, iterator func(T) bool) {
c.rw.RLock()
defer c.rw.RUnlock()

tree, ok := c.trees[index]
if !ok {
return
}

fn := func(ent entry) bool {
return iterator(c.values[ent.ref])
}

// select the appropriate ascend variant based on wether or not
// start/stop points were specified.
switch {
case start == "" && stop == "":
tree.Ascend(fn)
case start == "":
tree.AscendLessThan(entry{key: stop}, fn)
case stop == "":
tree.AscendGreaterOrEqual(entry{key: start}, fn)
default:
tree.AscendRange(entry{key: start}, entry{key: stop}, fn)
}
}

// AscendPaginated returns a page from a range of items in the sortcache in ascending order, and the nextKey.
func (c *SortCache[T]) AscendPaginated(index, startKey string, endKey string, pageSize int) ([]T, string) {
page := make([]T, 0, pageSize+1)
func (c *SortCache[T]) Ascend(index, start, stop string) iter.Seq[T] {
return func(yield func(T) bool) {
c.rw.RLock()
defer c.rw.RUnlock()

tree, ok := c.trees[index]
if !ok {
return
}

c.Ascend(index, startKey, endKey, func(r T) bool {
page = append(page, r)
return len(page) <= pageSize
})
fn := func(ent entry) bool {
return yield(c.values[ent.ref])
}

var nextKey string
if len(page) > pageSize {
nextKey = c.KeyOf(index, page[pageSize])
page = page[:pageSize]
// select the appropriate ascend variant based on wether or not
// start/stop points were specified.
switch {
case start == "" && stop == "":
tree.Ascend(fn)
case start == "":
tree.AscendLessThan(entry{key: stop}, fn)
case stop == "":
tree.AscendGreaterOrEqual(entry{key: start}, fn)
default:
tree.AscendRange(entry{key: start}, entry{key: stop}, fn)
}
}

return page, nextKey
}

// Descend iterates the specified range from greatest to least. iteration is terminated early if the
// supplied closure returns false. if this method is being used to read a range, it is strongly recommended
// that all values retained be cloned. any mutation that results in changing a value's index keys will put
// the sort cache into a permanently bad state. empty strings are treated as "open" bounds. passing an empty
// string for both the start and stop bounds iterates all values.
// Descend iterates the specified range from greatest to least. if this method is being used to read a range,
// it is strongly recommended that all values retained be cloned. any mutation that results in changing a
// value's index keys will put the sort cache into a permanently bad state. empty strings are treated as
// "open" bounds. passing an empty string for both the start and stop bounds iterates all values.
//
// NOTE: descending sort order is the *opposite* of what most teleport range-based logic uses, meaning that
// many common patterns need to be inverted when using this method (e.g. `backend.RangeEnd` actually gives
// you the start position for descending ranges).
func (c *SortCache[T]) Descend(index, start, stop string, iterator func(T) bool) {
c.rw.RLock()
defer c.rw.RUnlock()

tree, ok := c.trees[index]
if !ok {
return
}

fn := func(ent entry) bool {
return iterator(c.values[ent.ref])
}
func (c *SortCache[T]) Descend(index, start, stop string) iter.Seq[T] {
return func(yield func(T) bool) {

// select the appropriate descend variant based on wether or not
// start/stop points were specified.
switch {
case start == "" && stop == "":
tree.Descend(fn)
case start == "":
tree.DescendGreaterThan(entry{key: stop}, fn)
case stop == "":
tree.DescendLessOrEqual(entry{key: start}, fn)
default:
tree.DescendRange(entry{key: start}, entry{key: stop}, fn)
}
}
c.rw.RLock()
defer c.rw.RUnlock()

// DescendPaginated returns a page from a range of items in the sortcache in descending order, and the nextKey.
func (c *SortCache[T]) DescendPaginated(index, startKey string, endKey string, pageSize int) ([]T, string) {
page := make([]T, 0, pageSize+1)
tree, ok := c.trees[index]
if !ok {
return
}

c.Descend(index, startKey, endKey, func(r T) bool {
page = append(page, r)
return len(page) <= pageSize
})
fn := func(ent entry) bool {
return yield(c.values[ent.ref])
}

var nextKey string
if len(page) > pageSize {
nextKey = c.KeyOf(index, page[pageSize])
page = page[:pageSize]
// select the appropriate descend variant based on wether or not
// start/stop points were specified.
switch {
case start == "" && stop == "":
tree.Descend(fn)
case start == "":
tree.DescendGreaterThan(entry{key: stop}, fn)
case stop == "":
tree.DescendLessOrEqual(entry{key: start}, fn)
default:
tree.DescendRange(entry{key: start}, entry{key: stop}, fn)
}
}

return page, nextKey
}

// Len returns the number of values currently stored.
Expand Down
Loading

0 comments on commit dec6349

Please sign in to comment.