1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-21 17:46:58 -05:00

Revert "DNS: fix typo & refine code (#1183)"

This reverts commit 73470e8dd8.
This commit is contained in:
Shelikhoo 2021-09-04 11:49:10 +01:00
parent 9b0f8b7747
commit becbc3a3e2
No known key found for this signature in database
GPG Key ID: C4D5E79D22B25316
4 changed files with 101 additions and 106 deletions

View File

@ -33,7 +33,7 @@ import (
// thus most of the DOH implementation is copied from udpns.go // thus most of the DOH implementation is copied from udpns.go
type DoHNameServer struct { type DoHNameServer struct {
sync.RWMutex sync.RWMutex
ips map[string]*record ips map[string]record
pub *pubsub.Service pub *pubsub.Service
cleanup *task.Periodic cleanup *task.Periodic
reqID uint32 reqID uint32
@ -113,7 +113,7 @@ func NewDoHLocalNameServer(url *url.URL) *DoHNameServer {
func baseDOHNameServer(url *url.URL, prefix string) *DoHNameServer { func baseDOHNameServer(url *url.URL, prefix string) *DoHNameServer {
s := &DoHNameServer{ s := &DoHNameServer{
ips: make(map[string]*record), ips: make(map[string]record),
pub: pubsub.NewService(), pub: pubsub.NewService(),
name: prefix + "//" + url.Host, name: prefix + "//" + url.Host,
dohURL: url.String(), dohURL: url.String(),
@ -157,7 +157,7 @@ func (s *DoHNameServer) Cleanup() error {
} }
if len(s.ips) == 0 { if len(s.ips) == 0 {
s.ips = make(map[string]*record) s.ips = make(map[string]record)
} }
return nil return nil
@ -167,10 +167,7 @@ func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
elapsed := time.Since(req.start) elapsed := time.Since(req.start)
s.Lock() s.Lock()
rec, found := s.ips[req.domain] rec := s.ips[req.domain]
if !found {
rec = &record{}
}
updated := false updated := false
switch req.reqType { switch req.reqType {
@ -180,7 +177,7 @@ func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
updated = true updated = true
} }
case dnsmessage.TypeAAAA: case dnsmessage.TypeAAAA:
addr := make([]net.Address, 0, len(ipRec.IP)) addr := make([]net.Address, 0)
for _, ip := range ipRec.IP { for _, ip := range ipRec.IP {
if len(ip.IP()) == net.IPv6len { if len(ip.IP()) == net.IPv6len {
addr = append(addr, ip) addr = append(addr, ip)
@ -299,30 +296,34 @@ func (s *DoHNameServer) findIPsForDomain(domain string, option dns_feature.IPOpt
return nil, errRecordNotFound return nil, errRecordNotFound
} }
var err4 error
var err6 error
var ips []net.Address var ips []net.Address
var ip6 []net.Address var lastErr error
if option.IPv6Enable && record.AAAA != nil && record.AAAA.RCode == dnsmessage.RCodeSuccess {
aaaa, err := record.AAAA.getIPs()
if err != nil {
lastErr = err
}
ips = append(ips, aaaa...)
}
switch { if option.IPv4Enable && record.A != nil && record.A.RCode == dnsmessage.RCodeSuccess {
case option.IPv4Enable: a, err := record.A.getIPs()
ips, err4 = record.A.getIPs() if err != nil {
fallthrough lastErr = err
case option.IPv6Enable: }
ip6, err6 = record.AAAA.getIPs() ips = append(ips, a...)
ips = append(ips, ip6...)
} }
if len(ips) > 0 { if len(ips) > 0 {
return toNetIP(ips) return toNetIP(ips)
} }
if err4 != nil { if lastErr != nil {
return nil, err4 return nil, lastErr
} }
if err6 != nil { if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) {
return nil, err6 return nil, dns_feature.ErrEmptyResponse
} }
return nil, errRecordNotFound return nil, errRecordNotFound

View File

@ -31,12 +31,12 @@ const handshakeIdleTimeout = time.Second * 8
// QUICNameServer implemented DNS over QUIC // QUICNameServer implemented DNS over QUIC
type QUICNameServer struct { type QUICNameServer struct {
sync.RWMutex sync.RWMutex
ips map[string]*record ips map[string]record
pub *pubsub.Service pub *pubsub.Service
cleanup *task.Periodic cleanup *task.Periodic
reqID uint32 reqID uint32
name string name string
destination *net.Destination destination net.Destination
session quic.Session session quic.Session
} }
@ -55,10 +55,10 @@ func NewQUICNameServer(url *url.URL) (*QUICNameServer, error) {
dest := net.UDPDestination(net.ParseAddress(url.Hostname()), port) dest := net.UDPDestination(net.ParseAddress(url.Hostname()), port)
s := &QUICNameServer{ s := &QUICNameServer{
ips: make(map[string]*record), ips: make(map[string]record),
pub: pubsub.NewService(), pub: pubsub.NewService(),
name: url.String(), name: url.String(),
destination: &dest, destination: dest,
} }
s.cleanup = &task.Periodic{ s.cleanup = &task.Periodic{
Interval: time.Minute, Interval: time.Minute,
@ -100,7 +100,7 @@ func (s *QUICNameServer) Cleanup() error {
} }
if len(s.ips) == 0 { if len(s.ips) == 0 {
s.ips = make(map[string]*record) s.ips = make(map[string]record)
} }
return nil return nil
@ -110,10 +110,7 @@ func (s *QUICNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
elapsed := time.Since(req.start) elapsed := time.Since(req.start)
s.Lock() s.Lock()
rec, found := s.ips[req.domain] rec := s.ips[req.domain]
if !found {
rec = &record{}
}
updated := false updated := false
switch req.reqType { switch req.reqType {
@ -233,30 +230,34 @@ func (s *QUICNameServer) findIPsForDomain(domain string, option dns_feature.IPOp
return nil, errRecordNotFound return nil, errRecordNotFound
} }
var err4 error
var err6 error
var ips []net.Address var ips []net.Address
var ip6 []net.Address var lastErr error
if option.IPv6Enable && record.AAAA != nil && record.AAAA.RCode == dnsmessage.RCodeSuccess {
aaaa, err := record.AAAA.getIPs()
if err != nil {
lastErr = err
}
ips = append(ips, aaaa...)
}
switch { if option.IPv4Enable && record.A != nil && record.A.RCode == dnsmessage.RCodeSuccess {
case option.IPv4Enable: a, err := record.A.getIPs()
ips, err4 = record.A.getIPs() if err != nil {
fallthrough lastErr = err
case option.IPv6Enable: }
ip6, err6 = record.AAAA.getIPs() ips = append(ips, a...)
ips = append(ips, ip6...)
} }
if len(ips) > 0 { if len(ips) > 0 {
return toNetIP(ips) return toNetIP(ips)
} }
if err4 != nil { if lastErr != nil {
return nil, err4 return nil, lastErr
} }
if err6 != nil { if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) {
return nil, err6 return nil, dns_feature.ErrEmptyResponse
} }
return nil, errRecordNotFound return nil, errRecordNotFound

View File

@ -30,8 +30,8 @@ import (
type TCPNameServer struct { type TCPNameServer struct {
sync.RWMutex sync.RWMutex
name string name string
destination *net.Destination destination net.Destination
ips map[string]*record ips map[string]record
pub *pubsub.Service pub *pubsub.Service
cleanup *task.Periodic cleanup *task.Periodic
reqID uint32 reqID uint32
@ -46,7 +46,7 @@ func NewTCPNameServer(url *url.URL, dispatcher routing.Dispatcher) (*TCPNameServ
} }
s.dial = func(ctx context.Context) (net.Conn, error) { s.dial = func(ctx context.Context) (net.Conn, error) {
link, err := dispatcher.Dispatch(ctx, *s.destination) link, err := dispatcher.Dispatch(ctx, s.destination)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -68,7 +68,7 @@ func NewTCPLocalNameServer(url *url.URL) (*TCPNameServer, error) {
} }
s.dial = func(ctx context.Context) (net.Conn, error) { s.dial = func(ctx context.Context) (net.Conn, error) {
return internet.DialSystem(ctx, *s.destination, nil) return internet.DialSystem(ctx, s.destination, nil)
} }
return s, nil return s, nil
@ -86,8 +86,8 @@ func baseTCPNameServer(url *url.URL, prefix string) (*TCPNameServer, error) {
dest := net.TCPDestination(net.ParseAddress(url.Hostname()), port) dest := net.TCPDestination(net.ParseAddress(url.Hostname()), port)
s := &TCPNameServer{ s := &TCPNameServer{
destination: &dest, destination: dest,
ips: make(map[string]*record), ips: make(map[string]record),
pub: pubsub.NewService(), pub: pubsub.NewService(),
name: prefix + "//" + dest.NetAddr(), name: prefix + "//" + dest.NetAddr(),
} }
@ -131,7 +131,7 @@ func (s *TCPNameServer) Cleanup() error {
} }
if len(s.ips) == 0 { if len(s.ips) == 0 {
s.ips = make(map[string]*record) s.ips = make(map[string]record)
} }
return nil return nil
@ -141,10 +141,7 @@ func (s *TCPNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
elapsed := time.Since(req.start) elapsed := time.Since(req.start)
s.Lock() s.Lock()
rec, found := s.ips[req.domain] rec := s.ips[req.domain]
if !found {
rec = &record{}
}
updated := false updated := false
switch req.reqType { switch req.reqType {
@ -278,30 +275,30 @@ func (s *TCPNameServer) findIPsForDomain(domain string, option dns_feature.IPOpt
return nil, errRecordNotFound return nil, errRecordNotFound
} }
var err4 error
var err6 error
var ips []net.Address var ips []net.Address
var ip6 []net.Address var lastErr error
if option.IPv4Enable {
a, err := record.A.getIPs()
if err != nil {
lastErr = err
}
ips = append(ips, a...)
}
switch { if option.IPv6Enable {
case option.IPv4Enable: aaaa, err := record.AAAA.getIPs()
ips, err4 = record.A.getIPs() if err != nil {
fallthrough lastErr = err
case option.IPv6Enable: }
ip6, err6 = record.AAAA.getIPs() ips = append(ips, aaaa...)
ips = append(ips, ip6...)
} }
if len(ips) > 0 { if len(ips) > 0 {
return toNetIP(ips) return toNetIP(ips)
} }
if err4 != nil { if lastErr != nil {
return nil, err4 return nil, lastErr
}
if err6 != nil {
return nil, err6
} }
return nil, dns_feature.ErrEmptyResponse return nil, dns_feature.ErrEmptyResponse

View File

@ -29,9 +29,9 @@ import (
type ClassicNameServer struct { type ClassicNameServer struct {
sync.RWMutex sync.RWMutex
name string name string
address *net.Destination address net.Destination
ips map[string]*record ips map[string]record
requests map[uint16]*dnsRequest requests map[uint16]dnsRequest
pub *pubsub.Service pub *pubsub.Service
udpServer *udp.Dispatcher udpServer *udp.Dispatcher
cleanup *task.Periodic cleanup *task.Periodic
@ -46,9 +46,9 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
} }
s := &ClassicNameServer{ s := &ClassicNameServer{
address: &address, address: address,
ips: make(map[string]*record), ips: make(map[string]record),
requests: make(map[uint16]*dnsRequest), requests: make(map[uint16]dnsRequest),
pub: pubsub.NewService(), pub: pubsub.NewService(),
name: strings.ToUpper(address.String()), name: strings.ToUpper(address.String()),
} }
@ -85,7 +85,6 @@ func (s *ClassicNameServer) Cleanup() error {
} }
if record.A == nil && record.AAAA == nil { if record.A == nil && record.AAAA == nil {
newError(s.name, " cleanup ", domain).AtDebug().WriteToLog()
delete(s.ips, domain) delete(s.ips, domain)
} else { } else {
s.ips[domain] = record s.ips[domain] = record
@ -93,7 +92,7 @@ func (s *ClassicNameServer) Cleanup() error {
} }
if len(s.ips) == 0 { if len(s.ips) == 0 {
s.ips = make(map[string]*record) s.ips = make(map[string]record)
} }
for id, req := range s.requests { for id, req := range s.requests {
@ -103,7 +102,7 @@ func (s *ClassicNameServer) Cleanup() error {
} }
if len(s.requests) == 0 { if len(s.requests) == 0 {
s.requests = make(map[uint16]*dnsRequest) s.requests = make(map[uint16]dnsRequest)
} }
return nil return nil
@ -141,17 +140,15 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
elapsed := time.Since(req.start) elapsed := time.Since(req.start)
newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) { if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
s.updateIP(req.domain, &rec) s.updateIP(req.domain, rec)
} }
} }
func (s *ClassicNameServer) updateIP(domain string, newRec *record) { func (s *ClassicNameServer) updateIP(domain string, newRec record) {
s.Lock() s.Lock()
rec, found := s.ips[domain] newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog()
if !found { rec := s.ips[domain]
rec = &record{}
}
updated := false updated := false
if isNewer(rec.A, newRec.A) { if isNewer(rec.A, newRec.A) {
@ -164,7 +161,6 @@ func (s *ClassicNameServer) updateIP(domain string, newRec *record) {
} }
if updated { if updated {
newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog()
s.ips[domain] = rec s.ips[domain] = rec
} }
if newRec.A != nil { if newRec.A != nil {
@ -187,7 +183,7 @@ func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) {
id := req.msg.ID id := req.msg.ID
req.expire = time.Now().Add(time.Second * 8) req.expire = time.Now().Add(time.Second * 8)
s.requests[id] = req s.requests[id] = *req
} }
func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
@ -205,7 +201,7 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, client
udpCtx = session.ContextWithContent(udpCtx, &session.Content{ udpCtx = session.ContextWithContent(udpCtx, &session.Content{
Protocol: "dns", Protocol: "dns",
}) })
s.udpServer.Dispatch(udpCtx, *s.address, b) s.udpServer.Dispatch(udpCtx, s.address, b)
} }
} }
@ -218,30 +214,30 @@ func (s *ClassicNameServer) findIPsForDomain(domain string, option dns_feature.I
return nil, errRecordNotFound return nil, errRecordNotFound
} }
var err4 error
var err6 error
var ips []net.Address var ips []net.Address
var ip6 []net.Address var lastErr error
if option.IPv4Enable {
a, err := record.A.getIPs()
if err != nil {
lastErr = err
}
ips = append(ips, a...)
}
switch { if option.IPv6Enable {
case option.IPv4Enable: aaaa, err := record.AAAA.getIPs()
ips, err4 = record.A.getIPs() if err != nil {
fallthrough lastErr = err
case option.IPv6Enable: }
ip6, err6 = record.AAAA.getIPs() ips = append(ips, aaaa...)
ips = append(ips, ip6...)
} }
if len(ips) > 0 { if len(ips) > 0 {
return toNetIP(ips) return toNetIP(ips)
} }
if err4 != nil { if lastErr != nil {
return nil, err4 return nil, lastErr
}
if err6 != nil {
return nil, err6
} }
return nil, dns_feature.ErrEmptyResponse return nil, dns_feature.ErrEmptyResponse