From 81ef92c077cce950199c80d0358e669725e7e858 Mon Sep 17 00:00:00 2001 From: Vigilans Date: Sun, 31 Oct 2021 19:54:35 +0800 Subject: [PATCH] Refactor strmatcher.ACAutomatonMatcherGroup --- .../strmatcher/matchergroup_ac_automation.go | 415 ++++++++++-------- .../matchergroup_ac_automation_test.go | 182 ++++++++ 2 files changed, 416 insertions(+), 181 deletions(-) diff --git a/common/strmatcher/matchergroup_ac_automation.go b/common/strmatcher/matchergroup_ac_automation.go index 22f1760aa..e1d5e52ea 100644 --- a/common/strmatcher/matchergroup_ac_automation.go +++ b/common/strmatcher/matchergroup_ac_automation.go @@ -4,234 +4,287 @@ import ( "container/list" ) -const validCharCount = 53 - -type MatchType struct { - matchType Type - exist bool -} - const ( - TrieEdge bool = true - FailEdge bool = false + acValidCharCount = 38 // aA-zZ (26), 0-9 (10), - (1), . (1) + acMatchTypeCount = 3 // Full, Domain and Substr ) -type Edge struct { - edgeType bool - nextNode int -} +type acEdge byte + +const ( + acTrieEdge acEdge = 1 + acFailEdge acEdge = 0 +) + +type acNode struct { + next [acValidCharCount]uint32 // EdgeIdx -> Next NodeIdx (Next trie node or fail node) + edge [acValidCharCount]acEdge // EdgeIdx -> Trie Edge / Fail Edge + fail uint32 // NodeIdx of *next matched* Substr Pattern on its fail path + match uint32 // MatchIdx of matchers registered on this node, 0 indicates no match +} // Sizeof acNode: (4+1)*acValidCharCount + + 4 + 4 + +type acValue [acMatchTypeCount][]uint32 // MatcherType -> Registered Matcher Values // ACAutoMationMatcherGroup is an implementation of MatcherGroup. // It uses an AC Automata to provide support for Full, Domain and Substr matcher. Trie node is char based. +// +// NOTICE: ACAutomatonMatcherGroup currently uses a restricted charset (LDH Subset), +// upstream should manually in a way to ensure all patterns and inputs passed to it to be in this charset. type ACAutomatonMatcherGroup struct { - trie [][validCharCount]Edge - fail []int - exists []MatchType - count int -} - -func newNode() [validCharCount]Edge { - var s [validCharCount]Edge - for i := range s { - s[i] = Edge{ - edgeType: FailEdge, - nextNode: 0, - } - } - return s -} - -var char2Index = []int{ - 'A': 0, - 'a': 0, - 'B': 1, - 'b': 1, - 'C': 2, - 'c': 2, - 'D': 3, - 'd': 3, - 'E': 4, - 'e': 4, - 'F': 5, - 'f': 5, - 'G': 6, - 'g': 6, - 'H': 7, - 'h': 7, - 'I': 8, - 'i': 8, - 'J': 9, - 'j': 9, - 'K': 10, - 'k': 10, - 'L': 11, - 'l': 11, - 'M': 12, - 'm': 12, - 'N': 13, - 'n': 13, - 'O': 14, - 'o': 14, - 'P': 15, - 'p': 15, - 'Q': 16, - 'q': 16, - 'R': 17, - 'r': 17, - 'S': 18, - 's': 18, - 'T': 19, - 't': 19, - 'U': 20, - 'u': 20, - 'V': 21, - 'v': 21, - 'W': 22, - 'w': 22, - 'X': 23, - 'x': 23, - 'Y': 24, - 'y': 24, - 'Z': 25, - 'z': 25, - '!': 26, - '$': 27, - '&': 28, - '\'': 29, - '(': 30, - ')': 31, - '*': 32, - '+': 33, - ',': 34, - ';': 35, - '=': 36, - ':': 37, - '%': 38, - '-': 39, - '.': 40, - '_': 41, - '~': 42, - '0': 43, - '1': 44, - '2': 45, - '3': 46, - '4': 47, - '5': 48, - '6': 49, - '7': 50, - '8': 51, - '9': 52, + nodes []acNode // NodeIdx -> acNode + values []acValue // MatchIdx -> acValue } func NewACAutomatonMatcherGroup() *ACAutomatonMatcherGroup { ac := new(ACAutomatonMatcherGroup) - ac.trie = append(ac.trie, newNode()) - ac.fail = append(ac.fail, 0) - ac.exists = append(ac.exists, MatchType{ - matchType: Full, - exist: false, - }) + ac.addNode() // Create root node (NodeIdx 0) + ac.addMatchEntry() // Create sentinel match entry (MatchIdx 0) return ac } // AddFullMatcher implements MatcherGroupForFull.AddFullMatcher. -func (ac *ACAutomatonMatcherGroup) AddFullMatcher(matcher FullMatcher, _ uint32) { - ac.addPattern(0, matcher.Pattern(), matcher.Type()) +func (ac *ACAutomatonMatcherGroup) AddFullMatcher(matcher FullMatcher, value uint32) { + ac.addPattern(0, matcher.Pattern(), matcher.Type(), value) } // AddDomainMatcher implements MatcherGroupForDomain.AddDomainMatcher. -func (ac *ACAutomatonMatcherGroup) AddDomainMatcher(matcher DomainMatcher, _ uint32) { - node := ac.addPattern(0, matcher.Pattern(), Full) - ac.addPattern(node, ".", Domain) +func (ac *ACAutomatonMatcherGroup) AddDomainMatcher(matcher DomainMatcher, value uint32) { + node := ac.addPattern(0, matcher.Pattern(), matcher.Type(), value) // For full domain match + ac.addPattern(node, ".", matcher.Type(), value) // For partial domain match } // AddSubstrMatcher implements MatcherGroupForSubstr.AddSubstrMatcher. -func (ac *ACAutomatonMatcherGroup) AddSubstrMatcher(matcher SubstrMatcher, _ uint32) { - ac.addPattern(0, matcher.Pattern(), matcher.Type()) +func (ac *ACAutomatonMatcherGroup) AddSubstrMatcher(matcher SubstrMatcher, value uint32) { + ac.addPattern(0, matcher.Pattern(), matcher.Type(), value) } -func (ac *ACAutomatonMatcherGroup) addPattern(node int, pattern string, matcherType Type) int { +func (ac *ACAutomatonMatcherGroup) addPattern(nodeIdx uint32, pattern string, matcherType Type, value uint32) uint32 { + node := &ac.nodes[nodeIdx] for i := len(pattern) - 1; i >= 0; i-- { - idx := char2Index[pattern[i]] - if ac.trie[node][idx].nextNode == 0 { - ac.count++ - if len(ac.trie) < ac.count+1 { - ac.trie = append(ac.trie, newNode()) - ac.fail = append(ac.fail, 0) - ac.exists = append(ac.exists, MatchType{ - matchType: Full, - exist: false, - }) - } - ac.trie[node][idx] = Edge{ - edgeType: TrieEdge, - nextNode: ac.count, - } + edgeIdx := acCharset[pattern[i]] + nextIdx := node.next[edgeIdx] + if nextIdx == 0 { // Add new Trie Edge + nextIdx = ac.addNode() + ac.nodes[nodeIdx].next[edgeIdx] = nextIdx + ac.nodes[nodeIdx].edge[edgeIdx] = acTrieEdge } - node = ac.trie[node][idx].nextNode + nodeIdx = nextIdx + node = &ac.nodes[nodeIdx] } - ac.exists[node] = MatchType{ - matchType: matcherType, - exist: true, + if node.match == 0 { // Add new match entry + node.match = ac.addMatchEntry() } - return node + ac.values[node.match][matcherType] = append(ac.values[node.match][matcherType], value) + return nodeIdx } -func (ac *ACAutomatonMatcherGroup) Build() { +func (ac *ACAutomatonMatcherGroup) addNode() uint32 { + ac.nodes = append(ac.nodes, acNode{}) + return uint32(len(ac.nodes) - 1) +} + +func (ac *ACAutomatonMatcherGroup) addMatchEntry() uint32 { + ac.values = append(ac.values, acValue{}) + return uint32(len(ac.values) - 1) +} + +func (ac *ACAutomatonMatcherGroup) Build() error { + fail := make([]uint32, len(ac.nodes)) queue := list.New() - for i := 0; i < validCharCount; i++ { - if ac.trie[0][i].nextNode != 0 { - queue.PushBack(ac.trie[0][i]) + for edgeIdx := 0; edgeIdx < acValidCharCount; edgeIdx++ { + if nextIdx := ac.nodes[0].next[edgeIdx]; nextIdx != 0 { + queue.PushBack(nextIdx) } } for { front := queue.Front() if front == nil { break - } else { - node := front.Value.(Edge).nextNode - queue.Remove(front) - for i := 0; i < validCharCount; i++ { - if ac.trie[node][i].nextNode != 0 { - ac.fail[ac.trie[node][i].nextNode] = ac.trie[ac.fail[node]][i].nextNode - queue.PushBack(ac.trie[node][i]) - } else { - ac.trie[node][i] = Edge{ - edgeType: FailEdge, - nextNode: ac.trie[ac.fail[node]][i].nextNode, - } + } + queue.Remove(front) + nodeIdx := front.Value.(uint32) + node := &ac.nodes[nodeIdx] // Current node + failNode := &ac.nodes[fail[nodeIdx]] // Fail node of currrent node + for edgeIdx := 0; edgeIdx < acValidCharCount; edgeIdx++ { + nodeIdx := node.next[edgeIdx] // Next node through trie edge + failIdx := failNode.next[edgeIdx] // Next node through fail edge + if nodeIdx != 0 { + queue.PushBack(nodeIdx) + fail[nodeIdx] = failIdx + if match := ac.nodes[failIdx].match; match != 0 && len(ac.values[match][Substr]) > 0 { // Fail node is a Substr match node + ac.nodes[nodeIdx].fail = failIdx + } else { // Use path compression to reduce fail path to only contain match nodes + ac.nodes[nodeIdx].fail = ac.nodes[failIdx].fail } + } else { // Add new fail edge + node.next[edgeIdx] = failIdx + node.edge[edgeIdx] = acFailEdge } } } -} - -// Match implements MatcherGroup.Match. -func (*ACAutomatonMatcherGroup) Match(_ string) []uint32 { return nil } -// MatchAny implements MatcherGroup.MatchAny. -func (ac *ACAutomatonMatcherGroup) MatchAny(s string) bool { - node := 0 - fullMatch := true +// Match implements MatcherGroup.Match. +func (ac *ACAutomatonMatcherGroup) Match(input string) []uint32 { + var suffixMatches [][]uint32 + var substrMatches [][]uint32 + fullMatch := true // fullMatch indicates no fail edge traversed so far. + node := &ac.nodes[0] // start from root node. // 1. the match string is all through trie edge. FULL MATCH or DOMAIN // 2. the match string is through a fail edge. NOT FULL MATCH // 2.1 Through a fail edge, but there exists a valid node. SUBSTR - for i := len(s) - 1; i >= 0; i-- { - idx := char2Index[s[i]] - fullMatch = fullMatch && ac.trie[node][idx].edgeType - node = ac.trie[node][idx].nextNode - switch ac.exists[node].matchType { - case Substr: - return true - case Domain: - if fullMatch { - return true + for i := len(input) - 1; i >= 0; i-- { + edge := acCharset[input[i]] + fullMatch = fullMatch && (node.edge[edge] == acTrieEdge) + node = &ac.nodes[node.next[edge]] // Advance to next node + // When entering a new node, traverse the fail path to find all possible Substr patterns: + // 1. The fail path is compressed to only contains match nodes and root node (for terminate condition). + // 2. node.fail != 0 is added here for better performance (as shown by benchmark), possibly it helps branch prediction. + if node.fail != 0 { + for failIdx, failNode := node.fail, &ac.nodes[node.fail]; failIdx != 0; failIdx, failNode = failNode.fail, &ac.nodes[failIdx] { + substrMatches = append(substrMatches, ac.values[failNode.match][Substr]) + } + } + // When entering a new node, check whether this node is a match. + // For Substr matchers: + // 1. Matched in any situation, whether a failNode edge is traversed or not. + // For Domain matchers: + // 1. Should not traverse any fail edge (fullMatch). + // 2. Only check on dot separator (input[i] == '.'). + if node.match != 0 { + values := ac.values[node.match] + if len(values[Substr]) > 0 { + substrMatches = append(substrMatches, values[Substr]) + } + if fullMatch && input[i] == '.' && len(values[Domain]) > 0 { + suffixMatches = append(suffixMatches, values[Domain]) } - default: - break } } - return fullMatch && ac.exists[node].exist + // At the end of input, check if the whole string matches a pattern. + // For Domain matchers: + // 1. Exact match on Domain Matcher works like Full Match. e.g. foo.com is a full match for domain:foo.com. + // For Full matchers: + // 1. Only when no fail edge is traversed (fullMatch). + // 2. Takes the highest priority (added at last). + if fullMatch && node.match != 0 { + values := ac.values[node.match] + if len(values[Domain]) > 0 { + suffixMatches = append(suffixMatches, values[Domain]) + } + if len(values[Full]) > 0 { + suffixMatches = append(suffixMatches, values[Full]) + } + } + switch matches := append(substrMatches, suffixMatches...); len(matches) { // nolint: gocritic + case 0: + return nil + case 1: + return matches[0] + default: + result := []uint32{} + for i := len(matches) - 1; i >= 0; i-- { + result = append(result, matches[i]...) + } + return result + } +} + +// MatchAny implements MatcherGroup.MatchAny. +func (ac *ACAutomatonMatcherGroup) MatchAny(input string) bool { + fullMatch := true + node := &ac.nodes[0] + for i := len(input) - 1; i >= 0; i-- { + edge := acCharset[input[i]] + fullMatch = fullMatch && (node.edge[edge] == acTrieEdge) + node = &ac.nodes[node.next[edge]] + if node.fail != 0 { // There is a match on this node's fail path + return true + } + if node.match != 0 { // There is a match on this node + values := ac.values[node.match] + if len(values[Substr]) > 0 { // Substr match succeeds unconditionally + return true + } + if fullMatch && input[i] == '.' && len(values[Domain]) > 0 { // Domain match only succeeds with dot separator on trie path + return true + } + } + } + return fullMatch && node.match != 0 // At the end of input, Domain and Full match will succeed if no fail edge is traversed +} + +// Letter-Digit-Hyphen (LDH) subset (https://tools.ietf.org/html/rfc952): +// * Letters A to Z (no distinction is made between uppercase and lowercase) +// * Digits 0 to 9 +// * Hyphens(-) and Periods(.) +// +// If for future the strmatcher are used for other scenarios than domain, +// we could add a new Charset interface to represent variable charsets. +var acCharset = []int{ + 'A': 0, + 'a': 0, + 'B': 1, + 'b': 1, + 'C': 2, + 'c': 2, + 'D': 3, + 'd': 3, + 'E': 4, + 'e': 4, + 'F': 5, + 'f': 5, + 'G': 6, + 'g': 6, + 'H': 7, + 'h': 7, + 'I': 8, + 'i': 8, + 'J': 9, + 'j': 9, + 'K': 10, + 'k': 10, + 'L': 11, + 'l': 11, + 'M': 12, + 'm': 12, + 'N': 13, + 'n': 13, + 'O': 14, + 'o': 14, + 'P': 15, + 'p': 15, + 'Q': 16, + 'q': 16, + 'R': 17, + 'r': 17, + 'S': 18, + 's': 18, + 'T': 19, + 't': 19, + 'U': 20, + 'u': 20, + 'V': 21, + 'v': 21, + 'W': 22, + 'w': 22, + 'X': 23, + 'x': 23, + 'Y': 24, + 'y': 24, + 'Z': 25, + 'z': 25, + '-': 26, + '.': 27, + '0': 28, + '1': 29, + '2': 30, + '3': 31, + '4': 32, + '5': 33, + '6': 34, + '7': 35, + '8': 36, + '9': 37, } diff --git a/common/strmatcher/matchergroup_ac_automation_test.go b/common/strmatcher/matchergroup_ac_automation_test.go index fa3483882..4e731f654 100644 --- a/common/strmatcher/matchergroup_ac_automation_test.go +++ b/common/strmatcher/matchergroup_ac_automation_test.go @@ -1,6 +1,7 @@ package strmatcher_test import ( + "reflect" "testing" "github.com/v2fly/v2ray-core/v5/common" @@ -180,4 +181,185 @@ func TestACAutomatonMatcherGroup(t *testing.T) { } } } + + { + cases4Input := []struct { + pattern string + mType Type + }{ + { + pattern: "apis", + mType: Substr, + }, + { + pattern: "googleapis.com", + mType: Domain, + }, + } + ac := NewACAutomatonMatcherGroup() + for _, test := range cases4Input { + matcher, err := test.mType.New(test.pattern) + common.Must(err) + common.Must(AddMatcherToGroup(ac, matcher, 0)) + } + ac.Build() + cases4Output := []struct { + pattern string + res bool + }{ + { + pattern: "gapis.com", + res: true, + }, + } + for _, test := range cases4Output { + if m := ac.MatchAny(test.pattern); m != test.res { + t.Error("unexpected output: ", m, " for test case ", test) + } + } + } +} + +func TestACAutomatonMatcherGroupSubstr(t *testing.T) { + patterns := []struct { + pattern string + mType Type + }{ + { + pattern: "apis", + mType: Substr, + }, + { + pattern: "google", + mType: Substr, + }, + { + pattern: "apis", + mType: Substr, + }, + } + cases := []struct { + input string + output []uint32 + }{ + { + input: "google.com", + output: []uint32{1}, + }, + { + input: "apis.com", + output: []uint32{0, 2}, + }, + { + input: "googleapis.com", + output: []uint32{1, 0, 2}, + }, + { + input: "fonts.googleapis.com", + output: []uint32{1, 0, 2}, + }, + { + input: "apis.googleapis.com", + output: []uint32{0, 2, 1, 0, 2}, + }, + } + matcherGroup := NewACAutomatonMatcherGroup() + for id, entry := range patterns { + matcher, err := entry.mType.New(entry.pattern) + common.Must(err) + common.Must(AddMatcherToGroup(matcherGroup, matcher, uint32(id))) + } + matcherGroup.Build() + for _, test := range cases { + if r := matcherGroup.Match(test.input); !reflect.DeepEqual(r, test.output) { + t.Error("unexpected output: ", r, " for test case ", test) + } + } +} + +// See https://github.com/v2fly/v2ray-core/issues/92#issuecomment-673238489 +func TestACAutomatonMatcherGroupAsIndexMatcher(t *testing.T) { + rules := []struct { + Type Type + Domain string + }{ + // Regex not supported by ACAutomationMatcherGroup + // { + // Type: Regex, + // Domain: "apis\\.us$", + // }, + { + Type: Substr, + Domain: "apis", + }, + { + Type: Domain, + Domain: "googleapis.com", + }, + { + Type: Domain, + Domain: "com", + }, + { + Type: Full, + Domain: "www.baidu.com", + }, + { + Type: Substr, + Domain: "apis", + }, + { + Type: Domain, + Domain: "googleapis.com", + }, + { + Type: Full, + Domain: "fonts.googleapis.com", + }, + { + Type: Full, + Domain: "www.baidu.com", + }, + { + Type: Domain, + Domain: "example.com", + }, + } + cases := []struct { + Input string + Output []uint32 + }{ + { + Input: "www.baidu.com", + Output: []uint32{5, 9, 4}, + }, + { + Input: "fonts.googleapis.com", + Output: []uint32{8, 3, 7, 4, 2, 6}, + }, + { + Input: "example.googleapis.com", + Output: []uint32{3, 7, 4, 2, 6}, + }, + { + Input: "testapis.us", + Output: []uint32{2, 6 /*, 1*/}, + }, + { + Input: "example.com", + Output: []uint32{10, 4}, + }, + } + matcherGroup := NewACAutomatonMatcherGroup() + for i, rule := range rules { + matcher, err := rule.Type.New(rule.Domain) + common.Must(err) + common.Must(AddMatcherToGroup(matcherGroup, matcher, uint32(i+2))) + } + matcherGroup.Build() + for _, test := range cases { + if m := matcherGroup.Match(test.Input); !reflect.DeepEqual(m, test.Output) { + t.Error("unexpected output: ", m, " for test case ", test) + } + } }