mirror of
https://github.com/v2fly/v2ray-core.git
synced 2025-01-02 23:47:07 -05:00
simplify classic dns server
This commit is contained in:
parent
4f33540b19
commit
9cfb2bfd51
@ -2,17 +2,9 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"v2ray.com/core"
|
|
||||||
"v2ray.com/core/common"
|
|
||||||
"v2ray.com/core/common/buf"
|
|
||||||
"v2ray.com/core/common/dice"
|
|
||||||
"v2ray.com/core/common/net"
|
"v2ray.com/core/common/net"
|
||||||
"v2ray.com/core/common/task"
|
|
||||||
"v2ray.com/core/transport/internet/udp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -29,203 +21,12 @@ type ARecord struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type NameServer interface {
|
type NameServer interface {
|
||||||
QueryA(domain string) <-chan *ARecord
|
QueryIP(ctx context.Context, domain string) ([]net.IP, error)
|
||||||
}
|
|
||||||
|
|
||||||
type PendingRequest struct {
|
|
||||||
expire time.Time
|
|
||||||
response chan<- *ARecord
|
|
||||||
}
|
|
||||||
|
|
||||||
type UDPNameServer struct {
|
|
||||||
sync.Mutex
|
|
||||||
address net.Destination
|
|
||||||
requests map[uint16]*PendingRequest
|
|
||||||
udpServer *udp.Dispatcher
|
|
||||||
cleanup *task.Periodic
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewUDPNameServer(address net.Destination, dispatcher core.Dispatcher) *UDPNameServer {
|
|
||||||
s := &UDPNameServer{
|
|
||||||
address: address,
|
|
||||||
requests: make(map[uint16]*PendingRequest),
|
|
||||||
udpServer: udp.NewDispatcher(dispatcher),
|
|
||||||
}
|
|
||||||
s.cleanup = &task.Periodic{
|
|
||||||
Interval: time.Minute,
|
|
||||||
Execute: s.Cleanup,
|
|
||||||
}
|
|
||||||
common.Must(s.cleanup.Start())
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *UDPNameServer) Cleanup() error {
|
|
||||||
now := time.Now()
|
|
||||||
s.Lock()
|
|
||||||
for id, r := range s.requests {
|
|
||||||
if r.expire.Before(now) {
|
|
||||||
close(r.response)
|
|
||||||
delete(s.requests, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 {
|
|
||||||
var id uint16
|
|
||||||
s.Lock()
|
|
||||||
|
|
||||||
for {
|
|
||||||
id = dice.RollUint16()
|
|
||||||
if _, found := s.requests[id]; found {
|
|
||||||
time.Sleep(time.Millisecond * 500)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
newError("add pending request id ", id).AtDebug().WriteToLog()
|
|
||||||
s.requests[id] = &PendingRequest{
|
|
||||||
expire: time.Now().Add(time.Second * 8),
|
|
||||||
response: response,
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
s.Unlock()
|
|
||||||
return id
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *UDPNameServer) HandleResponse(payload *buf.Buffer) {
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
err := msg.Unpack(payload.Bytes())
|
|
||||||
if err == dns.ErrTruncated {
|
|
||||||
newError("truncated message received. DNS server should still work. If you see anything abnormal, please submit an issue to v2ray-core.").AtWarning().WriteToLog()
|
|
||||||
} else if err != nil {
|
|
||||||
newError("failed to parse DNS response").Base(err).AtWarning().WriteToLog()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
record := &ARecord{
|
|
||||||
IPs: make([]net.IP, 0, 16),
|
|
||||||
}
|
|
||||||
id := msg.Id
|
|
||||||
ttl := uint32(3600) // an hour
|
|
||||||
newError("handling response for id ", id, " content: ", msg).AtDebug().WriteToLog()
|
|
||||||
|
|
||||||
s.Lock()
|
|
||||||
request, found := s.requests[id]
|
|
||||||
if !found {
|
|
||||||
s.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
delete(s.requests, id)
|
|
||||||
s.Unlock()
|
|
||||||
|
|
||||||
for _, rr := range msg.Answer {
|
|
||||||
switch rr := rr.(type) {
|
|
||||||
case *dns.A:
|
|
||||||
record.IPs = append(record.IPs, rr.A)
|
|
||||||
if rr.Hdr.Ttl < ttl {
|
|
||||||
ttl = rr.Hdr.Ttl
|
|
||||||
}
|
|
||||||
case *dns.AAAA:
|
|
||||||
record.IPs = append(record.IPs, rr.AAAA)
|
|
||||||
if rr.Hdr.Ttl < ttl {
|
|
||||||
ttl = rr.Hdr.Ttl
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
record.Expire = time.Now().Add(time.Second * time.Duration(ttl))
|
|
||||||
|
|
||||||
request.response <- record
|
|
||||||
close(request.response)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *UDPNameServer) buildAMsg(domain string, id uint16) *dns.Msg {
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
msg.Id = id
|
|
||||||
msg.RecursionDesired = true
|
|
||||||
msg.Question = []dns.Question{
|
|
||||||
{
|
|
||||||
Name: dns.Fqdn(domain),
|
|
||||||
Qtype: dns.TypeA,
|
|
||||||
Qclass: dns.ClassINET,
|
|
||||||
}}
|
|
||||||
if multiQuestionDNS[s.address.Address] {
|
|
||||||
msg.Question = append(msg.Question, dns.Question{
|
|
||||||
Name: dns.Fqdn(domain),
|
|
||||||
Qtype: dns.TypeAAAA,
|
|
||||||
Qclass: dns.ClassINET,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
|
|
||||||
func msgToBuffer(msg *dns.Msg) (*buf.Buffer, error) {
|
|
||||||
buffer := buf.New()
|
|
||||||
if err := buffer.Reset(func(b []byte) (int, error) {
|
|
||||||
writtenBuffer, err := msg.PackBuffer(b)
|
|
||||||
return len(writtenBuffer), err
|
|
||||||
}); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return buffer, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *UDPNameServer) QueryA(domain string) <-chan *ARecord {
|
|
||||||
response := make(chan *ARecord, 1)
|
|
||||||
id := s.AssignUnusedID(response)
|
|
||||||
|
|
||||||
msg := s.buildAMsg(domain, id)
|
|
||||||
b, err := msgToBuffer(msg)
|
|
||||||
if err != nil {
|
|
||||||
newError("failed to build A query for domain ", domain).Base(err).WriteToLog()
|
|
||||||
s.Lock()
|
|
||||||
delete(s.requests, id)
|
|
||||||
s.Unlock()
|
|
||||||
close(response)
|
|
||||||
return response
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
s.udpServer.Dispatch(ctx, s.address, b, s.HandleResponse)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for i := 0; i < 2; i++ {
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
s.Lock()
|
|
||||||
_, found := s.requests[id]
|
|
||||||
s.Unlock()
|
|
||||||
if !found {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
b, _ := msgToBuffer(msg)
|
|
||||||
s.udpServer.Dispatch(ctx, s.address, b, s.HandleResponse)
|
|
||||||
}
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
|
|
||||||
return response
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type LocalNameServer struct {
|
type LocalNameServer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*LocalNameServer) QueryA(domain string) <-chan *ARecord {
|
func (*LocalNameServer) QueryIP(ctx context.Context, domain string) ([]net.IP, error) {
|
||||||
response := make(chan *ARecord, 1)
|
return net.LookupIP(domain)
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer close(response)
|
|
||||||
|
|
||||||
ips, err := net.LookupIP(domain)
|
|
||||||
if err != nil {
|
|
||||||
newError("failed to lookup IPs for domain ", domain).Base(err).AtWarning().WriteToLog()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response <- &ARecord{
|
|
||||||
IPs: ips,
|
|
||||||
Expire: time.Now().Add(time.Hour),
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return response
|
|
||||||
}
|
}
|
||||||
|
@ -7,48 +7,24 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
dnsmsg "github.com/miekg/dns"
|
|
||||||
"v2ray.com/core"
|
"v2ray.com/core"
|
||||||
"v2ray.com/core/common"
|
"v2ray.com/core/common"
|
||||||
"v2ray.com/core/common/net"
|
"v2ray.com/core/common/net"
|
||||||
"v2ray.com/core/common/task"
|
"v2ray.com/core/common/task"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
QueryTimeout = time.Second * 8
|
|
||||||
)
|
|
||||||
|
|
||||||
type DomainRecord struct {
|
|
||||||
IP []net.IP
|
|
||||||
Expire time.Time
|
|
||||||
LastAccess time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *DomainRecord) Expired() bool {
|
|
||||||
return r.Expire.Before(time.Now())
|
|
||||||
}
|
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
hosts map[string]net.IP
|
hosts map[string]net.IP
|
||||||
records map[string]*DomainRecord
|
|
||||||
servers []NameServer
|
servers []NameServer
|
||||||
task *task.Periodic
|
task *task.Periodic
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(ctx context.Context, config *Config) (*Server, error) {
|
func New(ctx context.Context, config *Config) (*Server, error) {
|
||||||
server := &Server{
|
server := &Server{
|
||||||
records: make(map[string]*DomainRecord),
|
|
||||||
servers: make([]NameServer, len(config.NameServers)),
|
servers: make([]NameServer, len(config.NameServers)),
|
||||||
hosts: config.GetInternalHosts(),
|
hosts: config.GetInternalHosts(),
|
||||||
}
|
}
|
||||||
server.task = &task.Periodic{
|
|
||||||
Interval: time.Minute * 10,
|
|
||||||
Execute: func() error {
|
|
||||||
server.cleanup()
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
v := core.MustFromContext(ctx)
|
v := core.MustFromContext(ctx)
|
||||||
if err := v.RegisterFeature((*core.DNSClient)(nil), server); err != nil {
|
if err := v.RegisterFeature((*core.DNSClient)(nil), server); err != nil {
|
||||||
return nil, newError("unable to register DNSClient.").Base(err)
|
return nil, newError("unable to register DNSClient.").Base(err)
|
||||||
@ -64,7 +40,7 @@ func New(ctx context.Context, config *Config) (*Server, error) {
|
|||||||
dest.Network = net.Network_UDP
|
dest.Network = net.Network_UDP
|
||||||
}
|
}
|
||||||
if dest.Network == net.Network_UDP {
|
if dest.Network == net.Network_UDP {
|
||||||
server.servers[idx] = NewUDPNameServer(dest, v.Dispatcher())
|
server.servers[idx] = NewClassicNameServer(dest, v.Dispatcher())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -85,64 +61,25 @@ func (s *Server) Close() error {
|
|||||||
return s.task.Close()
|
return s.task.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) GetCached(domain string) []net.IP {
|
|
||||||
s.Lock()
|
|
||||||
defer s.Unlock()
|
|
||||||
|
|
||||||
if record, found := s.records[domain]; found && !record.Expired() {
|
|
||||||
record.LastAccess = time.Now()
|
|
||||||
return record.IP
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) cleanup() {
|
|
||||||
s.Lock()
|
|
||||||
defer s.Unlock()
|
|
||||||
|
|
||||||
for d, r := range s.records {
|
|
||||||
if r.Expired() {
|
|
||||||
delete(s.records, d)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(s.records) == 0 {
|
|
||||||
s.records = make(map[string]*DomainRecord)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) LookupIP(domain string) ([]net.IP, error) {
|
func (s *Server) LookupIP(domain string) ([]net.IP, error) {
|
||||||
if ip, found := s.hosts[domain]; found {
|
if ip, found := s.hosts[domain]; found {
|
||||||
return []net.IP{ip}, nil
|
return []net.IP{ip}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
domain = dnsmsg.Fqdn(domain)
|
var lastErr error
|
||||||
ips := s.GetCached(domain)
|
|
||||||
if ips != nil {
|
|
||||||
return ips, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, server := range s.servers {
|
for _, server := range s.servers {
|
||||||
response := server.QueryA(domain)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*4)
|
||||||
select {
|
ips, err := server.QueryIP(ctx, domain)
|
||||||
case a, open := <-response:
|
cancel()
|
||||||
if !open || a == nil {
|
if err != nil {
|
||||||
continue
|
lastErr = err
|
||||||
}
|
}
|
||||||
s.Lock()
|
if len(ips) > 0 {
|
||||||
s.records[domain] = &DomainRecord{
|
return ips, nil
|
||||||
IP: a.IPs,
|
|
||||||
Expire: a.Expire,
|
|
||||||
LastAccess: time.Now(),
|
|
||||||
}
|
|
||||||
s.Unlock()
|
|
||||||
newError("returning ", len(a.IPs), " IPs for domain ", domain).AtDebug().WriteToLog()
|
|
||||||
return a.IPs, nil
|
|
||||||
case <-time.After(QueryTimeout):
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, newError("returning nil for domain ", domain)
|
return nil, newError("returning nil for domain ", domain).Base(lastErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
229
app/dns/udpns.go
Normal file
229
app/dns/udpns.go
Normal file
@ -0,0 +1,229 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"v2ray.com/core"
|
||||||
|
"v2ray.com/core/common"
|
||||||
|
"v2ray.com/core/common/buf"
|
||||||
|
"v2ray.com/core/common/net"
|
||||||
|
"v2ray.com/core/common/signal"
|
||||||
|
"v2ray.com/core/common/task"
|
||||||
|
"v2ray.com/core/transport/internet/udp"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IPRecord struct {
|
||||||
|
IP net.IP
|
||||||
|
Expire time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClassicNameServer struct {
|
||||||
|
sync.RWMutex
|
||||||
|
address net.Destination
|
||||||
|
ips map[string][]IPRecord
|
||||||
|
updated signal.Notifier
|
||||||
|
udpServer *udp.Dispatcher
|
||||||
|
cleanup *task.Periodic
|
||||||
|
reqID uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClassicNameServer(address net.Destination, dispatcher core.Dispatcher) *ClassicNameServer {
|
||||||
|
s := &ClassicNameServer{
|
||||||
|
address: address,
|
||||||
|
ips: make(map[string][]IPRecord),
|
||||||
|
udpServer: udp.NewDispatcher(dispatcher),
|
||||||
|
}
|
||||||
|
s.cleanup = &task.Periodic{
|
||||||
|
Interval: time.Minute,
|
||||||
|
Execute: s.Cleanup,
|
||||||
|
}
|
||||||
|
common.Must(s.cleanup.Start())
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClassicNameServer) Cleanup() error {
|
||||||
|
now := time.Now()
|
||||||
|
s.Lock()
|
||||||
|
for domain, ips := range s.ips {
|
||||||
|
newIPs := make([]IPRecord, 0, len(ips))
|
||||||
|
for _, ip := range ips {
|
||||||
|
if ip.Expire.After(now) {
|
||||||
|
newIPs = append(newIPs, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(newIPs) == 0 {
|
||||||
|
delete(s.ips, domain)
|
||||||
|
} else if len(newIPs) < len(ips) {
|
||||||
|
s.ips[domain] = newIPs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.ips) == 0 {
|
||||||
|
s.ips = make(map[string][]IPRecord)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClassicNameServer) HandleResponse(payload *buf.Buffer) {
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
err := msg.Unpack(payload.Bytes())
|
||||||
|
if err == dns.ErrTruncated {
|
||||||
|
newError("truncated message received. DNS server should still work. If you see anything abnormal, please submit an issue to v2ray-core.").AtWarning().WriteToLog()
|
||||||
|
} else if err != nil {
|
||||||
|
newError("failed to parse DNS response").Base(err).AtWarning().WriteToLog()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var domain string
|
||||||
|
ips := make([]IPRecord, 0, 16)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for _, rr := range msg.Answer {
|
||||||
|
var ip net.IP
|
||||||
|
domain = rr.Header().Name
|
||||||
|
ttl := rr.Header().Ttl
|
||||||
|
switch rr := rr.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
ip = rr.A
|
||||||
|
case *dns.AAAA:
|
||||||
|
ip = rr.AAAA
|
||||||
|
}
|
||||||
|
if ttl == 0 {
|
||||||
|
ttl = 300
|
||||||
|
}
|
||||||
|
if len(ip) > 0 {
|
||||||
|
ips = append(ips, IPRecord{
|
||||||
|
IP: ip,
|
||||||
|
Expire: now.Add(time.Second * time.Duration(ttl)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(domain) > 0 && len(ips) > 0 {
|
||||||
|
s.updateIP(domain, ips)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClassicNameServer) updateIP(domain string, ips []IPRecord) {
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
|
||||||
|
newError("updating IP records for domain:", domain).AtDebug().WriteToLog()
|
||||||
|
now := time.Now()
|
||||||
|
eips := s.ips[domain]
|
||||||
|
for _, ip := range eips {
|
||||||
|
if ip.Expire.After(now) {
|
||||||
|
ips = append(ips, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.ips[domain] = ips
|
||||||
|
s.updated.Signal()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClassicNameServer) buildMsgs(domain string) []*dns.Msg {
|
||||||
|
allowMulti := multiQuestionDNS[s.address.Address]
|
||||||
|
|
||||||
|
var msgs []*dns.Msg
|
||||||
|
|
||||||
|
{
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
msg.Id = uint16(atomic.AddUint32(&s.reqID, 1))
|
||||||
|
msg.RecursionDesired = true
|
||||||
|
msg.Question = []dns.Question{
|
||||||
|
{
|
||||||
|
Name: domain,
|
||||||
|
Qtype: dns.TypeA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}}
|
||||||
|
if allowMulti {
|
||||||
|
msg.Question = append(msg.Question, dns.Question{
|
||||||
|
Name: domain,
|
||||||
|
Qtype: dns.TypeAAAA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
msgs = append(msgs, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !allowMulti {
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
msg.Id = uint16(atomic.AddUint32(&s.reqID, 1))
|
||||||
|
msg.RecursionDesired = true
|
||||||
|
msg.Question = []dns.Question{
|
||||||
|
{
|
||||||
|
Name: domain,
|
||||||
|
Qtype: dns.TypeAAAA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
msgs = append(msgs, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
|
||||||
|
func msgToBuffer(msg *dns.Msg) (*buf.Buffer, error) {
|
||||||
|
buffer := buf.New()
|
||||||
|
if err := buffer.Reset(func(b []byte) (int, error) {
|
||||||
|
writtenBuffer, err := msg.PackBuffer(b)
|
||||||
|
return len(writtenBuffer), err
|
||||||
|
}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buffer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string) {
|
||||||
|
msgs := s.buildMsgs(domain)
|
||||||
|
|
||||||
|
for _, msg := range msgs {
|
||||||
|
b, err := msgToBuffer(msg)
|
||||||
|
common.Must(err)
|
||||||
|
s.udpServer.Dispatch(ctx, s.address, b, s.HandleResponse)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClassicNameServer) findIPsForDomain(domain string) []net.IP {
|
||||||
|
records, found := s.ips[domain]
|
||||||
|
if found && len(records) > 0 {
|
||||||
|
var ips []net.IP
|
||||||
|
now := time.Now()
|
||||||
|
for _, rec := range records {
|
||||||
|
if rec.Expire.After(now) {
|
||||||
|
ips = append(ips, rec.IP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ips
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string) ([]net.IP, error) {
|
||||||
|
fqdn := dns.Fqdn(domain)
|
||||||
|
|
||||||
|
ips := s.findIPsForDomain(fqdn)
|
||||||
|
if len(ips) > 0 {
|
||||||
|
return ips, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendQuery(ctx, fqdn)
|
||||||
|
|
||||||
|
for {
|
||||||
|
ips := s.findIPsForDomain(fqdn)
|
||||||
|
if len(ips) > 0 {
|
||||||
|
return ips, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-s.updated.Wait():
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,26 +1,34 @@
|
|||||||
package signal
|
package signal
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
// Notifier is a utility for notifying changes. The change producer may notify changes multiple time, and the consumer may get notified asynchronously.
|
// Notifier is a utility for notifying changes. The change producer may notify changes multiple time, and the consumer may get notified asynchronously.
|
||||||
type Notifier struct {
|
type Notifier struct {
|
||||||
c chan struct{}
|
sync.Mutex
|
||||||
|
waiters []chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNotifier creates a new Notifier.
|
// NewNotifier creates a new Notifier.
|
||||||
func NewNotifier() *Notifier {
|
func NewNotifier() *Notifier {
|
||||||
return &Notifier{
|
return &Notifier{}
|
||||||
c: make(chan struct{}, 1),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Signal signals a change, usually by producer. This method never blocks.
|
// Signal signals a change, usually by producer. This method never blocks.
|
||||||
func (n *Notifier) Signal() {
|
func (n *Notifier) Signal() {
|
||||||
select {
|
n.Lock()
|
||||||
case n.c <- struct{}{}:
|
for _, w := range n.waiters {
|
||||||
default:
|
close(w)
|
||||||
}
|
}
|
||||||
|
n.waiters = make([]chan struct{}, 0, 8)
|
||||||
|
n.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait returns a channel for waiting for changes. The returned channel never gets closed.
|
// Wait returns a channel for waiting for changes. The returned channel never gets closed.
|
||||||
func (n *Notifier) Wait() <-chan struct{} {
|
func (n *Notifier) Wait() <-chan struct{} {
|
||||||
return n.c
|
n.Lock()
|
||||||
|
defer n.Unlock()
|
||||||
|
|
||||||
|
w := make(chan struct{})
|
||||||
|
n.waiters = append(n.waiters, w)
|
||||||
|
return w
|
||||||
}
|
}
|
||||||
|
23
common/signal/notifier_test.go
Normal file
23
common/signal/notifier_test.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package signal_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "v2ray.com/core/common/signal"
|
||||||
|
//. "v2ray.com/ext/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNotifierSignal(t *testing.T) {
|
||||||
|
//assert := With(t)
|
||||||
|
|
||||||
|
var n Notifier
|
||||||
|
|
||||||
|
w := n.Wait()
|
||||||
|
n.Signal()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-w:
|
||||||
|
default:
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user