diff --git a/app/dns/server_test.go b/app/dns/server_test.go index cbd3c8e1b..de3f8caf7 100644 --- a/app/dns/server_test.go +++ b/app/dns/server_test.go @@ -53,6 +53,9 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } else if q.Name == "api.google.com." && q.Qtype == dns.TypeA { rr, _ := dns.NewRR("api.google.com. IN A 8.8.7.7") ans.Answer = append(ans.Answer, rr) + } else if q.Name == "v2.api.google.com." && q.Qtype == dns.TypeA { + rr, _ := dns.NewRR("v2.api.google.com. IN A 8.8.7.8") + ans.Answer = append(ans.Answer, rr) } else if q.Name == "facebook.com." && q.Qtype == dns.TypeA { rr, _ := dns.NewRR("facebook.com. IN A 9.9.9.9") ans.Answer = append(ans.Answer, rr) @@ -847,14 +850,38 @@ func TestMultiMatchPrioritizedDomain(t *testing.T) { }, PrioritizedDomain: []*NameServer_PriorityDomain{ { - Type: DomainMatchingType_Full, + Type: DomainMatchingType_Subdomain, Domain: "api.google.com", }, }, Geoip: []*router.GeoIP{ { // Will only match 8.8.7.7 (api.google.com) Cidr: []*router.CIDR{ - {Ip: []byte{8, 8, 7, 7}, Prefix: 0}, + {Ip: []byte{8, 8, 7, 7}, Prefix: 32}, + }, + }, + }, + }, + { + Address: &net.Endpoint{ + Network: net.Network_UDP, + Address: &net.IPOrDomain{ + Address: &net.IPOrDomain_Ip{ + Ip: []byte{127, 0, 0, 1}, + }, + }, + Port: uint32(port), + }, + PrioritizedDomain: []*NameServer_PriorityDomain{ + { + Type: DomainMatchingType_Full, + Domain: "v2.api.google.com", + }, + }, + Geoip: []*router.GeoIP{ + { // Will only match 8.8.7.8 (v2.api.google.com) + Cidr: []*router.CIDR{ + {Ip: []byte{8, 8, 7, 8}, Prefix: 32}, }, }, }, @@ -902,7 +929,7 @@ func TestMultiMatchPrioritizedDomain(t *testing.T) { } } - { // Will match server 1,2,3 and server 1,2 returns unexpected ip, then server 3 returns expected one + { // Will match server 3,1,2 and server 3 returns expected one ips, err := client.LookupIP("api.google.com") if err != nil { t.Fatal("unexpected error: ", err) @@ -913,6 +940,17 @@ func TestMultiMatchPrioritizedDomain(t *testing.T) { } } + { // Will match server 4,3,1,2 and server 4 returns expected one + ips, err := client.LookupIP("v2.api.google.com") + if err != nil { + t.Fatal("unexpected error: ", err) + } + + if r := cmp.Diff(ips, []net.IP{{8, 8, 7, 8}}); r != "" { + t.Fatal(r) + } + } + endTime := time.Now() if startTime.After(endTime.Add(time.Second * 2)) { t.Error("DNS query doesn't finish in 2 seconds.") diff --git a/common/strmatcher/domain_matcher.go b/common/strmatcher/domain_matcher.go index 3b109aba5..ae8e65bc2 100644 --- a/common/strmatcher/domain_matcher.go +++ b/common/strmatcher/domain_matcher.go @@ -25,11 +25,6 @@ func (g *DomainMatcherGroup) Add(domain string, value uint32) { current := g.root parts := breakDomain(domain) for i := len(parts) - 1; i >= 0; i-- { - if len(current.values) > 0 { - // if current node is already a match, it is not necessary to match further. - return - } - part := parts[i] if current.sub == nil { current.sub = make(map[string]*node) @@ -43,7 +38,6 @@ func (g *DomainMatcherGroup) Add(domain string, value uint32) { } current.values = append(current.values, value) - current.sub = nil // shortcut sub nodes as current node is a match. } func (g *DomainMatcherGroup) addMatcher(m domainMatcher, value uint32) { @@ -69,6 +63,7 @@ func (g *DomainMatcherGroup) Match(domain string) []uint32 { return -1 } + matches := [][]uint32{} idx := len(domain) for { if idx == -1 || current.sub == nil { @@ -83,6 +78,21 @@ func (g *DomainMatcherGroup) Match(domain string) []uint32 { } current = next idx = nidx + if len(current.values) > 0 { + matches = append(matches, current.values) + } + } + switch len(matches) { + case 0: + return nil + case 1: + return matches[0] + default: + result := []uint32{} + for idx := range matches { + // Insert reversely, the subdomain that matches further ranks higher + result = append(result, matches[len(matches)-1-idx]...) + } + return result } - return current.values } diff --git a/common/strmatcher/domain_matcher_test.go b/common/strmatcher/domain_matcher_test.go index 10de133b9..660594a9f 100644 --- a/common/strmatcher/domain_matcher_test.go +++ b/common/strmatcher/domain_matcher_test.go @@ -33,9 +33,9 @@ func TestDomainMatcherGroup(t *testing.T) { Domain: "a.b.com", Result: []uint32{4}, }, - { + { // Matches [c.a.b.com, a.b.com] Domain: "c.a.b.com", - Result: []uint32{4}, + Result: []uint32{5, 4}, }, { Domain: "c.a..b.com",