From d013e8069d5beeb9a91c1c72b1c50b11215fe935 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Mon, 19 Nov 2018 13:13:02 +0100 Subject: [PATCH] switch to stdlib for dns queries --- app/dns/udpns.go | 167 +++++++++++++++++++++++++++-------------------- 1 file changed, 97 insertions(+), 70 deletions(-) diff --git a/app/dns/udpns.go b/app/dns/udpns.go index 20d99ffa2..1cc1b3e67 100644 --- a/app/dns/udpns.go +++ b/app/dns/udpns.go @@ -2,11 +2,12 @@ package dns import ( "context" + "encoding/binary" "sync" "sync/atomic" "time" - "github.com/miekg/dns" + "golang.org/x/net/dns/dnsmessage" "v2ray.com/core/common" "v2ray.com/core/common/buf" @@ -18,14 +19,6 @@ import ( "v2ray.com/core/transport/internet/udp" ) -var ( - multiQuestionDNS = map[net.Address]bool{ - net.IPAddress([]byte{8, 8, 8, 8}): true, - net.IPAddress([]byte{8, 8, 4, 4}): true, - net.IPAddress([]byte{9, 9, 9, 9}): true, - } -) - type IPRecord struct { IP net.IP Expire time.Time @@ -105,16 +98,15 @@ func (s *ClassicNameServer) Cleanup() error { } func (s *ClassicNameServer) HandleResponse(ctx context.Context, payload *buf.Buffer) { - msg := new(dns.Msg) - err := msg.Unpack(payload.Bytes()) - if err == dns.ErrTruncated { - newError("truncated message received. DNS server should still work. If you see anything abnormal, please submit an issue to v2ray-core.").AtWarning().WriteToLog() - } else if err != nil { + var parser dnsmessage.Parser + header, err := parser.Start(payload.Bytes()) + if err != nil { newError("failed to parse DNS response").Base(err).AtWarning().WriteToLog() return } + parser.SkipAllQuestions() - id := msg.Id + id := header.ID s.Lock() req, f := s.requests[id] if f { @@ -130,23 +122,35 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, payload *buf.Buf ips := make([]IPRecord, 0, 16) now := time.Now() - for _, rr := range msg.Answer { - var ip net.IP - ttl := rr.Header().Ttl - switch rr := rr.(type) { - case *dns.A: - ip = rr.A - case *dns.AAAA: - ip = rr.AAAA + for { + header, err := parser.AnswerHeader() + if err != nil { + break } + ttl := header.TTL if ttl == 0 { ttl = 600 } - if len(ip) > 0 { + switch header.Type { + case dnsmessage.TypeA: + ans, err := parser.AResource() + if err != nil { + break + } ips = append(ips, IPRecord{ - IP: ip, - Expire: now.Add(time.Second * time.Duration(ttl)), + IP: net.IP(ans.A[:]), + Expire: now.Add(time.Duration(ttl) * time.Second), }) + case dnsmessage.TypeAAAA: + ans, err := parser.AAAAResource() + if err != nil { + break + } + ips = append(ips, IPRecord{ + IP: net.IP(ans.AAAA[:]), + Expire: now.Add(time.Duration(ttl) * time.Second), + }) + default: } } @@ -173,31 +177,52 @@ func (s *ClassicNameServer) updateIP(domain string, ips []IPRecord) { common.Must(s.cleanup.Start()) } -func (s *ClassicNameServer) getMsgOptions() *dns.OPT { +func (s *ClassicNameServer) getMsgOptions() *dnsmessage.Resource { if len(s.clientIP) == 0 { return nil } - o := new(dns.OPT) - o.Hdr.Name = "." - o.Hdr.Rrtype = dns.TypeOPT - o.SetUDPSize(1350) + var netmask int + var family uint16 - e := new(dns.EDNS0_SUBNET) - e.Code = dns.EDNS0SUBNET if len(s.clientIP) == 4 { - e.Family = 1 // 1 for IPv4 source address, 2 for IPv6 - e.SourceNetmask = 24 // 32 for IPV4, 128 for IPv6 + family = 1 + netmask = 24 // 24 for IPV4, 96 for IPv6 } else { - e.Family = 2 - e.SourceNetmask = 96 + family = 2 + netmask = 96 } - e.SourceScope = 0 - e.Address = s.clientIP - o.Option = append(o.Option, e) + b := make([]byte, 4) + binary.BigEndian.PutUint16(b[0:], family) + b[2] = byte(netmask) + b[3] = 0 + switch family { + case 1: + ip := s.clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8)) + needLength := (netmask + 8 - 1) / 8 // division rounding up + b = append(b, ip[:needLength]...) + case 2: + ip := s.clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8)) + needLength := (netmask + 8 - 1) / 8 // division rounding up + b = append(b, ip[:needLength]...) + } - return o + const EDNS0SUBNET = 0x08 + + opt := new(dnsmessage.Resource) + common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true)) + + opt.Body = &dnsmessage.OPTResource{ + Options: []dnsmessage.Option{ + { + Code: EDNS0SUBNET, + Data: b, + }, + }, + } + + return opt } func (s *ClassicNameServer) addPendingRequest(domain string) uint16 { @@ -213,44 +238,39 @@ func (s *ClassicNameServer) addPendingRequest(domain string) uint16 { return id } -func (s *ClassicNameServer) buildMsgs(domain string) []*dns.Msg { - allowMulti := multiQuestionDNS[s.address.Address] - - qA := dns.Question{ - Name: domain, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, +func (s *ClassicNameServer) buildMsgs(domain string) []*dnsmessage.Message { + qA := dnsmessage.Question{ + Name: dnsmessage.MustNewName(domain), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, } - qAAAA := dns.Question{ - Name: domain, - Qtype: dns.TypeAAAA, - Qclass: dns.ClassINET, + qAAAA := dnsmessage.Question{ + Name: dnsmessage.MustNewName(domain), + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, } - var msgs []*dns.Msg + var msgs []*dnsmessage.Message { - msg := new(dns.Msg) - msg.Id = s.addPendingRequest(domain) - msg.RecursionDesired = true - msg.Question = []dns.Question{qA} - if allowMulti { - msg.Question = append(msg.Question, qAAAA) - } + msg := new(dnsmessage.Message) + msg.Header.ID = s.addPendingRequest(domain) + msg.Header.RecursionDesired = true + msg.Questions = []dnsmessage.Question{qA} if opt := s.getMsgOptions(); opt != nil { - msg.Extra = append(msg.Extra, opt) + msg.Additionals = append(msg.Additionals, *opt) } msgs = append(msgs, msg) } - if !allowMulti { - msg := new(dns.Msg) - msg.Id = s.addPendingRequest(domain) - msg.RecursionDesired = true - msg.Question = []dns.Question{qAAAA} + { + msg := new(dnsmessage.Message) + msg.Header.ID = s.addPendingRequest(domain) + msg.Header.RecursionDesired = true + msg.Questions = []dnsmessage.Question{qAAAA} if opt := s.getMsgOptions(); opt != nil { - msg.Extra = append(msg.Extra, opt) + msg.Additionals = append(msg.Additionals, *opt) } msgs = append(msgs, msg) } @@ -258,10 +278,10 @@ func (s *ClassicNameServer) buildMsgs(domain string) []*dns.Msg { return msgs } -func msgToBuffer(msg *dns.Msg) (*buf.Buffer, error) { +func msgToBuffer(msg *dnsmessage.Message) (*buf.Buffer, error) { buffer := buf.New() rawBytes := buffer.Extend(buf.Size) - packed, err := msg.PackBuffer(rawBytes) + packed, err := msg.AppendPack(rawBytes[:0]) if err != nil { buffer.Release() return nil, err @@ -300,8 +320,15 @@ func (s *ClassicNameServer) findIPsForDomain(domain string) []net.IP { return nil } +func Fqdn(domain string) string { + if len(domain) > 0 && domain[len(domain)-1] == '.' { + return domain + } + return domain + "." +} + func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string) ([]net.IP, error) { - fqdn := dns.Fqdn(domain) + fqdn := Fqdn(domain) ips := s.findIPsForDomain(fqdn) if len(ips) > 0 {