1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2025-01-04 16:37:12 -05:00

correctly propagate dns errors all the way through.

the internal dns system can correctly handle the cases where:
1) domain has no A or AAAA records
2) domain doesn't exist
fixes #1565
This commit is contained in:
Darien Raymond 2019-02-21 13:43:48 +01:00
parent c27050ad90
commit 9957c64b4a
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
7 changed files with 253 additions and 79 deletions

View File

@ -226,6 +226,9 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
if len(ips) > 0 { if len(ips) > 0 {
return ips, nil return ips, nil
} }
if err == dns.ErrEmptyResponse {
return nil, err
}
if err != nil { if err != nil {
newError("failed to lookup ip for domain ", domain, " at server ", ns.Name()).Base(err).WriteToLog() newError("failed to lookup ip for domain ", domain, " at server ", ns.Name()).Base(err).WriteToLog()
lastErr = err lastErr = err
@ -238,6 +241,9 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
if len(ips) > 0 { if len(ips) > 0 {
return ips, nil return ips, nil
} }
if err == dns.ErrEmptyResponse {
return nil, err
}
if err != nil { if err != nil {
newError("failed to lookup ip for domain ", domain, " at server ", client.Name()).Base(err).WriteToLog() newError("failed to lookup ip for domain ", domain, " at server ", client.Name()).Base(err).WriteToLog()
lastErr = err lastErr = err

View File

@ -60,6 +60,8 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888") rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888")
common.Must(err) common.Must(err)
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} else if q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA {
ans.MsgHdr.Rcode = dns.RcodeNameError
} }
} }
w.WriteMsg(ans) w.WriteMsg(ans)
@ -186,6 +188,27 @@ func TestUDPServer(t *testing.T) {
} }
} }
{
_, err := client.LookupIP("notexist.google.com")
if err == nil {
t.Fatal("nil error")
}
if r := feature_dns.RCodeFromError(err); r != uint16(dns.RcodeNameError) {
t.Fatal("expected NameError, but got ", r)
}
}
{
clientv6 := client.(feature_dns.IPv6Lookup)
ips, err := clientv6.LookupIPv6("ipv4only.google.com")
if err != feature_dns.ErrEmptyResponse {
t.Fatal("error: ", err)
}
if len(ips) != 0 {
t.Fatal("ips: ", ips)
}
}
dnsServer.Shutdown() dnsServer.Shutdown()
{ {

View File

@ -5,36 +5,60 @@ package dns
import ( import (
"context" "context"
"encoding/binary" "encoding/binary"
fmt "fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"golang.org/x/net/dns/dnsmessage" "golang.org/x/net/dns/dnsmessage"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/errors"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol/dns" "v2ray.com/core/common/protocol/dns"
udp_proto "v2ray.com/core/common/protocol/udp" udp_proto "v2ray.com/core/common/protocol/udp"
"v2ray.com/core/common/session" "v2ray.com/core/common/session"
"v2ray.com/core/common/signal/pubsub" "v2ray.com/core/common/signal/pubsub"
"v2ray.com/core/common/task" "v2ray.com/core/common/task"
dns_feature "v2ray.com/core/features/dns"
"v2ray.com/core/features/routing" "v2ray.com/core/features/routing"
"v2ray.com/core/transport/internet/udp" "v2ray.com/core/transport/internet/udp"
) )
type record struct {
A *IPRecord
AAAA *IPRecord
}
type IPRecord struct { type IPRecord struct {
IP net.Address IP []net.Address
Expire time.Time Expire time.Time
RCode dnsmessage.RCode
}
func (r *IPRecord) getIPs() ([]net.Address, error) {
if r == nil || r.Expire.Before(time.Now()) {
return nil, errRecordNotFound
}
if r.RCode != dnsmessage.RCodeSuccess {
return nil, dns_feature.RCodeError(r.RCode)
}
return r.IP, nil
} }
type pendingRequest struct { type pendingRequest struct {
domain string domain string
expire time.Time expire time.Time
recType dnsmessage.Type
} }
var (
errRecordNotFound = errors.New("record not found")
)
type ClassicNameServer struct { type ClassicNameServer struct {
sync.RWMutex sync.RWMutex
address net.Destination address net.Destination
ips map[string][]IPRecord ips map[string]record
requests map[uint16]pendingRequest requests map[uint16]pendingRequest
pub *pubsub.Service pub *pubsub.Service
udpServer *udp.Dispatcher udpServer *udp.Dispatcher
@ -46,7 +70,7 @@ type ClassicNameServer struct {
func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer { func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer {
s := &ClassicNameServer{ s := &ClassicNameServer{
address: address, address: address,
ips: make(map[string][]IPRecord), ips: make(map[string]record),
requests: make(map[uint16]pendingRequest), requests: make(map[uint16]pendingRequest),
clientIP: clientIP, clientIP: clientIP,
pub: pubsub.NewService(), pub: pubsub.NewService(),
@ -72,22 +96,23 @@ func (s *ClassicNameServer) Cleanup() error {
return newError("nothing to do. stopping...") return newError("nothing to do. stopping...")
} }
for domain, ips := range s.ips { for domain, record := range s.ips {
newIPs := make([]IPRecord, 0, len(ips)) if record.A != nil && record.A.Expire.Before(now) {
for _, ip := range ips { record.A = nil
if ip.Expire.After(now) {
newIPs = append(newIPs, ip)
} }
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
record.AAAA = nil
} }
if len(newIPs) == 0 {
if record.A == nil && record.AAAA == nil {
delete(s.ips, domain) delete(s.ips, domain)
} else if len(newIPs) < len(ips) { } else {
s.ips[domain] = newIPs s.ips[domain] = record
} }
} }
if len(s.ips) == 0 { if len(s.ips) == 0 {
s.ips = make(map[string][]IPRecord) s.ips = make(map[string]record)
} }
for id, req := range s.requests { for id, req := range s.requests {
@ -130,9 +155,14 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
} }
domain := req.domain domain := req.domain
ips := make([]IPRecord, 0, 16) recType := req.recType
now := time.Now() now := time.Now()
ipRecord := &IPRecord{
RCode: header.RCode,
Expire: now.Add(time.Second * 600),
}
for { for {
header, err := parser.AnswerHeader() header, err := parser.AnswerHeader()
if err != nil { if err != nil {
@ -145,6 +175,15 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
if ttl == 0 { if ttl == 0 {
ttl = 600 ttl = 600
} }
expire := now.Add(time.Duration(ttl) * time.Second)
if ipRecord.Expire.After(expire) {
ipRecord.Expire = expire
}
if header.Type != recType {
continue
}
switch header.Type { switch header.Type {
case dnsmessage.TypeA: case dnsmessage.TypeA:
ans, err := parser.AResource() ans, err := parser.AResource()
@ -152,20 +191,14 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog() newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
break break
} }
ips = append(ips, IPRecord{ ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
IP: net.IPAddress(ans.A[:]),
Expire: now.Add(time.Duration(ttl) * time.Second),
})
case dnsmessage.TypeAAAA: case dnsmessage.TypeAAAA:
ans, err := parser.AAAAResource() ans, err := parser.AAAAResource()
if err != nil { if err != nil {
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog() newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
break break
} }
ips = append(ips, IPRecord{ ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
IP: net.IPAddress(ans.AAAA[:]),
Expire: now.Add(time.Duration(ttl) * time.Second),
})
default: default:
if err := parser.SkipAnswer(); err != nil { if err := parser.SkipAnswer(); err != nil {
newError("failed to skip answer").Base(err).WriteToLog() newError("failed to skip answer").Base(err).WriteToLog()
@ -173,24 +206,49 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
} }
} }
if len(domain) > 0 && len(ips) > 0 { var rec record
s.updateIP(domain, ips) switch recType {
case dnsmessage.TypeA:
rec.A = ipRecord
case dnsmessage.TypeAAAA:
rec.AAAA = ipRecord
}
if len(domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
s.updateIP(domain, rec)
} }
} }
func (s *ClassicNameServer) updateIP(domain string, ips []IPRecord) { func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
if newRec == nil {
return false
}
if baseRec == nil {
return true
}
return baseRec.Expire.Before(newRec.Expire)
}
func (s *ClassicNameServer) updateIP(domain string, newRec record) {
s.Lock() s.Lock()
newError("updating IP records for domain:", domain).AtDebug().WriteToLog() newError("updating IP records for domain:", domain).AtDebug().WriteToLog()
now := time.Now() rec := s.ips[domain]
eips := s.ips[domain]
for _, ip := range eips { updated := false
if ip.Expire.After(now) { if isNewer(rec.A, newRec.A) {
ips = append(ips, ip) rec.A = newRec.A
updated = true
} }
if isNewer(rec.AAAA, newRec.AAAA) {
rec.AAAA = newRec.AAAA
updated = true
} }
s.ips[domain] = ips
if updated {
s.ips[domain] = rec
s.pub.Publish(domain, nil) s.pub.Publish(domain, nil)
}
s.Unlock() s.Unlock()
common.Must(s.cleanup.Start()) common.Must(s.cleanup.Start())
@ -244,7 +302,7 @@ func (s *ClassicNameServer) getMsgOptions() *dnsmessage.Resource {
return opt return opt
} }
func (s *ClassicNameServer) addPendingRequest(domain string) uint16 { func (s *ClassicNameServer) addPendingRequest(domain string, recType dnsmessage.Type) uint16 {
id := uint16(atomic.AddUint32(&s.reqID, 1)) id := uint16(atomic.AddUint32(&s.reqID, 1))
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
@ -252,6 +310,7 @@ func (s *ClassicNameServer) addPendingRequest(domain string) uint16 {
s.requests[id] = pendingRequest{ s.requests[id] = pendingRequest{
domain: domain, domain: domain,
expire: time.Now().Add(time.Second * 8), expire: time.Now().Add(time.Second * 8),
recType: recType,
} }
return id return id
@ -274,7 +333,7 @@ func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmess
if option.IPv4Enable { if option.IPv4Enable {
msg := new(dnsmessage.Message) msg := new(dnsmessage.Message)
msg.Header.ID = s.addPendingRequest(domain) msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeA)
msg.Header.RecursionDesired = true msg.Header.RecursionDesired = true
msg.Questions = []dnsmessage.Question{qA} msg.Questions = []dnsmessage.Question{qA}
if opt := s.getMsgOptions(); opt != nil { if opt := s.getMsgOptions(); opt != nil {
@ -285,7 +344,7 @@ func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmess
if option.IPv6Enable { if option.IPv6Enable {
msg := new(dnsmessage.Message) msg := new(dnsmessage.Message)
msg.Header.ID = s.addPendingRequest(domain) msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeAAAA)
msg.Header.RecursionDesired = true msg.Header.RecursionDesired = true
msg.Questions = []dnsmessage.Question{qAAAA} msg.Questions = []dnsmessage.Question{qAAAA}
if opt := s.getMsgOptions(); opt != nil { if opt := s.getMsgOptions(); opt != nil {
@ -313,22 +372,44 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option
} }
} }
func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) []net.IP { func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) ([]net.IP, error) {
s.RLock() s.RLock()
records, found := s.ips[domain] record, found := s.ips[domain]
s.RUnlock() s.RUnlock()
if found && len(records) > 0 { if !found {
return nil, errRecordNotFound
}
var ips []net.Address var ips []net.Address
now := time.Now() var lastErr error
for _, rec := range records { if option.IPv4Enable {
if rec.Expire.After(now) { a, err := record.A.getIPs()
ips = append(ips, rec.IP) if err != nil {
lastErr = err
} }
ips = append(ips, a...)
} }
return toNetIP(filterIP(ips, option))
if option.IPv6Enable {
aaaa, err := record.AAAA.getIPs()
if err != nil {
lastErr = err
} }
return nil ips = append(ips, aaaa...)
}
fmt.Println("IPs for ", domain, ": ", ips)
if len(ips) > 0 {
return toNetIP(ips), nil
}
if lastErr != nil {
return nil, lastErr
}
return nil, dns_feature.ErrEmptyResponse
} }
func Fqdn(domain string) string { func Fqdn(domain string) string {
@ -341,9 +422,9 @@ func Fqdn(domain string) string {
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) { func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
fqdn := Fqdn(domain) fqdn := Fqdn(domain)
ips := s.findIPsForDomain(fqdn, option) ips, err := s.findIPsForDomain(fqdn, option)
if len(ips) > 0 { if err != errRecordNotFound {
return ips, nil return ips, err
} }
sub := s.pub.Subscribe(fqdn) sub := s.pub.Subscribe(fqdn)
@ -352,9 +433,9 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option I
s.sendQuery(ctx, fqdn, option) s.sendQuery(ctx, fqdn, option)
for { for {
ips := s.findIPsForDomain(fqdn, option) ips, err := s.findIPsForDomain(fqdn, option)
if len(ips) > 0 { if err != errRecordNotFound {
return ips, nil return ips, err
} }
select { select {

View File

@ -1,7 +1,9 @@
package dns package dns
import ( import (
"v2ray.com/core/common/errors"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/serial"
"v2ray.com/core/features" "v2ray.com/core/features"
) )
@ -35,3 +37,23 @@ type IPv6Lookup interface {
func ClientType() interface{} { func ClientType() interface{} {
return (*Client)(nil) return (*Client)(nil)
} }
// ErrEmptyResponse indicates that DNS query succeeded but no answer was returned.
var ErrEmptyResponse = errors.New("empty response")
type RCodeError uint16
func (e RCodeError) Error() string {
return serial.Concat("rcode: ", uint16(e))
}
func RCodeFromError(err error) uint16 {
if err == nil {
return 0
}
cause := errors.Cause(err)
if r, ok := cause.(RCodeError); ok {
return uint16(r)
}
return 0
}

View File

@ -32,6 +32,9 @@ func (*Client) LookupIP(host string) ([]net.IP, error) {
parsedIPs = append(parsedIPs, parsed.IP()) parsedIPs = append(parsedIPs, parsed.IP())
} }
} }
if len(parsedIPs) == 0 {
return nil, dns.ErrEmptyResponse
}
return parsedIPs, nil return parsedIPs, nil
} }
@ -47,6 +50,9 @@ func (c *Client) LookupIPv4(host string) ([]net.IP, error) {
ipv4 = append(ipv4, ip) ipv4 = append(ipv4, ip)
} }
} }
if len(ipv4) == 0 {
return nil, dns.ErrEmptyResponse
}
return ipv4, nil return ipv4, nil
} }
@ -62,6 +68,9 @@ func (c *Client) LookupIPv6(host string) ([]net.IP, error) {
ipv6 = append(ipv6, ip) ipv6 = append(ipv6, ip)
} }
} }
if len(ipv6) == 0 {
return nil, dns.ErrEmptyResponse
}
return ipv6, nil return ipv6, nil
} }

View File

@ -218,20 +218,17 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string,
ips, err = h.ipv6Lookup.LookupIPv6(domain) ips, err = h.ipv6Lookup.LookupIPv6(domain)
} }
if err != nil { rcode := dns.RCodeFromError(err)
if rcode == 0 && len(ips) == 0 && err != dns.ErrEmptyResponse {
newError("ip query").Base(err).WriteToLog() newError("ip query").Base(err).WriteToLog()
return return
} }
if len(ips) == 0 {
return
}
b := buf.New() b := buf.New()
rawBytes := b.Extend(buf.Size) rawBytes := b.Extend(buf.Size)
builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{ builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{
ID: id, ID: id,
RCode: dnsmessage.RCodeSuccess, RCode: dnsmessage.RCode(rcode),
RecursionAvailable: true, RecursionAvailable: true,
RecursionDesired: true, RecursionDesired: true,
Response: true, Response: true,

View File

@ -63,6 +63,8 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888") rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888")
common.Must(err) common.Must(err)
ans.Answer = append(ans.Answer, rr) ans.Answer = append(ans.Answer, rr)
} else if q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA {
ans.MsgHdr.Rcode = dns.RcodeNameError
} }
} }
w.WriteMsg(ans) w.WriteMsg(ans)
@ -128,6 +130,7 @@ func TestUDPDNSTunnel(t *testing.T) {
common.Must(v.Start()) common.Must(v.Start())
defer v.Close() defer v.Close()
{
m1 := new(dns.Msg) m1 := new(dns.Msg)
m1.Id = dns.Id() m1.Id = dns.Id()
m1.RecursionDesired = true m1.RecursionDesired = true
@ -151,6 +154,39 @@ func TestUDPDNSTunnel(t *testing.T) {
} }
} }
{
m1 := new(dns.Msg)
m1.Id = dns.Id()
m1.RecursionDesired = true
m1.Question = make([]dns.Question, 1)
m1.Question[0] = dns.Question{"ipv4only.google.com.", dns.TypeAAAA, dns.ClassINET}
c := new(dns.Client)
in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort)))
common.Must(err)
if len(in.Answer) != 0 {
t.Fatal("len(answer): ", len(in.Answer))
}
}
{
m1 := new(dns.Msg)
m1.Id = dns.Id()
m1.RecursionDesired = true
m1.Question = make([]dns.Question, 1)
m1.Question[0] = dns.Question{"notexist.google.com.", dns.TypeAAAA, dns.ClassINET}
c := new(dns.Client)
in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort)))
common.Must(err)
if in.Rcode != dns.RcodeNameError {
t.Error("expected NameError, but got ", in.Rcode)
}
}
}
func TestTCPDNSTunnel(t *testing.T) { func TestTCPDNSTunnel(t *testing.T) {
port := udp.PickPort() port := udp.PickPort()