Tweak sortcache.SortCache API #52414

Open
wants to merge 1 commit into
base: master
26 changes: 11 additions & 15 deletions lib/services/access_request_cache.go
@@ -208,9 +208,9 @@ func (c *AccessRequestCache) ListMatchingAccessRequests(ctx context.Context, req
return nil, trace.Errorf("access request cache was not configured with sort index %q (this is a bug)", index)
}

traverse := cache.Ascend
accessRequests := cache.Ascend
if req.Descending {
traverse = cache.Descend
accessRequests = cache.Descend
}

limit := int(req.Limit)
@@ -219,34 +219,30 @@ func (c *AccessRequestCache) ListMatchingAccessRequests(ctx context.Context, req
var rsp proto.ListAccessRequestsResponse
now := time.Now()
var expired int
traverse(index, req.StartKey, "", func(r *types.AccessRequestV3) (continueTraversal bool) {
for r := range accessRequests(index, req.StartKey, "") {
if len(rsp.AccessRequests) == limit {
rsp.NextKey = cache.KeyOf(index, r)
break
}

if !r.Expiry().IsZero() && now.After(r.Expiry()) {
expired++
// skip requests that appear expired. some backends can take up to 48 hours to expire items
// and access requests showing up past their expiry time is particularly confusing.
return true
continue
}
if !req.Filter.Match(r) || !match(r) {
return true
continue
}

c := r.Copy()
cr, ok := c.(*types.AccessRequestV3)
if !ok {
slog.WarnContext(ctx, "clone returned unexpected type (this is a bug)", "expected", logutils.TypeAttr(r), "got", logutils.TypeAttr(c))
return true
continue
}

rsp.AccessRequests = append(rsp.AccessRequests, cr)

// halt when we have Limit+1 items so that we can create a
// correct 'NextKey'.
return len(rsp.AccessRequests) <= limit
})

if len(rsp.AccessRequests) > limit {
rsp.NextKey = cache.KeyOf(index, rsp.AccessRequests[limit])
rsp.AccessRequests = rsp.AccessRequests[:limit]
}

if expired > 0 {
40 changes: 24 additions & 16 deletions lib/services/notifications_cache.go
@@ -210,21 +210,25 @@ func (c *UserNotificationCache) StreamUserNotifications(ctx context.Context, use
startKey = fmt.Sprintf("%s/%s", username, startKey)
}

const limit = 50
Reviewer comment (Contributor), on the hardcoded limit:

I know this was hardcoded to 50 before, but there is a PageSize int in ListNotificationsRequest. I wonder why that isn't taken into account anywhere.

Nothing to do in this PR of course, just talking out loud.

Reply (Contributor, Author):

I believe that it is being taken into consideration when collecting the notifications from the stream.

var done bool
return stream.PageFunc(func() ([]*notificationsv1.Notification, error) {
if done {
return nil, io.EOF
}
notifications, nextKey := c.primaryCache.DescendPaginated(notificationKey, startKey, endKey, 50)
startKey = nextKey
done = nextKey == ""

// Return copies of the notification to prevent mutating the original.
clonedNotifications := make([]*notificationsv1.Notification, 0, len(notifications))
for _, notification := range notifications {
clonedNotifications = append(clonedNotifications, apiutils.CloneProtoMsg(notification))
notifications := make([]*notificationsv1.Notification, 0, limit)
for n := range c.primaryCache.Descend(notificationKey, startKey, endKey) {
if len(notifications) == limit {
startKey = c.primaryCache.KeyOf(notificationKey, n)
return notifications, nil
}

notifications = append(notifications, apiutils.CloneProtoMsg(n))
}
return clonedNotifications, nil

done = true
return notifications, nil
})
}
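
For reference, the paging pattern used here (and again in StreamGlobalNotifications below) can be written as a small standalone helper. This is an illustrative sketch only: collectPage and keyOf are hypothetical names, not part of this PR or of the sortcache package.

import "iter"

// collectPage drains at most pageSize values from seq. If the sequence holds
// more values than fit in the page, the key of the first value that did not
// fit is returned as nextKey so the caller can resume from it; an empty
// nextKey means the range was exhausted.
func collectPage[T any](seq iter.Seq[T], pageSize int, keyOf func(T) string) (page []T, nextKey string) {
	page = make([]T, 0, pageSize)
	for v := range seq {
		if len(page) == pageSize {
			return page, keyOf(v)
		}
		page = append(page, v)
	}
	return page, ""
}

The callers above carry startKey and done across stream.PageFunc invocations, so the next page resumes from nextKey and the stream ends with io.EOF once the range is exhausted.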

@@ -315,21 +319,25 @@ func (c *GlobalNotificationCache) StreamGlobalNotifications(ctx context.Context,
return stream.Fail[*notificationsv1.GlobalNotification](trace.Errorf("global notifications cache was not configured with index %q (this is a bug)", notificationID))
}

const limit = 50
var done bool
return stream.PageFunc(func() ([]*notificationsv1.GlobalNotification, error) {
if done {
return nil, io.EOF
}
notifications, nextKey := c.primaryCache.DescendPaginated(notificationID, startKey, "", 50)
startKey = nextKey
done = nextKey == ""

// Return copies of the notification to prevent mutating the original.
clonedNotifications := make([]*notificationsv1.GlobalNotification, 0, len(notifications))
for _, notification := range notifications {
clonedNotifications = append(clonedNotifications, apiutils.CloneProtoMsg(notification))
notifications := make([]*notificationsv1.GlobalNotification, 0, limit)
for n := range c.primaryCache.Descend(notificationID, startKey, "") {
if len(notifications) == limit {
startKey = c.primaryCache.KeyOf(notificationID, n)
return notifications, nil
}

notifications = append(notifications, apiutils.CloneProtoMsg(n))
}
return clonedNotifications, nil

done = true
return notifications, nil
})
}

158 changes: 70 additions & 88 deletions lib/utils/sortcache/sortcache.go
@@ -18,6 +18,7 @@
package sortcache

import (
"iter"
"sync"

"github.com/google/btree"
@@ -156,6 +157,20 @@ func (c *SortCache[T]) Put(value T) (evicted int) {
return
}

// Clear wipes all items from the cache and returns
// the cache to its initial empty state.
func (c *SortCache[T]) Clear() {
c.rw.Lock()
defer c.rw.Unlock()

for _, tree := range c.trees {
tree.Clear(true)
}

clear(c.values)
c.counter = 0
}

// Delete deletes the value associated with the specified index/key if one exists.
func (c *SortCache[T]) Delete(index, key string) {
c.rw.Lock()
@@ -183,111 +198,78 @@ func (c *SortCache[T]) deleteValue(ref uint64) {
}
}

// Ascend iterates the specified range from least to greatest. iteration is terminated early if the
// supplied closure returns false. if this method is being used to read a range, it is strongly recommended
// that all values retained be cloned. any mutation that results in changing a value's index keys will put
// the sort cache into a permanently bad state. empty strings are treated as "open" bounds. passing an empty
// string for both the start and stop bounds iterates all values.
// Ascend iterates the specified range from least to greatest. if this method is being used to read a range,
// it is strongly recommended that all values retained be cloned. any mutation that results in changing a
// value's index keys will put the sort cache into a permanently bad state. empty strings are treated as
// "open" bounds. passing an empty string for both the start and stop bounds iterates all values.
//
// NOTE: ascending ranges are equivalent to the default range logic used across most of teleport, so
// common helpers like `backend.RangeEnd` will function as expected with this method.
func (c *SortCache[T]) Ascend(index, start, stop string, iterator func(T) bool) {
c.rw.RLock()
defer c.rw.RUnlock()

tree, ok := c.trees[index]
if !ok {
return
}

fn := func(ent entry) bool {
return iterator(c.values[ent.ref])
}

// select the appropriate ascend variant based on whether or not
// start/stop points were specified.
switch {
case start == "" && stop == "":
tree.Ascend(fn)
case start == "":
tree.AscendLessThan(entry{key: stop}, fn)
case stop == "":
tree.AscendGreaterOrEqual(entry{key: start}, fn)
default:
tree.AscendRange(entry{key: start}, entry{key: stop}, fn)
}
}

// AscendPaginated returns a page from a range of items in the sortcache in ascending order, and the nextKey.
func (c *SortCache[T]) AscendPaginated(index, startKey string, endKey string, pageSize int) ([]T, string) {
page := make([]T, 0, pageSize+1)
func (c *SortCache[T]) Ascend(index, start, stop string) iter.Seq[T] {
return func(yield func(T) bool) {
c.rw.RLock()
defer c.rw.RUnlock()

tree, ok := c.trees[index]
if !ok {
return
}

c.Ascend(index, startKey, endKey, func(r T) bool {
page = append(page, r)
return len(page) <= pageSize
})
fn := func(ent entry) bool {
return yield(c.values[ent.ref])
}

var nextKey string
if len(page) > pageSize {
nextKey = c.KeyOf(index, page[pageSize])
page = page[:pageSize]
// select the appropriate ascend variant based on whether or not
// start/stop points were specified.
switch {
case start == "" && stop == "":
tree.Ascend(fn)
case start == "":
tree.AscendLessThan(entry{key: stop}, fn)
case stop == "":
tree.AscendGreaterOrEqual(entry{key: start}, fn)
default:
tree.AscendRange(entry{key: start}, entry{key: stop}, fn)
}
}

return page, nextKey
}
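
With the callback form gone, callers consume Ascend with an ordinary range-over-func loop, and breaking out of the loop replaces returning false from the old closure. A minimal usage sketch follows; the "name" index, the string value type, and the firstN helper are illustrative assumptions, not part of the package.

import "github.com/gravitational/teleport/lib/utils/sortcache"

// firstN returns up to n values from a hypothetical "name" index in ascending
// order. Early termination is just a break; the values here are plain strings,
// so no defensive copy is needed before retaining them.
func firstN(cache *sortcache.SortCache[string], n int) []string {
	out := make([]string, 0, n)
	for v := range cache.Ascend("name", "", "") {
		out = append(out, v)
		if len(out) == n {
			break
		}
	}
	return out
}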

// Descend iterates the specified range from greatest to least. iteration is terminated early if the
// supplied closure returns false. if this method is being used to read a range, it is strongly recommended
// that all values retained be cloned. any mutation that results in changing a value's index keys will put
// the sort cache into a permanently bad state. empty strings are treated as "open" bounds. passing an empty
// string for both the start and stop bounds iterates all values.
// Descend iterates the specified range from greatest to least. if this method is being used to read a range,
// it is strongly recommended that all values retained be cloned. any mutation that results in changing a
// value's index keys will put the sort cache into a permanently bad state. empty strings are treated as
// "open" bounds. passing an empty string for both the start and stop bounds iterates all values.
//
// NOTE: descending sort order is the *opposite* of what most teleport range-based logic uses, meaning that
// many common patterns need to be inverted when using this method (e.g. `backend.RangeEnd` actually gives
// you the start position for descending ranges).
func (c *SortCache[T]) Descend(index, start, stop string, iterator func(T) bool) {
c.rw.RLock()
defer c.rw.RUnlock()

tree, ok := c.trees[index]
if !ok {
return
}

fn := func(ent entry) bool {
return iterator(c.values[ent.ref])
}
func (c *SortCache[T]) Descend(index, start, stop string) iter.Seq[T] {
return func(yield func(T) bool) {

// select the appropriate descend variant based on whether or not
// start/stop points were specified.
switch {
case start == "" && stop == "":
tree.Descend(fn)
case start == "":
tree.DescendGreaterThan(entry{key: stop}, fn)
case stop == "":
tree.DescendLessOrEqual(entry{key: start}, fn)
default:
tree.DescendRange(entry{key: start}, entry{key: stop}, fn)
}
}
c.rw.RLock()
defer c.rw.RUnlock()

// DescendPaginated returns a page from a range of items in the sortcache in descending order, and the nextKey.
func (c *SortCache[T]) DescendPaginated(index, startKey string, endKey string, pageSize int) ([]T, string) {
page := make([]T, 0, pageSize+1)
tree, ok := c.trees[index]
if !ok {
return
}

c.Descend(index, startKey, endKey, func(r T) bool {
page = append(page, r)
return len(page) <= pageSize
})
fn := func(ent entry) bool {
return yield(c.values[ent.ref])
}

var nextKey string
if len(page) > pageSize {
nextKey = c.KeyOf(index, page[pageSize])
page = page[:pageSize]
// select the appropriate descend variant based on whether or not
// start/stop points were specified.
switch {
case start == "" && stop == "":
tree.Descend(fn)
case start == "":
tree.DescendGreaterThan(entry{key: stop}, fn)
case stop == "":
tree.DescendLessOrEqual(entry{key: start}, fn)
default:
tree.DescendRange(entry{key: start}, entry{key: stop}, fn)
}
}

return page, nextKey
}
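
As the doc comment notes, the bounds for Descend are inverted relative to Ascend: the first argument is the greatest key in the range. One use of this, under the same illustrative setup as the Ascend sketch above, is a point lookup for the newest entry at or below a given key (the "created" index name is an assumption).

// newestAtOrBefore returns the value whose key in the "created" index is the
// greatest key less than or equal to key. Descend yields values from greatest
// to least starting at the given upper bound, so the first yielded value is
// the answer.
func newestAtOrBefore(cache *sortcache.SortCache[string], key string) (string, bool) {
	for v := range cache.Descend("created", key, "") {
		return v, true
	}
	return "", false
}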

// Len returns the number of values currently stored.