From de618121adb1bd6a4915bdf7a84474b1f51f2a64 Mon Sep 17 00:00:00 2001 From: DarthVader <61409963+darsvador@users.noreply.github.com> Date: Wed, 20 Jan 2021 15:53:07 +0800 Subject: [PATCH] Refactor: A faster DomainMatcher implementation (#587) * a faster DomainMatcher implementation * rename benchmark name * fix linting errors --- app/router/condition.go | 18 ++ app/router/condition_test.go | 95 ++++++++- app/router/config.go | 2 +- common/strmatcher/ac_automaton_matcher.go | 243 ++++++++++++++++++++++ common/strmatcher/benchmark_test.go | 13 ++ common/strmatcher/matchers_test.go | 168 +++++++++++++++ common/strmatcher/strmatcher.go | 51 +++++ 7 files changed, 586 insertions(+), 4 deletions(-) create mode 100644 common/strmatcher/ac_automaton_matcher.go diff --git a/app/router/condition.go b/app/router/condition.go index ac6984feb..9da16378b 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -67,6 +67,24 @@ type DomainMatcher struct { matchers strmatcher.IndexMatcher } +func NewACAutomatonDomainMatcher(domains []*Domain) (*DomainMatcher, error) { + g := strmatcher.NewACAutomatonMatcherGroup() + for _, d := range domains { + matcherType, f := matcherTypeMap[d.Type] + if !f { + return nil, newError("unsupported domain type", d.Type) + } + _, err := g.AddPattern(d.Value, matcherType) + if err != nil { + return nil, err + } + } + g.Build() + return &DomainMatcher{ + matchers: g, + }, nil +} + func NewDomainMatcher(domains []*Domain) (*DomainMatcher, error) { g := new(strmatcher.MatcherGroup) for _, d := range domains { diff --git a/app/router/condition_test.go b/app/router/condition_test.go index 05d734669..468970de3 100644 --- a/app/router/condition_test.go +++ b/app/router/condition_test.go @@ -358,6 +358,8 @@ func TestChinaSites(t *testing.T) { matcher, err := NewDomainMatcher(domains) common.Must(err) + acMatcher, err := NewACAutomatonDomainMatcher(domains) + common.Must(err) type TestCase struct { Domain string @@ -387,9 +389,96 @@ func TestChinaSites(t *testing.T) { } for _, testCase := range testCases { - r := matcher.ApplyDomain(testCase.Domain) - if r != testCase.Output { - t.Error("expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r) + r1 := matcher.ApplyDomain(testCase.Domain) + r2 := acMatcher.ApplyDomain(testCase.Domain) + if r1 != testCase.Output { + t.Error("DomainMatcher expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r1) + } else if r2 != testCase.Output { + t.Error("ACDomainMatcher expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r2) + } + } +} + +func BenchmarkACDomainMatcher(b *testing.B) { + domains, err := loadGeoSite("CN") + common.Must(err) + + matcher, err := NewACAutomatonDomainMatcher(domains) + common.Must(err) + + type TestCase struct { + Domain string + Output bool + } + testCases := []TestCase{ + { + Domain: "163.com", + Output: true, + }, + { + Domain: "163.com", + Output: true, + }, + { + Domain: "164.com", + Output: false, + }, + { + Domain: "164.com", + Output: false, + }, + } + + for i := 0; i < 1024; i++ { + testCases = append(testCases, TestCase{Domain: strconv.Itoa(i) + ".not-exists.com", Output: false}) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, testCase := range testCases { + _ = matcher.ApplyDomain(testCase.Domain) + } + } +} + +func BenchmarkDomainMatcher(b *testing.B) { + domains, err := loadGeoSite("CN") + common.Must(err) + + matcher, err := NewDomainMatcher(domains) + common.Must(err) + + type TestCase struct { + Domain string + Output bool + } + testCases := []TestCase{ + { + Domain: "163.com", + Output: true, + }, + { + Domain: "163.com", + Output: true, + }, + { + Domain: "164.com", + Output: false, + }, + { + Domain: "164.com", + Output: false, + }, + } + + for i := 0; i < 1024; i++ { + testCases = append(testCases, TestCase{Domain: strconv.Itoa(i) + ".not-exists.com", Output: false}) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, testCase := range testCases { + _ = matcher.ApplyDomain(testCase.Domain) } } } diff --git a/app/router/config.go b/app/router/config.go index 8eb9d5aa1..62b0ada7c 100644 --- a/app/router/config.go +++ b/app/router/config.go @@ -69,7 +69,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) { conds := NewConditionChan() if len(rr.Domain) > 0 { - matcher, err := NewDomainMatcher(rr.Domain) + matcher, err := NewACAutomatonDomainMatcher(rr.Domain) if err != nil { return nil, newError("failed to build domain condition").Base(err) } diff --git a/common/strmatcher/ac_automaton_matcher.go b/common/strmatcher/ac_automaton_matcher.go new file mode 100644 index 000000000..e21364ecc --- /dev/null +++ b/common/strmatcher/ac_automaton_matcher.go @@ -0,0 +1,243 @@ +package strmatcher + +import ( + "container/list" +) + +const validCharCount = 53 + +type MatchType struct { + matchType Type + exist bool +} + +const ( + TrieEdge bool = true + FailEdge bool = false +) + +type Edge struct { + edgeType bool + nextNode int +} + +type ACAutomaton struct { + trie [][validCharCount]Edge + fail []int + exists []MatchType + count int +} + +func newNode() [validCharCount]Edge { + var s [validCharCount]Edge + for i := range s { + s[i] = Edge{ + edgeType: FailEdge, + nextNode: 0, + } + } + return s +} + +var char2Index = []int{ + 'A': 0, + 'a': 0, + 'B': 1, + 'b': 1, + 'C': 2, + 'c': 2, + 'D': 3, + 'd': 3, + 'E': 4, + 'e': 4, + 'F': 5, + 'f': 5, + 'G': 6, + 'g': 6, + 'H': 7, + 'h': 7, + 'I': 8, + 'i': 8, + 'J': 9, + 'j': 9, + 'K': 10, + 'k': 10, + 'L': 11, + 'l': 11, + 'M': 12, + 'm': 12, + 'N': 13, + 'n': 13, + 'O': 14, + 'o': 14, + 'P': 15, + 'p': 15, + 'Q': 16, + 'q': 16, + 'R': 17, + 'r': 17, + 'S': 18, + 's': 18, + 'T': 19, + 't': 19, + 'U': 20, + 'u': 20, + 'V': 21, + 'v': 21, + 'W': 22, + 'w': 22, + 'X': 23, + 'x': 23, + 'Y': 24, + 'y': 24, + 'Z': 25, + 'z': 25, + '!': 26, + '$': 27, + '&': 28, + '\'': 29, + '(': 30, + ')': 31, + '*': 32, + '+': 33, + ',': 34, + ';': 35, + '=': 36, + ':': 37, + '%': 38, + '-': 39, + '.': 40, + '_': 41, + '~': 42, + '0': 43, + '1': 44, + '2': 45, + '3': 46, + '4': 47, + '5': 48, + '6': 49, + '7': 50, + '8': 51, + '9': 52, +} + +func NewACAutomaton() *ACAutomaton { + var ac = new(ACAutomaton) + ac.trie = append(ac.trie, newNode()) + ac.fail = append(ac.fail, 0) + ac.exists = append(ac.exists, MatchType{ + matchType: Full, + exist: false, + }) + return ac +} + +func (ac *ACAutomaton) Add(domain string, t Type) { + var node = 0 + for i := len(domain) - 1; i >= 0; i-- { + var idx = char2Index[domain[i]] + if ac.trie[node][idx].nextNode == 0 { + ac.count++ + if len(ac.trie) < ac.count+1 { + ac.trie = append(ac.trie, newNode()) + ac.fail = append(ac.fail, 0) + ac.exists = append(ac.exists, MatchType{ + matchType: Full, + exist: false, + }) + } + ac.trie[node][idx] = Edge{ + edgeType: TrieEdge, + nextNode: ac.count, + } + } + node = ac.trie[node][idx].nextNode + } + ac.exists[node] = MatchType{ + matchType: t, + exist: true, + } + switch t { + case Domain: + ac.exists[node] = MatchType{ + matchType: Full, + exist: true, + } + var idx = char2Index['.'] + if ac.trie[node][idx].nextNode == 0 { + ac.count++ + if len(ac.trie) < ac.count+1 { + ac.trie = append(ac.trie, newNode()) + ac.fail = append(ac.fail, 0) + ac.exists = append(ac.exists, MatchType{ + matchType: Full, + exist: false, + }) + } + ac.trie[node][idx] = Edge{ + edgeType: TrieEdge, + nextNode: ac.count, + } + } + node = ac.trie[node][idx].nextNode + ac.exists[node] = MatchType{ + matchType: t, + exist: true, + } + default: + break + } +} + +func (ac *ACAutomaton) Build() { + var queue = list.New() + for i := 0; i < validCharCount; i++ { + if ac.trie[0][i].nextNode != 0 { + queue.PushBack(ac.trie[0][i]) + } + } + for { + var front = queue.Front() + if front == nil { + break + } else { + var node = front.Value.(Edge).nextNode + queue.Remove(front) + for i := 0; i < validCharCount; i++ { + if ac.trie[node][i].nextNode != 0 { + ac.fail[ac.trie[node][i].nextNode] = ac.trie[ac.fail[node]][i].nextNode + queue.PushBack(ac.trie[node][i]) + } else { + ac.trie[node][i] = Edge{ + edgeType: FailEdge, + nextNode: ac.trie[ac.fail[node]][i].nextNode, + } + } + } + } + } +} + +func (ac *ACAutomaton) Match(s string) bool { + var node = 0 + var fullMatch = true + // 1. the match string is all through trie edge. FULL MATCH or DOMAIN + // 2. the match string is through a fail edge. NOT FULL MATCH + // 2.1 Through a fail edge, but there exists a valid node. SUBSTR + for i := len(s) - 1; i >= 0; i-- { + var idx = char2Index[s[i]] + fullMatch = fullMatch && ac.trie[node][idx].edgeType + node = ac.trie[node][idx].nextNode + switch ac.exists[node].matchType { + case Substr: + return true + case Domain: + if fullMatch { + return true + } + default: + break + } + } + return fullMatch && ac.exists[node].exist +} diff --git a/common/strmatcher/benchmark_test.go b/common/strmatcher/benchmark_test.go index de5f5f626..5a0673a30 100644 --- a/common/strmatcher/benchmark_test.go +++ b/common/strmatcher/benchmark_test.go @@ -8,6 +8,19 @@ import ( . "v2ray.com/core/common/strmatcher" ) +func BenchmarkACAutomaton(b *testing.B) { + ac := NewACAutomaton() + for i := 1; i <= 1024; i++ { + ac.Add(strconv.Itoa(i)+".v2ray.com", Domain) + } + ac.Build() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ac.Match("0.v2ray.com") + } +} + func BenchmarkDomainMatcherGroup(b *testing.B) { g := new(DomainMatcherGroup) diff --git a/common/strmatcher/matchers_test.go b/common/strmatcher/matchers_test.go index 4d615b143..83317fa38 100644 --- a/common/strmatcher/matchers_test.go +++ b/common/strmatcher/matchers_test.go @@ -71,3 +71,171 @@ func TestMatcher(t *testing.T) { } } } +func TestACAutomaton(t *testing.T) { + cases1 := []struct { + pattern string + mType Type + input string + output bool + }{ + { + pattern: "v2ray.com", + mType: Domain, + input: "www.v2ray.com", + output: true, + }, + { + pattern: "v2ray.com", + mType: Domain, + input: "v2ray.com", + output: true, + }, + { + pattern: "v2ray.com", + mType: Domain, + input: "www.v3ray.com", + output: false, + }, + { + pattern: "v2ray.com", + mType: Domain, + input: "2ray.com", + output: false, + }, + { + pattern: "v2ray.com", + mType: Domain, + input: "xv2ray.com", + output: false, + }, + { + pattern: "v2ray.com", + mType: Full, + input: "v2ray.com", + output: true, + }, + { + pattern: "v2ray.com", + mType: Full, + input: "xv2ray.com", + output: false, + }, + } + for _, test := range cases1 { + var ac = NewACAutomaton() + ac.Add(test.pattern, test.mType) + ac.Build() + if m := ac.Match(test.input); m != test.output { + t.Error("unexpected output: ", m, " for test case ", test) + } + } + { + cases2Input := []struct { + pattern string + mType Type + }{ + { + pattern: "163.com", + mType: Domain, + }, + { + pattern: "m.126.com", + mType: Full, + }, + { + pattern: "3.com", + mType: Full, + }, + { + pattern: "google.com", + mType: Substr, + }, + { + pattern: "vgoogle.com", + mType: Substr, + }, + } + var ac = NewACAutomaton() + for _, test := range cases2Input { + ac.Add(test.pattern, test.mType) + } + ac.Build() + cases2Output := []struct { + pattern string + res bool + }{ + { + pattern: "126.com", + res: false, + }, + { + pattern: "m.163.com", + res: true, + }, + { + pattern: "mm163.com", + res: false, + }, + { + pattern: "m.126.com", + res: true, + }, + { + pattern: "163.com", + res: true, + }, + { + pattern: "63.com", + res: false, + }, + { + pattern: "oogle.com", + res: false, + }, + { + pattern: "vvgoogle.com", + res: true, + }, + } + for _, test := range cases2Output { + if m := ac.Match(test.pattern); m != test.res { + t.Error("unexpected output: ", m, " for test case ", test) + } + } + } + + { + cases3Input := []struct { + pattern string + mType Type + }{ + { + pattern: "video.google.com", + mType: Domain, + }, + { + pattern: "gle.com", + mType: Domain, + }, + } + var ac = NewACAutomaton() + for _, test := range cases3Input { + ac.Add(test.pattern, test.mType) + } + ac.Build() + cases3Output := []struct { + pattern string + res bool + }{ + { + pattern: "google.com", + res: false, + }, + } + for _, test := range cases3Output { + if m := ac.Match(test.pattern); m != test.res { + t.Error("unexpected output: ", m, " for test case ", test) + } + } + } +} diff --git a/common/strmatcher/strmatcher.go b/common/strmatcher/strmatcher.go index 9728047d5..5f1dcd953 100644 --- a/common/strmatcher/strmatcher.go +++ b/common/strmatcher/strmatcher.go @@ -58,6 +58,57 @@ type matcherEntry struct { id uint32 } +type ACAutomatonMatcherGroup struct { + count uint32 + ac *ACAutomaton + otherMatchers []matcherEntry +} + +func NewACAutomatonMatcherGroup() *ACAutomatonMatcherGroup { + var g = new(ACAutomatonMatcherGroup) + g.count = 1 + g.ac = NewACAutomaton() + return g +} + +func (g *ACAutomatonMatcherGroup) AddPattern(pattern string, t Type) (uint32, error) { + switch t { + case Full, Substr, Domain: + g.ac.Add(pattern, t) + case Regex: + g.count++ + r, err := regexp.Compile(pattern) + if err != nil { + return 0, err + } + g.otherMatchers = append(g.otherMatchers, matcherEntry{ + m: ®exMatcher{pattern: r}, + id: g.count, + }) + default: + panic("Unknown type") + } + return g.count, nil +} + +func (g *ACAutomatonMatcherGroup) Build() { + g.ac.Build() +} + +// Match implements IndexMatcher.Match. +func (g *ACAutomatonMatcherGroup) Match(pattern string) []uint32 { + result := []uint32{} + if g.ac.Match(pattern) { + result = append(result, 1) + } + for _, e := range g.otherMatchers { + if e.m.Match(pattern) { + result = append(result, e.id) + } + } + return result +} + // MatcherGroup is an implementation of IndexMatcher. // Empty initialization works. type MatcherGroup struct {