diff --git a/app/router/condition.go b/app/router/condition.go index 0da9ebf7e..749879289 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -64,13 +64,35 @@ func (v *AnyCondition) Len() int { return len(*v) } -type PlainDomainMatcher string - -func NewPlainDomainMatcher(pattern string) Condition { - return PlainDomainMatcher(pattern) +type CachableDomainMatcher struct { + matchers []domainMatcher } -func (v PlainDomainMatcher) Apply(ctx context.Context) bool { +func NewCachableDomainMatcher() *CachableDomainMatcher { + return &CachableDomainMatcher{ + matchers: make([]domainMatcher, 0, 64), + } +} + +func (m *CachableDomainMatcher) Add(domain *Domain) error { + switch domain.Type { + case Domain_Plain: + m.matchers = append(m.matchers, NewPlainDomainMatcher(domain.Value)) + case Domain_Regex: + rm, err := NewRegexpDomainMatcher(domain.Value) + if err != nil { + return err + } + m.matchers = append(m.matchers, rm) + case Domain_Domain: + m.matchers = append(m.matchers, NewSubDomainMatcher(domain.Value)) + default: + return newError("unknown domain type: ", domain.Type).AtError() + } + return nil +} + +func (m *CachableDomainMatcher) Apply(ctx context.Context) bool { dest, ok := proxy.TargetFromContext(ctx) if !ok { return false @@ -80,6 +102,27 @@ func (v PlainDomainMatcher) Apply(ctx context.Context) bool { return false } domain := dest.Address.Domain() + + for _, matcher := range m.matchers { + if matcher.Apply(domain) { + return true + } + } + + return false +} + +type domainMatcher interface { + Apply(domain string) bool +} + +type PlainDomainMatcher string + +func NewPlainDomainMatcher(pattern string) PlainDomainMatcher { + return PlainDomainMatcher(pattern) +} + +func (v PlainDomainMatcher) Apply(domain string) bool { return strings.Contains(domain, string(v)) } @@ -97,33 +140,17 @@ func NewRegexpDomainMatcher(pattern string) (*RegexpDomainMatcher, error) { }, nil } -func (v *RegexpDomainMatcher) Apply(ctx context.Context) bool { - dest, ok := proxy.TargetFromContext(ctx) - if !ok { - return false - } - if !dest.Address.Family().IsDomain() { - return false - } - domain := dest.Address.Domain() +func (v *RegexpDomainMatcher) Apply(domain string) bool { return v.pattern.MatchString(strings.ToLower(domain)) } type SubDomainMatcher string -func NewSubDomainMatcher(p string) Condition { +func NewSubDomainMatcher(p string) SubDomainMatcher { return SubDomainMatcher(p) } -func (m SubDomainMatcher) Apply(ctx context.Context) bool { - dest, ok := proxy.TargetFromContext(ctx) - if !ok { - return false - } - if !dest.Address.Family().IsDomain() { - return false - } - domain := dest.Address.Domain() +func (m SubDomainMatcher) Apply(domain string) bool { pattern := string(m) if !strings.HasSuffix(domain, pattern) { return false diff --git a/app/router/condition_test.go b/app/router/condition_test.go index f0283dc3d..2e35b7d7d 100644 --- a/app/router/condition_test.go +++ b/app/router/condition_test.go @@ -16,32 +16,32 @@ func TestSubDomainMatcher(t *testing.T) { cases := []struct { pattern string - input context.Context + input string output bool }{ { pattern: "v2ray.com", - input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("www.v2ray.com"), 80)), + input: "www.v2ray.com", output: true, }, { pattern: "v2ray.com", - input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("v2ray.com"), 80)), + input: "v2ray.com", output: true, }, { pattern: "v2ray.com", - input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("www.v3ray.com"), 80)), + input: "www.v3ray.com", output: false, }, { pattern: "v2ray.com", - input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("2ray.com"), 80)), + input: "2ray.com", output: false, }, { pattern: "v2ray.com", - input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("xv2ray.com"), 80)), + input: "xv2ray.com", output: false, }, } diff --git a/app/router/config.go b/app/router/config.go index 048f602f1..05ff36614 100644 --- a/app/router/config.go +++ b/app/router/config.go @@ -52,24 +52,11 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) { conds := NewConditionChan() if len(rr.Domain) > 0 { - anyCond := NewAnyCondition() + matcher := NewCachableDomainMatcher() for _, domain := range rr.Domain { - switch domain.Type { - case Domain_Plain: - anyCond.Add(NewPlainDomainMatcher(domain.Value)) - case Domain_Regex: - matcher, err := NewRegexpDomainMatcher(domain.Value) - if err != nil { - return nil, err - } - anyCond.Add(matcher) - case Domain_Domain: - anyCond.Add(NewSubDomainMatcher(domain.Value)) - default: - panic("Unknown domain type.") - } + matcher.Add(domain) } - conds.Add(anyCond) + conds.Add(matcher) } if len(rr.Cidr) > 0 {