From ed9641dad1ab971e2977e26f2fd4545eb6524e45 Mon Sep 17 00:00:00 2001 From: Ye Zhihao Date: Fri, 5 Nov 2021 13:24:46 +0800 Subject: [PATCH] Refactor strmatcher.MphMatcherGroup (#1364) * Refactor strmatcher.MphMatcherGroup * Add test for empty mph matcher group --- common/strmatcher/matchergroup_mph.go | 334 +++++++++------------ common/strmatcher/matchergroup_mph_test.go | 104 +++++++ 2 files changed, 246 insertions(+), 192 deletions(-) diff --git a/common/strmatcher/matchergroup_mph.go b/common/strmatcher/matchergroup_mph.go index 0ec1146e8..d842e4486 100644 --- a/common/strmatcher/matchergroup_mph.go +++ b/common/strmatcher/matchergroup_mph.go @@ -10,134 +10,187 @@ import ( // PrimeRK is the prime base used in Rabin-Karp algorithm. const PrimeRK = 16777619 -// calculate the rolling murmurHash of given string -func RollingHash(s string) uint32 { - h := uint32(0) - for i := len(s) - 1; i >= 0; i-- { - h = h*PrimeRK + uint32(s[i]) +// RollingHash calculates the rolling murmurHash of given string based on a provided suffix hash. +func RollingHash(hash uint32, input string) uint32 { + for i := len(input) - 1; i >= 0; i-- { + hash = hash*PrimeRK + uint32(input[i]) } - return h + return hash +} + +// MemHash is the hash function used by go map, it utilizes available hardware instructions(behaves +// as aeshash if aes instruction is available). +// With different seed, each MemHash performs as distinct hash functions. +func MemHash(seed uint32, input string) uint32 { + return uint32(strhash(unsafe.Pointer(&input), uintptr(seed))) // nosemgrep +} + +const ( + mphMatchTypeCount = 2 // Full and Domain +) + +type mphRuleInfo struct { + rollingHash uint32 + matchers [mphMatchTypeCount][]uint32 } // MphMatcherGroup is an implementation of MatcherGroup. // It implements Rabin-Karp algorithm and minimal perfect hash table for Full and Domain matcher. type MphMatcherGroup struct { - rules []string - level0 []uint32 - level0Mask int - level1 []uint32 - level1Mask int - ruleMap *map[string]uint32 + rules []string // RuleIdx -> pattern string, index 0 reserved for failed lookup + values [][]uint32 // RuleIdx -> registered matcher values for the pattern (Full Matcher takes precedence) + level0 []uint32 // RollingHash & Mask -> seed for Memhash + level0Mask uint32 // Mask restricting RollingHash to 0 ~ len(level0) + level1 []uint32 // Memhash & Mask -> stored index for rules + level1Mask uint32 // Mask for restricting Memhash to 0 ~ len(level1) + ruleInfos *map[string]mphRuleInfo } func NewMphMatcherGroup() *MphMatcherGroup { return &MphMatcherGroup{ - rules: nil, + rules: []string{""}, + values: [][]uint32{nil}, level0: nil, level0Mask: 0, level1: nil, level1Mask: 0, - ruleMap: &map[string]uint32{}, + ruleInfos: &map[string]mphRuleInfo{}, // Only used for building, destroyed after build complete } } // AddFullMatcher implements MatcherGroupForFull. -func (g *MphMatcherGroup) AddFullMatcher(matcher FullMatcher, _ uint32) { +func (g *MphMatcherGroup) AddFullMatcher(matcher FullMatcher, value uint32) { pattern := strings.ToLower(matcher.Pattern()) - (*g.ruleMap)[pattern] = RollingHash(pattern) + g.addPattern(0, "", pattern, matcher.Type(), value) } // AddDomainMatcher implements MatcherGroupForDomain. -func (g *MphMatcherGroup) AddDomainMatcher(matcher DomainMatcher, _ uint32) { +func (g *MphMatcherGroup) AddDomainMatcher(matcher DomainMatcher, value uint32) { pattern := strings.ToLower(matcher.Pattern()) - h := RollingHash(pattern) - (*g.ruleMap)[pattern] = h - (*g.ruleMap)["."+pattern] = h*PrimeRK + uint32('.') + hash := g.addPattern(0, "", pattern, matcher.Type(), value) // For full domain match + g.addPattern(hash, pattern, ".", matcher.Type(), value) // For partial domain match +} + +func (g *MphMatcherGroup) addPattern(suffixHash uint32, suffixPattern string, pattern string, matcherType Type, value uint32) uint32 { + fullPattern := pattern + suffixPattern + info, found := (*g.ruleInfos)[fullPattern] + if !found { + info = mphRuleInfo{rollingHash: RollingHash(suffixHash, pattern)} + g.rules = append(g.rules, fullPattern) + g.values = append(g.values, nil) + } + info.matchers[matcherType] = append(info.matchers[matcherType], value) + (*g.ruleInfos)[fullPattern] = info + return info.rollingHash } // Build builds a minimal perfect hash table for insert rules. -func (g *MphMatcherGroup) Build() { - keyLen := len(*g.ruleMap) - if keyLen == 0 { - keyLen = 1 - (*g.ruleMap)["empty___"] = RollingHash("empty___") - } - g.level0 = make([]uint32, nextPow2(keyLen/4)) - g.level0Mask = len(g.level0) - 1 - g.level1 = make([]uint32, nextPow2(keyLen)) - g.level1Mask = len(g.level1) - 1 - sparseBuckets := make([][]int, len(g.level0)) - var ruleIdx int - for rule, hash := range *g.ruleMap { - n := int(hash) & g.level0Mask - g.rules = append(g.rules, rule) - sparseBuckets[n] = append(sparseBuckets[n], ruleIdx) - ruleIdx++ - } - g.ruleMap = nil - var buckets []indexBucket - for n, vals := range sparseBuckets { - if len(vals) > 0 { - buckets = append(buckets, indexBucket{n, vals}) - } - } - sort.Sort(bySize(buckets)) +// Algorithm used: Hash, displace, and compress. See http://cmph.sourceforge.net/papers/esa09.pdf +func (g *MphMatcherGroup) Build() error { + ruleCount := len(*g.ruleInfos) + g.level0 = make([]uint32, nextPow2(ruleCount/4)) + g.level0Mask = uint32(len(g.level0) - 1) + g.level1 = make([]uint32, nextPow2(ruleCount)) + g.level1Mask = uint32(len(g.level1) - 1) - occ := make([]bool, len(g.level1)) - var tmpOcc []int - for _, bucket := range buckets { + // Create buckets based on all rule's rolling hash + buckets := make([][]uint32, len(g.level0)) + for ruleIdx := 1; ruleIdx < len(g.rules); ruleIdx++ { // Traverse rules starting from index 1 (0 reserved for failed lookup) + ruleInfo := (*g.ruleInfos)[g.rules[ruleIdx]] + bucketIdx := ruleInfo.rollingHash & g.level0Mask + buckets[bucketIdx] = append(buckets[bucketIdx], uint32(ruleIdx)) + g.values[ruleIdx] = append(ruleInfo.matchers[Full], ruleInfo.matchers[Domain]...) // nolint:gocritic + } + g.ruleInfos = nil // Set ruleInfos nil to release memory + + // Sort buckets in descending order with respect to each bucket's size + bucketIdxs := make([]int, len(buckets)) + for bucketIdx := range buckets { + bucketIdxs[bucketIdx] = bucketIdx + } + sort.Slice(bucketIdxs, func(i, j int) bool { return len(buckets[bucketIdxs[i]]) > len(buckets[bucketIdxs[j]]) }) + + // Exercise Hash, Displace, and Compress algorithm to construct minimal perfect hash table + occupied := make([]bool, len(g.level1)) // Whether a second-level hash has been already used + hashedBucket := make([]uint32, 0, 4) // Second-level hashes for each rule in a specific bucket + for _, bucketIdx := range bucketIdxs { + bucket := buckets[bucketIdx] + hashedBucket = hashedBucket[:0] seed := uint32(0) - for { - findSeed := true - tmpOcc = tmpOcc[:0] - for _, i := range bucket.vals { - n := int(strhashFallback(unsafe.Pointer(&g.rules[i]), uintptr(seed))) & g.level1Mask // nosemgrep - if occ[n] { - for _, n := range tmpOcc { - occ[n] = false + for len(hashedBucket) != len(bucket) { + for _, ruleIdx := range bucket { + memHash := MemHash(seed, g.rules[ruleIdx]) & g.level1Mask + if occupied[memHash] { // Collision occurred with this seed + for _, hash := range hashedBucket { // Revert all values in this hashed bucket + occupied[hash] = false + g.level1[hash] = 0 } - seed++ - findSeed = false + hashedBucket = hashedBucket[:0] + seed++ // Try next seed break } - occ[n] = true - tmpOcc = append(tmpOcc, n) - g.level1[n] = uint32(i) - } - if findSeed { - g.level0[bucket.n] = seed - break + occupied[memHash] = true + g.level1[memHash] = ruleIdx // The final value in the hash table + hashedBucket = append(hashedBucket, memHash) } } + g.level0[bucketIdx] = seed // Displacement value for this bucket } -} - -// Lookup searches for s in t and returns its index and whether it was found. -func (g *MphMatcherGroup) Lookup(h uint32, s string) bool { - i0 := int(h) & g.level0Mask - seed := g.level0[i0] - i1 := int(strhashFallback(unsafe.Pointer(&s), uintptr(seed))) & g.level1Mask // nosemgrep - n := g.level1[i1] - return s == g.rules[int(n)] -} - -// Match implements MatcherGroup.Match. -func (*MphMatcherGroup) Match(_ string) []uint32 { return nil } -// MatchAny implements MatcherGroup.MatchAny. -func (g *MphMatcherGroup) MatchAny(pattern string) bool { +// Lookup searches for input in minimal perfect hash table and returns its index. 0 indicates not found. +func (g *MphMatcherGroup) Lookup(rollingHash uint32, input string) uint32 { + i0 := rollingHash & g.level0Mask + seed := g.level0[i0] + i1 := MemHash(seed, input) & g.level1Mask + if n := g.level1[i1]; g.rules[n] == input { + return n + } + return 0 +} + +// Match implements MatcherGroup.Match. +func (g *MphMatcherGroup) Match(input string) []uint32 { + matches := [][]uint32{} hash := uint32(0) - for i := len(pattern) - 1; i >= 0; i-- { - hash = hash*PrimeRK + uint32(pattern[i]) - if pattern[i] == '.' { - if g.Lookup(hash, pattern[i:]) { + for i := len(input) - 1; i >= 0; i-- { + hash = hash*PrimeRK + uint32(input[i]) + if input[i] == '.' { + if mphIdx := g.Lookup(hash, input[i:]); mphIdx != 0 { + matches = append(matches, g.values[mphIdx]) + } + } + } + if mphIdx := g.Lookup(hash, input); mphIdx != 0 { + matches = append(matches, g.values[mphIdx]) + } + switch len(matches) { + case 0: + return nil + case 1: + return matches[0] + default: + result := []uint32{} + for i := len(matches) - 1; i >= 0; i-- { + result = append(result, matches[i]...) + } + return result + } +} + +// MatchAny implements MatcherGroup.MatchAny. +func (g *MphMatcherGroup) MatchAny(input string) bool { + hash := uint32(0) + for i := len(input) - 1; i >= 0; i-- { + hash = hash*PrimeRK + uint32(input[i]) + if input[i] == '.' { + if g.Lookup(hash, input[i:]) != 0 { return true } } } - return g.Lookup(hash, pattern) + return g.Lookup(hash, input) != 0 } func nextPow2(v int) int { @@ -149,109 +202,6 @@ func nextPow2(v int) int { return int(n) } -type indexBucket struct { - n int - vals []int -} - -type bySize []indexBucket - -func (s bySize) Len() int { return len(s) } -func (s bySize) Less(i, j int) bool { return len(s[i].vals) > len(s[j].vals) } -func (s bySize) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -type stringStruct struct { - str unsafe.Pointer - len int -} - -func strhashFallback(a unsafe.Pointer, h uintptr) uintptr { - x := (*stringStruct)(a) - return memhashFallback(x.str, h, uintptr(x.len)) -} - -const ( - // Constants for multiplication: four random odd 64-bit numbers. - m1 = 16877499708836156737 - m2 = 2820277070424839065 - m3 = 9497967016996688599 - m4 = 15839092249703872147 -) - -var hashkey = [4]uintptr{1, 1, 1, 1} - -func memhashFallback(p unsafe.Pointer, seed, s uintptr) uintptr { - h := uint64(seed + s*hashkey[0]) -tail: - switch { - case s == 0: - case s < 4: - h ^= uint64(*(*byte)(p)) - h ^= uint64(*(*byte)(add(p, s>>1))) << 8 - h ^= uint64(*(*byte)(add(p, s-1))) << 16 - h = rotl31(h*m1) * m2 - case s <= 8: - h ^= uint64(readUnaligned32(p)) - h ^= uint64(readUnaligned32(add(p, s-4))) << 32 - h = rotl31(h*m1) * m2 - case s <= 16: - h ^= readUnaligned64(p) - h = rotl31(h*m1) * m2 - h ^= readUnaligned64(add(p, s-8)) - h = rotl31(h*m1) * m2 - case s <= 32: - h ^= readUnaligned64(p) - h = rotl31(h*m1) * m2 - h ^= readUnaligned64(add(p, 8)) - h = rotl31(h*m1) * m2 - h ^= readUnaligned64(add(p, s-16)) - h = rotl31(h*m1) * m2 - h ^= readUnaligned64(add(p, s-8)) - h = rotl31(h*m1) * m2 - default: - v1 := h - v2 := uint64(seed * hashkey[1]) - v3 := uint64(seed * hashkey[2]) - v4 := uint64(seed * hashkey[3]) - for s >= 32 { - v1 ^= readUnaligned64(p) - v1 = rotl31(v1*m1) * m2 - p = add(p, 8) - v2 ^= readUnaligned64(p) - v2 = rotl31(v2*m2) * m3 - p = add(p, 8) - v3 ^= readUnaligned64(p) - v3 = rotl31(v3*m3) * m4 - p = add(p, 8) - v4 ^= readUnaligned64(p) - v4 = rotl31(v4*m4) * m1 - p = add(p, 8) - s -= 32 - } - h = v1 ^ v2 ^ v3 ^ v4 - goto tail - } - - h ^= h >> 29 - h *= m3 - h ^= h >> 32 - return uintptr(h) -} - -func add(p unsafe.Pointer, x uintptr) unsafe.Pointer { - return unsafe.Pointer(uintptr(p) + x) // nosemgrep -} - -func readUnaligned32(p unsafe.Pointer) uint32 { - q := (*[4]byte)(p) - return uint32(q[0]) | uint32(q[1])<<8 | uint32(q[2])<<16 | uint32(q[3])<<24 -} - -func rotl31(x uint64) uint64 { - return (x << 31) | (x >> (64 - 31)) -} - -func readUnaligned64(p unsafe.Pointer) uint64 { - q := (*[8]byte)(p) - return uint64(q[0]) | uint64(q[1])<<8 | uint64(q[2])<<16 | uint64(q[3])<<24 | uint64(q[4])<<32 | uint64(q[5])<<40 | uint64(q[6])<<48 | uint64(q[7])<<56 -} +//go:noescape +//go:linkname strhash runtime.strhash +func strhash(p unsafe.Pointer, h uintptr) uintptr diff --git a/common/strmatcher/matchergroup_mph_test.go b/common/strmatcher/matchergroup_mph_test.go index 88b569036..b876227e6 100644 --- a/common/strmatcher/matchergroup_mph_test.go +++ b/common/strmatcher/matchergroup_mph_test.go @@ -1,6 +1,7 @@ package strmatcher_test import ( + "reflect" "testing" "github.com/v2fly/v2ray-core/v4/common" @@ -172,3 +173,106 @@ func TestMphMatcherGroup(t *testing.T) { } } } + +// See https://github.com/v2fly/v2ray-core/issues/92#issuecomment-673238489 +func TestMphMatcherGroupAsIndexMatcher(t *testing.T) { + rules := []struct { + Type Type + Domain string + }{ + // Regex not supported by MphMatcherGroup + // { + // Type: Regex, + // Domain: "apis\\.us$", + // }, + // Substr not supported by MphMatcherGroup + // { + // Type: Substr, + // Domain: "apis", + // }, + { + Type: Domain, + Domain: "googleapis.com", + }, + { + Type: Domain, + Domain: "com", + }, + { + Type: Full, + Domain: "www.baidu.com", + }, + // Substr not supported by MphMatcherGroup, We add another matcher to preserve index + { + Type: Domain, // Substr, + Domain: "example.com", // "apis", + }, + { + Type: Domain, + Domain: "googleapis.com", + }, + { + Type: Full, + Domain: "fonts.googleapis.com", + }, + { + Type: Full, + Domain: "www.baidu.com", + }, + { // This matcher (index 10) is swapped with matcher (index 6) to test that full matcher takes high priority. + Type: Full, + Domain: "example.com", + }, + { + Type: Domain, + Domain: "example.com", + }, + } + cases := []struct { + Input string + Output []uint32 + }{ + { + Input: "www.baidu.com", + Output: []uint32{5, 9, 4}, + }, + { + Input: "fonts.googleapis.com", + Output: []uint32{8, 3, 7, 4 /*2, 6*/}, + }, + { + Input: "example.googleapis.com", + Output: []uint32{3, 7, 4 /*2, 6*/}, + }, + { + Input: "testapis.us", + // Output: []uint32{ /*2, 6*/ /*1,*/ }, + Output: nil, + }, + { + Input: "example.com", + Output: []uint32{10, 6, 11, 4}, + }, + } + matcherGroup := NewMphMatcherGroup() + for i, rule := range rules { + matcher, err := rule.Type.New(rule.Domain) + common.Must(err) + common.Must(AddMatcherToGroup(matcherGroup, matcher, uint32(i+3))) + } + matcherGroup.Build() + for _, test := range cases { + if m := matcherGroup.Match(test.Input); !reflect.DeepEqual(m, test.Output) { + t.Error("unexpected output: ", m, " for test case ", test) + } + } +} + +func TestEmptyMphMatcherGroup(t *testing.T) { + g := NewMphMatcherGroup() + g.Build() + r := g.Match("v2fly.org") + if len(r) != 0 { + t.Error("Expect [], but ", r) + } +}