From ac1e5cd92512003b4e0bd9ce9c5d0e5b7a6926e3 Mon Sep 17 00:00:00 2001 From: DarthVader <61409963+darsvador@users.noreply.github.com> Date: Mon, 15 Mar 2021 15:21:38 +0800 Subject: [PATCH] Add minimal perfect hash domain matcher (#743) * rename to HybridDomainMatcher & convert domain to lowercase * refactor code & add open hashing for rolling hash map * fix lint errors * update app/dns/dns.go * convert domain to lowercase in `strmatcher.go` * keep the original matcher behavior * add mph domain matcher & convert domain names to lowercase when matching * fix lint errors * fix lint errors --- app/router/condition.go | 6 +- app/router/condition_test.go | 6 +- app/router/config.go | 8 +- common/strmatcher/mph_matcher.go | 297 +++++++++++++++++++++++++++++++ common/strmatcher/strmatcher.go | 95 +--------- 5 files changed, 308 insertions(+), 104 deletions(-) create mode 100644 common/strmatcher/mph_matcher.go diff --git a/app/router/condition.go b/app/router/condition.go index 9c87a6a1c..d9aedd83d 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -68,8 +68,8 @@ type DomainMatcher struct { matchers strmatcher.IndexMatcher } -func NewACAutomatonDomainMatcher(domains []*Domain) (*DomainMatcher, error) { - g := strmatcher.NewACAutomatonMatcherGroup() +func NewMphMatcherGroup(domains []*Domain) (*DomainMatcher, error) { + g := strmatcher.NewMphMatcherGroup() for _, d := range domains { matcherType, f := matcherTypeMap[d.Type] if !f { @@ -102,7 +102,7 @@ func NewDomainMatcher(domains []*Domain) (*DomainMatcher, error) { } func (m *DomainMatcher) ApplyDomain(domain string) bool { - return len(m.matchers.Match(domain)) > 0 + return len(m.matchers.Match(strings.ToLower(domain))) > 0 } // Apply implements Condition. 
diff --git a/app/router/condition_test.go b/app/router/condition_test.go index ce3c581b5..229bf60be 100644 --- a/app/router/condition_test.go +++ b/app/router/condition_test.go @@ -358,7 +358,7 @@ func TestChinaSites(t *testing.T) { matcher, err := NewDomainMatcher(domains) common.Must(err) - acMatcher, err := NewACAutomatonDomainMatcher(domains) + acMatcher, err := NewMphMatcherGroup(domains) common.Must(err) type TestCase struct { @@ -399,11 +399,11 @@ func TestChinaSites(t *testing.T) { } } -func BenchmarkHybridDomainMatcher(b *testing.B) { +func BenchmarkMphDomainMatcher(b *testing.B) { domains, err := loadGeoSite("CN") common.Must(err) - matcher, err := NewACAutomatonDomainMatcher(domains) + matcher, err := NewMphMatcherGroup(domains) common.Must(err) type TestCase struct { diff --git a/app/router/config.go b/app/router/config.go index 52699fcfd..1612d546d 100644 --- a/app/router/config.go +++ b/app/router/config.go @@ -70,12 +70,12 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) { if len(rr.Domain) > 0 { switch rr.DomainMatcher { - case "hybrid": - matcher, err := NewACAutomatonDomainMatcher(rr.Domain) + case "mph": + matcher, err := NewMphMatcherGroup(rr.Domain) if err != nil { - return nil, newError("failed to build domain condition with ACAutomatonDomainMatcher").Base(err) + return nil, newError("failed to build domain condition with MphDomainMatcher").Base(err) } - newError("ACAutomatonDomainMatcher is enabled for ", len(rr.Domain), "domain rules(s)").AtDebug().WriteToLog() + newError("MphDomainMatcher is enabled for ", len(rr.Domain), "domain rules(s)").AtDebug().WriteToLog() conds.Add(matcher) case "linear": fallthrough diff --git a/common/strmatcher/mph_matcher.go b/common/strmatcher/mph_matcher.go new file mode 100644 index 000000000..794d6b8d4 --- /dev/null +++ b/common/strmatcher/mph_matcher.go @@ -0,0 +1,297 @@ +package strmatcher + +import ( + "math/bits" + "regexp" + "sort" + "strings" + "unsafe" +) + +// PrimeRK is the prime base 
used in Rabin-Karp algorithm. +const PrimeRK = 16777619 + +// calculate the rolling murmurHash of given string +func RollingHash(s string) uint32 { + h := uint32(0) + for i := len(s) - 1; i >= 0; i-- { + h = h*PrimeRK + uint32(s[i]) + } + return h +} + +// A MphMatcherGroup is divided into three parts: +// 1. `full` and `domain` patterns are matched by Rabin-Karp algorithm and minimal perfect hash table; +// 2. `substr` patterns are matched by ac automaton; +// 3. `regex` patterns are matched with the regex library. +type MphMatcherGroup struct { + ac *ACAutomaton + otherMatchers []matcherEntry + rules []string + level0 []uint32 + level0Mask int + level1 []uint32 + level1Mask int + count uint32 + ruleMap *map[string]uint32 +} + +func (g *MphMatcherGroup) AddFullOrDomainPattern(pattern string, t Type) { + h := RollingHash(pattern) + switch t { + case Domain: + (*g.ruleMap)["."+pattern] = h*PrimeRK + uint32('.') + fallthrough + case Full: + (*g.ruleMap)[pattern] = h + default: + } +} + +func NewMphMatcherGroup() *MphMatcherGroup { + return &MphMatcherGroup{ + ac: nil, + otherMatchers: nil, + rules: nil, + level0: nil, + level0Mask: 0, + level1: nil, + level1Mask: 0, + count: 1, + ruleMap: &map[string]uint32{}, + } +} + +// AddPattern adds a pattern to MphMatcherGroup +func (g *MphMatcherGroup) AddPattern(pattern string, t Type) (uint32, error) { + switch t { + case Substr: + if g.ac == nil { + g.ac = NewACAutomaton() + } + g.ac.Add(pattern, t) + case Full, Domain: + pattern = strings.ToLower(pattern) + g.AddFullOrDomainPattern(pattern, t) + case Regex: + r, err := regexp.Compile(pattern) + if err != nil { + return 0, err + } + g.otherMatchers = append(g.otherMatchers, matcherEntry{ + m: &regexMatcher{pattern: r}, + id: g.count, + }) + default: + panic("Unknown type") + } + return g.count, nil +} + +// Build builds a minimal perfect hash table and ac automaton from insert rules +func (g *MphMatcherGroup) Build() { + if g.ac != nil { + g.ac.Build() + } + keyLen := 
len(*g.ruleMap) + g.level0 = make([]uint32, nextPow2(keyLen/4)) + g.level0Mask = len(g.level0) - 1 + g.level1 = make([]uint32, nextPow2(keyLen)) + g.level1Mask = len(g.level1) - 1 + var sparseBuckets = make([][]int, len(g.level0)) + var ruleIdx int + for rule, hash := range *g.ruleMap { + n := int(hash) & g.level0Mask + g.rules = append(g.rules, rule) + sparseBuckets[n] = append(sparseBuckets[n], ruleIdx) + ruleIdx++ + } + g.ruleMap = nil + var buckets []indexBucket + for n, vals := range sparseBuckets { + if len(vals) > 0 { + buckets = append(buckets, indexBucket{n, vals}) + } + } + sort.Sort(bySize(buckets)) + + occ := make([]bool, len(g.level1)) + var tmpOcc []int + for _, bucket := range buckets { + var seed = uint32(0) + for { + findSeed := true + tmpOcc = tmpOcc[:0] + for _, i := range bucket.vals { + n := int(strhashFallback(unsafe.Pointer(&g.rules[i]), uintptr(seed))) & g.level1Mask + if occ[n] { + for _, n := range tmpOcc { + occ[n] = false + } + seed++ + findSeed = false + break + } + occ[n] = true + tmpOcc = append(tmpOcc, n) + g.level1[n] = uint32(i) + } + if findSeed { + g.level0[bucket.n] = seed + break + } + } + } +} + +func nextPow2(v int) int { + if v <= 1 { + return 1 + } + const MaxUInt = ^uint(0) + n := (MaxUInt >> bits.LeadingZeros(uint(v))) + 1 + return int(n) +} + +// Lookup searches for s in t and returns its index and whether it was found. +func (g *MphMatcherGroup) Lookup(h uint32, s string) bool { + i0 := int(h) & g.level0Mask + seed := g.level0[i0] + i1 := int(strhashFallback(unsafe.Pointer(&s), uintptr(seed))) & g.level1Mask + n := g.level1[i1] + return s == g.rules[int(n)] +} + +// Match implements IndexMatcher.Match. +func (g *MphMatcherGroup) Match(pattern string) []uint32 { + result := []uint32{} + hash := uint32(0) + for i := len(pattern) - 1; i >= 0; i-- { + hash = hash*PrimeRK + uint32(pattern[i]) + if pattern[i] == '.' 
{ + if g.Lookup(hash, pattern[i:]) { + result = append(result, 1) + return result + } + } + } + if g.Lookup(hash, pattern) { + result = append(result, 1) + return result + } + if g.ac != nil && g.ac.Match(pattern) { + result = append(result, 1) + return result + } + for _, e := range g.otherMatchers { + if e.m.Match(pattern) { + result = append(result, e.id) + return result + } + } + return nil +} + +type indexBucket struct { + n int + vals []int +} + +type bySize []indexBucket + +func (s bySize) Len() int { return len(s) } +func (s bySize) Less(i, j int) bool { return len(s[i].vals) > len(s[j].vals) } +func (s bySize) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +type stringStruct struct { + str unsafe.Pointer + len int +} + +func strhashFallback(a unsafe.Pointer, h uintptr) uintptr { + x := (*stringStruct)(a) + return memhashFallback(x.str, h, uintptr(x.len)) +} + +const ( + // Constants for multiplication: four random odd 64-bit numbers. + m1 = 16877499708836156737 + m2 = 2820277070424839065 + m3 = 9497967016996688599 + m4 = 15839092249703872147 +) + +var hashkey = [4]uintptr{1, 1, 1, 1} + +func memhashFallback(p unsafe.Pointer, seed, s uintptr) uintptr { + h := uint64(seed + s*hashkey[0]) +tail: + switch { + case s == 0: + case s < 4: + h ^= uint64(*(*byte)(p)) + h ^= uint64(*(*byte)(add(p, s>>1))) << 8 + h ^= uint64(*(*byte)(add(p, s-1))) << 16 + h = rotl31(h*m1) * m2 + case s <= 8: + h ^= uint64(readUnaligned32(p)) + h ^= uint64(readUnaligned32(add(p, s-4))) << 32 + h = rotl31(h*m1) * m2 + case s <= 16: + h ^= readUnaligned64(p) + h = rotl31(h*m1) * m2 + h ^= readUnaligned64(add(p, s-8)) + h = rotl31(h*m1) * m2 + case s <= 32: + h ^= readUnaligned64(p) + h = rotl31(h*m1) * m2 + h ^= readUnaligned64(add(p, 8)) + h = rotl31(h*m1) * m2 + h ^= readUnaligned64(add(p, s-16)) + h = rotl31(h*m1) * m2 + h ^= readUnaligned64(add(p, s-8)) + h = rotl31(h*m1) * m2 + default: + v1 := h + v2 := uint64(seed * hashkey[1]) + v3 := uint64(seed * hashkey[2]) + v4 := uint64(seed 
* hashkey[3]) + for s >= 32 { + v1 ^= readUnaligned64(p) + v1 = rotl31(v1*m1) * m2 + p = add(p, 8) + v2 ^= readUnaligned64(p) + v2 = rotl31(v2*m2) * m3 + p = add(p, 8) + v3 ^= readUnaligned64(p) + v3 = rotl31(v3*m3) * m4 + p = add(p, 8) + v4 ^= readUnaligned64(p) + v4 = rotl31(v4*m4) * m1 + p = add(p, 8) + s -= 32 + } + h = v1 ^ v2 ^ v3 ^ v4 + goto tail + } + + h ^= h >> 29 + h *= m3 + h ^= h >> 32 + return uintptr(h) +} +func add(p unsafe.Pointer, x uintptr) unsafe.Pointer { + return unsafe.Pointer(uintptr(p) + x) +} +func readUnaligned32(p unsafe.Pointer) uint32 { + q := (*[4]byte)(p) + return uint32(q[0]) | uint32(q[1])<<8 | uint32(q[2])<<16 | uint32(q[3])<<24 +} + +func rotl31(x uint64) uint64 { + return (x << 31) | (x >> (64 - 31)) +} +func readUnaligned64(p unsafe.Pointer) uint64 { + q := (*[8]byte)(p) + return uint64(q[0]) | uint64(q[1])<<8 | uint64(q[2])<<16 | uint64(q[3])<<24 | uint64(q[4])<<32 | uint64(q[5])<<40 | uint64(q[6])<<48 | uint64(q[7])<<56 +} diff --git a/common/strmatcher/strmatcher.go b/common/strmatcher/strmatcher.go index 97d404f77..294e6e73b 100644 --- a/common/strmatcher/strmatcher.go +++ b/common/strmatcher/strmatcher.go @@ -4,9 +4,6 @@ import ( "regexp" ) -// PrimeRK is the prime base used in Rabin-Karp algorithm. -const PrimeRK = 16777619 - // Matcher is the interface to determine a string matches a pattern. type Matcher interface { // Match returns true if the given string matches a predefined pattern. @@ -30,6 +27,7 @@ const ( // New creates a new Matcher based on the given pattern. func (t Type) New(pattern string) (Matcher, error) { + // 1. 
regex matching is case-sensitive switch t { case Full: return fullMatcher(pattern), nil @@ -61,97 +59,6 @@ type matcherEntry struct { id uint32 } -type ACAutomatonMatcherGroup struct { - count uint32 - ac *ACAutomaton - nonSubstrMap map[uint32]string - otherMatchers []matcherEntry -} - -func NewACAutomatonMatcherGroup() *ACAutomatonMatcherGroup { - var g = new(ACAutomatonMatcherGroup) - g.count = 1 - g.nonSubstrMap = map[uint32]string{} - return g -} - -// Add `full` or `domain` pattern to hashmap -func (g *ACAutomatonMatcherGroup) AddFullOrDomainPattern(pattern string, t Type) { - h := uint32(0) - for i := len(pattern) - 1; i >= 0; i-- { - h = h*PrimeRK + uint32(pattern[i]) - } - switch t { - case Full: - g.nonSubstrMap[h] = pattern - case Domain: - g.nonSubstrMap[h] = pattern - g.nonSubstrMap[h*PrimeRK+uint32('.')] = "." + pattern - default: - } -} - -func (g *ACAutomatonMatcherGroup) AddPattern(pattern string, t Type) (uint32, error) { - switch t { - case Substr: - if g.ac == nil { - g.ac = NewACAutomaton() - } - g.ac.Add(pattern, t) - case Full, Domain: - g.AddFullOrDomainPattern(pattern, t) - case Regex: - g.count++ - r, err := regexp.Compile(pattern) - if err != nil { - return 0, err - } - g.otherMatchers = append(g.otherMatchers, matcherEntry{ - m: &regexMatcher{pattern: r}, - id: g.count, - }) - default: - panic("Unknown type") - } - return g.count, nil -} - -func (g *ACAutomatonMatcherGroup) Build() { - if g.ac != nil { - g.ac.Build() - } -} - -// Match implements IndexMatcher.Match. -func (g *ACAutomatonMatcherGroup) Match(pattern string) []uint32 { - result := []uint32{} - hash := uint32(0) - for i := len(pattern) - 1; i >= 0; i-- { - hash = hash*PrimeRK + uint32(pattern[i]) - if pattern[i] == '.' 
{ - if v, ok := g.nonSubstrMap[hash]; ok && v == pattern[i:] { - result = append(result, 1) - return result - } - } - } - if v, ok := g.nonSubstrMap[hash]; ok && v == pattern { - result = append(result, 1) - return result - } - if g.ac != nil && g.ac.Match(pattern) { - result = append(result, 1) - return result - } - for _, e := range g.otherMatchers { - if e.m.Match(pattern) { - result = append(result, e.id) - return result - } - } - return result -} - // MatcherGroup is an implementation of IndexMatcher. // Empty initialization works. type MatcherGroup struct {