mirror of
https://github.com/v2fly/v2ray-core.git
synced 2024-09-07 20:45:19 -04:00
Refactor strmatcher.MphMatcherGroup (#1364)
* Refactor strmatcher.MphMatcherGroup * Add test for empty mph matcher group
This commit is contained in:
parent
80d92381af
commit
ed9641dad1
@ -10,134 +10,187 @@ import (
|
||||
// PrimeRK is the prime base used in Rabin-Karp algorithm.
|
||||
const PrimeRK = 16777619
|
||||
|
||||
// calculate the rolling murmurHash of given string
|
||||
func RollingHash(s string) uint32 {
|
||||
h := uint32(0)
|
||||
for i := len(s) - 1; i >= 0; i-- {
|
||||
h = h*PrimeRK + uint32(s[i])
|
||||
// RollingHash calculates the rolling murmurHash of given string based on a provided suffix hash.
|
||||
func RollingHash(hash uint32, input string) uint32 {
|
||||
for i := len(input) - 1; i >= 0; i-- {
|
||||
hash = hash*PrimeRK + uint32(input[i])
|
||||
}
|
||||
return h
|
||||
return hash
|
||||
}
|
||||
|
||||
// MemHash is the hash function used by go map, it utilizes available hardware instructions(behaves
|
||||
// as aeshash if aes instruction is available).
|
||||
// With different seed, each MemHash<seed> performs as distinct hash functions.
|
||||
func MemHash(seed uint32, input string) uint32 {
|
||||
return uint32(strhash(unsafe.Pointer(&input), uintptr(seed))) // nosemgrep
|
||||
}
|
||||
|
||||
const (
|
||||
mphMatchTypeCount = 2 // Full and Domain
|
||||
)
|
||||
|
||||
type mphRuleInfo struct {
|
||||
rollingHash uint32
|
||||
matchers [mphMatchTypeCount][]uint32
|
||||
}
|
||||
|
||||
// MphMatcherGroup is an implementation of MatcherGroup.
|
||||
// It implements Rabin-Karp algorithm and minimal perfect hash table for Full and Domain matcher.
|
||||
type MphMatcherGroup struct {
|
||||
rules []string
|
||||
level0 []uint32
|
||||
level0Mask int
|
||||
level1 []uint32
|
||||
level1Mask int
|
||||
ruleMap *map[string]uint32
|
||||
rules []string // RuleIdx -> pattern string, index 0 reserved for failed lookup
|
||||
values [][]uint32 // RuleIdx -> registered matcher values for the pattern (Full Matcher takes precedence)
|
||||
level0 []uint32 // RollingHash & Mask -> seed for Memhash
|
||||
level0Mask uint32 // Mask restricting RollingHash to 0 ~ len(level0)
|
||||
level1 []uint32 // Memhash<seed> & Mask -> stored index for rules
|
||||
level1Mask uint32 // Mask for restricting Memhash<seed> to 0 ~ len(level1)
|
||||
ruleInfos *map[string]mphRuleInfo
|
||||
}
|
||||
|
||||
func NewMphMatcherGroup() *MphMatcherGroup {
|
||||
return &MphMatcherGroup{
|
||||
rules: nil,
|
||||
rules: []string{""},
|
||||
values: [][]uint32{nil},
|
||||
level0: nil,
|
||||
level0Mask: 0,
|
||||
level1: nil,
|
||||
level1Mask: 0,
|
||||
ruleMap: &map[string]uint32{},
|
||||
ruleInfos: &map[string]mphRuleInfo{}, // Only used for building, destroyed after build complete
|
||||
}
|
||||
}
|
||||
|
||||
// AddFullMatcher implements MatcherGroupForFull.
|
||||
func (g *MphMatcherGroup) AddFullMatcher(matcher FullMatcher, _ uint32) {
|
||||
func (g *MphMatcherGroup) AddFullMatcher(matcher FullMatcher, value uint32) {
|
||||
pattern := strings.ToLower(matcher.Pattern())
|
||||
(*g.ruleMap)[pattern] = RollingHash(pattern)
|
||||
g.addPattern(0, "", pattern, matcher.Type(), value)
|
||||
}
|
||||
|
||||
// AddDomainMatcher implements MatcherGroupForDomain.
|
||||
func (g *MphMatcherGroup) AddDomainMatcher(matcher DomainMatcher, _ uint32) {
|
||||
func (g *MphMatcherGroup) AddDomainMatcher(matcher DomainMatcher, value uint32) {
|
||||
pattern := strings.ToLower(matcher.Pattern())
|
||||
h := RollingHash(pattern)
|
||||
(*g.ruleMap)[pattern] = h
|
||||
(*g.ruleMap)["."+pattern] = h*PrimeRK + uint32('.')
|
||||
hash := g.addPattern(0, "", pattern, matcher.Type(), value) // For full domain match
|
||||
g.addPattern(hash, pattern, ".", matcher.Type(), value) // For partial domain match
|
||||
}
|
||||
|
||||
func (g *MphMatcherGroup) addPattern(suffixHash uint32, suffixPattern string, pattern string, matcherType Type, value uint32) uint32 {
|
||||
fullPattern := pattern + suffixPattern
|
||||
info, found := (*g.ruleInfos)[fullPattern]
|
||||
if !found {
|
||||
info = mphRuleInfo{rollingHash: RollingHash(suffixHash, pattern)}
|
||||
g.rules = append(g.rules, fullPattern)
|
||||
g.values = append(g.values, nil)
|
||||
}
|
||||
info.matchers[matcherType] = append(info.matchers[matcherType], value)
|
||||
(*g.ruleInfos)[fullPattern] = info
|
||||
return info.rollingHash
|
||||
}
|
||||
|
||||
// Build builds a minimal perfect hash table for insert rules.
|
||||
func (g *MphMatcherGroup) Build() {
|
||||
keyLen := len(*g.ruleMap)
|
||||
if keyLen == 0 {
|
||||
keyLen = 1
|
||||
(*g.ruleMap)["empty___"] = RollingHash("empty___")
|
||||
}
|
||||
g.level0 = make([]uint32, nextPow2(keyLen/4))
|
||||
g.level0Mask = len(g.level0) - 1
|
||||
g.level1 = make([]uint32, nextPow2(keyLen))
|
||||
g.level1Mask = len(g.level1) - 1
|
||||
sparseBuckets := make([][]int, len(g.level0))
|
||||
var ruleIdx int
|
||||
for rule, hash := range *g.ruleMap {
|
||||
n := int(hash) & g.level0Mask
|
||||
g.rules = append(g.rules, rule)
|
||||
sparseBuckets[n] = append(sparseBuckets[n], ruleIdx)
|
||||
ruleIdx++
|
||||
}
|
||||
g.ruleMap = nil
|
||||
var buckets []indexBucket
|
||||
for n, vals := range sparseBuckets {
|
||||
if len(vals) > 0 {
|
||||
buckets = append(buckets, indexBucket{n, vals})
|
||||
}
|
||||
}
|
||||
sort.Sort(bySize(buckets))
|
||||
// Algorithm used: Hash, displace, and compress. See http://cmph.sourceforge.net/papers/esa09.pdf
|
||||
func (g *MphMatcherGroup) Build() error {
|
||||
ruleCount := len(*g.ruleInfos)
|
||||
g.level0 = make([]uint32, nextPow2(ruleCount/4))
|
||||
g.level0Mask = uint32(len(g.level0) - 1)
|
||||
g.level1 = make([]uint32, nextPow2(ruleCount))
|
||||
g.level1Mask = uint32(len(g.level1) - 1)
|
||||
|
||||
occ := make([]bool, len(g.level1))
|
||||
var tmpOcc []int
|
||||
for _, bucket := range buckets {
|
||||
// Create buckets based on all rule's rolling hash
|
||||
buckets := make([][]uint32, len(g.level0))
|
||||
for ruleIdx := 1; ruleIdx < len(g.rules); ruleIdx++ { // Traverse rules starting from index 1 (0 reserved for failed lookup)
|
||||
ruleInfo := (*g.ruleInfos)[g.rules[ruleIdx]]
|
||||
bucketIdx := ruleInfo.rollingHash & g.level0Mask
|
||||
buckets[bucketIdx] = append(buckets[bucketIdx], uint32(ruleIdx))
|
||||
g.values[ruleIdx] = append(ruleInfo.matchers[Full], ruleInfo.matchers[Domain]...) // nolint:gocritic
|
||||
}
|
||||
g.ruleInfos = nil // Set ruleInfos nil to release memory
|
||||
|
||||
// Sort buckets in descending order with respect to each bucket's size
|
||||
bucketIdxs := make([]int, len(buckets))
|
||||
for bucketIdx := range buckets {
|
||||
bucketIdxs[bucketIdx] = bucketIdx
|
||||
}
|
||||
sort.Slice(bucketIdxs, func(i, j int) bool { return len(buckets[bucketIdxs[i]]) > len(buckets[bucketIdxs[j]]) })
|
||||
|
||||
// Exercise Hash, Displace, and Compress algorithm to construct minimal perfect hash table
|
||||
occupied := make([]bool, len(g.level1)) // Whether a second-level hash has been already used
|
||||
hashedBucket := make([]uint32, 0, 4) // Second-level hashes for each rule in a specific bucket
|
||||
for _, bucketIdx := range bucketIdxs {
|
||||
bucket := buckets[bucketIdx]
|
||||
hashedBucket = hashedBucket[:0]
|
||||
seed := uint32(0)
|
||||
for {
|
||||
findSeed := true
|
||||
tmpOcc = tmpOcc[:0]
|
||||
for _, i := range bucket.vals {
|
||||
n := int(strhashFallback(unsafe.Pointer(&g.rules[i]), uintptr(seed))) & g.level1Mask // nosemgrep
|
||||
if occ[n] {
|
||||
for _, n := range tmpOcc {
|
||||
occ[n] = false
|
||||
for len(hashedBucket) != len(bucket) {
|
||||
for _, ruleIdx := range bucket {
|
||||
memHash := MemHash(seed, g.rules[ruleIdx]) & g.level1Mask
|
||||
if occupied[memHash] { // Collision occurred with this seed
|
||||
for _, hash := range hashedBucket { // Revert all values in this hashed bucket
|
||||
occupied[hash] = false
|
||||
g.level1[hash] = 0
|
||||
}
|
||||
seed++
|
||||
findSeed = false
|
||||
hashedBucket = hashedBucket[:0]
|
||||
seed++ // Try next seed
|
||||
break
|
||||
}
|
||||
occ[n] = true
|
||||
tmpOcc = append(tmpOcc, n)
|
||||
g.level1[n] = uint32(i)
|
||||
}
|
||||
if findSeed {
|
||||
g.level0[bucket.n] = seed
|
||||
break
|
||||
occupied[memHash] = true
|
||||
g.level1[memHash] = ruleIdx // The final value in the hash table
|
||||
hashedBucket = append(hashedBucket, memHash)
|
||||
}
|
||||
}
|
||||
g.level0[bucketIdx] = seed // Displacement value for this bucket
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup searches for s in t and returns its index and whether it was found.
|
||||
func (g *MphMatcherGroup) Lookup(h uint32, s string) bool {
|
||||
i0 := int(h) & g.level0Mask
|
||||
seed := g.level0[i0]
|
||||
i1 := int(strhashFallback(unsafe.Pointer(&s), uintptr(seed))) & g.level1Mask // nosemgrep
|
||||
n := g.level1[i1]
|
||||
return s == g.rules[int(n)]
|
||||
}
|
||||
|
||||
// Match implements MatcherGroup.Match.
|
||||
func (*MphMatcherGroup) Match(_ string) []uint32 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MatchAny implements MatcherGroup.MatchAny.
|
||||
func (g *MphMatcherGroup) MatchAny(pattern string) bool {
|
||||
// Lookup searches for input in minimal perfect hash table and returns its index. 0 indicates not found.
|
||||
func (g *MphMatcherGroup) Lookup(rollingHash uint32, input string) uint32 {
|
||||
i0 := rollingHash & g.level0Mask
|
||||
seed := g.level0[i0]
|
||||
i1 := MemHash(seed, input) & g.level1Mask
|
||||
if n := g.level1[i1]; g.rules[n] == input {
|
||||
return n
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Match implements MatcherGroup.Match.
|
||||
func (g *MphMatcherGroup) Match(input string) []uint32 {
|
||||
matches := [][]uint32{}
|
||||
hash := uint32(0)
|
||||
for i := len(pattern) - 1; i >= 0; i-- {
|
||||
hash = hash*PrimeRK + uint32(pattern[i])
|
||||
if pattern[i] == '.' {
|
||||
if g.Lookup(hash, pattern[i:]) {
|
||||
for i := len(input) - 1; i >= 0; i-- {
|
||||
hash = hash*PrimeRK + uint32(input[i])
|
||||
if input[i] == '.' {
|
||||
if mphIdx := g.Lookup(hash, input[i:]); mphIdx != 0 {
|
||||
matches = append(matches, g.values[mphIdx])
|
||||
}
|
||||
}
|
||||
}
|
||||
if mphIdx := g.Lookup(hash, input); mphIdx != 0 {
|
||||
matches = append(matches, g.values[mphIdx])
|
||||
}
|
||||
switch len(matches) {
|
||||
case 0:
|
||||
return nil
|
||||
case 1:
|
||||
return matches[0]
|
||||
default:
|
||||
result := []uint32{}
|
||||
for i := len(matches) - 1; i >= 0; i-- {
|
||||
result = append(result, matches[i]...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// MatchAny implements MatcherGroup.MatchAny.
|
||||
func (g *MphMatcherGroup) MatchAny(input string) bool {
|
||||
hash := uint32(0)
|
||||
for i := len(input) - 1; i >= 0; i-- {
|
||||
hash = hash*PrimeRK + uint32(input[i])
|
||||
if input[i] == '.' {
|
||||
if g.Lookup(hash, input[i:]) != 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return g.Lookup(hash, pattern)
|
||||
return g.Lookup(hash, input) != 0
|
||||
}
|
||||
|
||||
func nextPow2(v int) int {
|
||||
@ -149,109 +202,6 @@ func nextPow2(v int) int {
|
||||
return int(n)
|
||||
}
|
||||
|
||||
type indexBucket struct {
|
||||
n int
|
||||
vals []int
|
||||
}
|
||||
|
||||
type bySize []indexBucket
|
||||
|
||||
func (s bySize) Len() int { return len(s) }
|
||||
func (s bySize) Less(i, j int) bool { return len(s[i].vals) > len(s[j].vals) }
|
||||
func (s bySize) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
||||
type stringStruct struct {
|
||||
str unsafe.Pointer
|
||||
len int
|
||||
}
|
||||
|
||||
func strhashFallback(a unsafe.Pointer, h uintptr) uintptr {
|
||||
x := (*stringStruct)(a)
|
||||
return memhashFallback(x.str, h, uintptr(x.len))
|
||||
}
|
||||
|
||||
const (
|
||||
// Constants for multiplication: four random odd 64-bit numbers.
|
||||
m1 = 16877499708836156737
|
||||
m2 = 2820277070424839065
|
||||
m3 = 9497967016996688599
|
||||
m4 = 15839092249703872147
|
||||
)
|
||||
|
||||
var hashkey = [4]uintptr{1, 1, 1, 1}
|
||||
|
||||
func memhashFallback(p unsafe.Pointer, seed, s uintptr) uintptr {
|
||||
h := uint64(seed + s*hashkey[0])
|
||||
tail:
|
||||
switch {
|
||||
case s == 0:
|
||||
case s < 4:
|
||||
h ^= uint64(*(*byte)(p))
|
||||
h ^= uint64(*(*byte)(add(p, s>>1))) << 8
|
||||
h ^= uint64(*(*byte)(add(p, s-1))) << 16
|
||||
h = rotl31(h*m1) * m2
|
||||
case s <= 8:
|
||||
h ^= uint64(readUnaligned32(p))
|
||||
h ^= uint64(readUnaligned32(add(p, s-4))) << 32
|
||||
h = rotl31(h*m1) * m2
|
||||
case s <= 16:
|
||||
h ^= readUnaligned64(p)
|
||||
h = rotl31(h*m1) * m2
|
||||
h ^= readUnaligned64(add(p, s-8))
|
||||
h = rotl31(h*m1) * m2
|
||||
case s <= 32:
|
||||
h ^= readUnaligned64(p)
|
||||
h = rotl31(h*m1) * m2
|
||||
h ^= readUnaligned64(add(p, 8))
|
||||
h = rotl31(h*m1) * m2
|
||||
h ^= readUnaligned64(add(p, s-16))
|
||||
h = rotl31(h*m1) * m2
|
||||
h ^= readUnaligned64(add(p, s-8))
|
||||
h = rotl31(h*m1) * m2
|
||||
default:
|
||||
v1 := h
|
||||
v2 := uint64(seed * hashkey[1])
|
||||
v3 := uint64(seed * hashkey[2])
|
||||
v4 := uint64(seed * hashkey[3])
|
||||
for s >= 32 {
|
||||
v1 ^= readUnaligned64(p)
|
||||
v1 = rotl31(v1*m1) * m2
|
||||
p = add(p, 8)
|
||||
v2 ^= readUnaligned64(p)
|
||||
v2 = rotl31(v2*m2) * m3
|
||||
p = add(p, 8)
|
||||
v3 ^= readUnaligned64(p)
|
||||
v3 = rotl31(v3*m3) * m4
|
||||
p = add(p, 8)
|
||||
v4 ^= readUnaligned64(p)
|
||||
v4 = rotl31(v4*m4) * m1
|
||||
p = add(p, 8)
|
||||
s -= 32
|
||||
}
|
||||
h = v1 ^ v2 ^ v3 ^ v4
|
||||
goto tail
|
||||
}
|
||||
|
||||
h ^= h >> 29
|
||||
h *= m3
|
||||
h ^= h >> 32
|
||||
return uintptr(h)
|
||||
}
|
||||
|
||||
func add(p unsafe.Pointer, x uintptr) unsafe.Pointer {
|
||||
return unsafe.Pointer(uintptr(p) + x) // nosemgrep
|
||||
}
|
||||
|
||||
func readUnaligned32(p unsafe.Pointer) uint32 {
|
||||
q := (*[4]byte)(p)
|
||||
return uint32(q[0]) | uint32(q[1])<<8 | uint32(q[2])<<16 | uint32(q[3])<<24
|
||||
}
|
||||
|
||||
func rotl31(x uint64) uint64 {
|
||||
return (x << 31) | (x >> (64 - 31))
|
||||
}
|
||||
|
||||
func readUnaligned64(p unsafe.Pointer) uint64 {
|
||||
q := (*[8]byte)(p)
|
||||
return uint64(q[0]) | uint64(q[1])<<8 | uint64(q[2])<<16 | uint64(q[3])<<24 | uint64(q[4])<<32 | uint64(q[5])<<40 | uint64(q[6])<<48 | uint64(q[7])<<56
|
||||
}
|
||||
//go:noescape
|
||||
//go:linkname strhash runtime.strhash
|
||||
func strhash(p unsafe.Pointer, h uintptr) uintptr
|
||||
|
@ -1,6 +1,7 @@
|
||||
package strmatcher_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/v2fly/v2ray-core/v4/common"
|
||||
@ -172,3 +173,106 @@ func TestMphMatcherGroup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// See https://github.com/v2fly/v2ray-core/issues/92#issuecomment-673238489
|
||||
func TestMphMatcherGroupAsIndexMatcher(t *testing.T) {
|
||||
rules := []struct {
|
||||
Type Type
|
||||
Domain string
|
||||
}{
|
||||
// Regex not supported by MphMatcherGroup
|
||||
// {
|
||||
// Type: Regex,
|
||||
// Domain: "apis\\.us$",
|
||||
// },
|
||||
// Substr not supported by MphMatcherGroup
|
||||
// {
|
||||
// Type: Substr,
|
||||
// Domain: "apis",
|
||||
// },
|
||||
{
|
||||
Type: Domain,
|
||||
Domain: "googleapis.com",
|
||||
},
|
||||
{
|
||||
Type: Domain,
|
||||
Domain: "com",
|
||||
},
|
||||
{
|
||||
Type: Full,
|
||||
Domain: "www.baidu.com",
|
||||
},
|
||||
// Substr not supported by MphMatcherGroup, We add another matcher to preserve index
|
||||
{
|
||||
Type: Domain, // Substr,
|
||||
Domain: "example.com", // "apis",
|
||||
},
|
||||
{
|
||||
Type: Domain,
|
||||
Domain: "googleapis.com",
|
||||
},
|
||||
{
|
||||
Type: Full,
|
||||
Domain: "fonts.googleapis.com",
|
||||
},
|
||||
{
|
||||
Type: Full,
|
||||
Domain: "www.baidu.com",
|
||||
},
|
||||
{ // This matcher (index 10) is swapped with matcher (index 6) to test that full matcher takes high priority.
|
||||
Type: Full,
|
||||
Domain: "example.com",
|
||||
},
|
||||
{
|
||||
Type: Domain,
|
||||
Domain: "example.com",
|
||||
},
|
||||
}
|
||||
cases := []struct {
|
||||
Input string
|
||||
Output []uint32
|
||||
}{
|
||||
{
|
||||
Input: "www.baidu.com",
|
||||
Output: []uint32{5, 9, 4},
|
||||
},
|
||||
{
|
||||
Input: "fonts.googleapis.com",
|
||||
Output: []uint32{8, 3, 7, 4 /*2, 6*/},
|
||||
},
|
||||
{
|
||||
Input: "example.googleapis.com",
|
||||
Output: []uint32{3, 7, 4 /*2, 6*/},
|
||||
},
|
||||
{
|
||||
Input: "testapis.us",
|
||||
// Output: []uint32{ /*2, 6*/ /*1,*/ },
|
||||
Output: nil,
|
||||
},
|
||||
{
|
||||
Input: "example.com",
|
||||
Output: []uint32{10, 6, 11, 4},
|
||||
},
|
||||
}
|
||||
matcherGroup := NewMphMatcherGroup()
|
||||
for i, rule := range rules {
|
||||
matcher, err := rule.Type.New(rule.Domain)
|
||||
common.Must(err)
|
||||
common.Must(AddMatcherToGroup(matcherGroup, matcher, uint32(i+3)))
|
||||
}
|
||||
matcherGroup.Build()
|
||||
for _, test := range cases {
|
||||
if m := matcherGroup.Match(test.Input); !reflect.DeepEqual(m, test.Output) {
|
||||
t.Error("unexpected output: ", m, " for test case ", test)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmptyMphMatcherGroup(t *testing.T) {
|
||||
g := NewMphMatcherGroup()
|
||||
g.Build()
|
||||
r := g.Match("v2fly.org")
|
||||
if len(r) != 0 {
|
||||
t.Error("Expect [], but ", r)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user