correctly propagate dns errors all the way through.

the internal dns system can correctly handle the cases where:
1) domain has no A or AAAA records
2) domain doesn't exist
fixes #1565
This commit is contained in:
Darien Raymond 2019-02-21 13:43:48 +01:00
parent c27050ad90
commit 9957c64b4a
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
7 changed files with 253 additions and 79 deletions

View File

@ -226,6 +226,9 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
if len(ips) > 0 {
return ips, nil
}
if err == dns.ErrEmptyResponse {
return nil, err
}
if err != nil {
newError("failed to lookup ip for domain ", domain, " at server ", ns.Name()).Base(err).WriteToLog()
lastErr = err
@ -238,6 +241,9 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
if len(ips) > 0 {
return ips, nil
}
if err == dns.ErrEmptyResponse {
return nil, err
}
if err != nil {
newError("failed to lookup ip for domain ", domain, " at server ", client.Name()).Base(err).WriteToLog()
lastErr = err

View File

@ -60,6 +60,8 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888")
common.Must(err)
ans.Answer = append(ans.Answer, rr)
} else if q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA {
ans.MsgHdr.Rcode = dns.RcodeNameError
}
}
w.WriteMsg(ans)
@ -186,6 +188,27 @@ func TestUDPServer(t *testing.T) {
}
}
{
_, err := client.LookupIP("notexist.google.com")
if err == nil {
t.Fatal("nil error")
}
if r := feature_dns.RCodeFromError(err); r != uint16(dns.RcodeNameError) {
t.Fatal("expected NameError, but got ", r)
}
}
{
clientv6 := client.(feature_dns.IPv6Lookup)
ips, err := clientv6.LookupIPv6("ipv4only.google.com")
if err != feature_dns.ErrEmptyResponse {
t.Fatal("error: ", err)
}
if len(ips) != 0 {
t.Fatal("ips: ", ips)
}
}
dnsServer.Shutdown()
{

View File

@ -5,36 +5,60 @@ package dns
import (
"context"
"encoding/binary"
fmt "fmt"
"sync"
"sync/atomic"
"time"
"golang.org/x/net/dns/dnsmessage"
"v2ray.com/core/common"
"v2ray.com/core/common/errors"
"v2ray.com/core/common/net"
"v2ray.com/core/common/protocol/dns"
udp_proto "v2ray.com/core/common/protocol/udp"
"v2ray.com/core/common/session"
"v2ray.com/core/common/signal/pubsub"
"v2ray.com/core/common/task"
dns_feature "v2ray.com/core/features/dns"
"v2ray.com/core/features/routing"
"v2ray.com/core/transport/internet/udp"
)
type record struct {
A *IPRecord
AAAA *IPRecord
}
type IPRecord struct {
IP net.Address
IP []net.Address
Expire time.Time
RCode dnsmessage.RCode
}
func (r *IPRecord) getIPs() ([]net.Address, error) {
if r == nil || r.Expire.Before(time.Now()) {
return nil, errRecordNotFound
}
if r.RCode != dnsmessage.RCodeSuccess {
return nil, dns_feature.RCodeError(r.RCode)
}
return r.IP, nil
}
type pendingRequest struct {
domain string
expire time.Time
domain string
expire time.Time
recType dnsmessage.Type
}
var (
errRecordNotFound = errors.New("record not found")
)
type ClassicNameServer struct {
sync.RWMutex
address net.Destination
ips map[string][]IPRecord
ips map[string]record
requests map[uint16]pendingRequest
pub *pubsub.Service
udpServer *udp.Dispatcher
@ -46,7 +70,7 @@ type ClassicNameServer struct {
func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer {
s := &ClassicNameServer{
address: address,
ips: make(map[string][]IPRecord),
ips: make(map[string]record),
requests: make(map[uint16]pendingRequest),
clientIP: clientIP,
pub: pubsub.NewService(),
@ -72,22 +96,23 @@ func (s *ClassicNameServer) Cleanup() error {
return newError("nothing to do. stopping...")
}
for domain, ips := range s.ips {
newIPs := make([]IPRecord, 0, len(ips))
for _, ip := range ips {
if ip.Expire.After(now) {
newIPs = append(newIPs, ip)
}
for domain, record := range s.ips {
if record.A != nil && record.A.Expire.Before(now) {
record.A = nil
}
if len(newIPs) == 0 {
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
record.AAAA = nil
}
if record.A == nil && record.AAAA == nil {
delete(s.ips, domain)
} else if len(newIPs) < len(ips) {
s.ips[domain] = newIPs
} else {
s.ips[domain] = record
}
}
if len(s.ips) == 0 {
s.ips = make(map[string][]IPRecord)
s.ips = make(map[string]record)
}
for id, req := range s.requests {
@ -130,9 +155,14 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
}
domain := req.domain
ips := make([]IPRecord, 0, 16)
recType := req.recType
now := time.Now()
ipRecord := &IPRecord{
RCode: header.RCode,
Expire: now.Add(time.Second * 600),
}
for {
header, err := parser.AnswerHeader()
if err != nil {
@ -145,6 +175,15 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
if ttl == 0 {
ttl = 600
}
expire := now.Add(time.Duration(ttl) * time.Second)
if ipRecord.Expire.After(expire) {
ipRecord.Expire = expire
}
if header.Type != recType {
continue
}
switch header.Type {
case dnsmessage.TypeA:
ans, err := parser.AResource()
@ -152,20 +191,14 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
break
}
ips = append(ips, IPRecord{
IP: net.IPAddress(ans.A[:]),
Expire: now.Add(time.Duration(ttl) * time.Second),
})
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
case dnsmessage.TypeAAAA:
ans, err := parser.AAAAResource()
if err != nil {
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
break
}
ips = append(ips, IPRecord{
IP: net.IPAddress(ans.AAAA[:]),
Expire: now.Add(time.Duration(ttl) * time.Second),
})
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
default:
if err := parser.SkipAnswer(); err != nil {
newError("failed to skip answer").Base(err).WriteToLog()
@ -173,24 +206,49 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
}
}
if len(domain) > 0 && len(ips) > 0 {
s.updateIP(domain, ips)
var rec record
switch recType {
case dnsmessage.TypeA:
rec.A = ipRecord
case dnsmessage.TypeAAAA:
rec.AAAA = ipRecord
}
if len(domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
s.updateIP(domain, rec)
}
}
func (s *ClassicNameServer) updateIP(domain string, ips []IPRecord) {
func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
if newRec == nil {
return false
}
if baseRec == nil {
return true
}
return baseRec.Expire.Before(newRec.Expire)
}
func (s *ClassicNameServer) updateIP(domain string, newRec record) {
s.Lock()
newError("updating IP records for domain:", domain).AtDebug().WriteToLog()
now := time.Now()
eips := s.ips[domain]
for _, ip := range eips {
if ip.Expire.After(now) {
ips = append(ips, ip)
}
rec := s.ips[domain]
updated := false
if isNewer(rec.A, newRec.A) {
rec.A = newRec.A
updated = true
}
if isNewer(rec.AAAA, newRec.AAAA) {
rec.AAAA = newRec.AAAA
updated = true
}
if updated {
s.ips[domain] = rec
s.pub.Publish(domain, nil)
}
s.ips[domain] = ips
s.pub.Publish(domain, nil)
s.Unlock()
common.Must(s.cleanup.Start())
@ -244,14 +302,15 @@ func (s *ClassicNameServer) getMsgOptions() *dnsmessage.Resource {
return opt
}
func (s *ClassicNameServer) addPendingRequest(domain string) uint16 {
func (s *ClassicNameServer) addPendingRequest(domain string, recType dnsmessage.Type) uint16 {
id := uint16(atomic.AddUint32(&s.reqID, 1))
s.Lock()
defer s.Unlock()
s.requests[id] = pendingRequest{
domain: domain,
expire: time.Now().Add(time.Second * 8),
domain: domain,
expire: time.Now().Add(time.Second * 8),
recType: recType,
}
return id
@ -274,7 +333,7 @@ func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmess
if option.IPv4Enable {
msg := new(dnsmessage.Message)
msg.Header.ID = s.addPendingRequest(domain)
msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeA)
msg.Header.RecursionDesired = true
msg.Questions = []dnsmessage.Question{qA}
if opt := s.getMsgOptions(); opt != nil {
@ -285,7 +344,7 @@ func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmess
if option.IPv6Enable {
msg := new(dnsmessage.Message)
msg.Header.ID = s.addPendingRequest(domain)
msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeAAAA)
msg.Header.RecursionDesired = true
msg.Questions = []dnsmessage.Question{qAAAA}
if opt := s.getMsgOptions(); opt != nil {
@ -313,22 +372,44 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option
}
}
func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) []net.IP {
func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) ([]net.IP, error) {
s.RLock()
records, found := s.ips[domain]
record, found := s.ips[domain]
s.RUnlock()
if found && len(records) > 0 {
var ips []net.Address
now := time.Now()
for _, rec := range records {
if rec.Expire.After(now) {
ips = append(ips, rec.IP)
}
}
return toNetIP(filterIP(ips, option))
if !found {
return nil, errRecordNotFound
}
return nil
var ips []net.Address
var lastErr error
if option.IPv4Enable {
a, err := record.A.getIPs()
if err != nil {
lastErr = err
}
ips = append(ips, a...)
}
if option.IPv6Enable {
aaaa, err := record.AAAA.getIPs()
if err != nil {
lastErr = err
}
ips = append(ips, aaaa...)
}
fmt.Println("IPs for ", domain, ": ", ips)
if len(ips) > 0 {
return toNetIP(ips), nil
}
if lastErr != nil {
return nil, lastErr
}
return nil, dns_feature.ErrEmptyResponse
}
func Fqdn(domain string) string {
@ -341,9 +422,9 @@ func Fqdn(domain string) string {
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
fqdn := Fqdn(domain)
ips := s.findIPsForDomain(fqdn, option)
if len(ips) > 0 {
return ips, nil
ips, err := s.findIPsForDomain(fqdn, option)
if err != errRecordNotFound {
return ips, err
}
sub := s.pub.Subscribe(fqdn)
@ -352,9 +433,9 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option I
s.sendQuery(ctx, fqdn, option)
for {
ips := s.findIPsForDomain(fqdn, option)
if len(ips) > 0 {
return ips, nil
ips, err := s.findIPsForDomain(fqdn, option)
if err != errRecordNotFound {
return ips, err
}
select {

View File

@ -1,7 +1,9 @@
package dns
import (
"v2ray.com/core/common/errors"
"v2ray.com/core/common/net"
"v2ray.com/core/common/serial"
"v2ray.com/core/features"
)
@ -35,3 +37,23 @@ type IPv6Lookup interface {
func ClientType() interface{} {
return (*Client)(nil)
}
// ErrEmptyResponse indicates that DNS query succeeded but no answer was returned.
var ErrEmptyResponse = errors.New("empty response")
type RCodeError uint16
func (e RCodeError) Error() string {
return serial.Concat("rcode: ", uint16(e))
}
func RCodeFromError(err error) uint16 {
if err == nil {
return 0
}
cause := errors.Cause(err)
if r, ok := cause.(RCodeError); ok {
return uint16(r)
}
return 0
}

View File

@ -32,6 +32,9 @@ func (*Client) LookupIP(host string) ([]net.IP, error) {
parsedIPs = append(parsedIPs, parsed.IP())
}
}
if len(parsedIPs) == 0 {
return nil, dns.ErrEmptyResponse
}
return parsedIPs, nil
}
@ -47,6 +50,9 @@ func (c *Client) LookupIPv4(host string) ([]net.IP, error) {
ipv4 = append(ipv4, ip)
}
}
if len(ipv4) == 0 {
return nil, dns.ErrEmptyResponse
}
return ipv4, nil
}
@ -62,6 +68,9 @@ func (c *Client) LookupIPv6(host string) ([]net.IP, error) {
ipv6 = append(ipv6, ip)
}
}
if len(ipv6) == 0 {
return nil, dns.ErrEmptyResponse
}
return ipv6, nil
}

View File

@ -218,20 +218,17 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string,
ips, err = h.ipv6Lookup.LookupIPv6(domain)
}
if err != nil {
rcode := dns.RCodeFromError(err)
if rcode == 0 && len(ips) == 0 && err != dns.ErrEmptyResponse {
newError("ip query").Base(err).WriteToLog()
return
}
if len(ips) == 0 {
return
}
b := buf.New()
rawBytes := b.Extend(buf.Size)
builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{
ID: id,
RCode: dnsmessage.RCodeSuccess,
RCode: dnsmessage.RCode(rcode),
RecursionAvailable: true,
RecursionDesired: true,
Response: true,

View File

@ -63,6 +63,8 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888")
common.Must(err)
ans.Answer = append(ans.Answer, rr)
} else if q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA {
ans.MsgHdr.Rcode = dns.RcodeNameError
}
}
w.WriteMsg(ans)
@ -128,26 +130,60 @@ func TestUDPDNSTunnel(t *testing.T) {
common.Must(v.Start())
defer v.Close()
m1 := new(dns.Msg)
m1.Id = dns.Id()
m1.RecursionDesired = true
m1.Question = make([]dns.Question, 1)
m1.Question[0] = dns.Question{"google.com.", dns.TypeA, dns.ClassINET}
{
m1 := new(dns.Msg)
m1.Id = dns.Id()
m1.RecursionDesired = true
m1.Question = make([]dns.Question, 1)
m1.Question[0] = dns.Question{"google.com.", dns.TypeA, dns.ClassINET}
c := new(dns.Client)
in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort)))
common.Must(err)
c := new(dns.Client)
in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort)))
common.Must(err)
if len(in.Answer) != 1 {
t.Fatal("len(answer): ", len(in.Answer))
if len(in.Answer) != 1 {
t.Fatal("len(answer): ", len(in.Answer))
}
rr, ok := in.Answer[0].(*dns.A)
if !ok {
t.Fatal("not A record")
}
if r := cmp.Diff(rr.A[:], net.IP{8, 8, 8, 8}); r != "" {
t.Error(r)
}
}
rr, ok := in.Answer[0].(*dns.A)
if !ok {
t.Fatal("not A record")
{
m1 := new(dns.Msg)
m1.Id = dns.Id()
m1.RecursionDesired = true
m1.Question = make([]dns.Question, 1)
m1.Question[0] = dns.Question{"ipv4only.google.com.", dns.TypeAAAA, dns.ClassINET}
c := new(dns.Client)
in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort)))
common.Must(err)
if len(in.Answer) != 0 {
t.Fatal("len(answer): ", len(in.Answer))
}
}
if r := cmp.Diff(rr.A[:], net.IP{8, 8, 8, 8}); r != "" {
t.Error(r)
{
m1 := new(dns.Msg)
m1.Id = dns.Id()
m1.RecursionDesired = true
m1.Question = make([]dns.Question, 1)
m1.Question[0] = dns.Question{"notexist.google.com.", dns.TypeAAAA, dns.ClassINET}
c := new(dns.Client)
in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort)))
common.Must(err)
if in.Rcode != dns.RcodeNameError {
t.Error("expected NameError, but got ", in.Rcode)
}
}
}