mirror of
https://github.com/v2fly/v2ray-core.git
synced 2025-01-22 01:06:30 -05:00
Feat: DNS transport over TCP (#983)
* feat: DNS over TCP * fix: DNS over TCP misbehaving * fix: add a blank line after +build tag * style: rename NewTCPLNameServer to NewTCPLocalNameServer * style: add some comments * style: format Co-authored-by: Shelikhoo <xiaokangwang@outlook.com>
This commit is contained in:
parent
e98865a205
commit
f84a401704
@ -52,6 +52,10 @@ func NewServer(dest net.Destination, dispatcher routing.Dispatcher) (Server, err
|
||||
return NewDoHLocalNameServer(u), nil
|
||||
case strings.EqualFold(u.Scheme, "quic+local"): // DNS-over-QUIC Local mode
|
||||
return NewQUICNameServer(u)
|
||||
case strings.EqualFold(u.Scheme, "tcp"): // DNS-over-TCP Remote mode
|
||||
return NewTCPNameServer(u, dispatcher)
|
||||
case strings.EqualFold(u.Scheme, "tcp+local"): // DNS-over-TCP Local mode
|
||||
return NewTCPLocalNameServer(u)
|
||||
case strings.EqualFold(u.String(), "fakedns"):
|
||||
return NewFakeDNSServer(), nil
|
||||
}
|
||||
|
360
app/dns/nameserver_tcp.go
Normal file
360
app/dns/nameserver_tcp.go
Normal file
@ -0,0 +1,360 @@
|
||||
// +build !confonly
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net/url"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
|
||||
"github.com/v2fly/v2ray-core/v4/common"
|
||||
"github.com/v2fly/v2ray-core/v4/common/buf"
|
||||
"github.com/v2fly/v2ray-core/v4/common/net"
|
||||
"github.com/v2fly/v2ray-core/v4/common/protocol/dns"
|
||||
"github.com/v2fly/v2ray-core/v4/common/session"
|
||||
"github.com/v2fly/v2ray-core/v4/common/signal/pubsub"
|
||||
"github.com/v2fly/v2ray-core/v4/common/task"
|
||||
dns_feature "github.com/v2fly/v2ray-core/v4/features/dns"
|
||||
"github.com/v2fly/v2ray-core/v4/features/routing"
|
||||
"github.com/v2fly/v2ray-core/v4/transport/internet"
|
||||
)
|
||||
|
||||
// TCPNameServer implemented DNS over TCP (RFC7766).
|
||||
type TCPNameServer struct {
|
||||
sync.RWMutex
|
||||
name string
|
||||
destination net.Destination
|
||||
ips map[string]record
|
||||
pub *pubsub.Service
|
||||
cleanup *task.Periodic
|
||||
reqID uint32
|
||||
dial func(context.Context) (net.Conn, error)
|
||||
}
|
||||
|
||||
// NewTCPNameServer creates DNS over TCP server object for remote resolving.
|
||||
func NewTCPNameServer(url *url.URL, dispatcher routing.Dispatcher) (*TCPNameServer, error) {
|
||||
s, err := baseTCPNameServer(url, "TCP")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.dial = func(ctx context.Context) (net.Conn, error) {
|
||||
link, err := dispatcher.Dispatch(ctx, s.destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return net.NewConnection(
|
||||
net.ConnectionInputMulti(link.Writer),
|
||||
net.ConnectionOutputMulti(link.Reader),
|
||||
), nil
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// NewTCPLocalNameServer creates DNS over TCP client object for local resolving
|
||||
func NewTCPLocalNameServer(url *url.URL) (*TCPNameServer, error) {
|
||||
s, err := baseTCPNameServer(url, "TCPL")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.dial = func(ctx context.Context) (net.Conn, error) {
|
||||
return internet.DialSystem(ctx, s.destination, nil)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func baseTCPNameServer(url *url.URL, prefix string) (*TCPNameServer, error) {
|
||||
var err error
|
||||
port := net.Port(53)
|
||||
if url.Port() != "" {
|
||||
port, err = net.PortFromString(url.Port())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
dest := net.TCPDestination(net.DomainAddress(url.Hostname()), port)
|
||||
|
||||
s := &TCPNameServer{
|
||||
destination: dest,
|
||||
ips: make(map[string]record),
|
||||
pub: pubsub.NewService(),
|
||||
name: prefix + "//" + dest.NetAddr(),
|
||||
}
|
||||
s.cleanup = &task.Periodic{
|
||||
Interval: time.Minute,
|
||||
Execute: s.Cleanup,
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Name implements Server.
|
||||
func (s *TCPNameServer) Name() string {
|
||||
return s.name
|
||||
}
|
||||
|
||||
// Cleanup clears expired items from cache
|
||||
func (s *TCPNameServer) Cleanup() error {
|
||||
now := time.Now()
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
if len(s.ips) == 0 {
|
||||
return newError("nothing to do. stopping...")
|
||||
}
|
||||
|
||||
for domain, record := range s.ips {
|
||||
if record.A != nil && record.A.Expire.Before(now) {
|
||||
record.A = nil
|
||||
}
|
||||
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
|
||||
record.AAAA = nil
|
||||
}
|
||||
|
||||
if record.A == nil && record.AAAA == nil {
|
||||
newError(s.name, " cleanup ", domain).AtDebug().WriteToLog()
|
||||
delete(s.ips, domain)
|
||||
} else {
|
||||
s.ips[domain] = record
|
||||
}
|
||||
}
|
||||
|
||||
if len(s.ips) == 0 {
|
||||
s.ips = make(map[string]record)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *TCPNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
|
||||
elapsed := time.Since(req.start)
|
||||
|
||||
s.Lock()
|
||||
rec := s.ips[req.domain]
|
||||
updated := false
|
||||
|
||||
switch req.reqType {
|
||||
case dnsmessage.TypeA:
|
||||
if isNewer(rec.A, ipRec) {
|
||||
rec.A = ipRec
|
||||
updated = true
|
||||
}
|
||||
case dnsmessage.TypeAAAA:
|
||||
addr := make([]net.Address, 0)
|
||||
for _, ip := range ipRec.IP {
|
||||
if len(ip.IP()) == net.IPv6len {
|
||||
addr = append(addr, ip)
|
||||
}
|
||||
}
|
||||
ipRec.IP = addr
|
||||
if isNewer(rec.AAAA, ipRec) {
|
||||
rec.AAAA = ipRec
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
|
||||
|
||||
if updated {
|
||||
s.ips[req.domain] = rec
|
||||
}
|
||||
switch req.reqType {
|
||||
case dnsmessage.TypeA:
|
||||
s.pub.Publish(req.domain+"4", nil)
|
||||
case dnsmessage.TypeAAAA:
|
||||
s.pub.Publish(req.domain+"6", nil)
|
||||
}
|
||||
s.Unlock()
|
||||
common.Must(s.cleanup.Start())
|
||||
}
|
||||
|
||||
func (s *TCPNameServer) newReqID() uint16 {
|
||||
return uint16(atomic.AddUint32(&s.reqID, 1))
|
||||
}
|
||||
|
||||
func (s *TCPNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
|
||||
newError(s.name, " querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))
|
||||
|
||||
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP))
|
||||
|
||||
var deadline time.Time
|
||||
if d, ok := ctx.Deadline(); ok {
|
||||
deadline = d
|
||||
} else {
|
||||
deadline = time.Now().Add(time.Second * 5)
|
||||
}
|
||||
|
||||
for _, req := range reqs {
|
||||
go func(r *dnsRequest) {
|
||||
dnsCtx := ctx
|
||||
|
||||
if inbound := session.InboundFromContext(ctx); inbound != nil {
|
||||
dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
|
||||
}
|
||||
|
||||
dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{
|
||||
Protocol: "dns",
|
||||
SkipDNSResolve: true,
|
||||
})
|
||||
|
||||
var cancel context.CancelFunc
|
||||
dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline)
|
||||
defer cancel()
|
||||
|
||||
b, err := dns.PackMessage(r.msg)
|
||||
if err != nil {
|
||||
newError("failed to pack dns query").Base(err).AtError().WriteToLog()
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := s.dial(dnsCtx)
|
||||
if err != nil {
|
||||
newError("failed to dial namesever").Base(err).AtError().WriteToLog()
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
dnsReqBuf := buf.New()
|
||||
binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len()))
|
||||
dnsReqBuf.Write(b.Bytes())
|
||||
b.Release()
|
||||
|
||||
_, err = conn.Write(dnsReqBuf.Bytes())
|
||||
if err != nil {
|
||||
newError("failed to send query").Base(err).AtError().WriteToLog()
|
||||
return
|
||||
}
|
||||
dnsReqBuf.Release()
|
||||
|
||||
respBuf := buf.New()
|
||||
defer respBuf.Release()
|
||||
n, err := respBuf.ReadFullFrom(conn, 2)
|
||||
if err != nil && n == 0 {
|
||||
newError("failed to read response length").Base(err).AtError().WriteToLog()
|
||||
return
|
||||
}
|
||||
var length int16
|
||||
err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length)
|
||||
if err != nil {
|
||||
newError("failed to parse response length").Base(err).AtError().WriteToLog()
|
||||
return
|
||||
}
|
||||
respBuf.Clear()
|
||||
n, err = respBuf.ReadFullFrom(conn, int32(length))
|
||||
if err != nil && n == 0 {
|
||||
newError("failed to read response length").Base(err).AtError().WriteToLog()
|
||||
return
|
||||
}
|
||||
|
||||
rec, err := parseResponse(respBuf.Bytes())
|
||||
if err != nil {
|
||||
newError("failed to parse DNS over TCP response").Base(err).AtError().WriteToLog()
|
||||
return
|
||||
}
|
||||
|
||||
s.updateIP(r, rec)
|
||||
}(req)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TCPNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) {
|
||||
s.RLock()
|
||||
record, found := s.ips[domain]
|
||||
s.RUnlock()
|
||||
|
||||
if !found {
|
||||
return nil, errRecordNotFound
|
||||
}
|
||||
|
||||
var ips []net.Address
|
||||
var lastErr error
|
||||
if option.IPv4Enable {
|
||||
a, err := record.A.getIPs()
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
ips = append(ips, a...)
|
||||
}
|
||||
|
||||
if option.IPv6Enable {
|
||||
aaaa, err := record.AAAA.getIPs()
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
}
|
||||
ips = append(ips, aaaa...)
|
||||
}
|
||||
|
||||
if len(ips) > 0 {
|
||||
return toNetIP(ips)
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
return nil, dns_feature.ErrEmptyResponse
|
||||
}
|
||||
|
||||
// QueryIP implements Server.
|
||||
func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) {
|
||||
fqdn := Fqdn(domain)
|
||||
|
||||
if disableCache {
|
||||
newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog()
|
||||
} else {
|
||||
ips, err := s.findIPsForDomain(fqdn, option)
|
||||
if err != errRecordNotFound {
|
||||
newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
|
||||
return ips, err
|
||||
}
|
||||
}
|
||||
|
||||
// ipv4 and ipv6 belong to different subscription groups
|
||||
var sub4, sub6 *pubsub.Subscriber
|
||||
if option.IPv4Enable {
|
||||
sub4 = s.pub.Subscribe(fqdn + "4")
|
||||
defer sub4.Close()
|
||||
}
|
||||
if option.IPv6Enable {
|
||||
sub6 = s.pub.Subscribe(fqdn + "6")
|
||||
defer sub6.Close()
|
||||
}
|
||||
done := make(chan interface{})
|
||||
go func() {
|
||||
if sub4 != nil {
|
||||
select {
|
||||
case <-sub4.Wait():
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
if sub6 != nil {
|
||||
select {
|
||||
case <-sub6.Wait():
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
s.sendQuery(ctx, fqdn, clientIP, option)
|
||||
|
||||
for {
|
||||
ips, err := s.findIPsForDomain(fqdn, option)
|
||||
if err != errRecordNotFound {
|
||||
return ips, err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
}
|
60
app/dns/nameserver_tcp_test.go
Normal file
60
app/dns/nameserver_tcp_test.go
Normal file
@ -0,0 +1,60 @@
|
||||
package dns_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
. "github.com/v2fly/v2ray-core/v4/app/dns"
|
||||
"github.com/v2fly/v2ray-core/v4/common"
|
||||
"github.com/v2fly/v2ray-core/v4/common/net"
|
||||
dns_feature "github.com/v2fly/v2ray-core/v4/features/dns"
|
||||
)
|
||||
|
||||
func TestTCPLocalNameServer(t *testing.T) {
|
||||
url, err := url.Parse("tcp+local://8.8.8.8")
|
||||
common.Must(err)
|
||||
s, err := NewTCPLocalNameServer(url)
|
||||
common.Must(err)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
ips, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
|
||||
IPv4Enable: true,
|
||||
IPv6Enable: true,
|
||||
}, false)
|
||||
cancel()
|
||||
common.Must(err)
|
||||
if len(ips) == 0 {
|
||||
t.Error("expect some ips, but got 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPLocalNameServerWithCache(t *testing.T) {
|
||||
url, err := url.Parse("tcp+local://8.8.8.8")
|
||||
common.Must(err)
|
||||
s, err := NewTCPLocalNameServer(url)
|
||||
common.Must(err)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
ips, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
|
||||
IPv4Enable: true,
|
||||
IPv6Enable: true,
|
||||
}, false)
|
||||
cancel()
|
||||
common.Must(err)
|
||||
if len(ips) == 0 {
|
||||
t.Error("expect some ips, but got 0")
|
||||
}
|
||||
|
||||
ctx2, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
ips2, err := s.QueryIP(ctx2, "google.com", net.IP(nil), dns_feature.IPOption{
|
||||
IPv4Enable: true,
|
||||
IPv6Enable: true,
|
||||
}, true)
|
||||
cancel()
|
||||
common.Must(err)
|
||||
if r := cmp.Diff(ips2, ips); r != "" {
|
||||
t.Fatal(r)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user