Skip to content

Commit

Permalink
Chore: Refactor tool indexing (#1633)
Browse files Browse the repository at this point in the history
* Chore: Refactor tool indexing

Signed-off-by: Daishan Peng <[email protected]>
  • Loading branch information
StrongMonkey authored Feb 11, 2025
1 parent 351d29e commit d8faf4d
Show file tree
Hide file tree
Showing 17 changed files with 539 additions and 327 deletions.
15 changes: 9 additions & 6 deletions apiclient/types/toolreference.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,22 @@ const (
type ToolReferenceManifest struct {
Name string `json:"name"`
ToolType ToolReferenceType `json:"toolType"`
Commit string `json:"commit,omitempty"`
Reference string `json:"reference,omitempty"`
Active bool `json:"active,omitempty"`
}

type ToolReference struct {
Metadata
ToolReferenceManifest
Resolved bool `json:"resolved,omitempty"`
Error string `json:"error,omitempty"`
Builtin bool `json:"builtin,omitempty"`
Description string `json:"description,omitempty"`
Credentials []string `json:"credentials,omitempty"`
Params map[string]string `json:"params,omitempty"`
Resolved bool `json:"resolved,omitempty"`
Error string `json:"error,omitempty"`
Builtin bool `json:"builtin,omitempty"`
Description string `json:"description,omitempty"`
Credentials []string `json:"credentials,omitempty"`
Params map[string]string `json:"params,omitempty"`
Bundle bool `json:"bundle,omitempty"`
BundleToolName string `json:"bundleToolName,omitempty"`
}

type ToolReferenceList List[ToolReference]
38 changes: 23 additions & 15 deletions pkg/api/handlers/toolreferences.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/obot-platform/obot/apiclient/types"
"github.com/obot-platform/obot/pkg/api"
v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1"
"github.com/obot-platform/obot/pkg/tools"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
Expand All @@ -33,10 +34,13 @@ func convertToolReference(toolRef v1.ToolReference) types.ToolReference {
Name: toolRef.Name,
ToolType: toolRef.Spec.Type,
Reference: toolRef.Spec.Reference,
Commit: toolRef.Status.Commit,
},
Builtin: toolRef.Spec.Builtin,
Error: toolRef.Status.Error,
Resolved: toolRef.Generation == toolRef.Status.ObservedGeneration,
Builtin: toolRef.Spec.Builtin,
Bundle: toolRef.Spec.Bundle,
BundleToolName: toolRef.Spec.BundleToolName,
Error: toolRef.Status.Error,
Resolved: toolRef.Generation == toolRef.Status.ObservedGeneration,
}
if toolRef.Spec.Active == nil {
tf.Active = true
Expand Down Expand Up @@ -123,22 +127,22 @@ func (a *ToolReferenceHandler) Create(req api.Context) (err error) {
return apierrors.NewBadRequest(fmt.Sprintf("invalid tool type %s", newToolReference.ToolType))
}

toolRef := &v1.ToolReference{
ObjectMeta: metav1.ObjectMeta{
Name: newToolReference.Name,
Namespace: req.Namespace(),
},
Spec: v1.ToolReferenceSpec{
Type: newToolReference.ToolType,
Reference: newToolReference.Reference,
},
toolRefs, err := tools.ResolveToolReferences(req.Context(), a.gptscript, newToolReference.Name, newToolReference.Reference, false, newToolReference.ToolType)
if err != nil {
return apierrors.NewBadRequest(fmt.Sprintf("failed to resolve tool references for %s: %v", newToolReference.Reference, err))
}

if err = req.Create(toolRef); err != nil {
return err
if len(toolRefs) == 0 {
return apierrors.NewBadRequest(fmt.Sprintf("no tool references found for %s", newToolReference.Reference))
}

return req.Write(convertToolReference(*toolRef))
for _, toolRef := range toolRefs {
if err := req.Create(toolRef); err != nil && !apierrors.IsAlreadyExists(err) {
return apierrors.NewInternalError(fmt.Errorf("failed to create tool reference %s: %w", toolRef.GetName(), err))
}
}

return req.Write(convertToolReference(*toolRefs[0]))
}

func (a *ToolReferenceHandler) Delete(req api.Context) error {
Expand Down Expand Up @@ -166,6 +170,10 @@ func (a *ToolReferenceHandler) Delete(req api.Context) error {
return types.NewErrBadRequest("cannot delete builtin tool reference %s", id)
}

if !toolRef.Spec.Bundle && toolRef.Spec.BundleToolName != "" {
return types.NewErrBadRequest("cannot delete child tool that belongs to a bundle tool")
}

return req.Delete(&v1.ToolReference{
ObjectMeta: metav1.ObjectMeta{
Name: id,
Expand Down
3 changes: 3 additions & 0 deletions pkg/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ func (c *Controller) PreStart(ctx context.Context) error {
if err := data.Data(ctx, c.services.StorageClient, c.services.AgentsDir); err != nil {
return fmt.Errorf("failed to apply data: %w", err)
}
if err := toolreference.MigrateToolNames(ctx, c.services.StorageClient); err != nil {
return fmt.Errorf("failed to migrate tool names: %w", err)
}
return nil
}

Expand Down
83 changes: 83 additions & 0 deletions pkg/controller/handlers/toolreference/migrate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package toolreference

import (
"context"
"errors"

v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1"
kclient "sigs.k8s.io/controller-runtime/pkg/client"
)

var toolMigrations = map[string]string{
"file-summarizer-file-summarizer": "file-summarizer",
}

func MigrateToolNames(ctx context.Context, client kclient.Client) error {
if len(toolMigrations) == 0 {
return nil
}

var agents v1.AgentList
if err := client.List(ctx, &agents); err != nil {
return err
}

var workflows v1.WorkflowList
if err := client.List(ctx, &workflows); err != nil {
return err
}

var threads v1.ThreadList
if err := client.List(ctx, &threads); err != nil {
return err
}

var workflowSteps v1.WorkflowStepList
if err := client.List(ctx, &workflowSteps); err != nil {
return err
}

var objs []kclient.Object
for _, agent := range agents.Items {
objs = append(objs, &agent)
}
for _, workflow := range workflows.Items {
objs = append(objs, &workflow)
}
for _, thread := range threads.Items {
objs = append(objs, &thread)
}
for _, step := range workflowSteps.Items {
objs = append(objs, &step)
}

var tools []string
var errs []error
for _, obj := range objs {
switch o := obj.(type) {
case *v1.Agent:
tools = o.Spec.Manifest.Tools
case *v1.Workflow:
tools = o.Spec.Manifest.Tools
case *v1.Thread:
tools = o.Spec.Manifest.Tools
case *v1.WorkflowStep:
tools = o.Spec.Step.Tools
}
modified := false
for i, tool := range tools {
if newName, shouldMigrate := toolMigrations[tool]; shouldMigrate {
tools[i] = newName
modified = true
}
}

if !modified {
continue
}

errs = append(errs, client.Update(ctx, obj))
}

return errors.Join(errs...)
}
100 changes: 15 additions & 85 deletions pkg/controller/handlers/toolreference/toolreference.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/obot-platform/obot/pkg/gateway/server/dispatcher"
v1 "github.com/obot-platform/obot/pkg/storage/apis/obot.obot.ai/v1"
"github.com/obot-platform/obot/pkg/system"
"github.com/obot-platform/obot/pkg/tools"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/fields"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand All @@ -35,7 +36,6 @@ var jsonErrRegexp = regexp.MustCompile(`\{.*"error":.*}`)

type indexEntry struct {
Reference string `json:"reference,omitempty"`
All bool `json:"all,omitempty"`
}

type index struct {
Expand Down Expand Up @@ -68,93 +68,23 @@ func New(gptClient *gptscript.GPTScript,
}
}

func isValidTool(tool gptscript.Tool) bool {
if tool.MetaData["index"] == "false" {
return false
}
return tool.Name != "" && (tool.Type == "" || tool.Type == "tool")
}

func (h *Handler) toolsToToolReferences(ctx context.Context, toolType types.ToolReferenceType, registryURL string, entries map[string]indexEntry) (result []client.Object) {
annotations := map[string]string{
"obot.obot.ai/timestamp": time.Now().String(),
}
for name, entry := range entries {
if ref, ok := strings.CutPrefix(entry.Reference, "./"); ok {
entry.Reference = registryURL + "/" + ref
}
if !h.supportDocker && name == system.ShellTool {
continue
}

if entry.All {
prg, err := h.gptClient.LoadFile(ctx, "* from "+entry.Reference)
if err != nil {
log.Errorf("Failed to load tool %s: %v", entry.Reference, err)
continue
}
toolRefs, err := tools.ResolveToolReferences(ctx, h.gptClient, name, entry.Reference, true, toolType)
if err != nil {
log.Errorf("Failed to resolve tool references for %s: %v", entry.Reference, err)
continue
}

tool := prg.ToolSet[prg.EntryToolID]
if isValidTool(tool) {
toolName := tool.Name
if tool.MetaData["bundle"] == "true" {
toolName = "bundle"
}
result = append(result, &v1.ToolReference{
ObjectMeta: metav1.ObjectMeta{
Name: normalize(name, toolName),
Namespace: system.DefaultNamespace,
Finalizers: []string{v1.ToolReferenceFinalizer},
Annotations: annotations,
},
Spec: v1.ToolReferenceSpec{
Type: toolType,
Reference: entry.Reference,
Builtin: true,
},
})
}
for _, peerToolID := range tool.LocalTools {
// If this is the entry tool, then we already added it or skipped it above.
if peerToolID == prg.EntryToolID {
continue
}

peerTool := prg.ToolSet[peerToolID]
if isValidTool(peerTool) {
toolName := peerTool.Name
if peerTool.MetaData["bundle"] == "true" {
toolName += "-bundle"
}
result = append(result, &v1.ToolReference{
ObjectMeta: metav1.ObjectMeta{
Name: normalize(name, toolName),
Namespace: system.DefaultNamespace,
Finalizers: []string{v1.ToolReferenceFinalizer},
Annotations: annotations,
},
Spec: v1.ToolReferenceSpec{
Type: toolType,
Reference: fmt.Sprintf("%s from %s", peerTool.Name, entry.Reference),
Builtin: true,
},
})
}
}
} else {
if !h.supportDocker && name == system.ShellTool {
continue
}
result = append(result, &v1.ToolReference{
ObjectMeta: metav1.ObjectMeta{
Name: name,
Namespace: system.DefaultNamespace,
Finalizers: []string{v1.ToolReferenceFinalizer},
Annotations: annotations,
},
Spec: v1.ToolReferenceSpec{
Type: toolType,
Reference: entry.Reference,
Builtin: true,
},
})
for _, toolRef := range toolRefs {
result = append(result, toolRef)
}
}

Expand Down Expand Up @@ -215,10 +145,6 @@ func (h *Handler) readFromRegistry(ctx context.Context, c client.Client) error {
return apply.New(c).WithOwnerSubContext("toolreferences").Apply(ctx, nil, toAdd...)
}

func normalize(names ...string) string {
return strings.ToLower(strings.ReplaceAll(strings.ReplaceAll(strings.Join(names, "-"), " ", "-"), "_", "-"))
}

func (h *Handler) PollRegistries(ctx context.Context, c client.Client) {
if len(h.registryURLs) < 1 {
return
Expand Down Expand Up @@ -251,6 +177,7 @@ func (h *Handler) Populate(req router.Request, resp router.Response) error {
toolRef.Status.LastReferenceCheck = metav1.Now()
toolRef.Status.ObservedGeneration = toolRef.Generation
toolRef.Status.Reference = toolRef.Spec.Reference
toolRef.Status.Commit = ""
toolRef.Status.Tool = nil
toolRef.Status.Error = ""

Expand All @@ -269,6 +196,9 @@ func (h *Handler) Populate(req router.Request, resp router.Response) error {
Metadata: tool.MetaData,
Params: map[string]string{},
}
if tool.Source.Repo != nil {
toolRef.Status.Commit = tool.Source.Repo.Revision
}
if tool.Arguments != nil {
for name, param := range tool.Arguments.Properties {
if param.Value != nil {
Expand Down
1 change: 1 addition & 0 deletions pkg/controller/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ func (c *Controller) setupRoutes() error {
root.Type(&v1.KnowledgeSource{}).HandlerFunc(knowledgesource.Sync)

// ToolReferences
root.Type(&v1.ToolReference{}).HandlerFunc(cleanup.Cleanup)
root.Type(&v1.ToolReference{}).HandlerFunc(toolRef.Populate)
root.Type(&v1.ToolReference{}).HandlerFunc(toolRef.BackPopulateModels)
root.Type(&v1.ToolReference{}).IncludeFinalizing().HandlerFunc(removeOldFinalizers)
Expand Down
19 changes: 14 additions & 5 deletions pkg/storage/apis/obot.obot.ai/v1/toolreference.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,20 @@ func (in *ToolReference) GetColumns() [][]string {
}
}

func (in *ToolReference) DeleteRefs() []Ref {
return []Ref{
{ObjType: new(ToolReference), Name: in.Spec.BundleToolName},
}
}

type ToolReferenceSpec struct {
Type types.ToolReferenceType `json:"type,omitempty"`
Builtin bool `json:"builtin,omitempty"`
Reference string `json:"reference,omitempty"`
Active *bool `json:"active,omitempty"`
ForceRefresh metav1.Time `json:"forceRefresh,omitempty"`
Type types.ToolReferenceType `json:"type,omitempty"`
Builtin bool `json:"builtin,omitempty"`
Reference string `json:"reference,omitempty"`
Active *bool `json:"active,omitempty"`
Bundle bool `json:"bundle,omitempty"`
BundleToolName string `json:"bundleToolName,omitempty"`
ForceRefresh metav1.Time `json:"forceRefresh,omitempty"`
}

type ToolShortDescription struct {
Expand All @@ -70,6 +78,7 @@ type ToolShortDescription struct {

type ToolReferenceStatus struct {
Reference string `json:"reference,omitempty"`
Commit string `json:"commit,omitempty"`
ObservedGeneration int64 `json:"observedGeneration,omitempty"`
LastReferenceCheck metav1.Time `json:"lastReferenceCheck,omitempty"`
Tool *ToolShortDescription `json:"tool,omitempty"`
Expand Down
Loading

0 comments on commit d8faf4d

Please sign in to comment.