2018-06-26 09:04:47 -04:00
package dns
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
"v2ray.com/core"
"v2ray.com/core/common"
"v2ray.com/core/common/buf"
"v2ray.com/core/common/net"
2018-07-01 06:38:40 -04:00
"v2ray.com/core/common/signal/pubsub"
2018-06-26 09:04:47 -04:00
"v2ray.com/core/common/task"
"v2ray.com/core/transport/internet/udp"
)
2018-06-26 09:16:45 -04:00
// multiQuestionDNS lists upstream servers to which buildMsgs sends a single
// message carrying both the A and the AAAA question. Every other server
// receives one message per question type.
var multiQuestionDNS = map[net.Address]bool{
	net.IPAddress([]byte{8, 8, 8, 8}): true,
	net.IPAddress([]byte{8, 8, 4, 4}): true,
	net.IPAddress([]byte{9, 9, 9, 9}): true,
}
2018-06-26 09:04:47 -04:00
// IPRecord is one cached address of a domain, valid until Expire.
type IPRecord struct {
	// IP is the resolved address (taken from an A or AAAA answer).
	IP net.IP
	// Expire is the instant after which the record is no longer served.
	Expire time.Time
}
2018-07-01 11:15:29 -04:00
// pendingRequest remembers which domain an in-flight query asked about, so
// that a reply can be attributed to it by message ID.
type pendingRequest struct {
	// domain is the fully-qualified name that was queried.
	domain string
	// expire is when the request is considered abandoned and may be purged.
	expire time.Time
}
2018-06-26 09:04:47 -04:00
type ClassicNameServer struct {
sync . RWMutex
address net . Destination
ips map [ string ] [ ] IPRecord
2018-07-01 11:15:29 -04:00
requests map [ uint16 ] pendingRequest
2018-07-01 06:38:40 -04:00
pub * pubsub . Service
2018-06-26 09:04:47 -04:00
udpServer * udp . Dispatcher
cleanup * task . Periodic
reqID uint32
2018-06-26 17:23:59 -04:00
clientIP net . IP
2018-06-26 09:04:47 -04:00
}
2018-06-26 17:23:59 -04:00
func NewClassicNameServer ( address net . Destination , dispatcher core . Dispatcher , clientIP net . IP ) * ClassicNameServer {
2018-06-26 09:04:47 -04:00
s := & ClassicNameServer {
2018-07-03 15:38:02 -04:00
address : address ,
ips : make ( map [ string ] [ ] IPRecord ) ,
requests : make ( map [ uint16 ] pendingRequest ) ,
clientIP : clientIP ,
pub : pubsub . NewService ( ) ,
2018-06-26 09:04:47 -04:00
}
s . cleanup = & task . Periodic {
Interval : time . Minute ,
Execute : s . Cleanup ,
}
2018-07-03 15:38:02 -04:00
s . udpServer = udp . NewDispatcher ( dispatcher , s . HandleResponse )
2018-06-26 09:04:47 -04:00
common . Must ( s . cleanup . Start ( ) )
return s
}
func ( s * ClassicNameServer ) Cleanup ( ) error {
now := time . Now ( )
s . Lock ( )
for domain , ips := range s . ips {
newIPs := make ( [ ] IPRecord , 0 , len ( ips ) )
for _ , ip := range ips {
if ip . Expire . After ( now ) {
newIPs = append ( newIPs , ip )
}
}
if len ( newIPs ) == 0 {
delete ( s . ips , domain )
} else if len ( newIPs ) < len ( ips ) {
s . ips [ domain ] = newIPs
}
}
if len ( s . ips ) == 0 {
s . ips = make ( map [ string ] [ ] IPRecord )
}
2018-07-01 11:15:29 -04:00
for id , req := range s . requests {
if req . expire . Before ( now ) {
delete ( s . requests , id )
}
}
if len ( s . requests ) == 0 {
s . requests = make ( map [ uint16 ] pendingRequest )
}
2018-06-26 09:04:47 -04:00
s . Unlock ( )
return nil
}
2018-07-03 15:38:02 -04:00
func ( s * ClassicNameServer ) HandleResponse ( ctx context . Context , payload * buf . Buffer ) {
2018-06-26 09:04:47 -04:00
msg := new ( dns . Msg )
err := msg . Unpack ( payload . Bytes ( ) )
if err == dns . ErrTruncated {
newError ( "truncated message received. DNS server should still work. If you see anything abnormal, please submit an issue to v2ray-core." ) . AtWarning ( ) . WriteToLog ( )
} else if err != nil {
newError ( "failed to parse DNS response" ) . Base ( err ) . AtWarning ( ) . WriteToLog ( )
return
}
2018-07-01 11:15:29 -04:00
id := msg . Id
s . Lock ( )
req , f := s . requests [ id ]
if f {
delete ( s . requests , id )
}
s . Unlock ( )
if ! f {
return
}
domain := req . domain
2018-06-26 09:04:47 -04:00
ips := make ( [ ] IPRecord , 0 , 16 )
now := time . Now ( )
for _ , rr := range msg . Answer {
var ip net . IP
ttl := rr . Header ( ) . Ttl
switch rr := rr . ( type ) {
case * dns . A :
ip = rr . A
case * dns . AAAA :
ip = rr . AAAA
}
if ttl == 0 {
2018-07-01 06:38:40 -04:00
ttl = 600
2018-06-26 09:04:47 -04:00
}
if len ( ip ) > 0 {
ips = append ( ips , IPRecord {
IP : ip ,
Expire : now . Add ( time . Second * time . Duration ( ttl ) ) ,
} )
}
}
if len ( domain ) > 0 && len ( ips ) > 0 {
s . updateIP ( domain , ips )
}
}
func ( s * ClassicNameServer ) updateIP ( domain string , ips [ ] IPRecord ) {
s . Lock ( )
defer s . Unlock ( )
newError ( "updating IP records for domain:" , domain ) . AtDebug ( ) . WriteToLog ( )
now := time . Now ( )
eips := s . ips [ domain ]
for _ , ip := range eips {
if ip . Expire . After ( now ) {
ips = append ( ips , ip )
}
}
s . ips [ domain ] = ips
2018-07-01 06:38:40 -04:00
s . pub . Publish ( domain , nil )
2018-06-26 09:04:47 -04:00
}
2018-06-26 11:14:51 -04:00
func ( s * ClassicNameServer ) getMsgOptions ( ) * dns . OPT {
2018-06-26 17:23:59 -04:00
if len ( s . clientIP ) == 0 {
2018-06-26 11:14:51 -04:00
return nil
}
o := new ( dns . OPT )
o . Hdr . Name = "."
o . Hdr . Rrtype = dns . TypeOPT
2018-07-02 16:22:04 -04:00
o . SetUDPSize ( 1350 )
2018-06-26 11:14:51 -04:00
2018-06-26 17:23:59 -04:00
e := new ( dns . EDNS0_SUBNET )
e . Code = dns . EDNS0SUBNET
if len ( s . clientIP ) == 4 {
2018-07-02 16:22:04 -04:00
e . Family = 1 // 1 for IPv4 source address, 2 for IPv6
e . SourceNetmask = 24 // 32 for IPV4, 128 for IPv6
2018-06-26 17:23:59 -04:00
} else {
e . Family = 2
2018-07-02 16:22:04 -04:00
e . SourceNetmask = 96
2018-06-26 11:14:51 -04:00
}
2018-06-26 17:23:59 -04:00
e . SourceScope = 0
e . Address = s . clientIP
o . Option = append ( o . Option , e )
2018-06-26 11:14:51 -04:00
return o
2018-07-01 11:15:29 -04:00
}
func ( s * ClassicNameServer ) addPendingRequest ( domain string ) uint16 {
id := uint16 ( atomic . AddUint32 ( & s . reqID , 1 ) )
s . Lock ( )
defer s . Unlock ( )
s . requests [ id ] = pendingRequest {
domain : domain ,
expire : time . Now ( ) . Add ( time . Second * 8 ) ,
}
2018-06-26 11:14:51 -04:00
2018-07-01 11:15:29 -04:00
return id
2018-06-26 11:14:51 -04:00
}
2018-06-26 09:04:47 -04:00
func ( s * ClassicNameServer ) buildMsgs ( domain string ) [ ] * dns . Msg {
allowMulti := multiQuestionDNS [ s . address . Address ]
2018-06-26 11:14:51 -04:00
qA := dns . Question {
Name : domain ,
Qtype : dns . TypeA ,
Qclass : dns . ClassINET ,
}
qAAAA := dns . Question {
Name : domain ,
Qtype : dns . TypeAAAA ,
Qclass : dns . ClassINET ,
}
2018-06-26 09:04:47 -04:00
var msgs [ ] * dns . Msg
{
msg := new ( dns . Msg )
2018-07-01 11:15:29 -04:00
msg . Id = s . addPendingRequest ( domain )
2018-06-26 09:04:47 -04:00
msg . RecursionDesired = true
2018-06-26 11:14:51 -04:00
msg . Question = [ ] dns . Question { qA }
2018-06-26 09:04:47 -04:00
if allowMulti {
2018-06-26 11:14:51 -04:00
msg . Question = append ( msg . Question , qAAAA )
}
if opt := s . getMsgOptions ( ) ; opt != nil {
msg . Extra = append ( msg . Extra , opt )
2018-06-26 09:04:47 -04:00
}
msgs = append ( msgs , msg )
}
if ! allowMulti {
msg := new ( dns . Msg )
2018-07-01 11:15:29 -04:00
msg . Id = s . addPendingRequest ( domain )
2018-06-26 09:04:47 -04:00
msg . RecursionDesired = true
2018-06-26 11:14:51 -04:00
msg . Question = [ ] dns . Question { qAAAA }
if opt := s . getMsgOptions ( ) ; opt != nil {
msg . Extra = append ( msg . Extra , opt )
2018-06-26 09:04:47 -04:00
}
msgs = append ( msgs , msg )
}
return msgs
}
// msgToBuffer serializes msg into a newly acquired buf.Buffer.
//
// On failure the buffer is released back to its pool and nil is returned with
// the error.
func msgToBuffer(msg *dns.Msg) (*buf.Buffer, error) {
	buffer := buf.New()
	if err := buffer.Reset(func(b []byte) (int, error) {
		packed, err := msg.PackBuffer(b)
		if err != nil {
			return 0, err
		}
		// PackBuffer normally packs into b and returns a prefix of it. When b is
		// too small it allocates a separate slice instead. Only bytes stored in b
		// are visible through buffer, so copy and make sure everything fit.
		// Copying onto itself in the normal case is harmless.
		n := copy(b, packed)
		if n < len(packed) {
			return 0, newError("DNS message too large for buffer: ", len(packed), " bytes")
		}
		return n, nil
	}); err != nil {
		buffer.Release()
		return nil, err
	}
	return buffer, nil
}
func ( s * ClassicNameServer ) sendQuery ( ctx context . Context , domain string ) {
msgs := s . buildMsgs ( domain )
for _ , msg := range msgs {
b , err := msgToBuffer ( msg )
common . Must ( err )
2018-07-03 15:38:02 -04:00
s . udpServer . Dispatch ( context . Background ( ) , s . address , b )
2018-06-26 09:04:47 -04:00
}
}
func ( s * ClassicNameServer ) findIPsForDomain ( domain string ) [ ] net . IP {
2018-06-27 03:12:55 -04:00
s . RLock ( )
2018-06-26 09:04:47 -04:00
records , found := s . ips [ domain ]
2018-06-27 03:12:55 -04:00
s . RUnlock ( )
2018-06-26 09:04:47 -04:00
if found && len ( records ) > 0 {
var ips [ ] net . IP
now := time . Now ( )
for _ , rec := range records {
if rec . Expire . After ( now ) {
ips = append ( ips , rec . IP )
}
}
return ips
}
return nil
}
func ( s * ClassicNameServer ) QueryIP ( ctx context . Context , domain string ) ( [ ] net . IP , error ) {
fqdn := dns . Fqdn ( domain )
ips := s . findIPsForDomain ( fqdn )
if len ( ips ) > 0 {
return ips , nil
}
2018-07-01 06:38:40 -04:00
sub := s . pub . Subscribe ( fqdn )
defer sub . Close ( )
2018-06-26 09:04:47 -04:00
s . sendQuery ( ctx , fqdn )
for {
ips := s . findIPsForDomain ( fqdn )
if len ( ips ) > 0 {
return ips , nil
}
select {
case <- ctx . Done ( ) :
return nil , ctx . Err ( )
2018-07-01 06:38:40 -04:00
case <- sub . Wait ( ) :
2018-06-26 09:04:47 -04:00
}
}
}