diff --git a/app/dns/nameserver.go b/app/dns/nameserver.go index da7a11ac4..43b8c39e9 100644 --- a/app/dns/nameserver.go +++ b/app/dns/nameserver.go @@ -2,17 +2,9 @@ package dns import ( "context" - "sync" "time" - "github.com/miekg/dns" - "v2ray.com/core" - "v2ray.com/core/common" - "v2ray.com/core/common/buf" - "v2ray.com/core/common/dice" "v2ray.com/core/common/net" - "v2ray.com/core/common/task" - "v2ray.com/core/transport/internet/udp" ) var ( @@ -29,203 +21,12 @@ type ARecord struct { } type NameServer interface { - QueryA(domain string) <-chan *ARecord -} - -type PendingRequest struct { - expire time.Time - response chan<- *ARecord -} - -type UDPNameServer struct { - sync.Mutex - address net.Destination - requests map[uint16]*PendingRequest - udpServer *udp.Dispatcher - cleanup *task.Periodic -} - -func NewUDPNameServer(address net.Destination, dispatcher core.Dispatcher) *UDPNameServer { - s := &UDPNameServer{ - address: address, - requests: make(map[uint16]*PendingRequest), - udpServer: udp.NewDispatcher(dispatcher), - } - s.cleanup = &task.Periodic{ - Interval: time.Minute, - Execute: s.Cleanup, - } - common.Must(s.cleanup.Start()) - return s -} - -func (s *UDPNameServer) Cleanup() error { - now := time.Now() - s.Lock() - for id, r := range s.requests { - if r.expire.Before(now) { - close(r.response) - delete(s.requests, id) - } - } - s.Unlock() - return nil -} - -func (s *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 { - var id uint16 - s.Lock() - - for { - id = dice.RollUint16() - if _, found := s.requests[id]; found { - time.Sleep(time.Millisecond * 500) - continue - } - newError("add pending request id ", id).AtDebug().WriteToLog() - s.requests[id] = &PendingRequest{ - expire: time.Now().Add(time.Second * 8), - response: response, - } - break - } - s.Unlock() - return id -} - -func (s *UDPNameServer) HandleResponse(payload *buf.Buffer) { - msg := new(dns.Msg) - err := msg.Unpack(payload.Bytes()) - if err == dns.ErrTruncated { - newError("truncated message received. DNS server should still work. If you see anything abnormal, please submit an issue to v2ray-core.").AtWarning().WriteToLog() - } else if err != nil { - newError("failed to parse DNS response").Base(err).AtWarning().WriteToLog() - return - } - record := &ARecord{ - IPs: make([]net.IP, 0, 16), - } - id := msg.Id - ttl := uint32(3600) // an hour - newError("handling response for id ", id, " content: ", msg).AtDebug().WriteToLog() - - s.Lock() - request, found := s.requests[id] - if !found { - s.Unlock() - return - } - delete(s.requests, id) - s.Unlock() - - for _, rr := range msg.Answer { - switch rr := rr.(type) { - case *dns.A: - record.IPs = append(record.IPs, rr.A) - if rr.Hdr.Ttl < ttl { - ttl = rr.Hdr.Ttl - } - case *dns.AAAA: - record.IPs = append(record.IPs, rr.AAAA) - if rr.Hdr.Ttl < ttl { - ttl = rr.Hdr.Ttl - } - } - } - record.Expire = time.Now().Add(time.Second * time.Duration(ttl)) - - request.response <- record - close(request.response) -} - -func (s *UDPNameServer) buildAMsg(domain string, id uint16) *dns.Msg { - msg := new(dns.Msg) - msg.Id = id - msg.RecursionDesired = true - msg.Question = []dns.Question{ - { - Name: dns.Fqdn(domain), - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }} - if multiQuestionDNS[s.address.Address] { - msg.Question = append(msg.Question, dns.Question{ - Name: dns.Fqdn(domain), - Qtype: dns.TypeAAAA, - Qclass: dns.ClassINET, - }) - } - - return msg -} - -func msgToBuffer(msg *dns.Msg) (*buf.Buffer, error) { - buffer := buf.New() - if err := buffer.Reset(func(b []byte) (int, error) { - writtenBuffer, err := msg.PackBuffer(b) - return len(writtenBuffer), err - }); err != nil { - return nil, err - } - return buffer, nil -} - -func (s *UDPNameServer) QueryA(domain string) <-chan *ARecord { - response := make(chan *ARecord, 1) - id := s.AssignUnusedID(response) - - msg := s.buildAMsg(domain, id) - b, err := msgToBuffer(msg) - if err != nil { - newError("failed to build A query for domain ", domain).Base(err).WriteToLog() - s.Lock() - delete(s.requests, id) - s.Unlock() - close(response) - return response - } - - ctx, cancel := context.WithCancel(context.Background()) - s.udpServer.Dispatch(ctx, s.address, b, s.HandleResponse) - - go func() { - for i := 0; i < 2; i++ { - time.Sleep(time.Second) - s.Lock() - _, found := s.requests[id] - s.Unlock() - if !found { - break - } - b, _ := msgToBuffer(msg) - s.udpServer.Dispatch(ctx, s.address, b, s.HandleResponse) - } - cancel() - }() - - return response + QueryIP(ctx context.Context, domain string) ([]net.IP, error) } type LocalNameServer struct { } -func (*LocalNameServer) QueryA(domain string) <-chan *ARecord { - response := make(chan *ARecord, 1) - - go func() { - defer close(response) - - ips, err := net.LookupIP(domain) - if err != nil { - newError("failed to lookup IPs for domain ", domain).Base(err).AtWarning().WriteToLog() - return - } - - response <- &ARecord{ - IPs: ips, - Expire: time.Now().Add(time.Hour), - } - }() - - return response +func (*LocalNameServer) QueryIP(ctx context.Context, domain string) ([]net.IP, error) { + return net.LookupIP(domain) } diff --git a/app/dns/server.go b/app/dns/server.go index fe55c7f7a..6e4a66cea 100644 --- a/app/dns/server.go +++ b/app/dns/server.go @@ -7,48 +7,24 @@ import ( "sync" "time" - dnsmsg "github.com/miekg/dns" "v2ray.com/core" "v2ray.com/core/common" "v2ray.com/core/common/net" "v2ray.com/core/common/task" ) -const ( - QueryTimeout = time.Second * 8 -) - -type DomainRecord struct { - IP []net.IP - Expire time.Time - LastAccess time.Time -} - -func (r *DomainRecord) Expired() bool { - return r.Expire.Before(time.Now()) -} - type Server struct { sync.Mutex hosts map[string]net.IP - records map[string]*DomainRecord servers []NameServer task *task.Periodic } func New(ctx context.Context, config *Config) (*Server, error) { server := &Server{ - records: make(map[string]*DomainRecord), servers: make([]NameServer, len(config.NameServers)), hosts: config.GetInternalHosts(), } - server.task = &task.Periodic{ - Interval: time.Minute * 10, - Execute: func() error { - server.cleanup() - return nil - }, - } v := core.MustFromContext(ctx) if err := v.RegisterFeature((*core.DNSClient)(nil), server); err != nil { return nil, newError("unable to register DNSClient.").Base(err) @@ -64,7 +40,7 @@ func New(ctx context.Context, config *Config) (*Server, error) { dest.Network = net.Network_UDP } if dest.Network == net.Network_UDP { - server.servers[idx] = NewUDPNameServer(dest, v.Dispatcher()) + server.servers[idx] = NewClassicNameServer(dest, v.Dispatcher()) } } } @@ -85,64 +61,25 @@ func (s *Server) Close() error { return s.task.Close() } -func (s *Server) GetCached(domain string) []net.IP { - s.Lock() - defer s.Unlock() - - if record, found := s.records[domain]; found && !record.Expired() { - record.LastAccess = time.Now() - return record.IP - } - return nil -} - -func (s *Server) cleanup() { - s.Lock() - defer s.Unlock() - - for d, r := range s.records { - if r.Expired() { - delete(s.records, d) - } - } - - if len(s.records) == 0 { - s.records = make(map[string]*DomainRecord) - } -} - func (s *Server) LookupIP(domain string) ([]net.IP, error) { if ip, found := s.hosts[domain]; found { return []net.IP{ip}, nil } - domain = dnsmsg.Fqdn(domain) - ips := s.GetCached(domain) - if ips != nil { - return ips, nil - } - + var lastErr error for _, server := range s.servers { - response := server.QueryA(domain) - select { - case a, open := <-response: - if !open || a == nil { - continue - } - s.Lock() - s.records[domain] = &DomainRecord{ - IP: a.IPs, - Expire: a.Expire, - LastAccess: time.Now(), - } - s.Unlock() - newError("returning ", len(a.IPs), " IPs for domain ", domain).AtDebug().WriteToLog() - return a.IPs, nil - case <-time.After(QueryTimeout): + ctx, cancel := context.WithTimeout(context.Background(), time.Second*4) + ips, err := server.QueryIP(ctx, domain) + cancel() + if err != nil { + lastErr = err + } + if len(ips) > 0 { + return ips, nil } } - return nil, newError("returning nil for domain ", domain) + return nil, newError("returning nil for domain ", domain).Base(lastErr) } func init() { diff --git a/app/dns/udpns.go b/app/dns/udpns.go new file mode 100644 index 000000000..9b025784d --- /dev/null +++ b/app/dns/udpns.go @@ -0,0 +1,229 @@ +package dns + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/miekg/dns" + "v2ray.com/core" + "v2ray.com/core/common" + "v2ray.com/core/common/buf" + "v2ray.com/core/common/net" + "v2ray.com/core/common/signal" + "v2ray.com/core/common/task" + "v2ray.com/core/transport/internet/udp" +) + +type IPRecord struct { + IP net.IP + Expire time.Time +} + +type ClassicNameServer struct { + sync.RWMutex + address net.Destination + ips map[string][]IPRecord + updated signal.Notifier + udpServer *udp.Dispatcher + cleanup *task.Periodic + reqID uint32 +} + +func NewClassicNameServer(address net.Destination, dispatcher core.Dispatcher) *ClassicNameServer { + s := &ClassicNameServer{ + address: address, + ips: make(map[string][]IPRecord), + udpServer: udp.NewDispatcher(dispatcher), + } + s.cleanup = &task.Periodic{ + Interval: time.Minute, + Execute: s.Cleanup, + } + common.Must(s.cleanup.Start()) + return s +} + +func (s *ClassicNameServer) Cleanup() error { + now := time.Now() + s.Lock() + for domain, ips := range s.ips { + newIPs := make([]IPRecord, 0, len(ips)) + for _, ip := range ips { + if ip.Expire.After(now) { + newIPs = append(newIPs, ip) + } + } + if len(newIPs) == 0 { + delete(s.ips, domain) + } else if len(newIPs) < len(ips) { + s.ips[domain] = newIPs + } + } + + if len(s.ips) == 0 { + s.ips = make(map[string][]IPRecord) + } + + s.Unlock() + return nil +} + +func (s *ClassicNameServer) HandleResponse(payload *buf.Buffer) { + msg := new(dns.Msg) + err := msg.Unpack(payload.Bytes()) + if err == dns.ErrTruncated { + newError("truncated message received. DNS server should still work. If you see anything abnormal, please submit an issue to v2ray-core.").AtWarning().WriteToLog() + } else if err != nil { + newError("failed to parse DNS response").Base(err).AtWarning().WriteToLog() + return + } + + var domain string + ips := make([]IPRecord, 0, 16) + + now := time.Now() + for _, rr := range msg.Answer { + var ip net.IP + domain = rr.Header().Name + ttl := rr.Header().Ttl + switch rr := rr.(type) { + case *dns.A: + ip = rr.A + case *dns.AAAA: + ip = rr.AAAA + } + if ttl == 0 { + ttl = 300 + } + if len(ip) > 0 { + ips = append(ips, IPRecord{ + IP: ip, + Expire: now.Add(time.Second * time.Duration(ttl)), + }) + } + } + + if len(domain) > 0 && len(ips) > 0 { + s.updateIP(domain, ips) + } +} + +func (s *ClassicNameServer) updateIP(domain string, ips []IPRecord) { + s.Lock() + defer s.Unlock() + + newError("updating IP records for domain:", domain).AtDebug().WriteToLog() + now := time.Now() + eips := s.ips[domain] + for _, ip := range eips { + if ip.Expire.After(now) { + ips = append(ips, ip) + } + } + s.ips[domain] = ips + s.updated.Signal() +} + +func (s *ClassicNameServer) buildMsgs(domain string) []*dns.Msg { + allowMulti := multiQuestionDNS[s.address.Address] + + var msgs []*dns.Msg + + { + msg := new(dns.Msg) + msg.Id = uint16(atomic.AddUint32(&s.reqID, 1)) + msg.RecursionDesired = true + msg.Question = []dns.Question{ + { + Name: domain, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }} + if allowMulti { + msg.Question = append(msg.Question, dns.Question{ + Name: domain, + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + }) + } + msgs = append(msgs, msg) + } + + if !allowMulti { + msg := new(dns.Msg) + msg.Id = uint16(atomic.AddUint32(&s.reqID, 1)) + msg.RecursionDesired = true + msg.Question = []dns.Question{ + { + Name: domain, + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + }, + } + msgs = append(msgs, msg) + } + + return msgs +} + +func msgToBuffer(msg *dns.Msg) (*buf.Buffer, error) { + buffer := buf.New() + if err := buffer.Reset(func(b []byte) (int, error) { + writtenBuffer, err := msg.PackBuffer(b) + return len(writtenBuffer), err + }); err != nil { + return nil, err + } + return buffer, nil +} + +func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string) { + msgs := s.buildMsgs(domain) + + for _, msg := range msgs { + b, err := msgToBuffer(msg) + common.Must(err) + s.udpServer.Dispatch(ctx, s.address, b, s.HandleResponse) + } +} + +func (s *ClassicNameServer) findIPsForDomain(domain string) []net.IP { + records, found := s.ips[domain] + if found && len(records) > 0 { + var ips []net.IP + now := time.Now() + for _, rec := range records { + if rec.Expire.After(now) { + ips = append(ips, rec.IP) + } + } + return ips + } + return nil +} + +func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string) ([]net.IP, error) { + fqdn := dns.Fqdn(domain) + + ips := s.findIPsForDomain(fqdn) + if len(ips) > 0 { + return ips, nil + } + + s.sendQuery(ctx, fqdn) + + for { + ips := s.findIPsForDomain(fqdn) + if len(ips) > 0 { + return ips, nil + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-s.updated.Wait(): + } + } +} diff --git a/common/signal/notifier.go b/common/signal/notifier.go index 19836e54f..f850016bf 100755 --- a/common/signal/notifier.go +++ b/common/signal/notifier.go @@ -1,26 +1,34 @@ package signal +import "sync" + // Notifier is a utility for notifying changes. The change producer may notify changes multiple time, and the consumer may get notified asynchronously. type Notifier struct { - c chan struct{} + sync.Mutex + waiters []chan struct{} } // NewNotifier creates a new Notifier. func NewNotifier() *Notifier { - return &Notifier{ - c: make(chan struct{}, 1), - } + return &Notifier{} } // Signal signals a change, usually by producer. This method never blocks. func (n *Notifier) Signal() { - select { - case n.c <- struct{}{}: - default: + n.Lock() + for _, w := range n.waiters { + close(w) } + n.waiters = make([]chan struct{}, 0, 8) + n.Unlock() } // Wait returns a channel for waiting for changes. The returned channel never gets closed. func (n *Notifier) Wait() <-chan struct{} { - return n.c + n.Lock() + defer n.Unlock() + + w := make(chan struct{}) + n.waiters = append(n.waiters, w) + return w } diff --git a/common/signal/notifier_test.go b/common/signal/notifier_test.go new file mode 100644 index 000000000..ca6a77003 --- /dev/null +++ b/common/signal/notifier_test.go @@ -0,0 +1,23 @@ +package signal_test + +import ( + "testing" + + . "v2ray.com/core/common/signal" + //. "v2ray.com/ext/assert" +) + +func TestNotifierSignal(t *testing.T) { + //assert := With(t) + + var n Notifier + + w := n.Wait() + n.Signal() + + select { + case <-w: + default: + t.Fail() + } +}