mirror of
https://github.com/v2fly/v2ray-core.git
synced 2025-01-17 14:57:44 -05:00
correctly propagate dns errors all the way through.
the internal dns system can correctly handle the cases where: 1) domain has no A or AAAA records 2) domain doesn't exist fixes #1565
This commit is contained in:
parent
c27050ad90
commit
9957c64b4a
@ -226,6 +226,9 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
|
||||
if len(ips) > 0 {
|
||||
return ips, nil
|
||||
}
|
||||
if err == dns.ErrEmptyResponse {
|
||||
return nil, err
|
||||
}
|
||||
if err != nil {
|
||||
newError("failed to lookup ip for domain ", domain, " at server ", ns.Name()).Base(err).WriteToLog()
|
||||
lastErr = err
|
||||
@ -238,6 +241,9 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
|
||||
if len(ips) > 0 {
|
||||
return ips, nil
|
||||
}
|
||||
if err == dns.ErrEmptyResponse {
|
||||
return nil, err
|
||||
}
|
||||
if err != nil {
|
||||
newError("failed to lookup ip for domain ", domain, " at server ", client.Name()).Base(err).WriteToLog()
|
||||
lastErr = err
|
||||
|
@ -60,6 +60,8 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888")
|
||||
common.Must(err)
|
||||
ans.Answer = append(ans.Answer, rr)
|
||||
} else if q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA {
|
||||
ans.MsgHdr.Rcode = dns.RcodeNameError
|
||||
}
|
||||
}
|
||||
w.WriteMsg(ans)
|
||||
@ -186,6 +188,27 @@ func TestUDPServer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
_, err := client.LookupIP("notexist.google.com")
|
||||
if err == nil {
|
||||
t.Fatal("nil error")
|
||||
}
|
||||
if r := feature_dns.RCodeFromError(err); r != uint16(dns.RcodeNameError) {
|
||||
t.Fatal("expected NameError, but got ", r)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
clientv6 := client.(feature_dns.IPv6Lookup)
|
||||
ips, err := clientv6.LookupIPv6("ipv4only.google.com")
|
||||
if err != feature_dns.ErrEmptyResponse {
|
||||
t.Fatal("error: ", err)
|
||||
}
|
||||
if len(ips) != 0 {
|
||||
t.Fatal("ips: ", ips)
|
||||
}
|
||||
}
|
||||
|
||||
dnsServer.Shutdown()
|
||||
|
||||
{
|
||||
|
177
app/dns/udpns.go
177
app/dns/udpns.go
@ -5,36 +5,60 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
fmt "fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/errors"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/protocol/dns"
|
||||
udp_proto "v2ray.com/core/common/protocol/udp"
|
||||
"v2ray.com/core/common/session"
|
||||
"v2ray.com/core/common/signal/pubsub"
|
||||
"v2ray.com/core/common/task"
|
||||
dns_feature "v2ray.com/core/features/dns"
|
||||
"v2ray.com/core/features/routing"
|
||||
"v2ray.com/core/transport/internet/udp"
|
||||
)
|
||||
|
||||
type record struct {
|
||||
A *IPRecord
|
||||
AAAA *IPRecord
|
||||
}
|
||||
|
||||
type IPRecord struct {
|
||||
IP net.Address
|
||||
IP []net.Address
|
||||
Expire time.Time
|
||||
RCode dnsmessage.RCode
|
||||
}
|
||||
|
||||
func (r *IPRecord) getIPs() ([]net.Address, error) {
|
||||
if r == nil || r.Expire.Before(time.Now()) {
|
||||
return nil, errRecordNotFound
|
||||
}
|
||||
if r.RCode != dnsmessage.RCodeSuccess {
|
||||
return nil, dns_feature.RCodeError(r.RCode)
|
||||
}
|
||||
return r.IP, nil
|
||||
}
|
||||
|
||||
type pendingRequest struct {
|
||||
domain string
|
||||
expire time.Time
|
||||
recType dnsmessage.Type
|
||||
}
|
||||
|
||||
var (
|
||||
errRecordNotFound = errors.New("record not found")
|
||||
)
|
||||
|
||||
type ClassicNameServer struct {
|
||||
sync.RWMutex
|
||||
address net.Destination
|
||||
ips map[string][]IPRecord
|
||||
ips map[string]record
|
||||
requests map[uint16]pendingRequest
|
||||
pub *pubsub.Service
|
||||
udpServer *udp.Dispatcher
|
||||
@ -46,7 +70,7 @@ type ClassicNameServer struct {
|
||||
func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer {
|
||||
s := &ClassicNameServer{
|
||||
address: address,
|
||||
ips: make(map[string][]IPRecord),
|
||||
ips: make(map[string]record),
|
||||
requests: make(map[uint16]pendingRequest),
|
||||
clientIP: clientIP,
|
||||
pub: pubsub.NewService(),
|
||||
@ -72,22 +96,23 @@ func (s *ClassicNameServer) Cleanup() error {
|
||||
return newError("nothing to do. stopping...")
|
||||
}
|
||||
|
||||
for domain, ips := range s.ips {
|
||||
newIPs := make([]IPRecord, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
if ip.Expire.After(now) {
|
||||
newIPs = append(newIPs, ip)
|
||||
for domain, record := range s.ips {
|
||||
if record.A != nil && record.A.Expire.Before(now) {
|
||||
record.A = nil
|
||||
}
|
||||
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
|
||||
record.AAAA = nil
|
||||
}
|
||||
if len(newIPs) == 0 {
|
||||
|
||||
if record.A == nil && record.AAAA == nil {
|
||||
delete(s.ips, domain)
|
||||
} else if len(newIPs) < len(ips) {
|
||||
s.ips[domain] = newIPs
|
||||
} else {
|
||||
s.ips[domain] = record
|
||||
}
|
||||
}
|
||||
|
||||
if len(s.ips) == 0 {
|
||||
s.ips = make(map[string][]IPRecord)
|
||||
s.ips = make(map[string]record)
|
||||
}
|
||||
|
||||
for id, req := range s.requests {
|
||||
@ -130,9 +155,14 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
|
||||
}
|
||||
|
||||
domain := req.domain
|
||||
ips := make([]IPRecord, 0, 16)
|
||||
recType := req.recType
|
||||
|
||||
now := time.Now()
|
||||
ipRecord := &IPRecord{
|
||||
RCode: header.RCode,
|
||||
Expire: now.Add(time.Second * 600),
|
||||
}
|
||||
|
||||
for {
|
||||
header, err := parser.AnswerHeader()
|
||||
if err != nil {
|
||||
@ -145,6 +175,15 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
|
||||
if ttl == 0 {
|
||||
ttl = 600
|
||||
}
|
||||
expire := now.Add(time.Duration(ttl) * time.Second)
|
||||
if ipRecord.Expire.After(expire) {
|
||||
ipRecord.Expire = expire
|
||||
}
|
||||
|
||||
if header.Type != recType {
|
||||
continue
|
||||
}
|
||||
|
||||
switch header.Type {
|
||||
case dnsmessage.TypeA:
|
||||
ans, err := parser.AResource()
|
||||
@ -152,20 +191,14 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
|
||||
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
|
||||
break
|
||||
}
|
||||
ips = append(ips, IPRecord{
|
||||
IP: net.IPAddress(ans.A[:]),
|
||||
Expire: now.Add(time.Duration(ttl) * time.Second),
|
||||
})
|
||||
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
|
||||
case dnsmessage.TypeAAAA:
|
||||
ans, err := parser.AAAAResource()
|
||||
if err != nil {
|
||||
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
|
||||
break
|
||||
}
|
||||
ips = append(ips, IPRecord{
|
||||
IP: net.IPAddress(ans.AAAA[:]),
|
||||
Expire: now.Add(time.Duration(ttl) * time.Second),
|
||||
})
|
||||
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
|
||||
default:
|
||||
if err := parser.SkipAnswer(); err != nil {
|
||||
newError("failed to skip answer").Base(err).WriteToLog()
|
||||
@ -173,24 +206,49 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
|
||||
}
|
||||
}
|
||||
|
||||
if len(domain) > 0 && len(ips) > 0 {
|
||||
s.updateIP(domain, ips)
|
||||
var rec record
|
||||
switch recType {
|
||||
case dnsmessage.TypeA:
|
||||
rec.A = ipRecord
|
||||
case dnsmessage.TypeAAAA:
|
||||
rec.AAAA = ipRecord
|
||||
}
|
||||
|
||||
if len(domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
|
||||
s.updateIP(domain, rec)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) updateIP(domain string, ips []IPRecord) {
|
||||
func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
|
||||
if newRec == nil {
|
||||
return false
|
||||
}
|
||||
if baseRec == nil {
|
||||
return true
|
||||
}
|
||||
return baseRec.Expire.Before(newRec.Expire)
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) updateIP(domain string, newRec record) {
|
||||
s.Lock()
|
||||
|
||||
newError("updating IP records for domain:", domain).AtDebug().WriteToLog()
|
||||
now := time.Now()
|
||||
eips := s.ips[domain]
|
||||
for _, ip := range eips {
|
||||
if ip.Expire.After(now) {
|
||||
ips = append(ips, ip)
|
||||
rec := s.ips[domain]
|
||||
|
||||
updated := false
|
||||
if isNewer(rec.A, newRec.A) {
|
||||
rec.A = newRec.A
|
||||
updated = true
|
||||
}
|
||||
if isNewer(rec.AAAA, newRec.AAAA) {
|
||||
rec.AAAA = newRec.AAAA
|
||||
updated = true
|
||||
}
|
||||
s.ips[domain] = ips
|
||||
|
||||
if updated {
|
||||
s.ips[domain] = rec
|
||||
s.pub.Publish(domain, nil)
|
||||
}
|
||||
|
||||
s.Unlock()
|
||||
common.Must(s.cleanup.Start())
|
||||
@ -244,7 +302,7 @@ func (s *ClassicNameServer) getMsgOptions() *dnsmessage.Resource {
|
||||
return opt
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) addPendingRequest(domain string) uint16 {
|
||||
func (s *ClassicNameServer) addPendingRequest(domain string, recType dnsmessage.Type) uint16 {
|
||||
id := uint16(atomic.AddUint32(&s.reqID, 1))
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
@ -252,6 +310,7 @@ func (s *ClassicNameServer) addPendingRequest(domain string) uint16 {
|
||||
s.requests[id] = pendingRequest{
|
||||
domain: domain,
|
||||
expire: time.Now().Add(time.Second * 8),
|
||||
recType: recType,
|
||||
}
|
||||
|
||||
return id
|
||||
@ -274,7 +333,7 @@ func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmess
|
||||
|
||||
if option.IPv4Enable {
|
||||
msg := new(dnsmessage.Message)
|
||||
msg.Header.ID = s.addPendingRequest(domain)
|
||||
msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeA)
|
||||
msg.Header.RecursionDesired = true
|
||||
msg.Questions = []dnsmessage.Question{qA}
|
||||
if opt := s.getMsgOptions(); opt != nil {
|
||||
@ -285,7 +344,7 @@ func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmess
|
||||
|
||||
if option.IPv6Enable {
|
||||
msg := new(dnsmessage.Message)
|
||||
msg.Header.ID = s.addPendingRequest(domain)
|
||||
msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeAAAA)
|
||||
msg.Header.RecursionDesired = true
|
||||
msg.Questions = []dnsmessage.Question{qAAAA}
|
||||
if opt := s.getMsgOptions(); opt != nil {
|
||||
@ -313,22 +372,44 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) []net.IP {
|
||||
func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) ([]net.IP, error) {
|
||||
s.RLock()
|
||||
records, found := s.ips[domain]
|
||||
record, found := s.ips[domain]
|
||||
s.RUnlock()
|
||||
|
||||
if found && len(records) > 0 {
|
||||
if !found {
|
||||
return nil, errRecordNotFound
|
||||
}
|
||||
|
||||
var ips []net.Address
|
||||
now := time.Now()
|
||||
for _, rec := range records {
|
||||
if rec.Expire.After(now) {
|
||||
ips = append(ips, rec.IP)
|
||||
var lastErr error
|
||||
if option.IPv4Enable {
|
||||
a, err := record.A.getIPs()
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
ips = append(ips, a...)
|
||||
}
|
||||
return toNetIP(filterIP(ips, option))
|
||||
|
||||
if option.IPv6Enable {
|
||||
aaaa, err := record.AAAA.getIPs()
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
return nil
|
||||
ips = append(ips, aaaa...)
|
||||
}
|
||||
|
||||
fmt.Println("IPs for ", domain, ": ", ips)
|
||||
|
||||
if len(ips) > 0 {
|
||||
return toNetIP(ips), nil
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
return nil, dns_feature.ErrEmptyResponse
|
||||
}
|
||||
|
||||
func Fqdn(domain string) string {
|
||||
@ -341,9 +422,9 @@ func Fqdn(domain string) string {
|
||||
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
|
||||
fqdn := Fqdn(domain)
|
||||
|
||||
ips := s.findIPsForDomain(fqdn, option)
|
||||
if len(ips) > 0 {
|
||||
return ips, nil
|
||||
ips, err := s.findIPsForDomain(fqdn, option)
|
||||
if err != errRecordNotFound {
|
||||
return ips, err
|
||||
}
|
||||
|
||||
sub := s.pub.Subscribe(fqdn)
|
||||
@ -352,9 +433,9 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option I
|
||||
s.sendQuery(ctx, fqdn, option)
|
||||
|
||||
for {
|
||||
ips := s.findIPsForDomain(fqdn, option)
|
||||
if len(ips) > 0 {
|
||||
return ips, nil
|
||||
ips, err := s.findIPsForDomain(fqdn, option)
|
||||
if err != errRecordNotFound {
|
||||
return ips, err
|
||||
}
|
||||
|
||||
select {
|
||||
|
@ -1,7 +1,9 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"v2ray.com/core/common/errors"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/serial"
|
||||
"v2ray.com/core/features"
|
||||
)
|
||||
|
||||
@ -35,3 +37,23 @@ type IPv6Lookup interface {
|
||||
func ClientType() interface{} {
|
||||
return (*Client)(nil)
|
||||
}
|
||||
|
||||
// ErrEmptyResponse indicates that DNS query succeeded but no answer was returned.
|
||||
var ErrEmptyResponse = errors.New("empty response")
|
||||
|
||||
type RCodeError uint16
|
||||
|
||||
func (e RCodeError) Error() string {
|
||||
return serial.Concat("rcode: ", uint16(e))
|
||||
}
|
||||
|
||||
func RCodeFromError(err error) uint16 {
|
||||
if err == nil {
|
||||
return 0
|
||||
}
|
||||
cause := errors.Cause(err)
|
||||
if r, ok := cause.(RCodeError); ok {
|
||||
return uint16(r)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
@ -32,6 +32,9 @@ func (*Client) LookupIP(host string) ([]net.IP, error) {
|
||||
parsedIPs = append(parsedIPs, parsed.IP())
|
||||
}
|
||||
}
|
||||
if len(parsedIPs) == 0 {
|
||||
return nil, dns.ErrEmptyResponse
|
||||
}
|
||||
return parsedIPs, nil
|
||||
}
|
||||
|
||||
@ -47,6 +50,9 @@ func (c *Client) LookupIPv4(host string) ([]net.IP, error) {
|
||||
ipv4 = append(ipv4, ip)
|
||||
}
|
||||
}
|
||||
if len(ipv4) == 0 {
|
||||
return nil, dns.ErrEmptyResponse
|
||||
}
|
||||
return ipv4, nil
|
||||
}
|
||||
|
||||
@ -62,6 +68,9 @@ func (c *Client) LookupIPv6(host string) ([]net.IP, error) {
|
||||
ipv6 = append(ipv6, ip)
|
||||
}
|
||||
}
|
||||
if len(ipv6) == 0 {
|
||||
return nil, dns.ErrEmptyResponse
|
||||
}
|
||||
return ipv6, nil
|
||||
}
|
||||
|
||||
|
@ -218,20 +218,17 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string,
|
||||
ips, err = h.ipv6Lookup.LookupIPv6(domain)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
rcode := dns.RCodeFromError(err)
|
||||
if rcode == 0 && len(ips) == 0 && err != dns.ErrEmptyResponse {
|
||||
newError("ip query").Base(err).WriteToLog()
|
||||
return
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
b := buf.New()
|
||||
rawBytes := b.Extend(buf.Size)
|
||||
builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{
|
||||
ID: id,
|
||||
RCode: dnsmessage.RCodeSuccess,
|
||||
RCode: dnsmessage.RCode(rcode),
|
||||
RecursionAvailable: true,
|
||||
RecursionDesired: true,
|
||||
Response: true,
|
||||
|
@ -63,6 +63,8 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888")
|
||||
common.Must(err)
|
||||
ans.Answer = append(ans.Answer, rr)
|
||||
} else if q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA {
|
||||
ans.MsgHdr.Rcode = dns.RcodeNameError
|
||||
}
|
||||
}
|
||||
w.WriteMsg(ans)
|
||||
@ -128,6 +130,7 @@ func TestUDPDNSTunnel(t *testing.T) {
|
||||
common.Must(v.Start())
|
||||
defer v.Close()
|
||||
|
||||
{
|
||||
m1 := new(dns.Msg)
|
||||
m1.Id = dns.Id()
|
||||
m1.RecursionDesired = true
|
||||
@ -149,6 +152,39 @@ func TestUDPDNSTunnel(t *testing.T) {
|
||||
if r := cmp.Diff(rr.A[:], net.IP{8, 8, 8, 8}); r != "" {
|
||||
t.Error(r)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
m1 := new(dns.Msg)
|
||||
m1.Id = dns.Id()
|
||||
m1.RecursionDesired = true
|
||||
m1.Question = make([]dns.Question, 1)
|
||||
m1.Question[0] = dns.Question{"ipv4only.google.com.", dns.TypeAAAA, dns.ClassINET}
|
||||
|
||||
c := new(dns.Client)
|
||||
in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort)))
|
||||
common.Must(err)
|
||||
|
||||
if len(in.Answer) != 0 {
|
||||
t.Fatal("len(answer): ", len(in.Answer))
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
m1 := new(dns.Msg)
|
||||
m1.Id = dns.Id()
|
||||
m1.RecursionDesired = true
|
||||
m1.Question = make([]dns.Question, 1)
|
||||
m1.Question[0] = dns.Question{"notexist.google.com.", dns.TypeAAAA, dns.ClassINET}
|
||||
|
||||
c := new(dns.Client)
|
||||
in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort)))
|
||||
common.Must(err)
|
||||
|
||||
if in.Rcode != dns.RcodeNameError {
|
||||
t.Error("expected NameError, but got ", in.Rcode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPDNSTunnel(t *testing.T) {
|
||||
|
Loading…
Reference in New Issue
Block a user