Replace router.CachableDomainMatcher with DomainMatcher and index Domain-type
patterns in a label trie (strmatcher.DomainMatcherGroup). Result caching moves
into strmatcher.CachedMatcherGroup, which DomainMatcher uses once a rule holds
64 or more domains.

NOTE(review): no hunk here touches the body of MatcherGroup.Match. Confirm it
consults the new domainMatcher field; otherwise patterns routed to the trie by
MatcherGroup.Add are never matched.
NOTE(review): NewCachedMatcherGroup builds the cleanup task but nothing in this
patch starts it. Confirm it is started; otherwise cache entries never expire.

diff --git a/app/router/condition.go b/app/router/condition.go
index 26d50a325..5735bd4ca 100644
--- a/app/router/condition.go
+++ b/app/router/condition.go
@@ -3,8 +3,6 @@ package router
 import (
 	"context"
 	"strings"
-	"sync"
-	"time"
 
 	"v2ray.com/core/app/dispatcher"
 	"v2ray.com/core/common/net"
@@ -67,116 +65,62 @@ func (v *AnyCondition) Len() int {
 	return len(*v)
 }
 
-type timedResult struct {
-	timestamp time.Time
-	result    bool
-}
-
-type CachableDomainMatcher struct {
-	sync.Mutex
-	matchers *strmatcher.MatcherGroup
-	cache    map[string]timedResult
-	lastScan time.Time
-}
-
-func NewCachableDomainMatcher() *CachableDomainMatcher {
-	return &CachableDomainMatcher{
-		matchers: strmatcher.NewMatcherGroup(),
-		cache:    make(map[string]timedResult, 512),
-	}
-}
-
 var matcherTypeMap = map[Domain_Type]strmatcher.Type{
 	Domain_Plain:  strmatcher.Substr,
 	Domain_Regex:  strmatcher.Regex,
 	Domain_Domain: strmatcher.Domain,
 }
 
-func (m *CachableDomainMatcher) Add(domain *Domain) error {
+// domainToMatcher builds the strmatcher.Matcher described by domain. It fails
+// when the domain type is not in matcherTypeMap or the matcher cannot be created.
+func domainToMatcher(domain *Domain) (strmatcher.Matcher, error) {
 	matcherType, f := matcherTypeMap[domain.Type]
 	if !f {
-		return newError("unsupported domain type", domain.Type)
+		return nil, newError("unsupported domain type", domain.Type)
 	}
 
 	matcher, err := matcherType.New(domain.Value)
 	if err != nil {
-		return newError("failed to create domain matcher").Base(err)
+		return nil, newError("failed to create domain matcher").Base(err)
 	}
 
-	m.matchers.Add(matcher)
-	return nil
+	return matcher, nil
 }
 
-func (m *CachableDomainMatcher) applyInternal(domain string) bool {
+// DomainMatcher matches domain names against a set of Domain rules.
+type DomainMatcher struct {
+	matchers strmatcher.IndexMatcher
+}
+
+// NewDomainMatcher builds a DomainMatcher from domains. Match results are
+// cached only when there are 64 or more domains.
+func NewDomainMatcher(domains []*Domain) (*DomainMatcher, error) {
+	g := strmatcher.NewMatcherGroup()
+	for _, d := range domains {
+		m, err := domainToMatcher(d)
+		if err != nil {
+			return nil, err
+		}
+		g.Add(m)
+	}
+
+	if len(domains) < 64 {
+		return &DomainMatcher{
+			matchers: g,
+		}, nil
+	}
+
+	return &DomainMatcher{
+		matchers: strmatcher.NewCachedMatcherGroup(g),
+	}, nil
+}
+
+// ApplyDomain reports whether domain matches any of the configured rules.
+func (m *DomainMatcher) ApplyDomain(domain string) bool {
 	return m.matchers.Match(domain) > 0
 }
 
-type cacheResult int
-
-const (
-	cacheMiss cacheResult = iota
-	cacheHitTrue
-	cacheHitFalse
-)
-
-func (m *CachableDomainMatcher) findInCache(domain string) cacheResult {
-	m.Lock()
-	defer m.Unlock()
-
-	r, f := m.cache[domain]
-	if !f {
-		return cacheMiss
-	}
-	r.timestamp = time.Now()
-	m.cache[domain] = r
-
-	if r.result {
-		return cacheHitTrue
-	}
-	return cacheHitFalse
-}
-
-func (m *CachableDomainMatcher) ApplyDomain(domain string) bool {
-	if m.matchers.Size() < 64 {
-		return m.applyInternal(domain)
-	}
-
-	cr := m.findInCache(domain)
-
-	if cr == cacheHitTrue {
-		return true
-	}
-
-	if cr == cacheHitFalse {
-		return false
-	}
-
-	r := m.applyInternal(domain)
-	m.Lock()
-	defer m.Unlock()
-
-	m.cache[domain] = timedResult{
-		result:    r,
-		timestamp: time.Now(),
-	}
-
-	now := time.Now()
-	if len(m.cache) > 256 && now.Sub(m.lastScan)/time.Second > 5 {
-		now := time.Now()
-
-		for k, v := range m.cache {
-			if now.Sub(v.timestamp)/time.Second > 60 {
-				delete(m.cache, k)
-			}
-		}
-
-		m.lastScan = now
-	}
-
-	return r
-}
-
-func (m *CachableDomainMatcher) Apply(ctx context.Context) bool {
+func (m *DomainMatcher) Apply(ctx context.Context) bool {
 	dest, ok := proxy.TargetFromContext(ctx)
 	if !ok {
 		return false
diff --git a/app/router/condition_test.go b/app/router/condition_test.go
index f5a1a0747..6b3315455 100644
--- a/app/router/condition_test.go
+++ b/app/router/condition_test.go
@@ -189,10 +189,8 @@ func TestChinaSites(t *testing.T) {
 	domains, err := loadGeoSite("CN")
 	assert(err, IsNil)
 
-	matcher := NewCachableDomainMatcher()
-	for _, d := range domains {
-		assert(matcher.Add(d), IsNil)
-	}
+	matcher, err := NewDomainMatcher(domains)
+	assert(err, IsNil)
 
 	assert(matcher.ApplyDomain("163.com"), IsTrue)
 	assert(matcher.ApplyDomain("163.com"), IsTrue)
diff --git a/app/router/config.go b/app/router/config.go
index de8be63a8..e141062d1 100644
--- a/app/router/config.go
+++ b/app/router/config.go
@@ -52,11 +52,9 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 	conds := NewConditionChan()
 
 	if len(rr.Domain) > 0 {
-		matcher := NewCachableDomainMatcher()
-		for _, domain := range rr.Domain {
-			if err := matcher.Add(domain); err != nil {
-				return nil, newError("failed to build domain condition").Base(err)
-			}
+		matcher, err := NewDomainMatcher(rr.Domain)
+		if err != nil {
+			return nil, newError("failed to build domain condition").Base(err)
 		}
 		conds.Add(matcher)
 	}
diff --git a/common/strmatcher/benchmark_test.go b/common/strmatcher/benchmark_test.go
new file mode 100644
index 000000000..6eb7f0eba
--- /dev/null
+++ b/common/strmatcher/benchmark_test.go
@@ -0,0 +1,36 @@
+package strmatcher_test
+
+import (
+	"strconv"
+	"testing"
+
+	"v2ray.com/core/common"
+	. "v2ray.com/core/common/strmatcher"
+)
+
+func BenchmarkDomainMatcherGroup(b *testing.B) {
+	g := new(DomainMatcherGroup)
+
+	for i := 1; i <= 1024; i++ {
+		g.Add(strconv.Itoa(i)+".v2ray.com", uint32(i))
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_ = g.Match("0.v2ray.com")
+	}
+}
+
+func BenchmarkMatchGroup(b *testing.B) {
+	g := NewMatcherGroup()
+	for i := 1; i <= 1024; i++ {
+		m, err := Domain.New(strconv.Itoa(i) + ".v2ray.com")
+		common.Must(err)
+		g.Add(m)
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_ = g.Match("0.v2ray.com")
+	}
+}
diff --git a/common/strmatcher/domain_matcher.go b/common/strmatcher/domain_matcher.go
new file mode 100644
index 000000000..d31dbf724
--- /dev/null
+++ b/common/strmatcher/domain_matcher.go
@@ -0,0 +1,72 @@
+package strmatcher
+
+import "strings"
+
+// breakDomain splits domain into its dot-separated labels.
+func breakDomain(domain string) []string {
+	return strings.Split(domain, ".")
+}
+
+// node is a single label in the domain trie. value is the value registered
+// for the domain ending at this node (0 if none); sub holds the child labels.
+type node struct {
+	value uint32
+	sub   map[string]*node
+}
+
+// DomainMatcherGroup indexes domains in a trie that is walked label by label,
+// starting from the rightmost label. The zero value is an empty group.
+type DomainMatcherGroup struct {
+	root *node
+}
+
+// Add registers domain with the given value. Match reports 0 for "no match",
+// so value should be non-zero.
+func (g *DomainMatcherGroup) Add(domain string, value uint32) {
+	if g.root == nil {
+		g.root = &node{
+			sub: make(map[string]*node),
+		}
+	}
+
+	current := g.root
+	parts := breakDomain(domain)
+	for i := len(parts) - 1; i >= 0; i-- {
+		part := parts[i]
+		next := current.sub[part]
+		if next == nil {
+			next = &node{sub: make(map[string]*node)}
+			current.sub[part] = next
+		}
+		current = next
+	}
+
+	current.value = value
+}
+
+// Match returns the value of the longest registered domain that is equal to
+// domain or a parent of it, or 0 if there is none.
+func (g *DomainMatcherGroup) Match(domain string) uint32 {
+	current := g.root
+	if current == nil {
+		// Nothing has been added yet.
+		return 0
+	}
+
+	var matched uint32
+	parts := breakDomain(domain)
+	for i := len(parts) - 1; i >= 0; i-- {
+		next := current.sub[parts[i]]
+		if next == nil {
+			break
+		}
+		current = next
+		// A node may exist only as part of a longer registered domain and
+		// carry no value, so keep the deepest value seen rather than
+		// returning the value of the last node reached.
+		if current.value != 0 {
+			matched = current.value
+		}
+	}
+	return matched
+}
diff --git a/common/strmatcher/domain_matcher_test.go b/common/strmatcher/domain_matcher_test.go
new file mode 100644
index 000000000..54a81d9b8
--- /dev/null
+++ b/common/strmatcher/domain_matcher_test.go
@@ -0,0 +1,51 @@
+package strmatcher_test
+
+import (
+	"testing"
+
+	. "v2ray.com/core/common/strmatcher"
+)
+
+func TestDomainMatcherGroup(t *testing.T) {
+	g := new(DomainMatcherGroup)
+	g.Add("v2ray.com", 1)
+	g.Add("google.com", 2)
+	g.Add("x.a.com", 3)
+	g.Add("b.x.v2ray.com", 4)
+
+	testCases := []struct {
+		Domain string
+		Result uint32
+	}{
+		{
+			Domain: "x.v2ray.com",
+			Result: 1,
+		},
+		{
+			Domain: "y.com",
+			Result: 0,
+		},
+		{
+			Domain: "a.b.x.v2ray.com",
+			Result: 4,
+		},
+		{
+			Domain: "c.x.v2ray.com",
+			Result: 1,
+		},
+	}
+
+	for _, testCase := range testCases {
+		r := g.Match(testCase.Domain)
+		if r != testCase.Result {
+			t.Error("Failed to match domain: ", testCase.Domain, ", expect ", testCase.Result, ", but got ", r)
+		}
+	}
+}
+
+func TestEmptyDomainMatcherGroup(t *testing.T) {
+	g := new(DomainMatcherGroup)
+	if r := g.Match("v2ray.com"); r != 0 {
+		t.Error("Expect 0 from an empty group, but got ", r)
+	}
+}
diff --git a/common/strmatcher/strmatcher.go b/common/strmatcher/strmatcher.go
index 80e7d34ea..811110752 100644
--- a/common/strmatcher/strmatcher.go
+++ b/common/strmatcher/strmatcher.go
@@ -1,6 +1,12 @@
 package strmatcher
 
-import "regexp"
+import (
+	"regexp"
+	"sync"
+	"time"
+
+	"v2ray.com/core/common/task"
+)
 
 type Matcher interface {
 	Match(string) bool
@@ -36,6 +42,12 @@ func (t Type) New(pattern string) (Matcher, error) {
 	}
 }
 
+// IndexMatcher maps a string to the uint32 id of the pattern it matches.
+// Callers in the router treat a result of 0 as "no match".
+type IndexMatcher interface {
+	Match(pattern string) uint32
+}
+
 type matcherEntry struct {
 	m  Matcher
 	id uint32
@@ -44,6 +56,7 @@ type matcherEntry struct {
 type MatcherGroup struct {
 	count         uint32
 	fullMatchers  map[string]uint32
+	domainMatcher DomainMatcherGroup
 	otherMatchers []matcherEntry
 }
 
@@ -58,9 +71,12 @@ func (g *MatcherGroup) Add(m Matcher) uint32 {
 	c := g.count
 	g.count++
 
-	if fm, ok := m.(fullMatcher); ok {
-		g.fullMatchers[string(fm)] = c
-	} else {
+	switch tm := m.(type) {
+	case fullMatcher:
+		g.fullMatchers[string(tm)] = c
+	case domainMatcher:
+		g.domainMatcher.Add(string(tm), c)
+	default:
 		g.otherMatchers = append(g.otherMatchers, matcherEntry{
 			m:  m,
 			id: c,
@@ -87,3 +103,68 @@ func (g *MatcherGroup) Match(pattern string) uint32 {
 func (g *MatcherGroup) Size() uint32 {
 	return g.count
 }
+
+// cacheEntry is a cached match result and the time it was last used.
+type cacheEntry struct {
+	timestamp time.Time
+	result    uint32
+}
+
+// CachedMatcherGroup wraps a MatcherGroup and memoizes Match results per input
+// string. Match and the cleanup task serialize on the embedded mutex.
+type CachedMatcherGroup struct {
+	sync.Mutex
+	group   *MatcherGroup
+	cache   map[string]cacheEntry
+	cleanup *task.Periodic
+}
+
+// NewCachedMatcherGroup returns a CachedMatcherGroup backed by g. The cleanup
+// task, when run, drops entries that have not been used for 60 seconds.
+func NewCachedMatcherGroup(g *MatcherGroup) *CachedMatcherGroup {
+	r := &CachedMatcherGroup{
+		group: g,
+		cache: make(map[string]cacheEntry),
+	}
+	r.cleanup = &task.Periodic{
+		Interval: time.Second * 30,
+		Execute: func() error {
+			r.Lock()
+			defer r.Unlock()
+
+			expire := time.Now().Add(-1 * time.Second * 60)
+			for p, e := range r.cache {
+				if e.timestamp.Before(expire) {
+					delete(r.cache, p)
+				}
+			}
+
+			return nil
+		},
+	}
+	// NOTE(review): the task is built but not started here; confirm it is started.
+	return r
+}
+
+// Match returns the cached result for pattern, computing and storing it on a
+// miss. Every lookup refreshes the entry's timestamp.
+func (g *CachedMatcherGroup) Match(pattern string) uint32 {
+	g.Lock()
+	defer g.Unlock()
+
+	r, f := g.cache[pattern]
+	if f {
+		r.timestamp = time.Now()
+		g.cache[pattern] = r
+		return r.result
+	}
+
+	mr := g.group.Match(pattern)
+
+	g.cache[pattern] = cacheEntry{
+		result:    mr,
+		timestamp: time.Now(),
+	}
+
+	return mr
+}