From 484dc4e4880d35bf8696b7250fe9d00f7ece4631 Mon Sep 17 00:00:00 2001 From: Shelikhoo Date: Sat, 4 Sep 2021 11:46:14 +0100 Subject: [PATCH] reverting commit 50bcb683 --- app/dns/nameserver_doh.go | 48 ++++++++++++++------------- app/dns/nameserver_quic.go | 8 ++--- app/dns/nameserver_tcp.go | 8 ++--- app/dns/nameserver_udp.go | 67 +++++++++++++++++++++----------------- 4 files changed, 70 insertions(+), 61 deletions(-) diff --git a/app/dns/nameserver_doh.go b/app/dns/nameserver_doh.go index e24f69180..88350f2ee 100644 --- a/app/dns/nameserver_doh.go +++ b/app/dns/nameserver_doh.go @@ -1,3 +1,6 @@ +//go:build !confonly +// +build !confonly + package dns import ( @@ -30,7 +33,7 @@ import ( // thus most of the DOH implementation is copied from udpns.go type DoHNameServer struct { sync.RWMutex - ips map[string]record + ips map[string]*record pub *pubsub.Service cleanup *task.Periodic reqID uint32 @@ -110,7 +113,7 @@ func NewDoHLocalNameServer(url *url.URL) *DoHNameServer { func baseDOHNameServer(url *url.URL, prefix string) *DoHNameServer { s := &DoHNameServer{ - ips: make(map[string]record), + ips: make(map[string]*record), pub: pubsub.NewService(), name: prefix + "//" + url.Host, dohURL: url.String(), @@ -154,7 +157,7 @@ func (s *DoHNameServer) Cleanup() error { } if len(s.ips) == 0 { - s.ips = make(map[string]record) + s.ips = make(map[string]*record) } return nil @@ -164,7 +167,10 @@ func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { elapsed := time.Since(req.start) s.Lock() - rec := s.ips[req.domain] + rec, found := s.ips[req.domain] + if !found { + rec = &record{} + } updated := false switch req.reqType { @@ -174,7 +180,7 @@ func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { updated = true } case dnsmessage.TypeAAAA: - addr := make([]net.Address, 0) + addr := make([]net.Address, 0, len(ipRec.IP)) for _, ip := range ipRec.IP { if len(ip.IP()) == net.IPv6len { addr = append(addr, ip) @@ -293,34 +299,30 @@ func (s *DoHNameServer) findIPsForDomain(domain string, option dns_feature.IPOpt return nil, errRecordNotFound } + var err4 error + var err6 error var ips []net.Address - var lastErr error - if option.IPv6Enable && record.AAAA != nil && record.AAAA.RCode == dnsmessage.RCodeSuccess { - aaaa, err := record.AAAA.getIPs() - if err != nil { - lastErr = err - } - ips = append(ips, aaaa...) - } + var ip6 []net.Address - if option.IPv4Enable && record.A != nil && record.A.RCode == dnsmessage.RCodeSuccess { - a, err := record.A.getIPs() - if err != nil { - lastErr = err - } - ips = append(ips, a...) + switch { + case option.IPv4Enable: + ips, err4 = record.A.getIPs() + fallthrough + case option.IPv6Enable: + ip6, err6 = record.AAAA.getIPs() + ips = append(ips, ip6...) } if len(ips) > 0 { return toNetIP(ips) } - if lastErr != nil { - return nil, lastErr + if err4 != nil { + return nil, err4 } - if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) { - return nil, dns_feature.ErrEmptyResponse + if err6 != nil { + return nil, err6 } return nil, errRecordNotFound diff --git a/app/dns/nameserver_quic.go b/app/dns/nameserver_quic.go index e5ed77edb..ce1a1bc89 100644 --- a/app/dns/nameserver_quic.go +++ b/app/dns/nameserver_quic.go @@ -238,11 +238,11 @@ func (s *QUICNameServer) findIPsForDomain(domain string, option dns_feature.IPOp var ips []net.Address var ip6 []net.Address - if option.IPv4Enable { + switch { + case option.IPv4Enable: ips, err4 = record.A.getIPs() - } - - if option.IPv6Enable { + fallthrough + case option.IPv6Enable: ip6, err6 = record.AAAA.getIPs() ips = append(ips, ip6...) } diff --git a/app/dns/nameserver_tcp.go b/app/dns/nameserver_tcp.go index c114d6152..be51ffb57 100644 --- a/app/dns/nameserver_tcp.go +++ b/app/dns/nameserver_tcp.go @@ -283,11 +283,11 @@ func (s *TCPNameServer) findIPsForDomain(domain string, option dns_feature.IPOpt var ips []net.Address var ip6 []net.Address - if option.IPv4Enable { + switch { + case option.IPv4Enable: ips, err4 = record.A.getIPs() - } - - if option.IPv6Enable { + fallthrough + case option.IPv6Enable: ip6, err6 = record.AAAA.getIPs() ips = append(ips, ip6...) } diff --git a/app/dns/nameserver_udp.go b/app/dns/nameserver_udp.go index 3aa4976c1..5d88da148 100644 --- a/app/dns/nameserver_udp.go +++ b/app/dns/nameserver_udp.go @@ -1,3 +1,6 @@ +//go:build !confonly +// +build !confonly + package dns import ( @@ -26,9 +29,9 @@ import ( type ClassicNameServer struct { sync.RWMutex name string - address net.Destination - ips map[string]record - requests map[uint16]dnsRequest + address *net.Destination + ips map[string]*record + requests map[uint16]*dnsRequest pub *pubsub.Service udpServer *udp.Dispatcher cleanup *task.Periodic @@ -43,9 +46,9 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher } s := &ClassicNameServer{ - address: address, - ips: make(map[string]record), - requests: make(map[uint16]dnsRequest), + address: &address, + ips: make(map[string]*record), + requests: make(map[uint16]*dnsRequest), pub: pubsub.NewService(), name: strings.ToUpper(address.String()), } @@ -82,6 +85,7 @@ func (s *ClassicNameServer) Cleanup() error { } if record.A == nil && record.AAAA == nil { + newError(s.name, " cleanup ", domain).AtDebug().WriteToLog() delete(s.ips, domain) } else { s.ips[domain] = record @@ -89,7 +93,7 @@ func (s *ClassicNameServer) Cleanup() error { } if len(s.ips) == 0 { - s.ips = make(map[string]record) + s.ips = make(map[string]*record) } for id, req := range s.requests { @@ -99,7 +103,7 @@ func (s *ClassicNameServer) Cleanup() error { } if len(s.requests) == 0 { - s.requests = make(map[uint16]dnsRequest) + s.requests = make(map[uint16]*dnsRequest) } return nil @@ -137,15 +141,17 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot elapsed := time.Since(req.start) newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) { - s.updateIP(req.domain, rec) + s.updateIP(req.domain, &rec) } } -func (s *ClassicNameServer) updateIP(domain string, newRec record) { +func (s *ClassicNameServer) updateIP(domain string, newRec *record) { s.Lock() - newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog() - rec := s.ips[domain] + rec, found := s.ips[domain] + if !found { + rec = &record{} + } updated := false if isNewer(rec.A, newRec.A) { @@ -158,6 +164,7 @@ func (s *ClassicNameServer) updateIP(domain string, newRec record) { } if updated { + newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog() s.ips[domain] = rec } if newRec.A != nil { @@ -180,7 +187,7 @@ func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) { id := req.msg.ID req.expire = time.Now().Add(time.Second * 8) - s.requests[id] = *req + s.requests[id] = req } func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { @@ -198,7 +205,7 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, client udpCtx = session.ContextWithContent(udpCtx, &session.Content{ Protocol: "dns", }) - s.udpServer.Dispatch(udpCtx, s.address, b) + s.udpServer.Dispatch(udpCtx, *s.address, b) } } @@ -211,30 +218,30 @@ func (s *ClassicNameServer) findIPsForDomain(domain string, option dns_feature.I return nil, errRecordNotFound } + var err4 error + var err6 error var ips []net.Address - var lastErr error - if option.IPv4Enable { - a, err := record.A.getIPs() - if err != nil { - lastErr = err - } - ips = append(ips, a...) - } + var ip6 []net.Address - if option.IPv6Enable { - aaaa, err := record.AAAA.getIPs() - if err != nil { - lastErr = err - } - ips = append(ips, aaaa...) + switch { + case option.IPv4Enable: + ips, err4 = record.A.getIPs() + fallthrough + case option.IPv6Enable: + ip6, err6 = record.AAAA.getIPs() + ips = append(ips, ip6...) } if len(ips) > 0 { return toNetIP(ips) } - if lastErr != nil { - return nil, lastErr + if err4 != nil { + return nil, err4 + } + + if err6 != nil { + return nil, err6 } return nil, dns_feature.ErrEmptyResponse