diff --git a/util.go b/util.go index 28a37b0..4539c46 100644 --- a/util.go +++ b/util.go @@ -129,6 +129,6 @@ func ForEach(m Multiaddr, cb func(c Component) bool) { } func (m Multiaddr) Match(p ...meg.Pattern) (bool, error) { - s := meg.PatternToMatchState(p...) - return meg.Match(s, m) + matcher := meg.PatternToMatcher(p...) + return meg.Match(matcher, m) } diff --git a/x/meg/bench_test.go b/x/meg/bench_test.go new file mode 100644 index 0000000..15bf386 --- /dev/null +++ b/x/meg/bench_test.go @@ -0,0 +1,416 @@ +package meg_test + +import ( + "testing" + + "github.com/multiformats/go-multiaddr" + "github.com/multiformats/go-multiaddr/x/meg" +) + +type preallocatedCapture struct { + certHashes []string + matcher meg.Matcher +} + +func preallocateCapture() *preallocatedCapture { + p := &preallocatedCapture{} + p.matcher = meg.PatternToMatcher( + meg.Or( + meg.Val(multiaddr.P_IP4), + meg.Val(multiaddr.P_IP6), + meg.Val(multiaddr.P_DNS), + ), + meg.Val(multiaddr.P_UDP), + meg.Val(multiaddr.P_WEBRTC_DIRECT), + meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &p.certHashes), + ) + return p +} + +var webrtcMatchPrealloc *preallocatedCapture + +func (p *preallocatedCapture) IsWebRTCDirectMultiaddr(addr multiaddr.Multiaddr) (bool, int) { + found, _ := meg.Match(p.matcher, addr) + return found, len(p.certHashes) +} + +// IsWebRTCDirectMultiaddr returns whether addr is a /webrtc-direct multiaddr with the count of certhashes +// in addr +func IsWebRTCDirectMultiaddr(addr multiaddr.Multiaddr) (bool, int) { + if webrtcMatchPrealloc == nil { + webrtcMatchPrealloc = preallocateCapture() + } + return webrtcMatchPrealloc.IsWebRTCDirectMultiaddr(addr) +} + +// IsWebRTCDirectMultiaddrLoop returns whether addr is a /webrtc-direct multiaddr with the count of certhashes +// in addr +func IsWebRTCDirectMultiaddrLoop(addr multiaddr.Multiaddr) (bool, int) { + protos := [...]int{multiaddr.P_IP4, multiaddr.P_IP6, multiaddr.P_DNS, multiaddr.P_UDP, multiaddr.P_WEBRTC_DIRECT} + matchProtos := 
[...][]int{protos[:3], {protos[3]}, {protos[4]}} + certHashCount := 0 + for i, c := range addr { + if i >= len(matchProtos) { + if c.Code() == multiaddr.P_CERTHASH { + certHashCount++ + } else { + return false, 0 + } + } else { + found := false + for _, proto := range matchProtos[i] { + if c.Code() == proto { + found = true + break + } + } + if !found { + return false, 0 + } + } + } + return true, certHashCount +} + +var wtPrealloc *preallocatedCapture + +func isWebTransportMultiaddrPrealloc() *preallocatedCapture { + if wtPrealloc != nil { + return wtPrealloc + } + + p := &preallocatedCapture{} + var dnsName string + var ip4Addr string + var ip6Addr string + var udpPort string + var sni string + p.matcher = meg.PatternToMatcher( + meg.Or( + meg.CaptureVal(multiaddr.P_IP4, &ip4Addr), + meg.CaptureVal(multiaddr.P_IP6, &ip6Addr), + meg.CaptureVal(multiaddr.P_DNS4, &dnsName), + meg.CaptureVal(multiaddr.P_DNS6, &dnsName), + meg.CaptureVal(multiaddr.P_DNS, &dnsName), + ), + meg.CaptureVal(multiaddr.P_UDP, &udpPort), + meg.Val(multiaddr.P_QUIC_V1), + meg.Optional( + meg.CaptureVal(multiaddr.P_SNI, &sni), + ), + meg.Val(multiaddr.P_WEBTRANSPORT), + meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &p.certHashes), + ) + wtPrealloc = p + return p +} + +func IsWebTransportMultiaddrPrealloc(m multiaddr.Multiaddr) (bool, int) { + p := isWebTransportMultiaddrPrealloc() + found, _ := meg.Match(p.matcher, m) + return found, len(p.certHashes) +} + +func IsWebTransportMultiaddr(m multiaddr.Multiaddr) (bool, int) { + var dnsName string + var ip4Addr string + var ip6Addr string + var udpPort string + var sni string + var certHashesStr []string + matched, _ := m.Match( + meg.Or( + meg.CaptureVal(multiaddr.P_IP4, &ip4Addr), + meg.CaptureVal(multiaddr.P_IP6, &ip6Addr), + meg.CaptureVal(multiaddr.P_DNS4, &dnsName), + meg.CaptureVal(multiaddr.P_DNS6, &dnsName), + meg.CaptureVal(multiaddr.P_DNS, &dnsName), + ), + meg.CaptureVal(multiaddr.P_UDP, &udpPort), + meg.Val(multiaddr.P_QUIC_V1), + 
meg.Optional( + meg.CaptureVal(multiaddr.P_SNI, &sni), + ), + meg.Val(multiaddr.P_WEBTRANSPORT), + meg.CaptureZeroOrMore(multiaddr.P_CERTHASH, &certHashesStr), + ) + if !matched { + return false, 0 + } + return true, len(certHashesStr) +} + +func IsWebTransportMultiaddrNoCapture(m multiaddr.Multiaddr) (bool, int) { + matched, _ := m.Match( + meg.Or( + meg.Val(multiaddr.P_IP4), + meg.Val(multiaddr.P_IP6), + meg.Val(multiaddr.P_DNS4), + meg.Val(multiaddr.P_DNS6), + meg.Val(multiaddr.P_DNS), + ), + meg.Val(multiaddr.P_UDP), + meg.Val(multiaddr.P_QUIC_V1), + meg.Optional( + meg.Val(multiaddr.P_SNI), + ), + meg.Val(multiaddr.P_WEBTRANSPORT), + meg.ZeroOrMore(multiaddr.P_CERTHASH), + ) + if !matched { + return false, 0 + } + return true, 0 +} + +func IsWebTransportMultiaddrLoop(m multiaddr.Multiaddr) (bool, int) { + var ip4Addr string + var ip6Addr string + var dnsName string + var udpPort string + var sni string + + // Expected pattern: + // 0: one of: P_IP4, P_IP6, P_DNS4, P_DNS6, P_DNS + // 1: P_UDP + // 2: P_QUIC_V1 + // 3: optional P_SNI (if present) + // Next: P_WEBTRANSPORT + // Trailing: zero or more P_CERTHASH + + // Check minimum length (at least without SNI: 4 components) + if len(m) < 4 { + return false, 0 + } + + idx := 0 + + // Component 0: Must be one of IP or DNS protocols. + switch m[idx].Code() { + case multiaddr.P_IP4: + ip4Addr = m[idx].String() + case multiaddr.P_IP6: + ip6Addr = m[idx].String() + case multiaddr.P_DNS4, multiaddr.P_DNS6, multiaddr.P_DNS: + dnsName = m[idx].String() + default: + return false, 0 + } + idx++ + + // Component 1: Must be UDP. + if idx >= len(m) || m[idx].Code() != multiaddr.P_UDP { + return false, 0 + } + udpPort = m[idx].String() + idx++ + + // Component 2: Must be QUIC_V1. + if idx >= len(m) || m[idx].Code() != multiaddr.P_QUIC_V1 { + return false, 0 + } + idx++ + + // Optional component: SNI. 
+ if idx < len(m) && m[idx].Code() == multiaddr.P_SNI { + sni = m[idx].String() + idx++ + } + + // Next component: Must be WEBTRANSPORT. + if idx >= len(m) || m[idx].Code() != multiaddr.P_WEBTRANSPORT { + return false, 0 + } + idx++ + + // All remaining components must be CERTHASH. + certHashCount := 0 + for ; idx < len(m); idx++ { + if m[idx].Code() != multiaddr.P_CERTHASH { + return false, 0 + } + _ = m[idx].String() + certHashCount++ + } + + _ = ip4Addr + _ = ip6Addr + _ = dnsName + _ = udpPort + _ = sni + + return true, certHashCount +} + +func IsWebTransportMultiaddrLoopNoCapture(m multiaddr.Multiaddr) (bool, int) { + // Expected pattern: + // 0: one of: P_IP4, P_IP6, P_DNS4, P_DNS6, P_DNS + // 1: P_UDP + // 2: P_QUIC_V1 + // 3: optional P_SNI (if present) + // Next: P_WEBTRANSPORT + // Trailing: zero or more P_CERTHASH + + // Check minimum length (at least without SNI: 4 components) + if len(m) < 4 { + return false, 0 + } + + idx := 0 + + // Component 0: Must be one of IP or DNS protocols. + switch m[idx].Code() { + case multiaddr.P_IP4: + case multiaddr.P_IP6: + case multiaddr.P_DNS4, multiaddr.P_DNS6, multiaddr.P_DNS: + default: + return false, 0 + } + idx++ + + // Component 1: Must be UDP. + if idx >= len(m) || m[idx].Code() != multiaddr.P_UDP { + return false, 0 + } + idx++ + + // Component 2: Must be QUIC_V1. + if idx >= len(m) || m[idx].Code() != multiaddr.P_QUIC_V1 { + return false, 0 + } + idx++ + + // Optional component: SNI. + if idx < len(m) && m[idx].Code() == multiaddr.P_SNI { + idx++ + } + + // Next component: Must be WEBTRANSPORT. + if idx >= len(m) || m[idx].Code() != multiaddr.P_WEBTRANSPORT { + return false, 0 + } + idx++ + + // All remaining components must be CERTHASH. 
+ for ; idx < len(m); idx++ { + if m[idx].Code() != multiaddr.P_CERTHASH { + return false, 0 + } + _ = m[idx].String() + } + + return true, 0 +} + +func BenchmarkIsWebTransportMultiaddrPrealloc(b *testing.B) { + addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + isWT, count := IsWebTransportMultiaddrPrealloc(addr) + if !isWT || count != 0 { + b.Fatal("unexpected result") + } + } +} + +func BenchmarkIsWebTransportMultiaddrNoCapturePrealloc(b *testing.B) { + addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport") + + wtPreallocNoCapture := meg.PatternToMatcher( + meg.Or( + meg.Val(multiaddr.P_IP4), + meg.Val(multiaddr.P_IP6), + meg.Val(multiaddr.P_DNS4), + meg.Val(multiaddr.P_DNS6), + meg.Val(multiaddr.P_DNS), + ), + meg.Val(multiaddr.P_UDP), + meg.Val(multiaddr.P_QUIC_V1), + meg.Optional( + meg.Val(multiaddr.P_SNI), + ), + meg.Val(multiaddr.P_WEBTRANSPORT), + meg.ZeroOrMore(multiaddr.P_CERTHASH), + ) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + isWT, _ := meg.Match(wtPreallocNoCapture, addr) + if !isWT { + b.Fatal("unexpected result") + } + } +} + +func BenchmarkIsWebTransportMultiaddrNoCapture(b *testing.B) { + addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + isWT, count := IsWebTransportMultiaddrNoCapture(addr) + if !isWT || count != 0 { + b.Fatal("unexpected result") + } + } +} + +func BenchmarkIsWebTransportMultiaddr(b *testing.B) { + addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + isWT, count := IsWebTransportMultiaddr(addr) + if !isWT || count != 0 { + b.Fatal("unexpected result") + } + } +} + +func BenchmarkIsWebTransportMultiaddrLoop(b *testing.B) { + addr := multiaddr.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1/sni/example.com/webtransport") + + 
// stateKind is the type of the negative sentinel values that can be stored in
// MatchState.codeOrKind. It is an alias of int, so a kind and a protocol code
// can share that single field.
type stateKind = int

const (
	// done (-1) marks the accepting state of a matcher.
	done stateKind = (iota * -1) - 1
	// There is no named constant for split states: any codeOrKind value
	// below done denotes a split and encodes the index of the split's second
	// branch (see storeSplitIdx and restoreSplitIdx).
)
+ codeOrKind int } -type captureFunc *func(string) error -type captureMap map[captureFunc][]string +type captureFunc func(string) error -func (cm captureMap) clone() captureMap { - if cm == nil { - return nil - } - out := make(captureMap, len(cm)) - for k, v := range cm { - out[k] = slices.Clone(v) - } - return out +// capture is a linked list of capture funcs with values. +type capture struct { + f captureFunc + v string + prev *capture } type statesAndCaptures struct { - states []*MatchState - captures []captureMap + states []int + captures []*capture } -func (s *MatchState) String() string { - return fmt.Sprintf("state{kind: %d, generation: %d, code: %d}", s.kind, s.generation, s.code) +func (s MatchState) String() string { + if s.codeOrKind == done { + return "done" + } + if s.codeOrKind < done { + return fmt.Sprintf("split{left: %d, right: %d}", s.next, restoreSplitIdx(s.codeOrKind)) + } + return fmt.Sprintf("match{code: %d, next: %d}", s.codeOrKind, s.next) } type Matchable interface { @@ -60,52 +59,80 @@ type Matchable interface { // Match returns whether the given Components match the Pattern defined in MatchState. // Errors are used to communicate capture errors. // If the error is non-nil the returned bool will be false. -func Match[S ~[]T, T Matchable](s *MatchState, components S) (bool, error) { - listGeneration := s.generation + 1 // Start at the last generation + 1 - defer func() { s.generation = listGeneration }() // In case we reuse this state, store our highest generation number +func Match[S ~[]T, T Matchable](matcher Matcher, components S) (bool, error) { + states := matcher.states + startStateIdx := matcher.startIdx + + // Fast case for a small number of states (<128) + // Avoids allocation of a slice for the visitedBitSet. 
+ stackBitSet := [2]uint64{} + visitedBitSet := stackBitSet[:] + if len(states) > 128 { + visitedBitSet = make([]uint64, (len(states)+63)/64) + } currentStates := statesAndCaptures{ - states: make([]*MatchState, 0, 16), - captures: make([]captureMap, 0, 16), + states: make([]int, 0, 16), + captures: make([]*capture, 0, 16), } nextStates := statesAndCaptures{ - states: make([]*MatchState, 0, 16), - captures: make([]captureMap, 0, 16), + states: make([]int, 0, 16), + captures: make([]*capture, 0, 16), } - currentStates = appendState(currentStates, s, nil, listGeneration) + currentStates = appendState(currentStates, states, startStateIdx, nil, visitedBitSet) for _, c := range components { + clear(visitedBitSet) if len(currentStates.states) == 0 { return false, nil } - for i, s := range currentStates.states { - if s.kind == matchCode && s.code == c.Code() { + for i, stateIndex := range currentStates.states { + s := states[stateIndex] + if s.codeOrKind >= 0 && s.codeOrKind == c.Code() { cm := currentStates.captures[i] if s.capture != nil { + next := &capture{ + f: s.capture, + v: c.Value(), + } if cm == nil { - cm = make(captureMap) - currentStates.captures[i] = cm + cm = next + } else { + next.prev = cm + cm = next } - cm[s.capture] = append(cm[s.capture], c.Value()) + currentStates.captures[i] = cm } - nextStates = appendState(nextStates, s.next, currentStates.captures[i], listGeneration) + nextStates = appendState(nextStates, states, s.next, cm, visitedBitSet) } } currentStates, nextStates = nextStates, currentStates nextStates.states = nextStates.states[:0] nextStates.captures = nextStates.captures[:0] - listGeneration++ } - for i, s := range currentStates.states { - if s.kind == done { + for i, stateIndex := range currentStates.states { + s := states[stateIndex] + if s.codeOrKind == done { + // We found a complete path. 
Run the captures now - for f, v := range currentStates.captures[i] { - for _, s := range v { - if err := (*f)(s); err != nil { - return false, err - } + c := currentStates.captures[i] + + // Flip the order of the captures because we see captures from right + // to left, but users expect them left to right. + type captureWithVal struct { + f captureFunc + v string + } + reversedCaptures := make([]captureWithVal, 0, 16) + for c != nil { + reversedCaptures = append(reversedCaptures, captureWithVal{c.f, c.v}) + c = c.prev + } + for i := len(reversedCaptures) - 1; i >= 0; i-- { + if err := reversedCaptures[i].f(reversedCaptures[i].v); err != nil { + return false, err } } return true, nil @@ -114,17 +141,59 @@ func Match[S ~[]T, T Matchable](s *MatchState, components S) (bool, error) { return false, nil } -func appendState(arr statesAndCaptures, s *MatchState, c captureMap, listGeneration int) statesAndCaptures { - if s == nil || s.generation == listGeneration { - return arr +// appendState is a non-recursive way of appending states to statesAndCaptures. +// If a state is a split, both branches are appended to statesAndCaptures. +func appendState(arr statesAndCaptures, states []MatchState, stateIndex int, c *capture, visitedBitSet []uint64) statesAndCaptures { + // Local struct to hold state index and the associated capture pointer. + type task struct { + idx int + cap *capture } - s.generation = listGeneration - if s.kind == split { - arr = appendState(arr, s.next, c, listGeneration) - arr = appendState(arr, s.nextSplit, c.clone(), listGeneration) - } else { - arr.states = append(arr.states, s) - arr.captures = append(arr.captures, c) + + // Initialize the stack with the starting task. + stack := make([]task, 0, 16) + stack = append(stack, task{stateIndex, c}) + + // Process the stack until empty. + for len(stack) > 0 { + // Pop the last element (LIFO order). + n := len(stack) - 1 + t := stack[n] + stack = stack[:n] + + // If the state index is out of bounds, skip. 
// storeSplitIdx encodes the index of a split state's second branch as a
// codeOrKind value. Index 0 maps to -2, index 1 to -3, and so on, so every
// encoded index is strictly below done (-1) and cannot collide with a
// non-negative protocol code. restoreSplitIdx is the inverse.
func storeSplitIdx(splitIdx int) int {
	return -(splitIdx + 2)
}
OneOrMore(2)), + pattern: PatternToMatcher(Val(0), Val(1), OneOrMore(2)), skipQuickCheck: true, shouldMatch: [][]int{ {0, 1, 2, 2, 2, 2}, @@ -70,13 +80,13 @@ func TestSimple(t *testing.T) { for i, tc := range testCases { for _, m := range tc.shouldMatch { - if matches, _ := Match(tc.pattern, codesToCodeAndValue(m)); !matches { - t.Fatalf("failed to match %v with %s. idx=%d", m, tc.pattern, i) + if matches, err := Match(tc.pattern, codesToCodeAndValue(m)); !matches { + t.Fatalf("failed to match %v with %v. idx=%d. err=%v", m, tc.pattern, i, err) } } for _, m := range tc.shouldNotMatch { if matches, _ := Match(tc.pattern, codesToCodeAndValue(m)); matches { - t.Fatalf("failed to not match %v with %s. idx=%d", m, tc.pattern, i) + t.Fatalf("failed to not match %v with %v. idx=%d", m, tc.pattern, i) } } if tc.skipQuickCheck { @@ -98,7 +108,7 @@ func TestSimple(t *testing.T) { } func TestCapture(t *testing.T) { - type setupStateAndAssert func() (*MatchState, func()) + type setupStateAndAssert func() (Matcher, func()) type testCase struct { setup setupStateAndAssert parts []codeAndValue @@ -107,9 +117,9 @@ func TestCapture(t *testing.T) { testCases := []testCase{ { - setup: func() (*MatchState, func()) { + setup: func() (Matcher, func()) { var code0str string - return PatternToMatchState(CaptureVal(0, &code0str), Val(1)), func() { + return PatternToMatcher(CaptureVal(0, &code0str), Val(1)), func() { if code0str != "hello" { panic("unexpected value") } @@ -118,9 +128,9 @@ func TestCapture(t *testing.T) { parts: []codeAndValue{{0, "hello"}, {1, "world"}}, }, { - setup: func() (*MatchState, func()) { + setup: func() (Matcher, func()) { var code0strs []string - return PatternToMatchState(CaptureOneOrMore(0, &code0strs), Val(1)), func() { + return PatternToMatcher(CaptureOneOrMore(0, &code0strs), Val(1)), func() { if code0strs[0] != "hello" { panic("unexpected value") } @@ -137,7 +147,7 @@ func TestCapture(t *testing.T) { for _, tc := range testCases { state, assert := 
tc.setup() if matches, _ := Match(state, tc.parts); !matches { - t.Fatalf("failed to match %v with %s", tc.parts, state) + t.Fatalf("failed to match %v with %v", tc.parts, state) } assert() } @@ -161,7 +171,7 @@ func bytesToCodeAndValue(codes []byte) []codeAndValue { // FuzzMatchesRegexpBehavior fuzz tests the expression matcher by comparing it to the behavior of the regexp package. func FuzzMatchesRegexpBehavior(f *testing.F) { - bytesToRegexpAndPattern := func(exp []byte) ([]byte, []Pattern) { + bytesToRegexpAndPattern := func(exp []byte) (string, []Pattern) { if len(exp) < 3 { panic("regexp too short") } @@ -197,7 +207,7 @@ func FuzzMatchesRegexpBehavior(f *testing.F) { } } - return exp, pattern + return string(exp), pattern } simplifyB := func(buf []byte) []byte { @@ -218,7 +228,7 @@ func FuzzMatchesRegexpBehavior(f *testing.F) { // Malformed regex. Ignore return } - p := PatternToMatchState(pattern...) + p := PatternToMatcher(pattern...) otherMatched, _ := Match(p, bytesToCodeAndValue(corpus)) if otherMatched != matched { t.Log("regexp", string(regexpPattern)) diff --git a/x/meg/sugar.go b/x/meg/sugar.go index 369a315..ee961cc 100644 --- a/x/meg/sugar.go +++ b/x/meg/sugar.go @@ -2,38 +2,69 @@ package meg import ( "errors" + "fmt" + "strconv" + "strings" ) -type Pattern = func(next *MatchState) *MatchState +// Pattern is essentially a curried MatchState. +// Given the slice of current MatchStates and a handle (int index) to the next +// MatchState, it returns a handle to the inserted MatchState. 
+type Pattern = func(states *[]MatchState, nextIdx int) int -func PatternToMatchState(states ...Pattern) *MatchState { - nextState := &MatchState{kind: done} - for i := len(states) - 1; i >= 0; i-- { - nextState = states[i](nextState) +type Matcher struct { + states []MatchState + startIdx int +} + +func (s Matcher) String() string { + states := make([]string, len(s.states)) + for i, state := range s.states { + states[i] = state.String() + "@" + strconv.Itoa(i) + } + return fmt.Sprintf("RootMatchState{states: [%s], startIdx: %d}", strings.Join(states, ", "), s.startIdx) +} + +func PatternToMatcher(patterns ...Pattern) Matcher { + // Preallocate a slice to hold the MatchStates. + // Avoids small allocations for each pattern. + // The number is chosen experimentally. It is subject to change. + states := make([]MatchState, 0, len(patterns)*3) + // Append the done state. + states = append(states, MatchState{codeOrKind: done}) + nextIdx := len(states) - 1 + // Build the chain by composing patterns from right to left. + for i := len(patterns) - 1; i >= 0; i-- { + nextIdx = patterns[i](&states, nextIdx) } - return nextState + return Matcher{states: states, startIdx: nextIdx} } func Cat(left, right Pattern) Pattern { - return func(next *MatchState) *MatchState { - return left(right(next)) + return func(states *[]MatchState, nextIdx int) int { + // First run the right pattern, then feed the result into left. + return left(states, right(states, nextIdx)) } } func Or(p ...Pattern) Pattern { - return func(next *MatchState) *MatchState { + return func(states *[]MatchState, nextIdx int) int { if len(p) == 0 { - return next + return nextIdx } - if len(p) == 1 { - return p[0](next) - } - - return &MatchState{ - kind: split, - next: p[0](next), - nextSplit: Or(p[1:]...)(next), + // Evaluate the last pattern and use its result as the initial accumulator. + accum := p[len(p)-1](states, nextIdx) + // Iterate backwards from the second-to-last pattern to the first. 
+ for i := len(p) - 2; i >= 0; i-- { + leftIdx := p[i](states, nextIdx) + newState := MatchState{ + next: leftIdx, + codeOrKind: storeSplitIdx(accum), + } + *states = append(*states, newState) + accum = len(*states) - 1 } + return accum } } @@ -52,7 +83,7 @@ func captureOneValueOrErr(val *string) captureFunc { *val = s return nil } - return &f + return f } func captureMany(vals *[]string) captureFunc { @@ -63,17 +94,18 @@ func captureMany(vals *[]string) captureFunc { *vals = append(*vals, s) return nil } - return &f + return f } func captureValWithF(code int, f captureFunc) Pattern { - return func(next *MatchState) *MatchState { - return &MatchState{ - kind: matchCode, - capture: f, - code: code, - next: next, + return func(states *[]MatchState, nextIdx int) int { + newState := MatchState{ + capture: f, + codeOrKind: code, + next: nextIdx, } + *states = append(*states, newState) + return len(*states) - 1 } } @@ -90,18 +122,27 @@ func ZeroOrMore(code int) Pattern { } func captureZeroOrMoreWithF(code int, f captureFunc) Pattern { - return func(next *MatchState) *MatchState { - match := &MatchState{ - code: code, - capture: f, + return func(states *[]MatchState, nextIdx int) int { + // Create the match state. + matchState := MatchState{ + codeOrKind: code, + capture: f, } - s := &MatchState{ - kind: split, - next: match, - nextSplit: next, + *states = append(*states, matchState) + matchIdx := len(*states) - 1 + + // Create the split state that branches to the match state and to the next state. + s := MatchState{ + next: matchIdx, + codeOrKind: storeSplitIdx(nextIdx), } - match.next = s // Loop back to the split. - return s + *states = append(*states, s) + splitIdx := len(*states) - 1 + + // Close the loop: update the match state's next field. 
+ (*states)[matchIdx].next = splitIdx + + return splitIdx } } @@ -112,19 +153,24 @@ func CaptureZeroOrMore(code int, vals *[]string) Pattern { func OneOrMore(code int) Pattern { return CaptureOneOrMore(code, nil) } + func CaptureOneOrMore(code int, vals *[]string) Pattern { f := captureMany(vals) - return func(next *MatchState) *MatchState { - return captureValWithF(code, f)(captureZeroOrMoreWithF(code, f)(next)) + return func(states *[]MatchState, nextIdx int) int { + // First attach the zero-or-more loop. + zeroOrMoreIdx := captureZeroOrMoreWithF(code, f)(states, nextIdx) + // Then put the capture state before the loop. + return captureValWithF(code, f)(states, zeroOrMoreIdx) } } func Optional(s Pattern) Pattern { - return func(next *MatchState) *MatchState { - return &MatchState{ - kind: split, - next: s(next), - nextSplit: next, + return func(states *[]MatchState, nextIdx int) int { + newState := MatchState{ + next: s(states, nextIdx), + codeOrKind: storeSplitIdx(nextIdx), } + *states = append(*states, newState) + return len(*states) - 1 } }