1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-06-16 04:35:24 +00:00

Refactor strmatcher.MphMatcherGroup (#1364)

* Refactor strmatcher.MphMatcherGroup

* Add test for empty mph matcher group
This commit is contained in:
Ye Zhihao 2021-11-05 13:24:46 +08:00 committed by GitHub
parent 80d92381af
commit ed9641dad1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 246 additions and 192 deletions

View File

@ -10,134 +10,187 @@ import (
// PrimeRK is the prime base used in Rabin-Karp algorithm.
const PrimeRK = 16777619
// calculate the rolling murmurHash of given string
func RollingHash(s string) uint32 {
h := uint32(0)
for i := len(s) - 1; i >= 0; i-- {
h = h*PrimeRK + uint32(s[i])
// RollingHash calculates the rolling murmurHash of given string based on a provided suffix hash.
func RollingHash(hash uint32, input string) uint32 {
for i := len(input) - 1; i >= 0; i-- {
hash = hash*PrimeRK + uint32(input[i])
}
return h
return hash
}
// MemHash is the hash function used by go map, it utilizes available hardware instructions(behaves
// as aeshash if aes instruction is available).
// With different seed, each MemHash<seed> performs as distinct hash functions.
func MemHash(seed uint32, input string) uint32 {
return uint32(strhash(unsafe.Pointer(&input), uintptr(seed))) // nosemgrep
}
const (
mphMatchTypeCount = 2 // Full and Domain
)
type mphRuleInfo struct {
rollingHash uint32
matchers [mphMatchTypeCount][]uint32
}
// MphMatcherGroup is an implementation of MatcherGroup.
// It implements Rabin-Karp algorithm and minimal perfect hash table for Full and Domain matcher.
type MphMatcherGroup struct {
rules []string
level0 []uint32
level0Mask int
level1 []uint32
level1Mask int
ruleMap *map[string]uint32
rules []string // RuleIdx -> pattern string, index 0 reserved for failed lookup
values [][]uint32 // RuleIdx -> registered matcher values for the pattern (Full Matcher takes precedence)
level0 []uint32 // RollingHash & Mask -> seed for Memhash
level0Mask uint32 // Mask restricting RollingHash to 0 ~ len(level0)
level1 []uint32 // Memhash<seed> & Mask -> stored index for rules
level1Mask uint32 // Mask for restricting Memhash<seed> to 0 ~ len(level1)
ruleInfos *map[string]mphRuleInfo
}
func NewMphMatcherGroup() *MphMatcherGroup {
return &MphMatcherGroup{
rules: nil,
rules: []string{""},
values: [][]uint32{nil},
level0: nil,
level0Mask: 0,
level1: nil,
level1Mask: 0,
ruleMap: &map[string]uint32{},
ruleInfos: &map[string]mphRuleInfo{}, // Only used for building, destroyed after build complete
}
}
// AddFullMatcher implements MatcherGroupForFull.
func (g *MphMatcherGroup) AddFullMatcher(matcher FullMatcher, _ uint32) {
func (g *MphMatcherGroup) AddFullMatcher(matcher FullMatcher, value uint32) {
pattern := strings.ToLower(matcher.Pattern())
(*g.ruleMap)[pattern] = RollingHash(pattern)
g.addPattern(0, "", pattern, matcher.Type(), value)
}
// AddDomainMatcher implements MatcherGroupForDomain.
func (g *MphMatcherGroup) AddDomainMatcher(matcher DomainMatcher, _ uint32) {
func (g *MphMatcherGroup) AddDomainMatcher(matcher DomainMatcher, value uint32) {
pattern := strings.ToLower(matcher.Pattern())
h := RollingHash(pattern)
(*g.ruleMap)[pattern] = h
(*g.ruleMap)["."+pattern] = h*PrimeRK + uint32('.')
hash := g.addPattern(0, "", pattern, matcher.Type(), value) // For full domain match
g.addPattern(hash, pattern, ".", matcher.Type(), value) // For partial domain match
}
func (g *MphMatcherGroup) addPattern(suffixHash uint32, suffixPattern string, pattern string, matcherType Type, value uint32) uint32 {
fullPattern := pattern + suffixPattern
info, found := (*g.ruleInfos)[fullPattern]
if !found {
info = mphRuleInfo{rollingHash: RollingHash(suffixHash, pattern)}
g.rules = append(g.rules, fullPattern)
g.values = append(g.values, nil)
}
info.matchers[matcherType] = append(info.matchers[matcherType], value)
(*g.ruleInfos)[fullPattern] = info
return info.rollingHash
}
// Build builds a minimal perfect hash table for insert rules.
func (g *MphMatcherGroup) Build() {
keyLen := len(*g.ruleMap)
if keyLen == 0 {
keyLen = 1
(*g.ruleMap)["empty___"] = RollingHash("empty___")
}
g.level0 = make([]uint32, nextPow2(keyLen/4))
g.level0Mask = len(g.level0) - 1
g.level1 = make([]uint32, nextPow2(keyLen))
g.level1Mask = len(g.level1) - 1
sparseBuckets := make([][]int, len(g.level0))
var ruleIdx int
for rule, hash := range *g.ruleMap {
n := int(hash) & g.level0Mask
g.rules = append(g.rules, rule)
sparseBuckets[n] = append(sparseBuckets[n], ruleIdx)
ruleIdx++
}
g.ruleMap = nil
var buckets []indexBucket
for n, vals := range sparseBuckets {
if len(vals) > 0 {
buckets = append(buckets, indexBucket{n, vals})
}
}
sort.Sort(bySize(buckets))
// Algorithm used: Hash, displace, and compress. See http://cmph.sourceforge.net/papers/esa09.pdf
func (g *MphMatcherGroup) Build() error {
ruleCount := len(*g.ruleInfos)
g.level0 = make([]uint32, nextPow2(ruleCount/4))
g.level0Mask = uint32(len(g.level0) - 1)
g.level1 = make([]uint32, nextPow2(ruleCount))
g.level1Mask = uint32(len(g.level1) - 1)
occ := make([]bool, len(g.level1))
var tmpOcc []int
for _, bucket := range buckets {
// Create buckets based on all rule's rolling hash
buckets := make([][]uint32, len(g.level0))
for ruleIdx := 1; ruleIdx < len(g.rules); ruleIdx++ { // Traverse rules starting from index 1 (0 reserved for failed lookup)
ruleInfo := (*g.ruleInfos)[g.rules[ruleIdx]]
bucketIdx := ruleInfo.rollingHash & g.level0Mask
buckets[bucketIdx] = append(buckets[bucketIdx], uint32(ruleIdx))
g.values[ruleIdx] = append(ruleInfo.matchers[Full], ruleInfo.matchers[Domain]...) // nolint:gocritic
}
g.ruleInfos = nil // Set ruleInfos nil to release memory
// Sort buckets in descending order with respect to each bucket's size
bucketIdxs := make([]int, len(buckets))
for bucketIdx := range buckets {
bucketIdxs[bucketIdx] = bucketIdx
}
sort.Slice(bucketIdxs, func(i, j int) bool { return len(buckets[bucketIdxs[i]]) > len(buckets[bucketIdxs[j]]) })
// Exercise Hash, Displace, and Compress algorithm to construct minimal perfect hash table
occupied := make([]bool, len(g.level1)) // Whether a second-level hash has been already used
hashedBucket := make([]uint32, 0, 4) // Second-level hashes for each rule in a specific bucket
for _, bucketIdx := range bucketIdxs {
bucket := buckets[bucketIdx]
hashedBucket = hashedBucket[:0]
seed := uint32(0)
for {
findSeed := true
tmpOcc = tmpOcc[:0]
for _, i := range bucket.vals {
n := int(strhashFallback(unsafe.Pointer(&g.rules[i]), uintptr(seed))) & g.level1Mask // nosemgrep
if occ[n] {
for _, n := range tmpOcc {
occ[n] = false
for len(hashedBucket) != len(bucket) {
for _, ruleIdx := range bucket {
memHash := MemHash(seed, g.rules[ruleIdx]) & g.level1Mask
if occupied[memHash] { // Collision occurred with this seed
for _, hash := range hashedBucket { // Revert all values in this hashed bucket
occupied[hash] = false
g.level1[hash] = 0
}
seed++
findSeed = false
hashedBucket = hashedBucket[:0]
seed++ // Try next seed
break
}
occ[n] = true
tmpOcc = append(tmpOcc, n)
g.level1[n] = uint32(i)
}
if findSeed {
g.level0[bucket.n] = seed
break
occupied[memHash] = true
g.level1[memHash] = ruleIdx // The final value in the hash table
hashedBucket = append(hashedBucket, memHash)
}
}
g.level0[bucketIdx] = seed // Displacement value for this bucket
}
}
// Lookup searches for s in t and returns its index and whether it was found.
func (g *MphMatcherGroup) Lookup(h uint32, s string) bool {
i0 := int(h) & g.level0Mask
seed := g.level0[i0]
i1 := int(strhashFallback(unsafe.Pointer(&s), uintptr(seed))) & g.level1Mask // nosemgrep
n := g.level1[i1]
return s == g.rules[int(n)]
}
// Match implements MatcherGroup.Match.
func (*MphMatcherGroup) Match(_ string) []uint32 {
return nil
}
// MatchAny implements MatcherGroup.MatchAny.
func (g *MphMatcherGroup) MatchAny(pattern string) bool {
// Lookup searches for input in minimal perfect hash table and returns its index. 0 indicates not found.
func (g *MphMatcherGroup) Lookup(rollingHash uint32, input string) uint32 {
i0 := rollingHash & g.level0Mask
seed := g.level0[i0]
i1 := MemHash(seed, input) & g.level1Mask
if n := g.level1[i1]; g.rules[n] == input {
return n
}
return 0
}
// Match implements MatcherGroup.Match.
func (g *MphMatcherGroup) Match(input string) []uint32 {
matches := [][]uint32{}
hash := uint32(0)
for i := len(pattern) - 1; i >= 0; i-- {
hash = hash*PrimeRK + uint32(pattern[i])
if pattern[i] == '.' {
if g.Lookup(hash, pattern[i:]) {
for i := len(input) - 1; i >= 0; i-- {
hash = hash*PrimeRK + uint32(input[i])
if input[i] == '.' {
if mphIdx := g.Lookup(hash, input[i:]); mphIdx != 0 {
matches = append(matches, g.values[mphIdx])
}
}
}
if mphIdx := g.Lookup(hash, input); mphIdx != 0 {
matches = append(matches, g.values[mphIdx])
}
switch len(matches) {
case 0:
return nil
case 1:
return matches[0]
default:
result := []uint32{}
for i := len(matches) - 1; i >= 0; i-- {
result = append(result, matches[i]...)
}
return result
}
}
// MatchAny implements MatcherGroup.MatchAny.
func (g *MphMatcherGroup) MatchAny(input string) bool {
hash := uint32(0)
for i := len(input) - 1; i >= 0; i-- {
hash = hash*PrimeRK + uint32(input[i])
if input[i] == '.' {
if g.Lookup(hash, input[i:]) != 0 {
return true
}
}
}
return g.Lookup(hash, pattern)
return g.Lookup(hash, input) != 0
}
func nextPow2(v int) int {
@ -149,109 +202,6 @@ func nextPow2(v int) int {
return int(n)
}
type indexBucket struct {
n int
vals []int
}
type bySize []indexBucket
func (s bySize) Len() int { return len(s) }
func (s bySize) Less(i, j int) bool { return len(s[i].vals) > len(s[j].vals) }
func (s bySize) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
type stringStruct struct {
str unsafe.Pointer
len int
}
func strhashFallback(a unsafe.Pointer, h uintptr) uintptr {
x := (*stringStruct)(a)
return memhashFallback(x.str, h, uintptr(x.len))
}
const (
// Constants for multiplication: four random odd 64-bit numbers.
m1 = 16877499708836156737
m2 = 2820277070424839065
m3 = 9497967016996688599
m4 = 15839092249703872147
)
var hashkey = [4]uintptr{1, 1, 1, 1}
func memhashFallback(p unsafe.Pointer, seed, s uintptr) uintptr {
h := uint64(seed + s*hashkey[0])
tail:
switch {
case s == 0:
case s < 4:
h ^= uint64(*(*byte)(p))
h ^= uint64(*(*byte)(add(p, s>>1))) << 8
h ^= uint64(*(*byte)(add(p, s-1))) << 16
h = rotl31(h*m1) * m2
case s <= 8:
h ^= uint64(readUnaligned32(p))
h ^= uint64(readUnaligned32(add(p, s-4))) << 32
h = rotl31(h*m1) * m2
case s <= 16:
h ^= readUnaligned64(p)
h = rotl31(h*m1) * m2
h ^= readUnaligned64(add(p, s-8))
h = rotl31(h*m1) * m2
case s <= 32:
h ^= readUnaligned64(p)
h = rotl31(h*m1) * m2
h ^= readUnaligned64(add(p, 8))
h = rotl31(h*m1) * m2
h ^= readUnaligned64(add(p, s-16))
h = rotl31(h*m1) * m2
h ^= readUnaligned64(add(p, s-8))
h = rotl31(h*m1) * m2
default:
v1 := h
v2 := uint64(seed * hashkey[1])
v3 := uint64(seed * hashkey[2])
v4 := uint64(seed * hashkey[3])
for s >= 32 {
v1 ^= readUnaligned64(p)
v1 = rotl31(v1*m1) * m2
p = add(p, 8)
v2 ^= readUnaligned64(p)
v2 = rotl31(v2*m2) * m3
p = add(p, 8)
v3 ^= readUnaligned64(p)
v3 = rotl31(v3*m3) * m4
p = add(p, 8)
v4 ^= readUnaligned64(p)
v4 = rotl31(v4*m4) * m1
p = add(p, 8)
s -= 32
}
h = v1 ^ v2 ^ v3 ^ v4
goto tail
}
h ^= h >> 29
h *= m3
h ^= h >> 32
return uintptr(h)
}
func add(p unsafe.Pointer, x uintptr) unsafe.Pointer {
return unsafe.Pointer(uintptr(p) + x) // nosemgrep
}
func readUnaligned32(p unsafe.Pointer) uint32 {
q := (*[4]byte)(p)
return uint32(q[0]) | uint32(q[1])<<8 | uint32(q[2])<<16 | uint32(q[3])<<24
}
func rotl31(x uint64) uint64 {
return (x << 31) | (x >> (64 - 31))
}
func readUnaligned64(p unsafe.Pointer) uint64 {
q := (*[8]byte)(p)
return uint64(q[0]) | uint64(q[1])<<8 | uint64(q[2])<<16 | uint64(q[3])<<24 | uint64(q[4])<<32 | uint64(q[5])<<40 | uint64(q[6])<<48 | uint64(q[7])<<56
}
//go:noescape
//go:linkname strhash runtime.strhash
func strhash(p unsafe.Pointer, h uintptr) uintptr

View File

@ -1,6 +1,7 @@
package strmatcher_test
import (
"reflect"
"testing"
"github.com/v2fly/v2ray-core/v4/common"
@ -172,3 +173,106 @@ func TestMphMatcherGroup(t *testing.T) {
}
}
}
// See https://github.com/v2fly/v2ray-core/issues/92#issuecomment-673238489
func TestMphMatcherGroupAsIndexMatcher(t *testing.T) {
rules := []struct {
Type Type
Domain string
}{
// Regex not supported by MphMatcherGroup
// {
// Type: Regex,
// Domain: "apis\\.us$",
// },
// Substr not supported by MphMatcherGroup
// {
// Type: Substr,
// Domain: "apis",
// },
{
Type: Domain,
Domain: "googleapis.com",
},
{
Type: Domain,
Domain: "com",
},
{
Type: Full,
Domain: "www.baidu.com",
},
// Substr not supported by MphMatcherGroup, We add another matcher to preserve index
{
Type: Domain, // Substr,
Domain: "example.com", // "apis",
},
{
Type: Domain,
Domain: "googleapis.com",
},
{
Type: Full,
Domain: "fonts.googleapis.com",
},
{
Type: Full,
Domain: "www.baidu.com",
},
{ // This matcher (index 10) is swapped with matcher (index 6) to test that full matcher takes high priority.
Type: Full,
Domain: "example.com",
},
{
Type: Domain,
Domain: "example.com",
},
}
cases := []struct {
Input string
Output []uint32
}{
{
Input: "www.baidu.com",
Output: []uint32{5, 9, 4},
},
{
Input: "fonts.googleapis.com",
Output: []uint32{8, 3, 7, 4 /*2, 6*/},
},
{
Input: "example.googleapis.com",
Output: []uint32{3, 7, 4 /*2, 6*/},
},
{
Input: "testapis.us",
// Output: []uint32{ /*2, 6*/ /*1,*/ },
Output: nil,
},
{
Input: "example.com",
Output: []uint32{10, 6, 11, 4},
},
}
matcherGroup := NewMphMatcherGroup()
for i, rule := range rules {
matcher, err := rule.Type.New(rule.Domain)
common.Must(err)
common.Must(AddMatcherToGroup(matcherGroup, matcher, uint32(i+3)))
}
matcherGroup.Build()
for _, test := range cases {
if m := matcherGroup.Match(test.Input); !reflect.DeepEqual(m, test.Output) {
t.Error("unexpected output: ", m, " for test case ", test)
}
}
}
func TestEmptyMphMatcherGroup(t *testing.T) {
g := NewMphMatcherGroup()
g.Build()
r := g.Match("v2fly.org")
if len(r) != 0 {
t.Error("Expect [], but ", r)
}
}