Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions server/errors.json
Original file line number Diff line number Diff line change
Expand Up @@ -1998,5 +1998,15 @@
"help": "",
"url": "",
"deprecates": ""
},
{
"constant": "JSMessageSchedulesSourceInvalidErr",
"code": 400,
"error_code": 10202,
"description": "message schedules source is invalid",
"comment": "",
"help": "",
"url": "",
"deprecates": ""
}
]
18 changes: 14 additions & 4 deletions server/filestore.go
Original file line number Diff line number Diff line change
Expand Up @@ -6186,10 +6186,16 @@ func (fs *fileStore) runMsgScheduling() {
}
fs.scheduling.running = true

scheduledMsgs := fs.scheduling.getScheduledMessages(func(seq uint64, smv *StoreMsg) *StoreMsg {
sm, _ := fs.msgForSeqLocked(seq, smv, false)
return sm
})
scheduledMsgs := fs.scheduling.getScheduledMessages(
func(seq uint64, smv *StoreMsg) *StoreMsg {
sm, _ := fs.msgForSeqLocked(seq, smv, false)
return sm
},
func(subj string, smv *StoreMsg) *StoreMsg {
sm, _ := fs.loadLastLocked(subj, smv)
return sm
},
)
if len(scheduledMsgs) > 0 {
fs.mu.Unlock()
for _, msg := range scheduledMsgs {
Expand Down Expand Up @@ -7908,7 +7914,11 @@ func (fs *fileStore) LoadMsg(seq uint64, sm *StoreMsg) (*StoreMsg, error) {
func (fs *fileStore) loadLast(subj string, sm *StoreMsg) (lsm *StoreMsg, err error) {
fs.mu.RLock()
defer fs.mu.RUnlock()
return fs.loadLastLocked(subj, sm)
}

// Lock should be held.
func (fs *fileStore) loadLastLocked(subj string, sm *StoreMsg) (lsm *StoreMsg, err error) {
if fs.closed || fs.lmb == nil {
return nil, ErrStoreClosed
}
Expand Down
4 changes: 4 additions & 0 deletions server/jetstream_batching.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,10 @@ func checkMsgHeadersPreClusteredProposal(
!IsValidPublishSubject(scheduleTarget) || SubjectsCollide(scheduleTarget, subject) {
apiErr := NewJSMessageSchedulesTargetInvalidError()
return hdr, msg, 0, apiErr, apiErr
} else if scheduleSource := getMessageScheduleSource(hdr); scheduleSource != _EMPTY_ &&
(scheduleSource == scheduleTarget || scheduleSource == subject || !IsValidPublishSubject(scheduleSource)) {
apiErr := NewJSMessageSchedulesSourceInvalidError()
return hdr, msg, 0, apiErr, apiErr
} else {
mset.cfgMu.RLock()
match := slices.ContainsFunc(mset.cfg.Subjects, func(subj string) bool {
Expand Down
87 changes: 87 additions & 0 deletions server/jetstream_cluster_1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9439,6 +9439,93 @@ func TestJetStreamClusterScheduledDelayedMessage(t *testing.T) {
}
}

func TestJetStreamClusterScheduledMessageSubjectSourcing(t *testing.T) {
for _, replicas := range []int{1, 3} {
for _, storage := range []StorageType{FileStorage, MemoryStorage} {
t.Run(fmt.Sprintf("R%d/%s", replicas, storage), func(t *testing.T) {
c := createJetStreamClusterExplicit(t, "R3S", 3)
defer c.shutdown()

nc, js := jsClientConnect(t, c.randomServer())
defer nc.Close()

cfg := &StreamConfig{
Name: "SchedulesEnabled",
Subjects: []string{"foo.*"},
Storage: storage,
Replicas: replicas,
AllowMsgSchedules: true,
AllowMsgTTL: true,
}
_, err := jsStreamCreate(t, nc, cfg)
require_NoError(t, err)

m := nats.NewMsg("foo.data")
m.Header.Set("Header", "Value")
m.Data = []byte("data")

pubAck, err := js.PublishMsg(m)
require_NoError(t, err)
require_Equal(t, pubAck.Sequence, 1)

m = nats.NewMsg("foo.schedule")
m.Header.Set("Nats-Schedule", "@at 1970-01-01T00:00:00Z")
m.Header.Set("Nats-Schedule-Target", "foo.publish")

// Invalid sources include if the subject:
// - matches the schedule/target subject
// - contains wildcard/is not literal
for _, src := range []string{"foo.schedule", "foo.publish", "foo.*", "foo.>"} {
m.Header.Set("Nats-Schedule-Source", src)
_, err = js.PublishMsg(m)
require_Error(t, err, NewJSMessageSchedulesSourceInvalidError())
}

// Now publish using a correct source subject.
m.Header.Set("Nats-Schedule-Source", "foo.data")
pubAck, err = js.PublishMsg(m)
require_NoError(t, err)
require_Equal(t, pubAck.Sequence, 2)

sl := c.streamLeader(globalAccountName, "SchedulesEnabled")
mset, err := sl.globalAccount().lookupStream("SchedulesEnabled")
require_NoError(t, err)

state := mset.state()
require_Equal(t, state.LastSeq, 2)
require_Equal(t, state.Msgs, 2)

// Waiting for the delayed message to be published.
checkFor(t, 2*time.Second, 200*time.Millisecond, func() error {
state = mset.state()
if state.LastSeq != 3 {
return fmt.Errorf("expected last seq 3, got %d", state.LastSeq)
} else if state.Msgs != 2 {
// One is the scheduled message, one is the sourced message.
return fmt.Errorf("expected 2 msgs, got %d", state.Msgs)
}
return nil
})

// Confirm the scheduled message has the correct data.
rsm, err := js.GetLastMsg("SchedulesEnabled", "foo.publish")
require_NoError(t, err)
require_Equal(t, rsm.Sequence, 3)
require_True(t, bytes.Equal(rsm.Data, []byte("data")))
require_Len(t, len(rsm.Header), 3)
require_Equal(t, rsm.Header.Get("Nats-Scheduler"), "foo.schedule")
require_Equal(t, rsm.Header.Get("Nats-Schedule-Next"), "purge")
require_Equal(t, rsm.Header.Get("Header"), "Value")

// Servers should be synced.
checkFor(t, 2*time.Second, 200*time.Millisecond, func() error {
return checkState(t, c, globalAccountName, "SchedulesEnabled")
})
})
}
}
}

func TestJetStreamClusterScheduledDelayedMessageReversedHeaderOrder(t *testing.T) {
for _, replicas := range []int{1, 3} {
for _, storage := range []StorageType{FileStorage, MemoryStorage} {
Expand Down
14 changes: 14 additions & 0 deletions server/jetstream_errors_generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,9 @@ const (
// JSMessageSchedulesRollupInvalidErr message schedules invalid rollup
JSMessageSchedulesRollupInvalidErr ErrorIdentifier = 10192

// JSMessageSchedulesSourceInvalidErr message schedules source is invalid
JSMessageSchedulesSourceInvalidErr ErrorIdentifier = 10202

// JSMessageSchedulesTTLInvalidErr message schedules invalid per-message TTL
JSMessageSchedulesTTLInvalidErr ErrorIdentifier = 10191

Expand Down Expand Up @@ -711,6 +714,7 @@ var (
JSMessageSchedulesDisabledErr: {Code: 400, ErrCode: 10188, Description: "message schedules is disabled"},
JSMessageSchedulesPatternInvalidErr: {Code: 400, ErrCode: 10189, Description: "message schedules pattern is invalid"},
JSMessageSchedulesRollupInvalidErr: {Code: 400, ErrCode: 10192, Description: "message schedules invalid rollup"},
JSMessageSchedulesSourceInvalidErr: {Code: 400, ErrCode: 10202, Description: "message schedules source is invalid"},
JSMessageSchedulesTTLInvalidErr: {Code: 400, ErrCode: 10191, Description: "message schedules invalid per-message TTL"},
JSMessageSchedulesTargetInvalidErr: {Code: 400, ErrCode: 10190, Description: "message schedules target is invalid"},
JSMessageTTLDisabledErr: {Code: 400, ErrCode: 10166, Description: "per-message TTL is disabled"},
Expand Down Expand Up @@ -1953,6 +1957,16 @@ func NewJSMessageSchedulesRollupInvalidError(opts ...ErrorOption) *ApiError {
return ApiErrors[JSMessageSchedulesRollupInvalidErr]
}

// NewJSMessageSchedulesSourceInvalidError creates a new JSMessageSchedulesSourceInvalidErr error: "message schedules source is invalid"
func NewJSMessageSchedulesSourceInvalidError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
if ae, ok := eopts.err.(*ApiError); ok {
return ae
}

return ApiErrors[JSMessageSchedulesSourceInvalidErr]
}

// NewJSMessageSchedulesTTLInvalidError creates a new JSMessageSchedulesTTLInvalidErr error: "message schedules invalid per-message TTL"
func NewJSMessageSchedulesTTLInvalidError(opts ...ErrorOption) *ApiError {
eopts := parseOpts(opts)
Expand Down
24 changes: 17 additions & 7 deletions server/memstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -1346,10 +1346,16 @@ func (ms *memStore) runMsgScheduling() {
}
ms.scheduling.running = true

scheduledMsgs := ms.scheduling.getScheduledMessages(func(seq uint64, smv *StoreMsg) *StoreMsg {
sm, _ := ms.loadMsgLocked(seq, smv, false)
return sm
})
scheduledMsgs := ms.scheduling.getScheduledMessages(
func(seq uint64, smv *StoreMsg) *StoreMsg {
sm, _ := ms.loadMsgLocked(seq, smv, false)
return sm
},
func(subj string, smv *StoreMsg) *StoreMsg {
sm, _ := ms.loadLastLocked(subj, smv)
return sm
},
)
if len(scheduledMsgs) > 0 {
ms.mu.Unlock()
for _, msg := range scheduledMsgs {
Expand Down Expand Up @@ -1660,13 +1666,17 @@ func (ms *memStore) loadMsgLocked(seq uint64, smp *StoreMsg, needMSLock bool) (*
// LoadLastMsg will return the last message we have that matches a given subject.
// The subject can be a wildcard.
func (ms *memStore) LoadLastMsg(subject string, smp *StoreMsg) (*StoreMsg, error) {
var sm *StoreMsg
var ok bool

// This needs to be a write lock, as filteredStateLocked can
// mutate the per-subject state.
ms.mu.Lock()
defer ms.mu.Unlock()
return ms.loadLastLocked(subject, smp)
}

// Lock should be held.
func (ms *memStore) loadLastLocked(subject string, smp *StoreMsg) (*StoreMsg, error) {
var sm *StoreMsg
var ok bool

if subject == _EMPTY_ || subject == fwcs {
sm, ok = ms.msgs[ms.state.LastSeq]
Expand Down
16 changes: 12 additions & 4 deletions server/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (ms *MsgScheduling) resetTimer() {
}
}

func (ms *MsgScheduling) getScheduledMessages(loadMsg func(seq uint64, smv *StoreMsg) *StoreMsg) []*inMsg {
func (ms *MsgScheduling) getScheduledMessages(loadMsg func(seq uint64, smv *StoreMsg) *StoreMsg, loadLast func(subj string, smv *StoreMsg) *StoreMsg) []*inMsg {
var (
smv StoreMsg
sm *StoreMsg
Expand All @@ -155,7 +155,8 @@ func (ms *MsgScheduling) getScheduledMessages(loadMsg func(seq uint64, smv *Stor
if sm != nil {
// If already inflight, don't duplicate a scheduled message. The stream could
// be replicated and the scheduled message could take some time to propagate.
if ms.isInflight(sm.subj) {
subj := sm.subj
if ms.isInflight(subj) {
return false
}
// Validate the contents are correct if not, we just remove it from THW.
Expand All @@ -169,6 +170,13 @@ func (ms *MsgScheduling) getScheduledMessages(loadMsg func(seq uint64, smv *Stor
ms.remove(seq)
return true
}
source := getMessageScheduleSource(sm.hdr)
if source != _EMPTY_ {
if sm = loadLast(source, &smv); sm == nil {
ms.remove(seq)
return true
}
}

// Copy, as this is retrieved directly from storage, and we'll need to keep hold of this for some time.
// And in the case of headers, we'll copy all of them, but make changes.
Expand All @@ -183,13 +191,13 @@ func (ms *MsgScheduling) getScheduledMessages(loadMsg func(seq uint64, smv *Stor
hdr = removeHeaderIfPresent(hdr, JSMsgRollup)

// Add headers for the scheduled message.
hdr = genHeader(hdr, JSScheduler, sm.subj)
hdr = genHeader(hdr, JSScheduler, subj)
hdr = genHeader(hdr, JSScheduleNext, JSScheduleNextPurge) // Purge the schedule message itself.
if ttl != _EMPTY_ {
hdr = genHeader(hdr, JSMessageTTL, ttl)
}
msgs = append(msgs, &inMsg{seq: seq, subj: target, hdr: hdr, msg: msg})
ms.markInflight(sm.subj)
ms.markInflight(subj)
return false
}
ms.remove(seq)
Expand Down
19 changes: 19 additions & 0 deletions server/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ const (
JSSchedulePattern = "Nats-Schedule"
JSScheduleTTL = "Nats-Schedule-TTL"
JSScheduleTarget = "Nats-Schedule-Target"
JSScheduleSource = "Nats-Schedule-Source"
)

// Headers for published KV messages.
Expand Down Expand Up @@ -4839,6 +4840,14 @@ func getMessageScheduleTarget(hdr []byte) string {
return string(getHeader(JSScheduleTarget, hdr))
}

// Fast lookup of message schedule source.
func getMessageScheduleSource(hdr []byte) string {
if len(hdr) == 0 {
return _EMPTY_
}
return string(getHeader(JSScheduleSource, hdr))
}

// Fast lookup of message scheduler.
func getMessageScheduler(hdr []byte) string {
if len(hdr) == 0 {
Expand Down Expand Up @@ -5647,6 +5656,16 @@ func (mset *stream) processJetStreamMsg(subject, reply string, hdr, msg []byte,
outq.sendMsg(reply, b)
}
return apiErr
} else if scheduleSource := getMessageScheduleSource(hdr); scheduleSource != _EMPTY_ &&
(scheduleSource == scheduleTarget || scheduleSource == subject || !IsValidPublishSubject(scheduleSource)) {
apiErr := NewJSMessageSchedulesSourceInvalidError()
if canRespond {
resp.PubAck = &PubAck{Stream: name}
resp.Error = apiErr
b, _ := json.Marshal(resp)
outq.sendMsg(reply, b)
}
return apiErr
} else {
match := slices.ContainsFunc(mset.cfg.Subjects, func(subj string) bool {
return SubjectsCollide(subj, scheduleTarget)
Expand Down