1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-22 01:57:12 -05:00

Amending domain matcher with returning array of all matches

This commit is contained in:
Vigilans 2020-08-11 13:31:04 +08:00
parent 65c16cd44c
commit c74a33f827
9 changed files with 225 additions and 50 deletions

View File

@ -106,11 +106,14 @@ func filterIP(ips []net.Address, option IPOption) []net.Address {
// LookupIP returns IP address for the given domain, if exists in this StaticHosts.
func (h *StaticHosts) LookupIP(domain string, option IPOption) []net.Address {
id := h.matchers.Match(domain)
if id == 0 {
indices := h.matchers.Match(domain)
if len(indices) == 0 {
return nil
}
ips := h.ips[id]
ips := []net.Address{}
for _, id := range indices {
ips = append(ips, h.ips[id]...)
}
if len(ips) == 1 && ips[0].Family().IsDomain() {
return ips
}

View File

@ -330,8 +330,8 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
var lastErr error
var matchedClient Client
if s.domainMatcher != nil {
idx := s.domainMatcher.Match(domain)
if idx > 0 {
indices := s.domainMatcher.Match(domain)
for _, idx := range indices {
matchedClient = s.clients[s.domainIndexMap[idx]]
ips, err := s.queryIPTimeout(s.domainIndexMap[idx], matchedClient, domain, option)
if len(ips) > 0 {

View File

@ -50,6 +50,9 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
rr, _ := dns.NewRR("google.com. IN A 8.8.4.4")
ans.Answer = append(ans.Answer, rr)
}
} else if q.Name == "api.google.com." && q.Qtype == dns.TypeA {
rr, _ := dns.NewRR("api.google.com. IN A 8.8.7.7")
ans.Answer = append(ans.Answer, rr)
} else if q.Name == "facebook.com." && q.Qtype == dns.TypeA {
rr, _ := dns.NewRR("facebook.com. IN A 9.9.9.9")
ans.Answer = append(ans.Answer, rr)
@ -754,3 +757,164 @@ func TestLocalDomain(t *testing.T) {
t.Error("DNS query doesn't finish in 2 seconds.")
}
}
func TestMultiMatchPrioritizedDomain(t *testing.T) {
port := udp.PickPort()
dnsServer := dns.Server{
Addr: "127.0.0.1:" + port.String(),
Net: "udp",
Handler: &staticHandler{},
UDPSize: 1200,
}
go dnsServer.ListenAndServe()
time.Sleep(time.Second)
config := &core.Config{
App: []*serial.TypedMessage{
serial.ToTypedMessage(&Config{
NameServers: []*net.Endpoint{
{
Network: net.Network_UDP,
Address: &net.IPOrDomain{
Address: &net.IPOrDomain_Ip{
Ip: []byte{127, 0, 0, 1},
},
},
Port: 9999, /* unreachable */
},
},
NameServer: []*NameServer{
{
Address: &net.Endpoint{
Network: net.Network_UDP,
Address: &net.IPOrDomain{
Address: &net.IPOrDomain_Ip{
Ip: []byte{127, 0, 0, 1},
},
},
Port: uint32(port),
},
PrioritizedDomain: []*NameServer_PriorityDomain{
{
Type: DomainMatchingType_Subdomain,
Domain: "google.com",
},
},
Geoip: []*router.GeoIP{
{ // Will only match 8.8.8.8 and 8.8.4.4
Cidr: []*router.CIDR{
{Ip: []byte{8, 8, 8, 8}, Prefix: 32},
{Ip: []byte{8, 8, 4, 4}, Prefix: 32},
},
},
},
},
{
Address: &net.Endpoint{
Network: net.Network_UDP,
Address: &net.IPOrDomain{
Address: &net.IPOrDomain_Ip{
Ip: []byte{127, 0, 0, 1},
},
},
Port: uint32(port),
},
PrioritizedDomain: []*NameServer_PriorityDomain{
{
Type: DomainMatchingType_Subdomain,
Domain: "google.com",
},
},
Geoip: []*router.GeoIP{
{ // Will match 8.8.8.8 and 8.8.8.7, etc
Cidr: []*router.CIDR{
{Ip: []byte{8, 8, 8, 7}, Prefix: 24},
},
},
},
},
{
Address: &net.Endpoint{
Network: net.Network_UDP,
Address: &net.IPOrDomain{
Address: &net.IPOrDomain_Ip{
Ip: []byte{127, 0, 0, 1},
},
},
Port: uint32(port),
},
PrioritizedDomain: []*NameServer_PriorityDomain{
{
Type: DomainMatchingType_Full,
Domain: "api.google.com",
},
},
Geoip: []*router.GeoIP{
{ // Will only match 8.8.7.7 (api.google.com)
Cidr: []*router.CIDR{
{Ip: []byte{8, 8, 7, 7}, Prefix: 0},
},
},
},
},
},
}),
serial.ToTypedMessage(&dispatcher.Config{}),
serial.ToTypedMessage(&proxyman.OutboundConfig{}),
serial.ToTypedMessage(&policy.Config{}),
},
Outbound: []*core.OutboundHandlerConfig{
{
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
},
},
}
v, err := core.New(config)
common.Must(err)
client := v.GetFeature(feature_dns.ClientType()).(feature_dns.Client)
startTime := time.Now()
{ // Will match server 1,2 and server 1 returns expected ip
ips, err := client.LookupIP("google.com")
if err != nil {
t.Fatal("unexpected error: ", err)
}
if r := cmp.Diff(ips, []net.IP{{8, 8, 8, 8}}); r != "" {
t.Fatal(r)
}
}
{ // Will match server 1,2 and server 1 returns unexpected ip, then server 2 returns expected one
clientv4 := client.(feature_dns.IPv4Lookup)
ips, err := clientv4.LookupIPv4("ipv6.google.com")
if err != nil {
t.Fatal("unexpected error: ", err)
}
if r := cmp.Diff(ips, []net.IP{{8, 8, 8, 7}}); r != "" {
t.Fatal(r)
}
}
{ // Will match server 1,2,3 and server 1,2 returns unexpected ip, then server 3 returns expected one
ips, err := client.LookupIP("api.google.com")
if err != nil {
t.Fatal("unexpected error: ", err)
}
if r := cmp.Diff(ips, []net.IP{{8, 8, 7, 7}}); r != "" {
t.Fatal(r)
}
}
endTime := time.Now()
if startTime.After(endTime.Add(time.Second * 2)) {
t.Error("DNS query doesn't finish in 2 seconds.")
}
}

View File

@ -82,7 +82,7 @@ func NewDomainMatcher(domains []*Domain) (*DomainMatcher, error) {
}
func (m *DomainMatcher) ApplyDomain(domain string) bool {
return m.matchers.Match(domain) > 0
return len(m.matchers.Match(domain)) > 0
}
func (m *DomainMatcher) Apply(ctx *Context) bool {

View File

@ -7,8 +7,8 @@ func breakDomain(domain string) []string {
}
type node struct {
value uint32
sub map[string]*node
values []uint32
sub map[string]*node
}
// DomainMatcherGroup is a IndexMatcher for a large set of Domain matchers.
@ -25,7 +25,7 @@ func (g *DomainMatcherGroup) Add(domain string, value uint32) {
current := g.root
parts := breakDomain(domain)
for i := len(parts) - 1; i >= 0; i-- {
if current.value > 0 {
if len(current.values) > 0 {
// if current node is already a match, it is not necessary to match further.
return
}
@ -42,7 +42,7 @@ func (g *DomainMatcherGroup) Add(domain string, value uint32) {
current = next
}
current.value = value
current.values = append(current.values, value)
current.sub = nil // shortcut sub nodes as current node is a match.
}
@ -50,14 +50,14 @@ func (g *DomainMatcherGroup) addMatcher(m domainMatcher, value uint32) {
g.Add(string(m), value)
}
func (g *DomainMatcherGroup) Match(domain string) uint32 {
func (g *DomainMatcherGroup) Match(domain string) []uint32 {
if domain == "" {
return 0
return nil
}
current := g.root
if current == nil {
return 0
return nil
}
nextPart := func(idx int) int {
@ -84,5 +84,5 @@ func (g *DomainMatcherGroup) Match(domain string) uint32 {
current = next
idx = nidx
}
return current.value
return current.values
}

View File

@ -1,6 +1,7 @@
package strmatcher_test
import (
"reflect"
"testing"
. "v2ray.com/core/common/strmatcher"
@ -13,48 +14,54 @@ func TestDomainMatcherGroup(t *testing.T) {
g.Add("x.a.com", 3)
g.Add("a.b.com", 4)
g.Add("c.a.b.com", 5)
g.Add("x.y.com", 4)
g.Add("x.y.com", 6)
testCases := []struct {
Domain string
Result uint32
Result []uint32
}{
{
Domain: "x.v2ray.com",
Result: 1,
Result: []uint32{1},
},
{
Domain: "y.com",
Result: 0,
Result: nil,
},
{
Domain: "a.b.com",
Result: 4,
Result: []uint32{4},
},
{
Domain: "c.a.b.com",
Result: 4,
Result: []uint32{4},
},
{
Domain: "c.a..b.com",
Result: 0,
Result: nil,
},
{
Domain: ".com",
Result: 0,
Result: nil,
},
{
Domain: "com",
Result: 0,
Result: nil,
},
{
Domain: "",
Result: 0,
Result: nil,
},
{
Domain: "x.y.com",
Result: []uint32{4, 6},
},
}
for _, testCase := range testCases {
r := g.Match(testCase.Domain)
if r != testCase.Result {
if !reflect.DeepEqual(r, testCase.Result) {
t.Error("Failed to match domain: ", testCase.Domain, ", expect ", testCase.Result, ", but got ", r)
}
}
@ -63,7 +70,7 @@ func TestDomainMatcherGroup(t *testing.T) {
func TestEmptyDomainMatcherGroup(t *testing.T) {
g := new(DomainMatcherGroup)
r := g.Match("v2ray.com")
if r != 0 {
t.Error("Expect 0, but ", r)
if len(r) != 0 {
t.Error("Expect [], but ", r)
}
}

View File

@ -1,24 +1,24 @@
package strmatcher
type FullMatcherGroup struct {
matchers map[string]uint32
matchers map[string][]uint32
}
func (g *FullMatcherGroup) Add(domain string, value uint32) {
if g.matchers == nil {
g.matchers = make(map[string]uint32)
g.matchers = make(map[string][]uint32)
}
g.matchers[domain] = value
g.matchers[domain] = append(g.matchers[domain], value)
}
func (g *FullMatcherGroup) addMatcher(m fullMatcher, value uint32) {
g.Add(string(m), value)
}
func (g *FullMatcherGroup) Match(str string) uint32 {
func (g *FullMatcherGroup) Match(str string) []uint32 {
if g.matchers == nil {
return 0
return nil
}
return g.matchers[str]

View File

@ -1,6 +1,7 @@
package strmatcher_test
import (
"reflect"
"testing"
. "v2ray.com/core/common/strmatcher"
@ -11,24 +12,30 @@ func TestFullMatcherGroup(t *testing.T) {
g.Add("v2ray.com", 1)
g.Add("google.com", 2)
g.Add("x.a.com", 3)
g.Add("x.y.com", 4)
g.Add("x.y.com", 6)
testCases := []struct {
Domain string
Result uint32
Result []uint32
}{
{
Domain: "v2ray.com",
Result: 1,
Result: []uint32{1},
},
{
Domain: "y.com",
Result: 0,
Result: nil,
},
{
Domain: "x.y.com",
Result: []uint32{4, 6},
},
}
for _, testCase := range testCases {
r := g.Match(testCase.Domain)
if r != testCase.Result {
if !reflect.DeepEqual(r, testCase.Result) {
t.Error("Failed to match domain: ", testCase.Domain, ", expect ", testCase.Result, ", but got ", r)
}
}
@ -37,7 +44,7 @@ func TestFullMatcherGroup(t *testing.T) {
func TestEmptyFullMatcherGroup(t *testing.T) {
g := new(FullMatcherGroup)
r := g.Match("v2ray.com")
if r != 0 {
t.Error("Expect 0, but ", r)
if len(r) != 0 {
t.Error("Expect [], but ", r)
}
}

View File

@ -49,7 +49,7 @@ func (t Type) New(pattern string) (Matcher, error) {
// IndexMatcher is the interface for matching with a group of matchers.
type IndexMatcher interface {
// Match returns the the index of a matcher that matches the input. It returns 0 if no such matcher exists.
Match(input string) uint32
Match(input string) []uint32
}
type matcherEntry struct {
@ -87,22 +87,16 @@ func (g *MatcherGroup) Add(m Matcher) uint32 {
}
// Match implements IndexMatcher.Match.
func (g *MatcherGroup) Match(pattern string) uint32 {
if c := g.fullMatcher.Match(pattern); c > 0 {
return c
}
if c := g.domainMatcher.Match(pattern); c > 0 {
return c
}
func (g *MatcherGroup) Match(pattern string) []uint32 {
result := []uint32{}
result = append(result, g.fullMatcher.Match(pattern)...)
result = append(result, g.domainMatcher.Match(pattern)...)
for _, e := range g.otherMatchers {
if e.m.Match(pattern) {
return e.id
result = append(result, e.id)
}
}
return 0
return result
}
// Size returns the number of matchers in the MatcherGroup.