1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-06-10 09:50:43 +00:00

add DOH dns client

This commit is contained in:
vcptr 2019-06-29 23:43:30 +08:00
parent 7313089539
commit 6ef77246ab
10 changed files with 875 additions and 248 deletions

235
app/dns/dnscommon.go Normal file
View File

@ -0,0 +1,235 @@
// +build !confonly
package dns
import (
"encoding/binary"
"time"
"golang.org/x/net/dns/dnsmessage"
"v2ray.com/core/common"
"v2ray.com/core/common/errors"
"v2ray.com/core/common/net"
dns_feature "v2ray.com/core/features/dns"
)
func Fqdn(domain string) string {
if len(domain) > 0 && domain[len(domain)-1] == '.' {
return domain
}
return domain + "."
}
type record struct {
A *IPRecord
AAAA *IPRecord
}
type IPRecord struct {
ReqID uint16
IP []net.Address
Expire time.Time
RCode dnsmessage.RCode
}
func (r *IPRecord) getIPs() ([]net.Address, error) {
if r == nil || r.Expire.Before(time.Now()) {
return nil, errRecordNotFound
}
if r.RCode != dnsmessage.RCodeSuccess {
return nil, dns_feature.RCodeError(r.RCode)
}
return r.IP, nil
}
func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
if newRec == nil {
return false
}
if baseRec == nil {
return true
}
return baseRec.Expire.Before(newRec.Expire)
}
var (
errRecordNotFound = errors.New("record not found")
)
type dnsRequest struct {
reqType dnsmessage.Type
domain string
start time.Time
expire time.Time
msg *dnsmessage.Message
}
func genEDNS0Options(clientIP net.IP) *dnsmessage.Resource {
if len(clientIP) == 0 {
return nil
}
var netmask int
var family uint16
if len(clientIP) == 4 {
family = 1
netmask = 24 // 24 for IPV4, 96 for IPv6
} else {
family = 2
netmask = 96
}
b := make([]byte, 4)
binary.BigEndian.PutUint16(b[0:], family)
b[2] = byte(netmask)
b[3] = 0
switch family {
case 1:
ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
needLength := (netmask + 8 - 1) / 8 // division rounding up
b = append(b, ip[:needLength]...)
case 2:
ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
needLength := (netmask + 8 - 1) / 8 // division rounding up
b = append(b, ip[:needLength]...)
}
const EDNS0SUBNET = 0x08
opt := new(dnsmessage.Resource)
common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
opt.Body = &dnsmessage.OPTResource{
Options: []dnsmessage.Option{
{
Code: EDNS0SUBNET,
Data: b,
},
},
}
return opt
}
func buildReqMsgs(domain string, option IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest {
qA := dnsmessage.Question{
Name: dnsmessage.MustNewName(domain),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
}
qAAAA := dnsmessage.Question{
Name: dnsmessage.MustNewName(domain),
Type: dnsmessage.TypeAAAA,
Class: dnsmessage.ClassINET,
}
var reqs []*dnsRequest
now := time.Now()
if option.IPv4Enable {
msg := new(dnsmessage.Message)
msg.Header.ID = reqIDGen()
msg.Header.RecursionDesired = true
msg.Questions = []dnsmessage.Question{qA}
if reqOpts != nil {
msg.Additionals = append(msg.Additionals, *reqOpts)
}
reqs = append(reqs, &dnsRequest{
reqType: dnsmessage.TypeA,
domain: domain,
start: now,
msg: msg,
})
}
if option.IPv6Enable {
msg := new(dnsmessage.Message)
msg.Header.ID = reqIDGen()
msg.Header.RecursionDesired = true
msg.Questions = []dnsmessage.Question{qAAAA}
if reqOpts != nil {
msg.Additionals = append(msg.Additionals, *reqOpts)
}
reqs = append(reqs, &dnsRequest{
reqType: dnsmessage.TypeAAAA,
domain: domain,
start: now,
msg: msg,
})
}
return reqs
}
// parseResponse parse DNS answers from the returned payload
func parseResponse(payload []byte) (*IPRecord, error) {
var parser dnsmessage.Parser
h, err := parser.Start(payload)
if err != nil {
return nil, newError("failed to parse DNS response").Base(err).AtWarning()
}
if err := parser.SkipAllQuestions(); err != nil {
return nil, newError("failed to skip questions in DNS response").Base(err).AtWarning()
}
now := time.Now()
var ipRecExpire time.Time
if h.RCode != dnsmessage.RCodeSuccess {
// A default TTL, maybe a negtive cache
ipRecExpire = now.Add(time.Second * 120)
}
ipRecord := &IPRecord{
ReqID: h.ID,
RCode: h.RCode,
Expire: ipRecExpire,
}
L:
for {
ah, err := parser.AnswerHeader()
if err != nil {
if err != dnsmessage.ErrSectionDone {
newError("failed to parse answer section for domain: ", ah.Name.String()).Base(err).WriteToLog()
}
break
}
switch ah.Type {
case dnsmessage.TypeA:
ans, err := parser.AResource()
if err != nil {
newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
break L
}
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
case dnsmessage.TypeAAAA:
ans, err := parser.AAAAResource()
if err != nil {
newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
break L
}
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
default:
if err := parser.SkipAnswer(); err != nil {
newError("failed to skip answer").Base(err).WriteToLog()
break L
}
continue
}
if ipRecord.Expire.IsZero() {
ttl := ah.TTL
if ttl < 600 {
// at least 10 mins TTL
ipRecord.Expire = now.Add(time.Minute * 10)
} else {
ipRecord.Expire = now.Add(time.Duration(ttl) * time.Second)
}
}
}
return ipRecord, nil
}

166
app/dns/dnscommon_test.go Normal file
View File

@ -0,0 +1,166 @@
// +build !confonly
package dns
import (
"math/rand"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/miekg/dns"
"golang.org/x/net/dns/dnsmessage"
"v2ray.com/core/common"
"v2ray.com/core/common/net"
v2net "v2ray.com/core/common/net"
)
func Test_parseResponse(t *testing.T) {
type args struct {
payload []byte
}
var p [][]byte
ans := new(dns.Msg)
ans.Id = 0
p = append(p, common.Must2(ans.Pack()).([]byte))
p = append(p, []byte{})
ans = new(dns.Msg)
ans.Id = 1
ans.Answer = append(ans.Answer,
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN A 8.8.8.8")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN A 8.8.4.4")).(dns.RR),
)
p = append(p, common.Must2(ans.Pack()).([]byte))
ans = new(dns.Msg)
ans.Id = 2
ans.Answer = append(ans.Answer,
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN CNAME test.google.com")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8888")).(dns.RR),
common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8844")).(dns.RR),
)
p = append(p, common.Must2(ans.Pack()).([]byte))
tests := []struct {
name string
want *IPRecord
wantErr bool
}{
{"empty",
&IPRecord{0, []v2net.Address(nil), time.Time{}, dnsmessage.RCodeSuccess},
false,
},
{"error",
nil,
true,
},
{"a record",
&IPRecord{1, []v2net.Address{v2net.ParseAddress("8.8.8.8"), v2net.ParseAddress("8.8.4.4")},
time.Time{}, dnsmessage.RCodeSuccess},
false,
},
{"aaaa record",
&IPRecord{2, []v2net.Address{v2net.ParseAddress("2001::123:8888"), v2net.ParseAddress("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess},
false,
},
}
for i, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseResponse(p[i])
if (err != nil) != tt.wantErr {
t.Errorf("handleResponse() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != nil {
// reset the time
got.Expire = time.Time{}
}
if cmp.Diff(got, tt.want) != "" {
t.Errorf(cmp.Diff(got, tt.want))
// t.Errorf("handleResponse() = %#v, want %#v", got, tt.want)
}
})
}
}
func Test_buildReqMsgs(t *testing.T) {
stubID := func() uint16 {
return uint16(rand.Uint32())
}
type args struct {
domain string
option IPOption
reqOpts *dnsmessage.Resource
}
tests := []struct {
name string
args args
want int
}{
{"dual stack", args{"test.com", IPOption{true, true}, nil}, 2},
{"ipv4 only", args{"test.com", IPOption{true, false}, nil}, 1},
{"ipv6 only", args{"test.com", IPOption{false, true}, nil}, 1},
{"none/error", args{"test.com", IPOption{false, false}, nil}, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := buildReqMsgs(tt.args.domain, tt.args.option, stubID, tt.args.reqOpts); !(len(got) == tt.want) {
t.Errorf("buildReqMsgs() = %v, want %v", got, tt.want)
}
})
}
}
func Test_genEDNS0Options(t *testing.T) {
type args struct {
clientIP net.IP
}
tests := []struct {
name string
args args
want *dnsmessage.Resource
}{
// TODO: Add test cases.
{"ipv4", args{net.ParseIP("4.3.2.1")}, nil},
{"ipv6", args{net.ParseIP("2001::4321")}, nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := genEDNS0Options(tt.args.clientIP); got == nil {
t.Errorf("genEDNS0Options() = %v, want %v", got, tt.want)
}
})
}
}
func TestFqdn(t *testing.T) {
type args struct {
domain string
}
tests := []struct {
name string
args args
want string
}{
{"with fqdn", args{"www.v2ray.com."}, "www.v2ray.com."},
{"without fqdn", args{"www.v2ray.com"}, "www.v2ray.com."},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := Fqdn(tt.args.domain); got != tt.want {
t.Errorf("Fqdn() = %v, want %v", got, tt.want)
}
})
}
}

315
app/dns/dohdns.go Normal file
View File

@ -0,0 +1,315 @@
// +build !confonly
package dns
import (
"bytes"
"context"
"fmt"
"io/ioutil"
"net/http"
"sync"
"sync/atomic"
"time"
"golang.org/x/net/dns/dnsmessage"
"v2ray.com/core/common"
"v2ray.com/core/common/dice"
"v2ray.com/core/common/net"
"v2ray.com/core/common/protocol/dns"
"v2ray.com/core/common/session"
"v2ray.com/core/common/signal/pubsub"
"v2ray.com/core/common/task"
"v2ray.com/core/features/routing"
)
// DoHNameServer implimented DNS over HTTPS (RFC8484) Wire Format,
// which is compatiable with traditional dns over udp(RFC1035),
// thus most of the DOH implimentation is copied from udpns.go
type DoHNameServer struct {
sync.RWMutex
dispatcher routing.Dispatcher
dohDests []net.Destination
ips map[string]record
pub *pubsub.Service
cleanup *task.Periodic
reqID uint32
clientIP net.IP
httpClient *http.Client
dohURL string
name string
}
func NewDoHNameServer(dests []net.Destination, dohHost string, dispatcher routing.Dispatcher, clientIP net.IP) *DoHNameServer {
s := NewDoHLocalNameServer(dohHost, clientIP)
s.name = "DOH:" + dohHost
s.dispatcher = dispatcher
s.dohDests = dests
// Dispatched connection will be closed (interupted) after each request
// This makes DOH inefficient without a keeped-alive connection
// See: core/app/proxyman/outbound/handler.go:113
// Using mux (https request wrapped in a stream layer) improves the situation.
// Recommand to use NewDoHLocalNameServer (DOHL:) if v2ray instance is running on
// a normal network eg. the server side of v2ray
tr := &http.Transport{
MaxIdleConns: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
DialContext: s.DialContext,
}
dispatchedClient := &http.Client{
Transport: tr,
Timeout: 16 * time.Second,
}
s.httpClient = dispatchedClient
return s
}
func NewDoHLocalNameServer(dohHost string, clientIP net.IP) *DoHNameServer {
s := &DoHNameServer{
httpClient: http.DefaultClient,
ips: make(map[string]record),
clientIP: clientIP,
pub: pubsub.NewService(),
name: "DOHL:" + dohHost,
dohURL: fmt.Sprintf("https://%s/dns-query", dohHost),
}
s.cleanup = &task.Periodic{
Interval: time.Minute,
Execute: s.Cleanup,
}
return s
}
func (s *DoHNameServer) Name() string {
return s.name
}
func (s *DoHNameServer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
dest := s.dohDests[dice.Roll(len(s.dohDests))]
link, err := s.dispatcher.Dispatch(ctx, dest)
if err != nil {
return nil, err
}
return net.NewConnection(
net.ConnectionInputMulti(link.Writer),
net.ConnectionOutputMulti(link.Reader),
), nil
}
func (s *DoHNameServer) Cleanup() error {
now := time.Now()
s.Lock()
defer s.Unlock()
if len(s.ips) == 0 {
return newError("nothing to do. stopping...")
}
for domain, record := range s.ips {
if record.A != nil && record.A.Expire.Before(now) {
record.A = nil
}
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
record.AAAA = nil
}
if record.A == nil && record.AAAA == nil {
newError(s.name, " cleanup ", domain).AtDebug().WriteToLog()
delete(s.ips, domain)
} else {
s.ips[domain] = record
}
}
if len(s.ips) == 0 {
s.ips = make(map[string]record)
}
return nil
}
func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
elapsed := time.Since(req.start)
newError(s.name, " got answere: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
s.Lock()
rec := s.ips[req.domain]
updated := false
switch req.reqType {
case dnsmessage.TypeA:
if isNewer(rec.A, ipRec) {
rec.A = ipRec
updated = true
}
case dnsmessage.TypeAAAA:
if isNewer(rec.AAAA, ipRec) {
rec.AAAA = ipRec
updated = true
}
}
if updated {
s.ips[req.domain] = rec
s.pub.Publish(req.domain, nil)
}
s.Unlock()
common.Must(s.cleanup.Start())
}
func (s *DoHNameServer) newReqID() uint16 {
return uint16(atomic.AddUint32(&s.reqID, 1))
}
func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, option IPOption) {
newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx))
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP))
var deadline time.Time
if d, ok := ctx.Deadline(); ok {
deadline = d
} else {
deadline = time.Now().Add(time.Second * 8)
}
for _, req := range reqs {
go func(r *dnsRequest) {
// generate new context for each req, using same context
// may cause reqs all aborted if any one encounter an error
dnsCtx := context.Background()
// reserve internal dns server requested Inbound
if inbound := session.InboundFromContext(ctx); inbound != nil {
dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
}
dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{
Protocol: "https",
})
// forced to use mux for DOH
dnsCtx = session.ContextWithMuxPrefered(dnsCtx, true)
dnsCtx, cancel := context.WithDeadline(dnsCtx, deadline)
defer cancel()
b, _ := dns.PackMessage(r.msg)
resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes())
if err != nil {
newError("failed to retrive response").Base(err).AtError().WriteToLog()
return
}
rec, err := parseResponse(resp)
if err != nil {
newError("failed to handle DOH response").Base(err).AtError().WriteToLog()
return
}
s.updateIP(r, rec)
}(req)
}
}
func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte, error) {
body := bytes.NewBuffer(b)
req, err := http.NewRequest("POST", s.dohURL, body)
if err != nil {
return nil, err
}
req.Header.Add("Accept", "application/dns-message")
req.Header.Add("Content-Type", "application/dns-message")
resp, err := s.httpClient.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("DOH HTTPS server returned with non-OK code %d", resp.StatusCode)
return nil, err
}
return ioutil.ReadAll(resp.Body)
}
func (s *DoHNameServer) findIPsForDomain(domain string, option IPOption) ([]net.IP, error) {
s.RLock()
record, found := s.ips[domain]
s.RUnlock()
if !found {
return nil, errRecordNotFound
}
var ips []net.Address
var lastErr error
if option.IPv6Enable && record.AAAA != nil && record.AAAA.RCode == dnsmessage.RCodeSuccess {
aaaa, err := record.AAAA.getIPs()
if err != nil {
lastErr = err
}
ips = append(ips, aaaa...)
}
if option.IPv4Enable && record.A != nil && record.A.RCode == dnsmessage.RCodeSuccess {
a, err := record.A.getIPs()
if err != nil {
lastErr = err
}
ips = append(ips, a...)
}
if len(ips) > 0 {
return toNetIP(ips), nil
}
if lastErr != nil {
return nil, lastErr
}
return nil, errRecordNotFound
}
// QueryIP is called from dns.Server->queryIPTimeout
func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
fqdn := Fqdn(domain)
ips, err := s.findIPsForDomain(fqdn, option)
if err != errRecordNotFound {
newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
return ips, err
}
sub := s.pub.Subscribe(fqdn)
defer sub.Close()
s.sendQuery(ctx, fqdn, option)
for {
ips, err := s.findIPsForDomain(fqdn, option)
if err != errRecordNotFound {
return ips, err
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-sub.Wait():
}
}
}

View File

@ -6,6 +6,8 @@ package dns
import (
"context"
"fmt"
"strings"
"sync"
"time"
@ -87,6 +89,49 @@ func New(ctx context.Context, config *Config) (*Server, error) {
address := endpoint.Address.AsAddress()
if address.Family().IsDomain() && address.Domain() == "localhost" {
server.clients = append(server.clients, NewLocalNameServer())
newError("DNS: localhost inited").AtInfo().WriteToLog()
} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOHL_") {
dohHost := address.Domain()[5:]
server.clients = append(server.clients, NewDoHLocalNameServer(dohHost, server.clientIP))
newError("DNS: DOH - Local inited for https://", dohHost).AtInfo().WriteToLog()
} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOH_") {
// DOH_ prefix makes net.Address think it's a domain
// need to process the real address here.
dohHost := address.Domain()[4:]
dohAddr := net.ParseAddress(dohHost)
dohIP := dohHost
var dests []net.Destination
if dohAddr.Family().IsDomain() {
// resolve DOH server in advance
ips, err := net.LookupIP(dohAddr.Domain())
if err != nil || len(ips) == 0 {
return 0
}
for _, ip := range ips {
dohIP := ip.String()
if len(ip) == net.IPv6len {
dohIP = fmt.Sprintf("[%s]", dohIP)
}
dohdest, _ := net.ParseDestination(fmt.Sprintf("tcp:%s:443", dohIP))
dests = append(dests, dohdest)
}
} else {
// rfc8484, DOH service only use port 443
dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:443", dohIP))
if err != nil {
return 0
}
dests = []net.Destination{dest}
}
// need the core dispatcher, register DOHClient at callback
idx := len(server.clients)
server.clients = append(server.clients, nil)
common.Must(core.RequireFeatures(ctx, func(d routing.Dispatcher) {
server.clients[idx] = NewDoHNameServer(dests, dohHost, d, server.clientIP)
newError("DNS: DOH - Remote client inited for https://", dohHost).AtInfo().WriteToLog()
}))
} else {
dest := endpoint.AsDestination()
if dest.Network == net.Network_Unknown {
@ -100,6 +145,7 @@ func New(ctx context.Context, config *Config) (*Server, error) {
server.clients[idx] = NewClassicNameServer(dest, d, server.clientIP)
}))
}
newError("DNS: UDP client inited for ", dest.NetAddr()).AtInfo().WriteToLog()
}
return len(server.clients) - 1
}
@ -272,10 +318,16 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
return nil, newError("empty domain name")
}
// normalize the FQDN form query
if domain[len(domain)-1] == '.' {
domain = domain[:len(domain)-1]
}
// skip domain without any dot
if strings.Index(domain, ".") == -1 {
return nil, newError("invalid domain name")
}
ips := s.lookupStatic(domain, option, 0)
if ips != nil && ips[0].Family().IsIP() {
newError("returning ", len(ips), " IPs for domain ", domain).WriteToLog()
@ -331,7 +383,7 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
}
}
return nil, newError("returning nil for domain ", domain).Base(lastErr)
return nil, dns.ErrEmptyResponse.Base(lastErr)
}
func init() {

View File

@ -4,14 +4,13 @@ package dns
import (
"context"
"encoding/binary"
"strings"
"sync"
"sync/atomic"
"time"
"golang.org/x/net/dns/dnsmessage"
"v2ray.com/core/common"
"v2ray.com/core/common/errors"
"v2ray.com/core/common/net"
"v2ray.com/core/common/protocol/dns"
udp_proto "v2ray.com/core/common/protocol/udp"
@ -23,42 +22,12 @@ import (
"v2ray.com/core/transport/internet/udp"
)
type record struct {
A *IPRecord
AAAA *IPRecord
}
type IPRecord struct {
IP []net.Address
Expire time.Time
RCode dnsmessage.RCode
}
func (r *IPRecord) getIPs() ([]net.Address, error) {
if r == nil || r.Expire.Before(time.Now()) {
return nil, errRecordNotFound
}
if r.RCode != dnsmessage.RCodeSuccess {
return nil, dns_feature.RCodeError(r.RCode)
}
return r.IP, nil
}
type pendingRequest struct {
domain string
expire time.Time
recType dnsmessage.Type
}
var (
errRecordNotFound = errors.New("record not found")
)
type ClassicNameServer struct {
sync.RWMutex
name string
address net.Destination
ips map[string]record
requests map[uint16]pendingRequest
requests map[uint16]dnsRequest
pub *pubsub.Service
udpServer *udp.Dispatcher
cleanup *task.Periodic
@ -70,9 +39,10 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
s := &ClassicNameServer{
address: address,
ips: make(map[string]record),
requests: make(map[uint16]pendingRequest),
requests: make(map[uint16]dnsRequest),
clientIP: clientIP,
pub: pubsub.NewService(),
name: strings.ToUpper(address.String()),
}
s.cleanup = &task.Periodic{
Interval: time.Minute,
@ -83,7 +53,7 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
}
func (s *ClassicNameServer) Name() string {
return s.address.String()
return s.name
}
func (s *ClassicNameServer) Cleanup() error {
@ -92,7 +62,7 @@ func (s *ClassicNameServer) Cleanup() error {
defer s.Unlock()
if len(s.ips) == 0 && len(s.requests) == 0 {
return newError("nothing to do. stopping...")
return newError(s.name, " nothing to do. stopping...")
}
for domain, record := range s.ips {
@ -121,123 +91,52 @@ func (s *ClassicNameServer) Cleanup() error {
}
if len(s.requests) == 0 {
s.requests = make(map[uint16]pendingRequest)
s.requests = make(map[uint16]dnsRequest)
}
return nil
}
func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
payload := packet.Payload
var parser dnsmessage.Parser
header, err := parser.Start(payload.Bytes())
ipRec, err := parseResponse(packet.Payload.Bytes())
if err != nil {
newError("failed to parse DNS response").Base(err).AtWarning().WriteToLog()
return
}
if err := parser.SkipAllQuestions(); err != nil {
newError("failed to skip questions in DNS response").Base(err).AtWarning().WriteToLog()
newError(s.name, " fail to parse responsed DNS udp").AtError().WriteToLog()
return
}
id := header.ID
s.Lock()
req, f := s.requests[id]
if f {
id := ipRec.ReqID
req, ok := s.requests[id]
if ok {
// remove the pending request
delete(s.requests, id)
}
s.Unlock()
if !f {
if !ok {
newError(s.name, " cannot find the pending request").AtError().WriteToLog()
return
}
domain := req.domain
recType := req.recType
now := time.Now()
ipRecord := &IPRecord{
RCode: header.RCode,
Expire: now.Add(time.Second * 600),
}
L:
for {
header, err := parser.AnswerHeader()
if err != nil {
if err != dnsmessage.ErrSectionDone {
newError("failed to parse answer section for domain: ", domain).Base(err).WriteToLog()
}
break
}
ttl := header.TTL
if ttl == 0 {
ttl = 600
}
expire := now.Add(time.Duration(ttl) * time.Second)
if ipRecord.Expire.After(expire) {
ipRecord.Expire = expire
}
if header.Type != recType {
if err := parser.SkipAnswer(); err != nil {
newError("failed to skip answer").Base(err).WriteToLog()
break L
}
continue
}
switch header.Type {
case dnsmessage.TypeA:
ans, err := parser.AResource()
if err != nil {
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
break L
}
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
case dnsmessage.TypeAAAA:
ans, err := parser.AAAAResource()
if err != nil {
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
break L
}
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
default:
if err := parser.SkipAnswer(); err != nil {
newError("failed to skip answer").Base(err).WriteToLog()
break L
}
}
}
var rec record
switch recType {
switch req.reqType {
case dnsmessage.TypeA:
rec.A = ipRecord
rec.A = ipRec
case dnsmessage.TypeAAAA:
rec.AAAA = ipRecord
rec.AAAA = ipRec
}
if len(domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
s.updateIP(domain, rec)
elapsed := time.Since(req.start)
newError(s.name, " got answere: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
s.updateIP(req.domain, rec)
}
}
func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
if newRec == nil {
return false
}
if baseRec == nil {
return true
}
return baseRec.Expire.Before(newRec.Expire)
}
func (s *ClassicNameServer) updateIP(domain string, newRec record) {
s.Lock()
newError("updating IP records for domain:", domain).AtDebug().WriteToLog()
newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog()
rec := s.ips[domain]
updated := false
@ -259,116 +158,27 @@ func (s *ClassicNameServer) updateIP(domain string, newRec record) {
common.Must(s.cleanup.Start())
}
func (s *ClassicNameServer) getMsgOptions() *dnsmessage.Resource {
if len(s.clientIP) == 0 {
return nil
}
var netmask int
var family uint16
if len(s.clientIP) == 4 {
family = 1
netmask = 24 // 24 for IPV4, 96 for IPv6
} else {
family = 2
netmask = 96
}
b := make([]byte, 4)
binary.BigEndian.PutUint16(b[0:], family)
b[2] = byte(netmask)
b[3] = 0
switch family {
case 1:
ip := s.clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
needLength := (netmask + 8 - 1) / 8 // division rounding up
b = append(b, ip[:needLength]...)
case 2:
ip := s.clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
needLength := (netmask + 8 - 1) / 8 // division rounding up
b = append(b, ip[:needLength]...)
}
const EDNS0SUBNET = 0x08
opt := new(dnsmessage.Resource)
common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
opt.Body = &dnsmessage.OPTResource{
Options: []dnsmessage.Option{
{
Code: EDNS0SUBNET,
Data: b,
},
},
}
return opt
func (s *ClassicNameServer) newReqID() uint16 {
return uint16(atomic.AddUint32(&s.reqID, 1))
}
func (s *ClassicNameServer) addPendingRequest(domain string, recType dnsmessage.Type) uint16 {
id := uint16(atomic.AddUint32(&s.reqID, 1))
func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) {
s.Lock()
defer s.Unlock()
s.requests[id] = pendingRequest{
domain: domain,
expire: time.Now().Add(time.Second * 8),
recType: recType,
}
return id
}
func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmessage.Message {
qA := dnsmessage.Question{
Name: dnsmessage.MustNewName(domain),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
}
qAAAA := dnsmessage.Question{
Name: dnsmessage.MustNewName(domain),
Type: dnsmessage.TypeAAAA,
Class: dnsmessage.ClassINET,
}
var msgs []*dnsmessage.Message
if option.IPv4Enable {
msg := new(dnsmessage.Message)
msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeA)
msg.Header.RecursionDesired = true
msg.Questions = []dnsmessage.Question{qA}
if opt := s.getMsgOptions(); opt != nil {
msg.Additionals = append(msg.Additionals, *opt)
}
msgs = append(msgs, msg)
}
if option.IPv6Enable {
msg := new(dnsmessage.Message)
msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeAAAA)
msg.Header.RecursionDesired = true
msg.Questions = []dnsmessage.Question{qAAAA}
if opt := s.getMsgOptions(); opt != nil {
msg.Additionals = append(msg.Additionals, *opt)
}
msgs = append(msgs, msg)
}
return msgs
id := req.msg.ID
req.expire = time.Now().Add(time.Second * 8)
s.requests[id] = *req
}
func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option IPOption) {
newError("querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))
newError(s.name, " querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))
msgs := s.buildMsgs(domain, option)
for _, msg := range msgs {
b, _ := dns.PackMessage(msg)
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP))
for _, req := range reqs {
s.addPendingRequest(req)
b, _ := dns.PackMessage(req.msg)
udpCtx := context.Background()
if inbound := session.InboundFromContext(ctx); inbound != nil {
udpCtx = session.ContextWithInbound(udpCtx, inbound)
@ -418,18 +228,13 @@ func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) ([]
return nil, dns_feature.ErrEmptyResponse
}
func Fqdn(domain string) string {
if len(domain) > 0 && domain[len(domain)-1] == '.' {
return domain
}
return domain + "."
}
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
fqdn := Fqdn(domain)
ips, err := s.findIPsForDomain(fqdn, option)
if err != errRecordNotFound {
newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
return ips, err
}

View File

@ -68,12 +68,13 @@ func NewHandler(ctx context.Context, config *core.OutboundHandlerConfig) (outbou
return nil, newError("not an outbound handler")
}
if h.senderSettings != nil && h.senderSettings.MultiplexSettings != nil && h.senderSettings.MultiplexSettings.Enabled {
if h.senderSettings != nil && h.senderSettings.MultiplexSettings != nil {
config := h.senderSettings.MultiplexSettings
if config.Concurrency < 1 || config.Concurrency > 1024 {
return nil, newError("invalid mux concurrency: ", config.Concurrency).AtWarning()
}
h.mux = &mux.ClientManager{
Enabled: h.senderSettings.MultiplexSettings.Enabled,
Picker: &mux.IncrementalWorkerPicker{
Factory: &mux.DialingWorkerFactory{
Proxy: proxyHandler,
@ -98,7 +99,7 @@ func (h *Handler) Tag() string {
// Dispatch implements proxy.Outbound.Dispatch.
func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
if h.mux != nil {
if h.mux != nil && (h.mux.Enabled || session.MuxPreferedFromContext(ctx)) {
if err := h.mux.Dispatch(ctx, link); err != nil {
newError("failed to process mux outbound traffic").Base(err).WriteToLog(session.ExportIDToError(ctx))
common.Interrupt(link.Writer)

View File

@ -21,7 +21,8 @@ import (
)
type ClientManager struct {
Picker WorkerPicker
Enabled bool // wheather mux is enabled from user config
Picker WorkerPicker
}
func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error {

View File

@ -9,6 +9,7 @@ const (
inboundSessionKey
outboundSessionKey
contentSessionKey
MuxPreferedSessionKey
)
// ContextWithID returns a new context with the given ID.
@ -56,3 +57,16 @@ func ContentFromContext(ctx context.Context) *Content {
}
return nil
}
// ContextWithMuxPrefered returns a new context with the given bool
func ContextWithMuxPrefered(ctx context.Context, forced bool) context.Context {
return context.WithValue(ctx, MuxPreferedSessionKey, forced)
}
// MuxPreferedFromContext returns value in this context, or false if not contained.
func MuxPreferedFromContext(ctx context.Context) bool {
if val, ok := ctx.Value(MuxPreferedSessionKey).(bool); ok {
return val
}
return false
}

View File

@ -75,15 +75,24 @@ func (c *SniffingConfig) Build() (*proxyman.SniffingConfig, error) {
}
type MuxConfig struct {
Enabled bool `json:"enabled"`
Concurrency uint16 `json:"concurrency"`
Enabled bool `json:"enabled"`
Concurrency int16 `json:"concurrency"`
}
func (c *MuxConfig) GetConcurrency() uint16 {
if c.Concurrency == 0 {
return 8
func (m *MuxConfig) Build() *proxyman.MultiplexingConfig {
if m.Concurrency < 0 {
return nil
}
var con uint32 = 8
if m.Concurrency > 0 {
con = uint32(m.Concurrency)
}
return &proxyman.MultiplexingConfig{
Enabled: m.Enabled,
Concurrency: con,
}
return c.Concurrency
}
type InboundDetourAllocationConfig struct {
@ -246,11 +255,8 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
senderSettings.ProxySettings = ps
}
if c.MuxSettings != nil && c.MuxSettings.Enabled {
senderSettings.MultiplexSettings = &proxyman.MultiplexingConfig{
Enabled: true,
Concurrency: uint32(c.MuxSettings.GetConcurrency()),
}
if c.MuxSettings != nil {
senderSettings.MultiplexSettings = c.MuxSettings.Build()
}
settings := []byte("{}")

View File

@ -2,15 +2,16 @@ package conf_test
import (
"encoding/json"
"reflect"
"testing"
"github.com/golang/protobuf/proto"
"v2ray.com/core"
"v2ray.com/core/app/dispatcher"
"v2ray.com/core/app/log"
"v2ray.com/core/app/proxyman"
"v2ray.com/core/app/router"
"v2ray.com/core/common"
clog "v2ray.com/core/common/log"
"v2ray.com/core/common/net"
"v2ray.com/core/common/protocol"
@ -337,3 +338,34 @@ func TestV2RayConfig(t *testing.T) {
},
})
}
func TestMuxConfig_Build(t *testing.T) {
tests := []struct {
name string
fields string
want *proxyman.MultiplexingConfig
}{
{"default", `{"enabled": true, "concurrency": 16}`, &proxyman.MultiplexingConfig{
Enabled: true,
Concurrency: 16,
}},
{"empty def", `{}`, &proxyman.MultiplexingConfig{
Enabled: false,
Concurrency: 8,
}},
{"not enable", `{"enabled": false, "concurrency": 4}`, &proxyman.MultiplexingConfig{
Enabled: false,
Concurrency: 4,
}},
{"forbidden", `{"enabled": false, "concurrency": -1}`, nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := &MuxConfig{}
common.Must(json.Unmarshal([]byte(tt.fields), m))
if got := m.Build(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("MuxConfig.Build() = %v, want %v", got, tt.want)
}
})
}
}