// Mirror of https://github.com/v2fly/v2ray-core.git (synced 2024-11-16 09:26:21 -05:00).
// File size: 454 lines, 9.5 KiB. Language: Go.

// +build !confonly
|
|
|
|
package dns
|
|
|
|
import (
|
|
"context"
|
|
"encoding/binary"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"golang.org/x/net/dns/dnsmessage"
|
|
"v2ray.com/core/common"
|
|
"v2ray.com/core/common/errors"
|
|
"v2ray.com/core/common/net"
|
|
"v2ray.com/core/common/protocol/dns"
|
|
udp_proto "v2ray.com/core/common/protocol/udp"
|
|
"v2ray.com/core/common/session"
|
|
"v2ray.com/core/common/signal/pubsub"
|
|
"v2ray.com/core/common/task"
|
|
dns_feature "v2ray.com/core/features/dns"
|
|
"v2ray.com/core/features/routing"
|
|
"v2ray.com/core/transport/internet/udp"
|
|
)
|
|
|
|
type record struct {
|
|
A *IPRecord
|
|
AAAA *IPRecord
|
|
}
|
|
|
|
type IPRecord struct {
|
|
IP []net.Address
|
|
Expire time.Time
|
|
RCode dnsmessage.RCode
|
|
}
|
|
|
|
func (r *IPRecord) getIPs() ([]net.Address, error) {
|
|
if r == nil || r.Expire.Before(time.Now()) {
|
|
return nil, errRecordNotFound
|
|
}
|
|
if r.RCode != dnsmessage.RCodeSuccess {
|
|
return nil, dns_feature.RCodeError(r.RCode)
|
|
}
|
|
return r.IP, nil
|
|
}
|
|
|
|
type pendingRequest struct {
|
|
domain string
|
|
expire time.Time
|
|
recType dnsmessage.Type
|
|
}
|
|
|
|
var (
|
|
errRecordNotFound = errors.New("record not found")
|
|
)
|
|
|
|
type ClassicNameServer struct {
|
|
sync.RWMutex
|
|
address net.Destination
|
|
ips map[string]record
|
|
requests map[uint16]pendingRequest
|
|
pub *pubsub.Service
|
|
udpServer *udp.Dispatcher
|
|
cleanup *task.Periodic
|
|
reqID uint32
|
|
clientIP net.IP
|
|
}
|
|
|
|
func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer {
|
|
s := &ClassicNameServer{
|
|
address: address,
|
|
ips: make(map[string]record),
|
|
requests: make(map[uint16]pendingRequest),
|
|
clientIP: clientIP,
|
|
pub: pubsub.NewService(),
|
|
}
|
|
s.cleanup = &task.Periodic{
|
|
Interval: time.Minute,
|
|
Execute: s.Cleanup,
|
|
}
|
|
s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse)
|
|
return s
|
|
}
|
|
|
|
func (s *ClassicNameServer) Name() string {
|
|
return s.address.String()
|
|
}
|
|
|
|
func (s *ClassicNameServer) Cleanup() error {
|
|
now := time.Now()
|
|
s.Lock()
|
|
defer s.Unlock()
|
|
|
|
if len(s.ips) == 0 && len(s.requests) == 0 {
|
|
return newError("nothing to do. stopping...")
|
|
}
|
|
|
|
for domain, record := range s.ips {
|
|
if record.A != nil && record.A.Expire.Before(now) {
|
|
record.A = nil
|
|
}
|
|
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
|
|
record.AAAA = nil
|
|
}
|
|
|
|
if record.A == nil && record.AAAA == nil {
|
|
delete(s.ips, domain)
|
|
} else {
|
|
s.ips[domain] = record
|
|
}
|
|
}
|
|
|
|
if len(s.ips) == 0 {
|
|
s.ips = make(map[string]record)
|
|
}
|
|
|
|
for id, req := range s.requests {
|
|
if req.expire.Before(now) {
|
|
delete(s.requests, id)
|
|
}
|
|
}
|
|
|
|
if len(s.requests) == 0 {
|
|
s.requests = make(map[uint16]pendingRequest)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
|
|
payload := packet.Payload
|
|
|
|
var parser dnsmessage.Parser
|
|
header, err := parser.Start(payload.Bytes())
|
|
if err != nil {
|
|
newError("failed to parse DNS response").Base(err).AtWarning().WriteToLog()
|
|
return
|
|
}
|
|
if err := parser.SkipAllQuestions(); err != nil {
|
|
newError("failed to skip questions in DNS response").Base(err).AtWarning().WriteToLog()
|
|
return
|
|
}
|
|
|
|
id := header.ID
|
|
s.Lock()
|
|
req, f := s.requests[id]
|
|
if f {
|
|
delete(s.requests, id)
|
|
}
|
|
s.Unlock()
|
|
|
|
if !f {
|
|
return
|
|
}
|
|
|
|
domain := req.domain
|
|
recType := req.recType
|
|
|
|
now := time.Now()
|
|
ipRecord := &IPRecord{
|
|
RCode: header.RCode,
|
|
Expire: now.Add(time.Second * 600),
|
|
}
|
|
|
|
L:
|
|
for {
|
|
header, err := parser.AnswerHeader()
|
|
if err != nil {
|
|
if err != dnsmessage.ErrSectionDone {
|
|
newError("failed to parse answer section for domain: ", domain).Base(err).WriteToLog()
|
|
}
|
|
break
|
|
}
|
|
ttl := header.TTL
|
|
if ttl == 0 {
|
|
ttl = 600
|
|
}
|
|
expire := now.Add(time.Duration(ttl) * time.Second)
|
|
if ipRecord.Expire.After(expire) {
|
|
ipRecord.Expire = expire
|
|
}
|
|
|
|
if header.Type != recType {
|
|
if err := parser.SkipAnswer(); err != nil {
|
|
newError("failed to skip answer").Base(err).WriteToLog()
|
|
break L
|
|
}
|
|
continue
|
|
}
|
|
|
|
switch header.Type {
|
|
case dnsmessage.TypeA:
|
|
ans, err := parser.AResource()
|
|
if err != nil {
|
|
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
|
|
break L
|
|
}
|
|
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
|
|
case dnsmessage.TypeAAAA:
|
|
ans, err := parser.AAAAResource()
|
|
if err != nil {
|
|
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
|
|
break L
|
|
}
|
|
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
|
|
default:
|
|
if err := parser.SkipAnswer(); err != nil {
|
|
newError("failed to skip answer").Base(err).WriteToLog()
|
|
break L
|
|
}
|
|
}
|
|
}
|
|
|
|
var rec record
|
|
switch recType {
|
|
case dnsmessage.TypeA:
|
|
rec.A = ipRecord
|
|
case dnsmessage.TypeAAAA:
|
|
rec.AAAA = ipRecord
|
|
}
|
|
|
|
if len(domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
|
|
s.updateIP(domain, rec)
|
|
}
|
|
}
|
|
|
|
func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
|
|
if newRec == nil {
|
|
return false
|
|
}
|
|
if baseRec == nil {
|
|
return true
|
|
}
|
|
return baseRec.Expire.Before(newRec.Expire)
|
|
}
|
|
|
|
func (s *ClassicNameServer) updateIP(domain string, newRec record) {
|
|
s.Lock()
|
|
|
|
newError("updating IP records for domain:", domain).AtDebug().WriteToLog()
|
|
rec := s.ips[domain]
|
|
|
|
updated := false
|
|
if isNewer(rec.A, newRec.A) {
|
|
rec.A = newRec.A
|
|
updated = true
|
|
}
|
|
if isNewer(rec.AAAA, newRec.AAAA) {
|
|
rec.AAAA = newRec.AAAA
|
|
updated = true
|
|
}
|
|
|
|
if updated {
|
|
s.ips[domain] = rec
|
|
s.pub.Publish(domain, nil)
|
|
}
|
|
|
|
s.Unlock()
|
|
common.Must(s.cleanup.Start())
|
|
}
|
|
|
|
func (s *ClassicNameServer) getMsgOptions() *dnsmessage.Resource {
|
|
if len(s.clientIP) == 0 {
|
|
return nil
|
|
}
|
|
|
|
var netmask int
|
|
var family uint16
|
|
|
|
if len(s.clientIP) == 4 {
|
|
family = 1
|
|
netmask = 24 // 24 for IPV4, 96 for IPv6
|
|
} else {
|
|
family = 2
|
|
netmask = 96
|
|
}
|
|
|
|
b := make([]byte, 4)
|
|
binary.BigEndian.PutUint16(b[0:], family)
|
|
b[2] = byte(netmask)
|
|
b[3] = 0
|
|
switch family {
|
|
case 1:
|
|
ip := s.clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
|
|
needLength := (netmask + 8 - 1) / 8 // division rounding up
|
|
b = append(b, ip[:needLength]...)
|
|
case 2:
|
|
ip := s.clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
|
|
needLength := (netmask + 8 - 1) / 8 // division rounding up
|
|
b = append(b, ip[:needLength]...)
|
|
}
|
|
|
|
const EDNS0SUBNET = 0x08
|
|
|
|
opt := new(dnsmessage.Resource)
|
|
common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
|
|
|
|
opt.Body = &dnsmessage.OPTResource{
|
|
Options: []dnsmessage.Option{
|
|
{
|
|
Code: EDNS0SUBNET,
|
|
Data: b,
|
|
},
|
|
},
|
|
}
|
|
|
|
return opt
|
|
}
|
|
|
|
func (s *ClassicNameServer) addPendingRequest(domain string, recType dnsmessage.Type) uint16 {
|
|
id := uint16(atomic.AddUint32(&s.reqID, 1))
|
|
s.Lock()
|
|
defer s.Unlock()
|
|
|
|
s.requests[id] = pendingRequest{
|
|
domain: domain,
|
|
expire: time.Now().Add(time.Second * 8),
|
|
recType: recType,
|
|
}
|
|
|
|
return id
|
|
}
|
|
|
|
func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmessage.Message {
|
|
qA := dnsmessage.Question{
|
|
Name: dnsmessage.MustNewName(domain),
|
|
Type: dnsmessage.TypeA,
|
|
Class: dnsmessage.ClassINET,
|
|
}
|
|
|
|
qAAAA := dnsmessage.Question{
|
|
Name: dnsmessage.MustNewName(domain),
|
|
Type: dnsmessage.TypeAAAA,
|
|
Class: dnsmessage.ClassINET,
|
|
}
|
|
|
|
var msgs []*dnsmessage.Message
|
|
|
|
if option.IPv4Enable {
|
|
msg := new(dnsmessage.Message)
|
|
msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeA)
|
|
msg.Header.RecursionDesired = true
|
|
msg.Questions = []dnsmessage.Question{qA}
|
|
if opt := s.getMsgOptions(); opt != nil {
|
|
msg.Additionals = append(msg.Additionals, *opt)
|
|
}
|
|
msgs = append(msgs, msg)
|
|
}
|
|
|
|
if option.IPv6Enable {
|
|
msg := new(dnsmessage.Message)
|
|
msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeAAAA)
|
|
msg.Header.RecursionDesired = true
|
|
msg.Questions = []dnsmessage.Question{qAAAA}
|
|
if opt := s.getMsgOptions(); opt != nil {
|
|
msg.Additionals = append(msg.Additionals, *opt)
|
|
}
|
|
msgs = append(msgs, msg)
|
|
}
|
|
|
|
return msgs
|
|
}
|
|
|
|
func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option IPOption) {
|
|
newError("querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))
|
|
|
|
msgs := s.buildMsgs(domain, option)
|
|
|
|
for _, msg := range msgs {
|
|
b, _ := dns.PackMessage(msg)
|
|
|
|
udpCtx := context.Background()
|
|
if inbound := session.InboundFromContext(ctx); inbound != nil {
|
|
udpCtx = session.ContextWithInbound(udpCtx, inbound)
|
|
}
|
|
udpCtx = session.ContextWithContent(udpCtx, &session.Content{
|
|
Protocol: "dns",
|
|
})
|
|
s.udpServer.Dispatch(udpCtx, s.address, b)
|
|
}
|
|
}
|
|
|
|
func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) ([]net.IP, error) {
|
|
s.RLock()
|
|
record, found := s.ips[domain]
|
|
s.RUnlock()
|
|
|
|
if !found {
|
|
return nil, errRecordNotFound
|
|
}
|
|
|
|
var ips []net.Address
|
|
var lastErr error
|
|
if option.IPv4Enable {
|
|
a, err := record.A.getIPs()
|
|
if err != nil {
|
|
lastErr = err
|
|
}
|
|
ips = append(ips, a...)
|
|
}
|
|
|
|
if option.IPv6Enable {
|
|
aaaa, err := record.AAAA.getIPs()
|
|
if err != nil {
|
|
lastErr = err
|
|
}
|
|
ips = append(ips, aaaa...)
|
|
}
|
|
|
|
if len(ips) > 0 {
|
|
return toNetIP(ips), nil
|
|
}
|
|
|
|
if lastErr != nil {
|
|
return nil, lastErr
|
|
}
|
|
|
|
return nil, dns_feature.ErrEmptyResponse
|
|
}
|
|
|
|
// Fqdn returns domain in fully-qualified form, i.e. ending with a dot.
// A domain that already ends with '.' is returned unchanged; the empty
// string becomes ".".
func Fqdn(domain string) string {
	n := len(domain)
	if n == 0 || domain[n-1] != '.' {
		return domain + "."
	}
	return domain
}
|
|
|
|
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
|
|
fqdn := Fqdn(domain)
|
|
|
|
ips, err := s.findIPsForDomain(fqdn, option)
|
|
if err != errRecordNotFound {
|
|
return ips, err
|
|
}
|
|
|
|
sub := s.pub.Subscribe(fqdn)
|
|
defer sub.Close()
|
|
|
|
s.sendQuery(ctx, fqdn, option)
|
|
|
|
for {
|
|
ips, err := s.findIPsForDomain(fqdn, option)
|
|
if err != errRecordNotFound {
|
|
return ips, err
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-sub.Wait():
|
|
}
|
|
}
|
|
}
|