diff --git a/app/dns/hosts.go b/app/dns/hosts.go index b4d18ff1e..f47a32fef 100644 --- a/app/dns/hosts.go +++ b/app/dns/hosts.go @@ -106,11 +106,14 @@ func filterIP(ips []net.Address, option IPOption) []net.Address { // LookupIP returns IP address for the given domain, if exists in this StaticHosts. func (h *StaticHosts) LookupIP(domain string, option IPOption) []net.Address { - id := h.matchers.Match(domain) - if id == 0 { + indices := h.matchers.Match(domain) + if len(indices) == 0 { return nil } - ips := h.ips[id] + ips := []net.Address{} + for _, id := range indices { + ips = append(ips, h.ips[id]...) + } if len(ips) == 1 && ips[0].Family().IsDomain() { return ips } diff --git a/app/dns/server.go b/app/dns/server.go index 33394f693..551af9ceb 100644 --- a/app/dns/server.go +++ b/app/dns/server.go @@ -330,8 +330,8 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err var lastErr error var matchedClient Client if s.domainMatcher != nil { - idx := s.domainMatcher.Match(domain) - if idx > 0 { + indices := s.domainMatcher.Match(domain) + for _, idx := range indices { matchedClient = s.clients[s.domainIndexMap[idx]] ips, err := s.queryIPTimeout(s.domainIndexMap[idx], matchedClient, domain, option) if len(ips) > 0 { diff --git a/app/dns/server_test.go b/app/dns/server_test.go index c340b41b1..cbd3c8e1b 100644 --- a/app/dns/server_test.go +++ b/app/dns/server_test.go @@ -50,6 +50,9 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { rr, _ := dns.NewRR("google.com. IN A 8.8.4.4") ans.Answer = append(ans.Answer, rr) } + } else if q.Name == "api.google.com." && q.Qtype == dns.TypeA { + rr, _ := dns.NewRR("api.google.com. IN A 8.8.7.7") + ans.Answer = append(ans.Answer, rr) } else if q.Name == "facebook.com." && q.Qtype == dns.TypeA { rr, _ := dns.NewRR("facebook.com. IN A 9.9.9.9") ans.Answer = append(ans.Answer, rr) @@ -754,3 +757,164 @@ func TestLocalDomain(t *testing.T) { t.Error("DNS query doesn't finish in 2 seconds.") } } + +func TestMultiMatchPrioritizedDomain(t *testing.T) { + port := udp.PickPort() + + dnsServer := dns.Server{ + Addr: "127.0.0.1:" + port.String(), + Net: "udp", + Handler: &staticHandler{}, + UDPSize: 1200, + } + + go dnsServer.ListenAndServe() + time.Sleep(time.Second) + + config := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&Config{ + NameServers: []*net.Endpoint{ + { + Network: net.Network_UDP, + Address: &net.IPOrDomain{ + Address: &net.IPOrDomain_Ip{ + Ip: []byte{127, 0, 0, 1}, + }, + }, + Port: 9999, /* unreachable */ + }, + }, + NameServer: []*NameServer{ + { + Address: &net.Endpoint{ + Network: net.Network_UDP, + Address: &net.IPOrDomain{ + Address: &net.IPOrDomain_Ip{ + Ip: []byte{127, 0, 0, 1}, + }, + }, + Port: uint32(port), + }, + PrioritizedDomain: []*NameServer_PriorityDomain{ + { + Type: DomainMatchingType_Subdomain, + Domain: "google.com", + }, + }, + Geoip: []*router.GeoIP{ + { // Will only match 8.8.8.8 and 8.8.4.4 + Cidr: []*router.CIDR{ + {Ip: []byte{8, 8, 8, 8}, Prefix: 32}, + {Ip: []byte{8, 8, 4, 4}, Prefix: 32}, + }, + }, + }, + }, + { + Address: &net.Endpoint{ + Network: net.Network_UDP, + Address: &net.IPOrDomain{ + Address: &net.IPOrDomain_Ip{ + Ip: []byte{127, 0, 0, 1}, + }, + }, + Port: uint32(port), + }, + PrioritizedDomain: []*NameServer_PriorityDomain{ + { + Type: DomainMatchingType_Subdomain, + Domain: "google.com", + }, + }, + Geoip: []*router.GeoIP{ + { // Will match 8.8.8.8 and 8.8.8.7, etc + Cidr: []*router.CIDR{ + {Ip: []byte{8, 8, 8, 7}, Prefix: 24}, + }, + }, + }, + }, + { + Address: &net.Endpoint{ + Network: net.Network_UDP, + Address: &net.IPOrDomain{ + Address: &net.IPOrDomain_Ip{ + Ip: []byte{127, 0, 0, 1}, + }, + }, + Port: uint32(port), + }, + PrioritizedDomain: []*NameServer_PriorityDomain{ + { + Type: DomainMatchingType_Full, + Domain: "api.google.com", + }, + }, + Geoip: []*router.GeoIP{ + { // Will only match 8.8.7.7 (api.google.com) + Cidr: []*router.CIDR{ + {Ip: []byte{8, 8, 7, 7}, Prefix: 0}, + }, + }, + }, + }, + }, + }), + serial.ToTypedMessage(&dispatcher.Config{}), + serial.ToTypedMessage(&proxyman.OutboundConfig{}), + serial.ToTypedMessage(&policy.Config{}), + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&freedom.Config{}), + }, + }, + } + + v, err := core.New(config) + common.Must(err) + + client := v.GetFeature(feature_dns.ClientType()).(feature_dns.Client) + + startTime := time.Now() + + { // Will match server 1,2 and server 1 returns expected ip + ips, err := client.LookupIP("google.com") + if err != nil { + t.Fatal("unexpected error: ", err) + } + + if r := cmp.Diff(ips, []net.IP{{8, 8, 8, 8}}); r != "" { + t.Fatal(r) + } + } + + { // Will match server 1,2 and server 1 returns unexpected ip, then server 2 returns expected one + clientv4 := client.(feature_dns.IPv4Lookup) + ips, err := clientv4.LookupIPv4("ipv6.google.com") + if err != nil { + t.Fatal("unexpected error: ", err) + } + + if r := cmp.Diff(ips, []net.IP{{8, 8, 8, 7}}); r != "" { + t.Fatal(r) + } + } + + { // Will match server 1,2,3 and server 1,2 returns unexpected ip, then server 3 returns expected one + ips, err := client.LookupIP("api.google.com") + if err != nil { + t.Fatal("unexpected error: ", err) + } + + if r := cmp.Diff(ips, []net.IP{{8, 8, 7, 7}}); r != "" { + t.Fatal(r) + } + } + + endTime := time.Now() + if startTime.After(endTime.Add(time.Second * 2)) { + t.Error("DNS query doesn't finish in 2 seconds.") + } +} diff --git a/app/router/condition.go b/app/router/condition.go index d189f32e1..ffafdb3c3 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -82,7 +82,7 @@ func NewDomainMatcher(domains []*Domain) (*DomainMatcher, error) { } func (m *DomainMatcher) ApplyDomain(domain string) bool { - return m.matchers.Match(domain) > 0 + return len(m.matchers.Match(domain)) > 0 } func (m *DomainMatcher) Apply(ctx *Context) bool { diff --git a/common/strmatcher/domain_matcher.go b/common/strmatcher/domain_matcher.go index aabbf43eb..3b109aba5 100644 --- a/common/strmatcher/domain_matcher.go +++ b/common/strmatcher/domain_matcher.go @@ -7,8 +7,8 @@ func breakDomain(domain string) []string { } type node struct { - value uint32 - sub map[string]*node + values []uint32 + sub map[string]*node } // DomainMatcherGroup is a IndexMatcher for a large set of Domain matchers. @@ -25,7 +25,7 @@ func (g *DomainMatcherGroup) Add(domain string, value uint32) { current := g.root parts := breakDomain(domain) for i := len(parts) - 1; i >= 0; i-- { - if current.value > 0 { + if len(current.values) > 0 { // if current node is already a match, it is not necessary to match further. return } @@ -42,7 +42,7 @@ func (g *DomainMatcherGroup) Add(domain string, value uint32) { current = next } - current.value = value + current.values = append(current.values, value) current.sub = nil // shortcut sub nodes as current node is a match. } @@ -50,14 +50,14 @@ func (g *DomainMatcherGroup) addMatcher(m domainMatcher, value uint32) { g.Add(string(m), value) } -func (g *DomainMatcherGroup) Match(domain string) uint32 { +func (g *DomainMatcherGroup) Match(domain string) []uint32 { if domain == "" { - return 0 + return nil } current := g.root if current == nil { - return 0 + return nil } nextPart := func(idx int) int { @@ -84,5 +84,5 @@ func (g *DomainMatcherGroup) Match(domain string) uint32 { current = next idx = nidx } - return current.value + return current.values } diff --git a/common/strmatcher/domain_matcher_test.go b/common/strmatcher/domain_matcher_test.go index a6319740b..10de133b9 100644 --- a/common/strmatcher/domain_matcher_test.go +++ b/common/strmatcher/domain_matcher_test.go @@ -1,6 +1,7 @@ package strmatcher_test import ( + "reflect" "testing" . "v2ray.com/core/common/strmatcher" @@ -13,48 +14,54 @@ func TestDomainMatcherGroup(t *testing.T) { g.Add("x.a.com", 3) g.Add("a.b.com", 4) g.Add("c.a.b.com", 5) + g.Add("x.y.com", 4) + g.Add("x.y.com", 6) testCases := []struct { Domain string - Result uint32 + Result []uint32 }{ { Domain: "x.v2ray.com", - Result: 1, + Result: []uint32{1}, }, { Domain: "y.com", - Result: 0, + Result: nil, }, { Domain: "a.b.com", - Result: 4, + Result: []uint32{4}, }, { Domain: "c.a.b.com", - Result: 4, + Result: []uint32{4}, }, { Domain: "c.a..b.com", - Result: 0, + Result: nil, }, { Domain: ".com", - Result: 0, + Result: nil, }, { Domain: "com", - Result: 0, + Result: nil, }, { Domain: "", - Result: 0, + Result: nil, + }, + { + Domain: "x.y.com", + Result: []uint32{4, 6}, }, } for _, testCase := range testCases { r := g.Match(testCase.Domain) - if r != testCase.Result { + if !reflect.DeepEqual(r, testCase.Result) { t.Error("Failed to match domain: ", testCase.Domain, ", expect ", testCase.Result, ", but got ", r) } } @@ -63,7 +70,7 @@ func TestDomainMatcherGroup(t *testing.T) { func TestEmptyDomainMatcherGroup(t *testing.T) { g := new(DomainMatcherGroup) r := g.Match("v2ray.com") - if r != 0 { - t.Error("Expect 0, but ", r) + if len(r) != 0 { + t.Error("Expect [], but ", r) } } diff --git a/common/strmatcher/full_matcher.go b/common/strmatcher/full_matcher.go index fc7e0c335..e00d02aa9 100644 --- a/common/strmatcher/full_matcher.go +++ b/common/strmatcher/full_matcher.go @@ -1,24 +1,24 @@ package strmatcher type FullMatcherGroup struct { - matchers map[string]uint32 + matchers map[string][]uint32 } func (g *FullMatcherGroup) Add(domain string, value uint32) { if g.matchers == nil { - g.matchers = make(map[string]uint32) + g.matchers = make(map[string][]uint32) } - g.matchers[domain] = value + g.matchers[domain] = append(g.matchers[domain], value) } func (g *FullMatcherGroup) addMatcher(m fullMatcher, value uint32) { g.Add(string(m), value) } -func (g *FullMatcherGroup) Match(str string) uint32 { +func (g *FullMatcherGroup) Match(str string) []uint32 { if g.matchers == nil { - return 0 + return nil } return g.matchers[str] diff --git a/common/strmatcher/full_matcher_test.go b/common/strmatcher/full_matcher_test.go index a19a60c77..2fe6ee34e 100644 --- a/common/strmatcher/full_matcher_test.go +++ b/common/strmatcher/full_matcher_test.go @@ -1,6 +1,7 @@ package strmatcher_test import ( + "reflect" "testing" . "v2ray.com/core/common/strmatcher" @@ -11,24 +12,30 @@ func TestFullMatcherGroup(t *testing.T) { g.Add("v2ray.com", 1) g.Add("google.com", 2) g.Add("x.a.com", 3) + g.Add("x.y.com", 4) + g.Add("x.y.com", 6) testCases := []struct { Domain string - Result uint32 + Result []uint32 }{ { Domain: "v2ray.com", - Result: 1, + Result: []uint32{1}, }, { Domain: "y.com", - Result: 0, + Result: nil, + }, + { + Domain: "x.y.com", + Result: []uint32{4, 6}, }, } for _, testCase := range testCases { r := g.Match(testCase.Domain) - if r != testCase.Result { + if !reflect.DeepEqual(r, testCase.Result) { t.Error("Failed to match domain: ", testCase.Domain, ", expect ", testCase.Result, ", but got ", r) } } @@ -37,7 +44,7 @@ func TestFullMatcherGroup(t *testing.T) { func TestEmptyFullMatcherGroup(t *testing.T) { g := new(FullMatcherGroup) r := g.Match("v2ray.com") - if r != 0 { - t.Error("Expect 0, but ", r) + if len(r) != 0 { + t.Error("Expect [], but ", r) } } diff --git a/common/strmatcher/strmatcher.go b/common/strmatcher/strmatcher.go index fb63eda56..6486d8369 100644 --- a/common/strmatcher/strmatcher.go +++ b/common/strmatcher/strmatcher.go @@ -49,7 +49,7 @@ func (t Type) New(pattern string) (Matcher, error) { // IndexMatcher is the interface for matching with a group of matchers. type IndexMatcher interface { // Match returns the the index of a matcher that matches the input. It returns 0 if no such matcher exists. - Match(input string) uint32 + Match(input string) []uint32 } type matcherEntry struct { @@ -87,22 +87,16 @@ func (g *MatcherGroup) Add(m Matcher) uint32 { } // Match implements IndexMatcher.Match. -func (g *MatcherGroup) Match(pattern string) uint32 { - if c := g.fullMatcher.Match(pattern); c > 0 { - return c - } - - if c := g.domainMatcher.Match(pattern); c > 0 { - return c - } - +func (g *MatcherGroup) Match(pattern string) []uint32 { + result := []uint32{} + result = append(result, g.fullMatcher.Match(pattern)...) + result = append(result, g.domainMatcher.Match(pattern)...) for _, e := range g.otherMatchers { if e.m.Match(pattern) { - return e.id + result = append(result, e.id) } } - - return 0 + return result } // Size returns the number of matchers in the MatcherGroup.