Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 95 additions & 12 deletions integrations/slack-gateway/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"net/url"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -343,6 +344,10 @@ func TestSlackEventRoutesToConversationAndReplies(t *testing.T) {
sync.Mutex
values []string
}
var promptPayload struct {
sync.Mutex
value map[string]any
}
var channelConversationCall struct {
sync.Mutex
authHeaders []string
Expand Down Expand Up @@ -453,6 +458,9 @@ func TestSlackEventRoutesToConversationAndReplies(t *testing.T) {
case "session/load":
_ = conn.WriteJSON(map[string]any{"jsonrpc": "2.0", "id": message["id"], "result": map[string]any{}})
case "session/prompt":
promptPayload.Lock()
promptPayload.value = message
promptPayload.Unlock()
_ = conn.WriteJSON(map[string]any{
"jsonrpc": "2.0",
"method": "session/update",
Expand Down Expand Up @@ -560,6 +568,31 @@ func TestSlackEventRoutesToConversationAndReplies(t *testing.T) {
if channelConversationCall.payloads[1]["externalConversationId"] != "1711387376.000100" {
t.Fatalf("expected alias upsert to persist the bot reply ts, got %#v", channelConversationCall.payloads[1]["externalConversationId"])
}
promptPayload.Lock()
capturedPromptPayload := promptPayload.value
promptPayload.Unlock()
params, ok := capturedPromptPayload["params"].(map[string]any)
if !ok {
t.Fatalf("expected prompt params payload, got %#v", capturedPromptPayload)
}
promptItems, ok := params["prompt"].([]any)
if !ok || len(promptItems) != 1 {
t.Fatalf("expected a single prompt item, got %#v", params["prompt"])
}
item, ok := promptItems[0].(map[string]any)
if !ok {
t.Fatalf("expected prompt item object, got %#v", promptItems[0])
}
text := fmt.Sprint(item["text"])
if !strings.Contains(text, "<spritz-channel-context>") {
t.Fatalf("expected trusted channel context in prompt text, got %q", text)
}
if !strings.Contains(text, "\"actor_user_id\":\"U_1\"") {
t.Fatalf("expected actor metadata in prompt text, got %q", text)
}
if !strings.HasSuffix(text, "\n\nhello") {
t.Fatalf("expected normalized prompt body after metadata block, got %q", text)
}
return
}
time.Sleep(20 * time.Millisecond)
Expand Down Expand Up @@ -1330,6 +1363,56 @@ func TestNormalizeSlackPromptTextPreservesNonGatewayMentions(t *testing.T) {
}
}

func TestBuildSlackPromptTextPrependsTrustedContext(t *testing.T) {
prompt := buildSlackPromptText(
"T_workspace_1",
slackEventInner{
Type: "app_mention",
User: "U_requester",
Text: "<@U_BOT> create a zeno for me",
Channel: "C_channel_1",
ChannelType: "channel",
TS: "1711387375.000100",
},
"U_BOT",
)

const prefix = "<spritz-channel-context>"
if !strings.HasPrefix(prompt, prefix) {
t.Fatalf("expected trusted context prefix, got %q", prompt)
}
endIndex := strings.Index(prompt, "</spritz-channel-context>")
if endIndex < 0 {
t.Fatalf("expected trusted context suffix, got %q", prompt)
}

var payload map[string]any
if err := json.Unmarshal([]byte(prompt[len(prefix):endIndex]), &payload); err != nil {
t.Fatalf("decode prompt context: %v", err)
}
if payload["source"] != "spritz-slack-gateway" {
t.Fatalf("expected source metadata, got %#v", payload["source"])
}
if payload["provider"] != "slack" {
t.Fatalf("expected slack provider, got %#v", payload["provider"])
}
if payload["workspace_id"] != "T_workspace_1" {
t.Fatalf("expected workspace metadata, got %#v", payload["workspace_id"])
}
if payload["actor_user_id"] != "U_requester" {
t.Fatalf("expected actor metadata, got %#v", payload["actor_user_id"])
}
if payload["conversation_id"] != "1711387375.000100" {
t.Fatalf("expected top-level conversation identity, got %#v", payload["conversation_id"])
}
if payload["direct_message"] != false {
t.Fatalf("expected non-DM metadata, got %#v", payload["direct_message"])
}
if !strings.HasSuffix(prompt, "\n\ncreate a zeno for me") {
t.Fatalf("expected normalized user text after metadata block, got %q", prompt)
}
}

func TestShouldIgnoreSlackMessageEventRejectsSystemSubtypes(t *testing.T) {
if !shouldIgnoreSlackMessageEvent(
slackEventInner{Type: "message", Subtype: "channel_join"},
Expand Down Expand Up @@ -1932,7 +2015,7 @@ func TestProcessMessageEventSuppressesRetryAfterSlackReplyFailure(t *testing.T)
}))
defer backend.Close()

var promptCalls int
var promptCalls atomic.Int32
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
spritz := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
Expand Down Expand Up @@ -1979,7 +2062,7 @@ func TestProcessMessageEventSuppressesRetryAfterSlackReplyFailure(t *testing.T)
case "session/load":
_ = conn.WriteJSON(map[string]any{"jsonrpc": "2.0", "id": message["id"], "result": map[string]any{}})
case "session/prompt":
promptCalls++
promptCalls.Add(1)
_ = conn.WriteJSON(map[string]any{
"jsonrpc": "2.0",
"method": "session/update",
Expand Down Expand Up @@ -2040,8 +2123,8 @@ func TestProcessMessageEventSuppressesRetryAfterSlackReplyFailure(t *testing.T)
if err := gateway.processMessageEvent(t.Context(), envelope); err != nil {
t.Fatalf("expected duplicate slack delivery to be suppressed after prompt side effects, got %v", err)
}
if promptCalls != 1 {
t.Fatalf("expected ACP prompt to run once, got %d", promptCalls)
if promptCalls.Load() != 1 {
t.Fatalf("expected ACP prompt to run once, got %d", promptCalls.Load())
}
if postCalls != 1 {
t.Fatalf("expected one slack post attempt before dedupe suppression, got %d", postCalls)
Expand Down Expand Up @@ -2069,17 +2152,17 @@ func TestProcessMessageEventAllowsRetryWhenPromptWasNotDelivered(t *testing.T) {
}))
defer backend.Close()

postCalls := 0
var postCalls atomic.Int32
slackAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/chat.postMessage" {
t.Fatalf("unexpected slack path %s", r.URL.Path)
}
postCalls++
postCalls.Add(1)
writeJSON(w, http.StatusOK, map[string]any{"ok": true})
}))
defer slackAPI.Close()

var promptCalls int
var promptCalls atomic.Int32
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
spritz := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
Expand Down Expand Up @@ -2126,7 +2209,7 @@ func TestProcessMessageEventAllowsRetryWhenPromptWasNotDelivered(t *testing.T) {
case "session/load":
_ = conn.WriteJSON(map[string]any{"jsonrpc": "2.0", "id": message["id"], "result": map[string]any{}})
case "session/prompt":
promptCalls++
promptCalls.Add(1)
return
default:
t.Fatalf("unexpected ACP method %#v", message["method"])
Expand Down Expand Up @@ -2173,11 +2256,11 @@ func TestProcessMessageEventAllowsRetryWhenPromptWasNotDelivered(t *testing.T) {
if err := gateway.processMessageEvent(t.Context(), envelope); err == nil {
t.Fatalf("expected retry to re-attempt prompt delivery")
}
if promptCalls != 2 {
t.Fatalf("expected ACP prompt to run twice after retryable failures, got %d", promptCalls)
if promptCalls.Load() != 2 {
t.Fatalf("expected ACP prompt to run twice after retryable failures, got %d", promptCalls.Load())
}
if postCalls != 0 {
t.Fatalf("expected no slack reply on undelivered prompt failure, got %d posts", postCalls)
if postCalls.Load() != 0 {
t.Fatalf("expected no slack reply on undelivered prompt failure, got %d posts", postCalls.Load())
}
}

Expand Down
45 changes: 44 additions & 1 deletion integrations/slack-gateway/slack_events.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,11 @@ func (g *slackGateway) processMessageEventWithDelivery(
return fmt.Errorf("slack api_app_id mismatch for team %s", envelope.TeamID)
}

promptText := normalizeSlackPromptText(event.Type, event.Text, session.ProviderAuth.BotUserID)
promptText := buildSlackPromptText(
envelope.TeamID,
event,
session.ProviderAuth.BotUserID,
)
if promptText == "" {
return nil
}
Expand Down Expand Up @@ -300,6 +304,45 @@ func normalizeSlackPromptText(eventType, text, botUserID string) string {
return normalized
}

type slackPromptContext struct {
Source string `json:"source"`
Provider string `json:"provider"`
WorkspaceID string `json:"workspace_id"`
ActorUserID string `json:"actor_user_id"`
ChannelID string `json:"channel_id"`
ChannelType string `json:"channel_type,omitempty"`
MessageTS string `json:"message_ts"`
ThreadTS string `json:"thread_ts,omitempty"`
ConversationID string `json:"conversation_id"`
DirectMessage bool `json:"direct_message"`
}

func buildSlackPromptText(teamID string, event slackEventInner, botUserID string) string {
normalized := normalizeSlackPromptText(event.Type, event.Text, botUserID)
if normalized == "" {
return ""
}

payload, err := json.Marshal(
slackPromptContext{
Source: "spritz-slack-gateway",
Provider: slackProvider,
WorkspaceID: strings.TrimSpace(teamID),
ActorUserID: strings.TrimSpace(event.User),
ChannelID: strings.TrimSpace(event.Channel),
ChannelType: strings.TrimSpace(event.ChannelType),
MessageTS: strings.TrimSpace(event.TS),
ThreadTS: strings.TrimSpace(event.ThreadTS),
ConversationID: slackExternalConversationID(event),
DirectMessage: isSlackDirectMessageEvent(event),
},
)
if err != nil {
return normalized
}
return "<spritz-channel-context>" + string(payload) + "</spritz-channel-context>\n\n" + normalized
}

func slackReplyThreadTS(event slackEventInner) string {
if strings.TrimSpace(event.ThreadTS) != "" {
return strings.TrimSpace(event.ThreadTS)
Expand Down
Loading