From bb1efdebd1e1190919f60bcc9b7778fdf6d69dfd Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Mon, 19 Nov 2018 20:42:02 +0100 Subject: [PATCH] support querying either IPv4 or IPv6 dns --- app/dns/hosts.go | 17 +++++++- app/dns/hosts_test.go | 10 ++++- app/dns/nameserver.go | 34 +++++++++------ app/dns/nameserver_test.go | 5 ++- app/dns/server.go | 33 ++++++++++++--- app/dns/server_test.go | 65 +++++++++++++++++++++++++++++ app/dns/udpns.go | 26 ++++++------ features/dns/client.go | 33 ++++++--------- features/dns/localdns/client.go | 73 +++++++++++++++++++++++++++++++++ v2ray.go | 3 +- 10 files changed, 242 insertions(+), 57 deletions(-) create mode 100644 features/dns/localdns/client.go diff --git a/app/dns/hosts.go b/app/dns/hosts.go index c4dcdbc70..2871c1e7f 100644 --- a/app/dns/hosts.go +++ b/app/dns/hosts.go @@ -73,11 +73,24 @@ func NewStaticHosts(hosts []*Config_HostMapping, legacy map[string]*net.IPOrDoma return sh, nil } +func filterIP(ips []net.IP, option IPOption) []net.IP { + filtered := make([]net.IP, 0, len(ips)) + for _, ip := range ips { + if (len(ip) == net.IPv4len && option.IPv4Enable) || (len(ip) == net.IPv6len && option.IPv6Enable) { + filtered = append(filtered, ip) + } + } + if len(filtered) == 0 { + return nil + } + return filtered +} + // LookupIP returns IP address for the given domain, if exists in this StaticHosts. -func (h *StaticHosts) LookupIP(domain string) []net.IP { +func (h *StaticHosts) LookupIP(domain string, option IPOption) []net.IP { id := h.matchers.Match(domain) if id == 0 { return nil } - return h.ips[id] + return filterIP(h.ips[id], option) } diff --git a/app/dns/hosts_test.go b/app/dns/hosts_test.go index 87708c474..c4a9b15f9 100644 --- a/app/dns/hosts_test.go +++ b/app/dns/hosts_test.go @@ -31,7 +31,10 @@ func TestStaticHosts(t *testing.T) { common.Must(err) { - ips := hosts.LookupIP("v2ray.com") + ips := hosts.LookupIP("v2ray.com", IPOption{ + IPv4Enable: true, + IPv6Enable: true, + }) if len(ips) != 1 { t.Error("expect 1 IP, but got ", len(ips)) } @@ -41,7 +44,10 @@ func TestStaticHosts(t *testing.T) { } { - ips := hosts.LookupIP("www.v2ray.cn") + ips := hosts.LookupIP("www.v2ray.cn", IPOption{ + IPv4Enable: true, + IPv6Enable: true, + }) if len(ips) != 1 { t.Error("expect 1 IP, but got ", len(ips)) } diff --git a/app/dns/nameserver.go b/app/dns/nameserver.go index 8517be043..47412191c 100644 --- a/app/dns/nameserver.go +++ b/app/dns/nameserver.go @@ -4,32 +4,40 @@ import ( "context" "v2ray.com/core/common/net" + "v2ray.com/core/features/dns/localdns" ) +type IPOption struct { + IPv4Enable bool + IPv6Enable bool +} + type NameServerInterface interface { - QueryIP(ctx context.Context, domain string) ([]net.IP, error) + QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) } type localNameServer struct { - resolver net.Resolver + client *localdns.Client } -func (s *localNameServer) QueryIP(ctx context.Context, domain string) ([]net.IP, error) { - ipAddr, err := s.resolver.LookupIPAddr(ctx, domain) - if err != nil { - return nil, err +func (s *localNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) { + if option.IPv4Enable && option.IPv6Enable { + return s.client.LookupIP(domain) } - var ips []net.IP - for _, addr := range ipAddr { - ips = append(ips, addr.IP) + + if option.IPv4Enable { + return s.client.LookupIPv4(domain) } - return ips, nil + + if option.IPv6Enable { + return s.client.LookupIPv6(domain) + } + + return nil, newError("neither IPv4 nor IPv6 is enabled") } func NewLocalNameServer() *localNameServer { return &localNameServer{ - resolver: net.Resolver{ - PreferGo: true, - }, + client: localdns.New(), } } diff --git a/app/dns/nameserver_test.go b/app/dns/nameserver_test.go index 46ae837e4..221d1451c 100644 --- a/app/dns/nameserver_test.go +++ b/app/dns/nameserver_test.go @@ -12,7 +12,10 @@ import ( func TestLocalNameServer(t *testing.T) { s := NewLocalNameServer() ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) - ips, err := s.QueryIP(ctx, "google.com") + ips, err := s.QueryIP(ctx, "google.com", IPOption{ + IPv4Enable: true, + IPv6Enable: true, + }) cancel() common.Must(err) if len(ips) == 0 { diff --git a/app/dns/server.go b/app/dns/server.go index c4e1bd274..96c5b4314 100644 --- a/app/dns/server.go +++ b/app/dns/server.go @@ -116,16 +116,39 @@ func (s *Server) Close() error { return nil } -func (s *Server) queryIPTimeout(server NameServerInterface, domain string) ([]net.IP, error) { +func (s *Server) queryIPTimeout(server NameServerInterface, domain string, option IPOption) ([]net.IP, error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*4) - ips, err := server.QueryIP(ctx, domain) + ips, err := server.QueryIP(ctx, domain, option) cancel() return ips, err } // LookupIP implements dns.Client. func (s *Server) LookupIP(domain string) ([]net.IP, error) { - if ip := s.hosts.LookupIP(domain); len(ip) > 0 { + return s.lookupIPInternal(domain, IPOption{ + IPv4Enable: true, + IPv6Enable: true, + }) +} + +// LookupIPv4 implements dns.IPv4Lookup. +func (s *Server) LookupIPv4(domain string) ([]net.IP, error) { + return s.lookupIPInternal(domain, IPOption{ + IPv4Enable: true, + IPv6Enable: false, + }) +} + +// LookupIPv6 implements dns.IPv6Lookup. +func (s *Server) LookupIPv6(domain string) ([]net.IP, error) { + return s.lookupIPInternal(domain, IPOption{ + IPv4Enable: false, + IPv6Enable: true, + }) +} + +func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, error) { + if ip := s.hosts.LookupIP(domain, option); len(ip) > 0 { return ip, nil } @@ -134,7 +157,7 @@ func (s *Server) LookupIP(domain string) ([]net.IP, error) { idx := s.domainMatcher.Match(domain) if idx > 0 { ns := s.servers[s.domainIndexMap[idx]] - ips, err := s.queryIPTimeout(ns, domain) + ips, err := s.queryIPTimeout(ns, domain, option) if len(ips) > 0 { return ips, nil } @@ -145,7 +168,7 @@ func (s *Server) LookupIP(domain string) ([]net.IP, error) { } for _, server := range s.servers { - ips, err := s.queryIPTimeout(server, domain) + ips, err := s.queryIPTimeout(server, domain, option) if len(ips) > 0 { return ips, nil } diff --git a/app/dns/server_test.go b/app/dns/server_test.go index a9bdb771a..1654dfed0 100644 --- a/app/dns/server_test.go +++ b/app/dns/server_test.go @@ -11,6 +11,7 @@ import ( "v2ray.com/core/app/policy" "v2ray.com/core/app/proxyman" _ "v2ray.com/core/app/proxyman/outbound" + "v2ray.com/core/common" "v2ray.com/core/common/net" "v2ray.com/core/common/serial" feature_dns "v2ray.com/core/features/dns" @@ -52,6 +53,14 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } else if q.Name == "facebook.com." && q.Qtype == dns.TypeA { rr, _ := dns.NewRR("facebook.com. IN A 9.9.9.9") ans.Answer = append(ans.Answer, rr) + } else if q.Name == "ipv6.google.com." && q.Qtype == dns.TypeA { + rr, err := dns.NewRR("ipv6.google.com. IN A 8.8.8.7") + common.Must(err) + ans.Answer = append(ans.Answer, rr) + } else if q.Name == "ipv6.google.com." && q.Qtype == dns.TypeAAAA { + rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888") + common.Must(err) + ans.Answer = append(ans.Answer, rr) } } w.WriteMsg(ans) @@ -259,3 +268,59 @@ func TestPrioritizedDomain(t *testing.T) { t.Error("DNS query doesn't finish in 2 seconds.") } } + +func TestUDPServerIPv6(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("doesn't work on Windows due to miekg/dns changes.") + } + assert := With(t) + + port := udp.PickPort() + + dnsServer := dns.Server{ + Addr: "127.0.0.1:" + port.String(), + Net: "udp", + Handler: &staticHandler{}, + UDPSize: 1200, + } + + go dnsServer.ListenAndServe() + time.Sleep(time.Second) + + config := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&Config{ + NameServers: []*net.Endpoint{ + { + Network: net.Network_UDP, + Address: &net.IPOrDomain{ + Address: &net.IPOrDomain_Ip{ + Ip: []byte{127, 0, 0, 1}, + }, + }, + Port: uint32(port), + }, + }, + }), + serial.ToTypedMessage(&dispatcher.Config{}), + serial.ToTypedMessage(&proxyman.OutboundConfig{}), + serial.ToTypedMessage(&policy.Config{}), + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&freedom.Config{}), + }, + }, + } + + v, err := core.New(config) + assert(err, IsNil) + + client := v.GetFeature(feature_dns.ClientType()).(feature_dns.Client) + client6 := client.(feature_dns.IPv6Lookup) + + ips, err := client6.LookupIPv6("ipv6.google.com") + assert(err, IsNil) + assert(len(ips), Equals, 1) + assert([]byte(ips[0]), Equals, []byte{32, 1, 72, 96, 72, 96, 0, 0, 0, 0, 0, 0, 0, 0, 136, 136}) +} diff --git a/app/dns/udpns.go b/app/dns/udpns.go index f9bd8353c..b940d1c21 100644 --- a/app/dns/udpns.go +++ b/app/dns/udpns.go @@ -241,7 +241,7 @@ func (s *ClassicNameServer) addPendingRequest(domain string) uint16 { return id } -func (s *ClassicNameServer) buildMsgs(domain string) []*dnsmessage.Message { +func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmessage.Message { qA := dnsmessage.Question{ Name: dnsmessage.MustNewName(domain), Type: dnsmessage.TypeA, @@ -256,7 +256,7 @@ func (s *ClassicNameServer) buildMsgs(domain string) []*dnsmessage.Message { var msgs []*dnsmessage.Message - { + if option.IPv4Enable { msg := new(dnsmessage.Message) msg.Header.ID = s.addPendingRequest(domain) msg.Header.RecursionDesired = true @@ -267,7 +267,7 @@ func (s *ClassicNameServer) buildMsgs(domain string) []*dnsmessage.Message { msgs = append(msgs, msg) } - { + if option.IPv6Enable { msg := new(dnsmessage.Message) msg.Header.ID = s.addPendingRequest(domain) msg.Header.RecursionDesired = true @@ -281,7 +281,7 @@ func (s *ClassicNameServer) buildMsgs(domain string) []*dnsmessage.Message { return msgs } -func msgToBuffer(msg *dnsmessage.Message) (*buf.Buffer, error) { +func msgToBuffer2(msg *dnsmessage.Message) (*buf.Buffer, error) { buffer := buf.New() rawBytes := buffer.Extend(buf.Size) packed, err := msg.AppendPack(rawBytes[:0]) @@ -293,19 +293,19 @@ func msgToBuffer(msg *dnsmessage.Message) (*buf.Buffer, error) { return buffer, nil } -func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string) { +func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option IPOption) { newError("querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx)) - msgs := s.buildMsgs(domain) + msgs := s.buildMsgs(domain, option) for _, msg := range msgs { - b, err := msgToBuffer(msg) + b, err := msgToBuffer2(msg) common.Must(err) s.udpServer.Dispatch(context.Background(), s.address, b) } } -func (s *ClassicNameServer) findIPsForDomain(domain string) []net.IP { +func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) []net.IP { s.RLock() records, found := s.ips[domain] s.RUnlock() @@ -318,7 +318,7 @@ func (s *ClassicNameServer) findIPsForDomain(domain string) []net.IP { ips = append(ips, rec.IP) } } - return ips + return filterIP(ips, option) } return nil } @@ -330,10 +330,10 @@ func Fqdn(domain string) string { return domain + "." } -func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string) ([]net.IP, error) { +func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) { fqdn := Fqdn(domain) - ips := s.findIPsForDomain(fqdn) + ips := s.findIPsForDomain(fqdn, option) if len(ips) > 0 { return ips, nil } @@ -341,10 +341,10 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string) ([]net.I sub := s.pub.Subscribe(fqdn) defer sub.Close() - s.sendQuery(ctx, fqdn) + s.sendQuery(ctx, fqdn, option) for { - ips := s.findIPsForDomain(fqdn) + ips := s.findIPsForDomain(fqdn, option) if len(ips) > 0 { return ips, nil } diff --git a/features/dns/client.go b/features/dns/client.go index 7f3c274a6..7b2ec84e2 100644 --- a/features/dns/client.go +++ b/features/dns/client.go @@ -8,29 +8,22 @@ import ( // Client is a V2Ray feature for querying DNS information. type Client interface { features.Feature - LookupIP(host string) ([]net.IP, error) + + // LookupIP returns IP address for the given domain. IPs may contain IPv4 and/or IPv6 addresses. + LookupIP(domain string) ([]net.IP, error) +} + +// IPv4Lookup is an optional feature for querying IPv4 addresses only. +type IPv4Lookup interface { + LookupIPv4(domain string) ([]net.IP, error) +} + +// IPv6Lookup is an optional feature for querying IPv6 addresses only. +type IPv6Lookup interface { + LookupIPv6(domain string) ([]net.IP, error) } // ClientType returns the type of Client interface. Can be used for implementing common.HasType. func ClientType() interface{} { return (*Client)(nil) } - -// LocalClient is an implementation of Client, which queries localhost for DNS. -type LocalClient struct{} - -// Type implements common.HasType. -func (LocalClient) Type() interface{} { - return ClientType() -} - -// Start implements common.Runnable. -func (LocalClient) Start() error { return nil } - -// Close implements common.Closable. -func (LocalClient) Close() error { return nil } - -// LookupIP implements Client. -func (LocalClient) LookupIP(host string) ([]net.IP, error) { - return net.LookupIP(host) -} diff --git a/features/dns/localdns/client.go b/features/dns/localdns/client.go new file mode 100644 index 000000000..de1c91067 --- /dev/null +++ b/features/dns/localdns/client.go @@ -0,0 +1,73 @@ +package localdns + +import ( + "context" + "net" + + "v2ray.com/core/features/dns" +) + +// Client is an implementation of dns.Client, which queries localhost for DNS. +type Client struct { + resolver net.Resolver +} + +// Type implements common.HasType. +func (*Client) Type() interface{} { + return dns.ClientType() +} + +// Start implements common.Runnable. +func (*Client) Start() error { return nil } + +// Close implements common.Closable. +func (*Client) Close() error { return nil } + +// LookupIP implements Client. +func (c *Client) LookupIP(host string) ([]net.IP, error) { + ipAddr, err := c.resolver.LookupIPAddr(context.Background(), host) + if err != nil { + return nil, err + } + ips := make([]net.IP, 0, len(ipAddr)) + for _, addr := range ipAddr { + ips = append(ips, addr.IP) + } + return ips, nil +} + +func (c *Client) LookupIPv4(host string) ([]net.IP, error) { + ips, err := c.LookupIP(host) + if err != nil { + return nil, err + } + var ipv4 []net.IP + for _, ip := range ips { + if len(ip) == net.IPv4len { + ipv4 = append(ipv4, ip) + } + } + return ipv4, nil +} + +func (c *Client) LookupIPv6(host string) ([]net.IP, error) { + ips, err := c.LookupIP(host) + if err != nil { + return nil, err + } + var ipv6 []net.IP + for _, ip := range ips { + if len(ip) == net.IPv6len { + ipv6 = append(ipv6, ip) + } + } + return ipv6, nil +} + +func New() *Client { + return &Client{ + resolver: net.Resolver{ + PreferGo: true, + }, + } +} diff --git a/v2ray.go b/v2ray.go index ac777240f..c8fbc9bff 100755 --- a/v2ray.go +++ b/v2ray.go @@ -9,6 +9,7 @@ import ( "v2ray.com/core/common/serial" "v2ray.com/core/features" "v2ray.com/core/features/dns" + "v2ray.com/core/features/dns/localdns" "v2ray.com/core/features/inbound" "v2ray.com/core/features/outbound" "v2ray.com/core/features/policy" @@ -183,7 +184,7 @@ func New(config *Config) (*Instance, error) { Type interface{} Instance features.Feature }{ - {dns.ClientType(), dns.LocalClient{}}, + {dns.ClientType(), localdns.New()}, {policy.ManagerType(), policy.DefaultManager{}}, {routing.RouterType(), routing.DefaultRouter{}}, {stats.ManagerType(), stats.NoopManager{}},