mirror of
https://github.com/v2fly/v2ray-core.git
synced 2025-01-18 07:17:32 -05:00
commit
d8a7763847
237
app/dns/dnscommon.go
Normal file
237
app/dns/dnscommon.go
Normal file
@ -0,0 +1,237 @@
|
|||||||
|
// +build !confonly
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/dns/dnsmessage"
|
||||||
|
"v2ray.com/core/common"
|
||||||
|
"v2ray.com/core/common/errors"
|
||||||
|
"v2ray.com/core/common/net"
|
||||||
|
dns_feature "v2ray.com/core/features/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Fqdn normalize domain make sure it ends with '.'
|
||||||
|
func Fqdn(domain string) string {
|
||||||
|
if len(domain) > 0 && domain[len(domain)-1] == '.' {
|
||||||
|
return domain
|
||||||
|
}
|
||||||
|
return domain + "."
|
||||||
|
}
|
||||||
|
|
||||||
|
type record struct {
|
||||||
|
A *IPRecord
|
||||||
|
AAAA *IPRecord
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPRecord is a cacheable item for a resolved domain
|
||||||
|
type IPRecord struct {
|
||||||
|
ReqID uint16
|
||||||
|
IP []net.Address
|
||||||
|
Expire time.Time
|
||||||
|
RCode dnsmessage.RCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *IPRecord) getIPs() ([]net.Address, error) {
|
||||||
|
if r == nil || r.Expire.Before(time.Now()) {
|
||||||
|
return nil, errRecordNotFound
|
||||||
|
}
|
||||||
|
if r.RCode != dnsmessage.RCodeSuccess {
|
||||||
|
return nil, dns_feature.RCodeError(r.RCode)
|
||||||
|
}
|
||||||
|
return r.IP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
|
||||||
|
if newRec == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if baseRec == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return baseRec.Expire.Before(newRec.Expire)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
errRecordNotFound = errors.New("record not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
type dnsRequest struct {
|
||||||
|
reqType dnsmessage.Type
|
||||||
|
domain string
|
||||||
|
start time.Time
|
||||||
|
expire time.Time
|
||||||
|
msg *dnsmessage.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
func genEDNS0Options(clientIP net.IP) *dnsmessage.Resource {
|
||||||
|
if len(clientIP) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var netmask int
|
||||||
|
var family uint16
|
||||||
|
|
||||||
|
if len(clientIP) == 4 {
|
||||||
|
family = 1
|
||||||
|
netmask = 24 // 24 for IPV4, 96 for IPv6
|
||||||
|
} else {
|
||||||
|
family = 2
|
||||||
|
netmask = 96
|
||||||
|
}
|
||||||
|
|
||||||
|
b := make([]byte, 4)
|
||||||
|
binary.BigEndian.PutUint16(b[0:], family)
|
||||||
|
b[2] = byte(netmask)
|
||||||
|
b[3] = 0
|
||||||
|
switch family {
|
||||||
|
case 1:
|
||||||
|
ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
|
||||||
|
needLength := (netmask + 8 - 1) / 8 // division rounding up
|
||||||
|
b = append(b, ip[:needLength]...)
|
||||||
|
case 2:
|
||||||
|
ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
|
||||||
|
needLength := (netmask + 8 - 1) / 8 // division rounding up
|
||||||
|
b = append(b, ip[:needLength]...)
|
||||||
|
}
|
||||||
|
|
||||||
|
const EDNS0SUBNET = 0x08
|
||||||
|
|
||||||
|
opt := new(dnsmessage.Resource)
|
||||||
|
common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
|
||||||
|
|
||||||
|
opt.Body = &dnsmessage.OPTResource{
|
||||||
|
Options: []dnsmessage.Option{
|
||||||
|
{
|
||||||
|
Code: EDNS0SUBNET,
|
||||||
|
Data: b,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return opt
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildReqMsgs(domain string, option IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest {
|
||||||
|
qA := dnsmessage.Question{
|
||||||
|
Name: dnsmessage.MustNewName(domain),
|
||||||
|
Type: dnsmessage.TypeA,
|
||||||
|
Class: dnsmessage.ClassINET,
|
||||||
|
}
|
||||||
|
|
||||||
|
qAAAA := dnsmessage.Question{
|
||||||
|
Name: dnsmessage.MustNewName(domain),
|
||||||
|
Type: dnsmessage.TypeAAAA,
|
||||||
|
Class: dnsmessage.ClassINET,
|
||||||
|
}
|
||||||
|
|
||||||
|
var reqs []*dnsRequest
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
if option.IPv4Enable {
|
||||||
|
msg := new(dnsmessage.Message)
|
||||||
|
msg.Header.ID = reqIDGen()
|
||||||
|
msg.Header.RecursionDesired = true
|
||||||
|
msg.Questions = []dnsmessage.Question{qA}
|
||||||
|
if reqOpts != nil {
|
||||||
|
msg.Additionals = append(msg.Additionals, *reqOpts)
|
||||||
|
}
|
||||||
|
reqs = append(reqs, &dnsRequest{
|
||||||
|
reqType: dnsmessage.TypeA,
|
||||||
|
domain: domain,
|
||||||
|
start: now,
|
||||||
|
msg: msg,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if option.IPv6Enable {
|
||||||
|
msg := new(dnsmessage.Message)
|
||||||
|
msg.Header.ID = reqIDGen()
|
||||||
|
msg.Header.RecursionDesired = true
|
||||||
|
msg.Questions = []dnsmessage.Question{qAAAA}
|
||||||
|
if reqOpts != nil {
|
||||||
|
msg.Additionals = append(msg.Additionals, *reqOpts)
|
||||||
|
}
|
||||||
|
reqs = append(reqs, &dnsRequest{
|
||||||
|
reqType: dnsmessage.TypeAAAA,
|
||||||
|
domain: domain,
|
||||||
|
start: now,
|
||||||
|
msg: msg,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return reqs
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseResponse parse DNS answers from the returned payload
|
||||||
|
func parseResponse(payload []byte) (*IPRecord, error) {
|
||||||
|
var parser dnsmessage.Parser
|
||||||
|
h, err := parser.Start(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, newError("failed to parse DNS response").Base(err).AtWarning()
|
||||||
|
}
|
||||||
|
if err := parser.SkipAllQuestions(); err != nil {
|
||||||
|
return nil, newError("failed to skip questions in DNS response").Base(err).AtWarning()
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
var ipRecExpire time.Time
|
||||||
|
if h.RCode != dnsmessage.RCodeSuccess {
|
||||||
|
// A default TTL, maybe a negtive cache
|
||||||
|
ipRecExpire = now.Add(time.Second * 120)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipRecord := &IPRecord{
|
||||||
|
ReqID: h.ID,
|
||||||
|
RCode: h.RCode,
|
||||||
|
Expire: ipRecExpire,
|
||||||
|
}
|
||||||
|
|
||||||
|
L:
|
||||||
|
for {
|
||||||
|
ah, err := parser.AnswerHeader()
|
||||||
|
if err != nil {
|
||||||
|
if err != dnsmessage.ErrSectionDone {
|
||||||
|
newError("failed to parse answer section for domain: ", ah.Name.String()).Base(err).WriteToLog()
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ah.Type {
|
||||||
|
case dnsmessage.TypeA:
|
||||||
|
ans, err := parser.AResource()
|
||||||
|
if err != nil {
|
||||||
|
newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
|
||||||
|
break L
|
||||||
|
}
|
||||||
|
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
|
||||||
|
case dnsmessage.TypeAAAA:
|
||||||
|
ans, err := parser.AAAAResource()
|
||||||
|
if err != nil {
|
||||||
|
newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
|
||||||
|
break L
|
||||||
|
}
|
||||||
|
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
|
||||||
|
default:
|
||||||
|
if err := parser.SkipAnswer(); err != nil {
|
||||||
|
newError("failed to skip answer").Base(err).WriteToLog()
|
||||||
|
break L
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if ipRecord.Expire.IsZero() {
|
||||||
|
ttl := ah.TTL
|
||||||
|
if ttl < 600 {
|
||||||
|
// at least 10 mins TTL
|
||||||
|
ipRecord.Expire = now.Add(time.Minute * 10)
|
||||||
|
} else {
|
||||||
|
ipRecord.Expire = now.Add(time.Duration(ttl) * time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ipRecord, nil
|
||||||
|
}
|
166
app/dns/dnscommon_test.go
Normal file
166
app/dns/dnscommon_test.go
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
// +build !confonly
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"golang.org/x/net/dns/dnsmessage"
|
||||||
|
"v2ray.com/core/common"
|
||||||
|
"v2ray.com/core/common/net"
|
||||||
|
v2net "v2ray.com/core/common/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_parseResponse(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
payload []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
var p [][]byte
|
||||||
|
|
||||||
|
ans := new(dns.Msg)
|
||||||
|
ans.Id = 0
|
||||||
|
p = append(p, common.Must2(ans.Pack()).([]byte))
|
||||||
|
|
||||||
|
p = append(p, []byte{})
|
||||||
|
|
||||||
|
ans = new(dns.Msg)
|
||||||
|
ans.Id = 1
|
||||||
|
ans.Answer = append(ans.Answer,
|
||||||
|
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
|
||||||
|
common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR),
|
||||||
|
common.Must2(dns.NewRR("google.com. IN A 8.8.8.8")).(dns.RR),
|
||||||
|
common.Must2(dns.NewRR("google.com. IN A 8.8.4.4")).(dns.RR),
|
||||||
|
)
|
||||||
|
p = append(p, common.Must2(ans.Pack()).([]byte))
|
||||||
|
|
||||||
|
ans = new(dns.Msg)
|
||||||
|
ans.Id = 2
|
||||||
|
ans.Answer = append(ans.Answer,
|
||||||
|
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
|
||||||
|
common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR),
|
||||||
|
common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
|
||||||
|
common.Must2(dns.NewRR("google.com. IN CNAME test.google.com")).(dns.RR),
|
||||||
|
common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8888")).(dns.RR),
|
||||||
|
common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8844")).(dns.RR),
|
||||||
|
)
|
||||||
|
p = append(p, common.Must2(ans.Pack()).([]byte))
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
want *IPRecord
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"empty",
|
||||||
|
&IPRecord{0, []v2net.Address(nil), time.Time{}, dnsmessage.RCodeSuccess},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{"error",
|
||||||
|
nil,
|
||||||
|
true,
|
||||||
|
},
|
||||||
|
{"a record",
|
||||||
|
&IPRecord{1, []v2net.Address{v2net.ParseAddress("8.8.8.8"), v2net.ParseAddress("8.8.4.4")},
|
||||||
|
time.Time{}, dnsmessage.RCodeSuccess},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
{"aaaa record",
|
||||||
|
&IPRecord{2, []v2net.Address{v2net.ParseAddress("2001::123:8888"), v2net.ParseAddress("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess},
|
||||||
|
false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for i, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := parseResponse(p[i])
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("handleResponse() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if got != nil {
|
||||||
|
// reset the time
|
||||||
|
got.Expire = time.Time{}
|
||||||
|
}
|
||||||
|
if cmp.Diff(got, tt.want) != "" {
|
||||||
|
t.Errorf(cmp.Diff(got, tt.want))
|
||||||
|
// t.Errorf("handleResponse() = %#v, want %#v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_buildReqMsgs(t *testing.T) {
|
||||||
|
|
||||||
|
stubID := func() uint16 {
|
||||||
|
return uint16(rand.Uint32())
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
domain string
|
||||||
|
option IPOption
|
||||||
|
reqOpts *dnsmessage.Resource
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"dual stack", args{"test.com", IPOption{true, true}, nil}, 2},
|
||||||
|
{"ipv4 only", args{"test.com", IPOption{true, false}, nil}, 1},
|
||||||
|
{"ipv6 only", args{"test.com", IPOption{false, true}, nil}, 1},
|
||||||
|
{"none/error", args{"test.com", IPOption{false, false}, nil}, 0},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := buildReqMsgs(tt.args.domain, tt.args.option, stubID, tt.args.reqOpts); !(len(got) == tt.want) {
|
||||||
|
t.Errorf("buildReqMsgs() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_genEDNS0Options(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
clientIP net.IP
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *dnsmessage.Resource
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
{"ipv4", args{net.ParseIP("4.3.2.1")}, nil},
|
||||||
|
{"ipv6", args{net.ParseIP("2001::4321")}, nil},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := genEDNS0Options(tt.args.clientIP); got == nil {
|
||||||
|
t.Errorf("genEDNS0Options() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFqdn(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
domain string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"with fqdn", args{"www.v2ray.com."}, "www.v2ray.com."},
|
||||||
|
{"without fqdn", args{"www.v2ray.com"}, "www.v2ray.com."},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := Fqdn(tt.args.domain); got != tt.want {
|
||||||
|
t.Errorf("Fqdn() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
370
app/dns/dohdns.go
Normal file
370
app/dns/dohdns.go
Normal file
@ -0,0 +1,370 @@
|
|||||||
|
// +build !confonly
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/dns/dnsmessage"
|
||||||
|
"v2ray.com/core/common"
|
||||||
|
"v2ray.com/core/common/dice"
|
||||||
|
"v2ray.com/core/common/net"
|
||||||
|
"v2ray.com/core/common/protocol/dns"
|
||||||
|
"v2ray.com/core/common/session"
|
||||||
|
"v2ray.com/core/common/signal/pubsub"
|
||||||
|
"v2ray.com/core/common/task"
|
||||||
|
"v2ray.com/core/features/routing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DoHNameServer implimented DNS over HTTPS (RFC8484) Wire Format,
|
||||||
|
// which is compatiable with traditional dns over udp(RFC1035),
|
||||||
|
// thus most of the DOH implimentation is copied from udpns.go
|
||||||
|
type DoHNameServer struct {
|
||||||
|
sync.RWMutex
|
||||||
|
dispatcher routing.Dispatcher
|
||||||
|
dohDests []net.Destination
|
||||||
|
ips map[string]record
|
||||||
|
pub *pubsub.Service
|
||||||
|
cleanup *task.Periodic
|
||||||
|
reqID uint32
|
||||||
|
clientIP net.IP
|
||||||
|
httpClient *http.Client
|
||||||
|
dohURL string
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDoHNameServer creates DOH client object for remote resolving
|
||||||
|
func NewDoHNameServer(dohHost string, dohPort uint32, dispatcher routing.Dispatcher, clientIP net.IP) (*DoHNameServer, error) {
|
||||||
|
|
||||||
|
dohAddr := net.ParseAddress(dohHost)
|
||||||
|
var dests []net.Destination
|
||||||
|
|
||||||
|
if dohPort == 0 {
|
||||||
|
dohPort = 443
|
||||||
|
}
|
||||||
|
|
||||||
|
parseIPDest := func(ip net.IP, port uint32) net.Destination {
|
||||||
|
strIP := ip.String()
|
||||||
|
if len(ip) == net.IPv6len {
|
||||||
|
strIP = fmt.Sprintf("[%s]", strIP)
|
||||||
|
}
|
||||||
|
dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:%d", strIP, port))
|
||||||
|
common.Must(err)
|
||||||
|
return dest
|
||||||
|
}
|
||||||
|
|
||||||
|
if dohAddr.Family().IsDomain() {
|
||||||
|
// resolve DOH server in advance
|
||||||
|
ips, err := net.LookupIP(dohAddr.Domain())
|
||||||
|
if err != nil || len(ips) == 0 {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, ip := range ips {
|
||||||
|
dests = append(dests, parseIPDest(ip, dohPort))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ip := dohAddr.IP()
|
||||||
|
dests = append(dests, parseIPDest(ip, dohPort))
|
||||||
|
}
|
||||||
|
|
||||||
|
newError("DNS: created remote DOH client for https://", dohHost, ":", dohPort).AtInfo().WriteToLog()
|
||||||
|
s := baseDOHNameServer(dohHost, dohPort, "DOH", clientIP)
|
||||||
|
s.dispatcher = dispatcher
|
||||||
|
s.dohDests = dests
|
||||||
|
|
||||||
|
// Dispatched connection will be closed (interupted) after each request
|
||||||
|
// This makes DOH inefficient without a keeped-alive connection
|
||||||
|
// See: core/app/proxyman/outbound/handler.go:113
|
||||||
|
// Using mux (https request wrapped in a stream layer) improves the situation.
|
||||||
|
// Recommand to use NewDoHLocalNameServer (DOHL:) if v2ray instance is running on
|
||||||
|
// a normal network eg. the server side of v2ray
|
||||||
|
tr := &http.Transport{
|
||||||
|
MaxIdleConns: 10,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
DialContext: s.DialContext,
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatchedClient := &http.Client{
|
||||||
|
Transport: tr,
|
||||||
|
Timeout: 16 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.httpClient = dispatchedClient
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDoHLocalNameServer creates DOH client object for local resolving
|
||||||
|
func NewDoHLocalNameServer(dohHost string, dohPort uint32, clientIP net.IP) *DoHNameServer {
|
||||||
|
|
||||||
|
if dohPort == 0 {
|
||||||
|
dohPort = 443
|
||||||
|
}
|
||||||
|
|
||||||
|
s := baseDOHNameServer(dohHost, dohPort, "DOHL", clientIP)
|
||||||
|
s.httpClient = &http.Client{
|
||||||
|
Timeout: time.Second * 180,
|
||||||
|
}
|
||||||
|
newError("DNS: created local DOH client for https://", dohHost, ":", dohPort).AtInfo().WriteToLog()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func baseDOHNameServer(dohHost string, dohPort uint32, prefix string, clientIP net.IP) *DoHNameServer {
|
||||||
|
|
||||||
|
if dohPort == 0 {
|
||||||
|
dohPort = 443
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &DoHNameServer{
|
||||||
|
ips: make(map[string]record),
|
||||||
|
clientIP: clientIP,
|
||||||
|
pub: pubsub.NewService(),
|
||||||
|
name: fmt.Sprintf("%s:%s:%d", prefix, dohHost, dohPort),
|
||||||
|
dohURL: fmt.Sprintf("https://%s:%d/dns-query", dohHost, dohPort),
|
||||||
|
}
|
||||||
|
s.cleanup = &task.Periodic{
|
||||||
|
Interval: time.Minute,
|
||||||
|
Execute: s.Cleanup,
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns client name
|
||||||
|
func (s *DoHNameServer) Name() string {
|
||||||
|
return s.name
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialContext offer dispatched connection through core routing
|
||||||
|
func (s *DoHNameServer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
|
||||||
|
dest := s.dohDests[dice.Roll(len(s.dohDests))]
|
||||||
|
|
||||||
|
link, err := s.dispatcher.Dispatch(ctx, dest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return net.NewConnection(
|
||||||
|
net.ConnectionInputMulti(link.Writer),
|
||||||
|
net.ConnectionOutputMulti(link.Reader),
|
||||||
|
), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup clears expired items from cache
|
||||||
|
func (s *DoHNameServer) Cleanup() error {
|
||||||
|
now := time.Now()
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
|
||||||
|
if len(s.ips) == 0 {
|
||||||
|
return newError("nothing to do. stopping...")
|
||||||
|
}
|
||||||
|
|
||||||
|
for domain, record := range s.ips {
|
||||||
|
if record.A != nil && record.A.Expire.Before(now) {
|
||||||
|
record.A = nil
|
||||||
|
}
|
||||||
|
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
|
||||||
|
record.AAAA = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if record.A == nil && record.AAAA == nil {
|
||||||
|
newError(s.name, " cleanup ", domain).AtDebug().WriteToLog()
|
||||||
|
delete(s.ips, domain)
|
||||||
|
} else {
|
||||||
|
s.ips[domain] = record
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(s.ips) == 0 {
|
||||||
|
s.ips = make(map[string]record)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
|
||||||
|
elapsed := time.Since(req.start)
|
||||||
|
newError(s.name, " got answere: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
|
||||||
|
|
||||||
|
s.Lock()
|
||||||
|
rec := s.ips[req.domain]
|
||||||
|
updated := false
|
||||||
|
|
||||||
|
switch req.reqType {
|
||||||
|
case dnsmessage.TypeA:
|
||||||
|
if isNewer(rec.A, ipRec) {
|
||||||
|
rec.A = ipRec
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
case dnsmessage.TypeAAAA:
|
||||||
|
if isNewer(rec.AAAA, ipRec) {
|
||||||
|
rec.AAAA = ipRec
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if updated {
|
||||||
|
s.ips[req.domain] = rec
|
||||||
|
s.pub.Publish(req.domain, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Unlock()
|
||||||
|
common.Must(s.cleanup.Start())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DoHNameServer) newReqID() uint16 {
|
||||||
|
return uint16(atomic.AddUint32(&s.reqID, 1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, option IPOption) {
|
||||||
|
newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx))
|
||||||
|
|
||||||
|
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP))
|
||||||
|
|
||||||
|
var deadline time.Time
|
||||||
|
if d, ok := ctx.Deadline(); ok {
|
||||||
|
deadline = d
|
||||||
|
} else {
|
||||||
|
deadline = time.Now().Add(time.Second * 8)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, req := range reqs {
|
||||||
|
|
||||||
|
go func(r *dnsRequest) {
|
||||||
|
|
||||||
|
// generate new context for each req, using same context
|
||||||
|
// may cause reqs all aborted if any one encounter an error
|
||||||
|
dnsCtx := context.Background()
|
||||||
|
|
||||||
|
// reserve internal dns server requested Inbound
|
||||||
|
if inbound := session.InboundFromContext(ctx); inbound != nil {
|
||||||
|
dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{
|
||||||
|
Protocol: "https",
|
||||||
|
})
|
||||||
|
|
||||||
|
// forced to use mux for DOH
|
||||||
|
dnsCtx = session.ContextWithMuxPrefered(dnsCtx, true)
|
||||||
|
|
||||||
|
dnsCtx, cancel := context.WithDeadline(dnsCtx, deadline)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
b, _ := dns.PackMessage(r.msg)
|
||||||
|
resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
newError("failed to retrive response").Base(err).AtError().WriteToLog()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rec, err := parseResponse(resp)
|
||||||
|
if err != nil {
|
||||||
|
newError("failed to handle DOH response").Base(err).AtError().WriteToLog()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.updateIP(r, rec)
|
||||||
|
}(req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte, error) {
|
||||||
|
|
||||||
|
body := bytes.NewBuffer(b)
|
||||||
|
req, err := http.NewRequest("POST", s.dohURL, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Add("Accept", "application/dns-message")
|
||||||
|
req.Header.Add("Content-Type", "application/dns-message")
|
||||||
|
|
||||||
|
resp, err := s.httpClient.Do(req.WithContext(ctx))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
err = fmt.Errorf("DOH HTTPS server returned with non-OK code %d", resp.StatusCode)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ioutil.ReadAll(resp.Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DoHNameServer) findIPsForDomain(domain string, option IPOption) ([]net.IP, error) {
|
||||||
|
s.RLock()
|
||||||
|
record, found := s.ips[domain]
|
||||||
|
s.RUnlock()
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
return nil, errRecordNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
var ips []net.Address
|
||||||
|
var lastErr error
|
||||||
|
if option.IPv6Enable && record.AAAA != nil && record.AAAA.RCode == dnsmessage.RCodeSuccess {
|
||||||
|
aaaa, err := record.AAAA.getIPs()
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
ips = append(ips, aaaa...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if option.IPv4Enable && record.A != nil && record.A.RCode == dnsmessage.RCodeSuccess {
|
||||||
|
a, err := record.A.getIPs()
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
ips = append(ips, a...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ips) > 0 {
|
||||||
|
return toNetIP(ips), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastErr != nil {
|
||||||
|
return nil, lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errRecordNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryIP is called from dns.Server->queryIPTimeout
|
||||||
|
func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
|
||||||
|
|
||||||
|
fqdn := Fqdn(domain)
|
||||||
|
|
||||||
|
ips, err := s.findIPsForDomain(fqdn, option)
|
||||||
|
if err != errRecordNotFound {
|
||||||
|
newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
|
||||||
|
return ips, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sub := s.pub.Subscribe(fqdn)
|
||||||
|
defer sub.Close()
|
||||||
|
|
||||||
|
s.sendQuery(ctx, fqdn, option)
|
||||||
|
|
||||||
|
for {
|
||||||
|
ips, err := s.findIPsForDomain(fqdn, option)
|
||||||
|
if err != errRecordNotFound {
|
||||||
|
return ips, err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-sub.Wait():
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -49,6 +49,7 @@ func (s *localNameServer) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewLocalNameServer() *localNameServer {
|
func NewLocalNameServer() *localNameServer {
|
||||||
|
newError("DNS: created localhost client").AtInfo().WriteToLog()
|
||||||
return &localNameServer{
|
return &localNameServer{
|
||||||
client: localdns.New(),
|
client: localdns.New(),
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,8 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -39,7 +41,7 @@ type MultiGeoIPMatcher struct {
|
|||||||
matchers []*router.GeoIPMatcher
|
matchers []*router.GeoIPMatcher
|
||||||
}
|
}
|
||||||
|
|
||||||
var errExpectedIPNonMatch = errors.New("expected ip not match")
|
var errExpectedIPNonMatch = errors.New("expectIPs not match")
|
||||||
|
|
||||||
// Match check ip match
|
// Match check ip match
|
||||||
func (c *MultiGeoIPMatcher) Match(ip net.IP) bool {
|
func (c *MultiGeoIPMatcher) Match(ip net.IP) bool {
|
||||||
@ -71,7 +73,7 @@ func New(ctx context.Context, config *Config) (*Server, error) {
|
|||||||
server.tag = generateRandomTag()
|
server.tag = generateRandomTag()
|
||||||
}
|
}
|
||||||
if len(config.ClientIp) > 0 {
|
if len(config.ClientIp) > 0 {
|
||||||
if len(config.ClientIp) != 4 && len(config.ClientIp) != 16 {
|
if len(config.ClientIp) != net.IPv4len && len(config.ClientIp) != net.IPv6len {
|
||||||
return nil, newError("unexpected IP length", len(config.ClientIp))
|
return nil, newError("unexpected IP length", len(config.ClientIp))
|
||||||
}
|
}
|
||||||
server.clientIP = net.IP(config.ClientIp)
|
server.clientIP = net.IP(config.ClientIp)
|
||||||
@ -87,6 +89,23 @@ func New(ctx context.Context, config *Config) (*Server, error) {
|
|||||||
address := endpoint.Address.AsAddress()
|
address := endpoint.Address.AsAddress()
|
||||||
if address.Family().IsDomain() && address.Domain() == "localhost" {
|
if address.Family().IsDomain() && address.Domain() == "localhost" {
|
||||||
server.clients = append(server.clients, NewLocalNameServer())
|
server.clients = append(server.clients, NewLocalNameServer())
|
||||||
|
} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOHL_") {
|
||||||
|
dohHost := address.Domain()[5:]
|
||||||
|
server.clients = append(server.clients, NewDoHLocalNameServer(dohHost, endpoint.Port, server.clientIP))
|
||||||
|
} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOH_") {
|
||||||
|
// DOH_ prefix makes net.Address think it's a domain
|
||||||
|
dohHost := address.Domain()[4:]
|
||||||
|
idx := len(server.clients)
|
||||||
|
server.clients = append(server.clients, nil)
|
||||||
|
|
||||||
|
// need the core dispatcher, register DOHClient at callback
|
||||||
|
common.Must(core.RequireFeatures(ctx, func(d routing.Dispatcher) {
|
||||||
|
c, err := NewDoHNameServer(dohHost, endpoint.Port, d, server.clientIP)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln(newError("DNS config error").Base(err))
|
||||||
|
}
|
||||||
|
server.clients[idx] = c
|
||||||
|
}))
|
||||||
} else {
|
} else {
|
||||||
dest := endpoint.AsDestination()
|
dest := endpoint.AsDestination()
|
||||||
if dest.Network == net.Network_Unknown {
|
if dest.Network == net.Network_Unknown {
|
||||||
@ -129,6 +148,8 @@ func New(ctx context.Context, config *Config) (*Server, error) {
|
|||||||
domainIndexMap[midx] = uint32(idx)
|
domainIndexMap[midx] = uint32(idx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// only add to ipIndexMap if GeoIP is configured
|
||||||
|
if len(ns.Geoip) > 0 {
|
||||||
var matchers []*router.GeoIPMatcher
|
var matchers []*router.GeoIPMatcher
|
||||||
for _, geoip := range ns.Geoip {
|
for _, geoip := range ns.Geoip {
|
||||||
matcher, err := geoIPMatcherContainer.Add(geoip)
|
matcher, err := geoIPMatcherContainer.Add(geoip)
|
||||||
@ -140,6 +161,7 @@ func New(ctx context.Context, config *Config) (*Server, error) {
|
|||||||
matcher := &MultiGeoIPMatcher{matchers: matchers}
|
matcher := &MultiGeoIPMatcher{matchers: matchers}
|
||||||
ipIndexMap[uint32(idx)] = matcher
|
ipIndexMap[uint32(idx)] = matcher
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
server.domainMatcher = domainMatcher
|
server.domainMatcher = domainMatcher
|
||||||
server.domainIndexMap = domainIndexMap
|
server.domainIndexMap = domainIndexMap
|
||||||
@ -177,12 +199,11 @@ func (s *Server) IsOwnLink(ctx context.Context) bool {
|
|||||||
func (s *Server) Match(idx uint32, client Client, domain string, ips []net.IP) ([]net.IP, error) {
|
func (s *Server) Match(idx uint32, client Client, domain string, ips []net.IP) ([]net.IP, error) {
|
||||||
matcher, exist := s.ipIndexMap[idx]
|
matcher, exist := s.ipIndexMap[idx]
|
||||||
if !exist {
|
if !exist {
|
||||||
newError("domain ", domain, " server not in ipIndexMap: ", client.Name(), " idx:", idx, " just return").AtDebug().WriteToLog()
|
|
||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !matcher.HasMatcher() {
|
if !matcher.HasMatcher() {
|
||||||
newError("domain ", domain, " server has not valid matcher: ", client.Name(), " idx:", idx, " just return").AtDebug().WriteToLog()
|
newError("domain ", domain, " server has no valid matcher: ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
|
||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -190,14 +211,12 @@ func (s *Server) Match(idx uint32, client Client, domain string, ips []net.IP) (
|
|||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
if matcher.Match(ip) {
|
if matcher.Match(ip) {
|
||||||
newIps = append(newIps, ip)
|
newIps = append(newIps, ip)
|
||||||
newError("domain ", domain, " ip ", ip, " is match at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
|
|
||||||
} else {
|
|
||||||
newError("domain ", domain, " ip ", ip, " is not match at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(newIps) == 0 {
|
if len(newIps) == 0 {
|
||||||
return nil, errExpectedIPNonMatch
|
return nil, errExpectedIPNonMatch
|
||||||
}
|
}
|
||||||
|
newError("domain ", domain, " expectIPs ", newIps, " matched at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
|
||||||
return newIps, nil
|
return newIps, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -272,10 +291,16 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
|
|||||||
return nil, newError("empty domain name")
|
return nil, newError("empty domain name")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// normalize the FQDN form query
|
||||||
if domain[len(domain)-1] == '.' {
|
if domain[len(domain)-1] == '.' {
|
||||||
domain = domain[:len(domain)-1]
|
domain = domain[:len(domain)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// skip domain without any dot
|
||||||
|
if strings.Index(domain, ".") == -1 {
|
||||||
|
return nil, newError("invalid domain name").AtWarning()
|
||||||
|
}
|
||||||
|
|
||||||
ips := s.lookupStatic(domain, option, 0)
|
ips := s.lookupStatic(domain, option, 0)
|
||||||
if ips != nil && ips[0].Family().IsIP() {
|
if ips != nil && ips[0].Family().IsIP() {
|
||||||
newError("returning ", len(ips), " IPs for domain ", domain).WriteToLog()
|
newError("returning ", len(ips), " IPs for domain ", domain).WriteToLog()
|
||||||
@ -294,7 +319,6 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
|
|||||||
idx := s.domainMatcher.Match(domain)
|
idx := s.domainMatcher.Match(domain)
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
matchedClient = s.clients[s.domainIndexMap[idx]]
|
matchedClient = s.clients[s.domainIndexMap[idx]]
|
||||||
newError("domain matched, direct lookup ip for domain ", domain, " at ", matchedClient.Name()).WriteToLog()
|
|
||||||
ips, err := s.queryIPTimeout(s.domainIndexMap[idx], matchedClient, domain, option)
|
ips, err := s.queryIPTimeout(s.domainIndexMap[idx], matchedClient, domain, option)
|
||||||
if len(ips) > 0 {
|
if len(ips) > 0 {
|
||||||
return ips, nil
|
return ips, nil
|
||||||
@ -315,10 +339,8 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
newError("try to lookup ip for domain ", domain, " at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
|
|
||||||
ips, err := s.queryIPTimeout(uint32(idx), client, domain, option)
|
ips, err := s.queryIPTimeout(uint32(idx), client, domain, option)
|
||||||
if len(ips) > 0 {
|
if len(ips) > 0 {
|
||||||
newError("lookup ip for domain ", domain, " success: ", ips, " at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
|
|
||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -331,7 +353,7 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, newError("returning nil for domain ", domain).Base(lastErr)
|
return nil, dns.ErrEmptyResponse.Base(lastErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
276
app/dns/udpns.go
276
app/dns/udpns.go
@ -4,14 +4,13 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/dns/dnsmessage"
|
"golang.org/x/net/dns/dnsmessage"
|
||||||
"v2ray.com/core/common"
|
"v2ray.com/core/common"
|
||||||
"v2ray.com/core/common/errors"
|
|
||||||
"v2ray.com/core/common/net"
|
"v2ray.com/core/common/net"
|
||||||
"v2ray.com/core/common/protocol/dns"
|
"v2ray.com/core/common/protocol/dns"
|
||||||
udp_proto "v2ray.com/core/common/protocol/udp"
|
udp_proto "v2ray.com/core/common/protocol/udp"
|
||||||
@ -23,42 +22,12 @@ import (
|
|||||||
"v2ray.com/core/transport/internet/udp"
|
"v2ray.com/core/transport/internet/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
type record struct {
|
|
||||||
A *IPRecord
|
|
||||||
AAAA *IPRecord
|
|
||||||
}
|
|
||||||
|
|
||||||
type IPRecord struct {
|
|
||||||
IP []net.Address
|
|
||||||
Expire time.Time
|
|
||||||
RCode dnsmessage.RCode
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *IPRecord) getIPs() ([]net.Address, error) {
|
|
||||||
if r == nil || r.Expire.Before(time.Now()) {
|
|
||||||
return nil, errRecordNotFound
|
|
||||||
}
|
|
||||||
if r.RCode != dnsmessage.RCodeSuccess {
|
|
||||||
return nil, dns_feature.RCodeError(r.RCode)
|
|
||||||
}
|
|
||||||
return r.IP, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type pendingRequest struct {
|
|
||||||
domain string
|
|
||||||
expire time.Time
|
|
||||||
recType dnsmessage.Type
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
errRecordNotFound = errors.New("record not found")
|
|
||||||
)
|
|
||||||
|
|
||||||
type ClassicNameServer struct {
|
type ClassicNameServer struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
|
name string
|
||||||
address net.Destination
|
address net.Destination
|
||||||
ips map[string]record
|
ips map[string]record
|
||||||
requests map[uint16]pendingRequest
|
requests map[uint16]dnsRequest
|
||||||
pub *pubsub.Service
|
pub *pubsub.Service
|
||||||
udpServer *udp.Dispatcher
|
udpServer *udp.Dispatcher
|
||||||
cleanup *task.Periodic
|
cleanup *task.Periodic
|
||||||
@ -67,23 +36,31 @@ type ClassicNameServer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer {
|
func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer {
|
||||||
|
|
||||||
|
// default to 53 if unspecific
|
||||||
|
if address.Port == 0 {
|
||||||
|
address.Port = net.Port(53)
|
||||||
|
}
|
||||||
|
|
||||||
s := &ClassicNameServer{
|
s := &ClassicNameServer{
|
||||||
address: address,
|
address: address,
|
||||||
ips: make(map[string]record),
|
ips: make(map[string]record),
|
||||||
requests: make(map[uint16]pendingRequest),
|
requests: make(map[uint16]dnsRequest),
|
||||||
clientIP: clientIP,
|
clientIP: clientIP,
|
||||||
pub: pubsub.NewService(),
|
pub: pubsub.NewService(),
|
||||||
|
name: strings.ToUpper(address.String()),
|
||||||
}
|
}
|
||||||
s.cleanup = &task.Periodic{
|
s.cleanup = &task.Periodic{
|
||||||
Interval: time.Minute,
|
Interval: time.Minute,
|
||||||
Execute: s.Cleanup,
|
Execute: s.Cleanup,
|
||||||
}
|
}
|
||||||
s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse)
|
s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse)
|
||||||
|
newError("DNS: created udp client inited for ", address.NetAddr()).AtInfo().WriteToLog()
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ClassicNameServer) Name() string {
|
func (s *ClassicNameServer) Name() string {
|
||||||
return s.address.String()
|
return s.name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ClassicNameServer) Cleanup() error {
|
func (s *ClassicNameServer) Cleanup() error {
|
||||||
@ -92,7 +69,7 @@ func (s *ClassicNameServer) Cleanup() error {
|
|||||||
defer s.Unlock()
|
defer s.Unlock()
|
||||||
|
|
||||||
if len(s.ips) == 0 && len(s.requests) == 0 {
|
if len(s.ips) == 0 && len(s.requests) == 0 {
|
||||||
return newError("nothing to do. stopping...")
|
return newError(s.name, " nothing to do. stopping...")
|
||||||
}
|
}
|
||||||
|
|
||||||
for domain, record := range s.ips {
|
for domain, record := range s.ips {
|
||||||
@ -121,123 +98,52 @@ func (s *ClassicNameServer) Cleanup() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(s.requests) == 0 {
|
if len(s.requests) == 0 {
|
||||||
s.requests = make(map[uint16]pendingRequest)
|
s.requests = make(map[uint16]dnsRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
|
func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
|
||||||
payload := packet.Payload
|
|
||||||
|
|
||||||
var parser dnsmessage.Parser
|
ipRec, err := parseResponse(packet.Payload.Bytes())
|
||||||
header, err := parser.Start(payload.Bytes())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
newError("failed to parse DNS response").Base(err).AtWarning().WriteToLog()
|
newError(s.name, " fail to parse responsed DNS udp").AtError().WriteToLog()
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := parser.SkipAllQuestions(); err != nil {
|
|
||||||
newError("failed to skip questions in DNS response").Base(err).AtWarning().WriteToLog()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
id := header.ID
|
|
||||||
s.Lock()
|
s.Lock()
|
||||||
req, f := s.requests[id]
|
id := ipRec.ReqID
|
||||||
if f {
|
req, ok := s.requests[id]
|
||||||
|
if ok {
|
||||||
|
// remove the pending request
|
||||||
delete(s.requests, id)
|
delete(s.requests, id)
|
||||||
}
|
}
|
||||||
s.Unlock()
|
s.Unlock()
|
||||||
|
if !ok {
|
||||||
if !f {
|
newError(s.name, " cannot find the pending request").AtError().WriteToLog()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
domain := req.domain
|
|
||||||
recType := req.recType
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
ipRecord := &IPRecord{
|
|
||||||
RCode: header.RCode,
|
|
||||||
Expire: now.Add(time.Second * 600),
|
|
||||||
}
|
|
||||||
|
|
||||||
L:
|
|
||||||
for {
|
|
||||||
header, err := parser.AnswerHeader()
|
|
||||||
if err != nil {
|
|
||||||
if err != dnsmessage.ErrSectionDone {
|
|
||||||
newError("failed to parse answer section for domain: ", domain).Base(err).WriteToLog()
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
ttl := header.TTL
|
|
||||||
if ttl == 0 {
|
|
||||||
ttl = 600
|
|
||||||
}
|
|
||||||
expire := now.Add(time.Duration(ttl) * time.Second)
|
|
||||||
if ipRecord.Expire.After(expire) {
|
|
||||||
ipRecord.Expire = expire
|
|
||||||
}
|
|
||||||
|
|
||||||
if header.Type != recType {
|
|
||||||
if err := parser.SkipAnswer(); err != nil {
|
|
||||||
newError("failed to skip answer").Base(err).WriteToLog()
|
|
||||||
break L
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
switch header.Type {
|
|
||||||
case dnsmessage.TypeA:
|
|
||||||
ans, err := parser.AResource()
|
|
||||||
if err != nil {
|
|
||||||
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
|
|
||||||
break L
|
|
||||||
}
|
|
||||||
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
|
|
||||||
case dnsmessage.TypeAAAA:
|
|
||||||
ans, err := parser.AAAAResource()
|
|
||||||
if err != nil {
|
|
||||||
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
|
|
||||||
break L
|
|
||||||
}
|
|
||||||
ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
|
|
||||||
default:
|
|
||||||
if err := parser.SkipAnswer(); err != nil {
|
|
||||||
newError("failed to skip answer").Base(err).WriteToLog()
|
|
||||||
break L
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var rec record
|
var rec record
|
||||||
switch recType {
|
switch req.reqType {
|
||||||
case dnsmessage.TypeA:
|
case dnsmessage.TypeA:
|
||||||
rec.A = ipRecord
|
rec.A = ipRec
|
||||||
case dnsmessage.TypeAAAA:
|
case dnsmessage.TypeAAAA:
|
||||||
rec.AAAA = ipRecord
|
rec.AAAA = ipRec
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
|
elapsed := time.Since(req.start)
|
||||||
s.updateIP(domain, rec)
|
newError(s.name, " got answere: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
|
||||||
|
if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
|
||||||
|
s.updateIP(req.domain, rec)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
|
|
||||||
if newRec == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if baseRec == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return baseRec.Expire.Before(newRec.Expire)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ClassicNameServer) updateIP(domain string, newRec record) {
|
func (s *ClassicNameServer) updateIP(domain string, newRec record) {
|
||||||
s.Lock()
|
s.Lock()
|
||||||
|
|
||||||
newError("updating IP records for domain:", domain).AtDebug().WriteToLog()
|
newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog()
|
||||||
rec := s.ips[domain]
|
rec := s.ips[domain]
|
||||||
|
|
||||||
updated := false
|
updated := false
|
||||||
@ -259,116 +165,27 @@ func (s *ClassicNameServer) updateIP(domain string, newRec record) {
|
|||||||
common.Must(s.cleanup.Start())
|
common.Must(s.cleanup.Start())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ClassicNameServer) getMsgOptions() *dnsmessage.Resource {
|
func (s *ClassicNameServer) newReqID() uint16 {
|
||||||
if len(s.clientIP) == 0 {
|
return uint16(atomic.AddUint32(&s.reqID, 1))
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var netmask int
|
|
||||||
var family uint16
|
|
||||||
|
|
||||||
if len(s.clientIP) == 4 {
|
|
||||||
family = 1
|
|
||||||
netmask = 24 // 24 for IPV4, 96 for IPv6
|
|
||||||
} else {
|
|
||||||
family = 2
|
|
||||||
netmask = 96
|
|
||||||
}
|
|
||||||
|
|
||||||
b := make([]byte, 4)
|
|
||||||
binary.BigEndian.PutUint16(b[0:], family)
|
|
||||||
b[2] = byte(netmask)
|
|
||||||
b[3] = 0
|
|
||||||
switch family {
|
|
||||||
case 1:
|
|
||||||
ip := s.clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
|
|
||||||
needLength := (netmask + 8 - 1) / 8 // division rounding up
|
|
||||||
b = append(b, ip[:needLength]...)
|
|
||||||
case 2:
|
|
||||||
ip := s.clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
|
|
||||||
needLength := (netmask + 8 - 1) / 8 // division rounding up
|
|
||||||
b = append(b, ip[:needLength]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
const EDNS0SUBNET = 0x08
|
|
||||||
|
|
||||||
opt := new(dnsmessage.Resource)
|
|
||||||
common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
|
|
||||||
|
|
||||||
opt.Body = &dnsmessage.OPTResource{
|
|
||||||
Options: []dnsmessage.Option{
|
|
||||||
{
|
|
||||||
Code: EDNS0SUBNET,
|
|
||||||
Data: b,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return opt
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ClassicNameServer) addPendingRequest(domain string, recType dnsmessage.Type) uint16 {
|
func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) {
|
||||||
id := uint16(atomic.AddUint32(&s.reqID, 1))
|
|
||||||
s.Lock()
|
s.Lock()
|
||||||
defer s.Unlock()
|
defer s.Unlock()
|
||||||
|
|
||||||
s.requests[id] = pendingRequest{
|
id := req.msg.ID
|
||||||
domain: domain,
|
req.expire = time.Now().Add(time.Second * 8)
|
||||||
expire: time.Now().Add(time.Second * 8),
|
s.requests[id] = *req
|
||||||
recType: recType,
|
|
||||||
}
|
|
||||||
|
|
||||||
return id
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmessage.Message {
|
|
||||||
qA := dnsmessage.Question{
|
|
||||||
Name: dnsmessage.MustNewName(domain),
|
|
||||||
Type: dnsmessage.TypeA,
|
|
||||||
Class: dnsmessage.ClassINET,
|
|
||||||
}
|
|
||||||
|
|
||||||
qAAAA := dnsmessage.Question{
|
|
||||||
Name: dnsmessage.MustNewName(domain),
|
|
||||||
Type: dnsmessage.TypeAAAA,
|
|
||||||
Class: dnsmessage.ClassINET,
|
|
||||||
}
|
|
||||||
|
|
||||||
var msgs []*dnsmessage.Message
|
|
||||||
|
|
||||||
if option.IPv4Enable {
|
|
||||||
msg := new(dnsmessage.Message)
|
|
||||||
msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeA)
|
|
||||||
msg.Header.RecursionDesired = true
|
|
||||||
msg.Questions = []dnsmessage.Question{qA}
|
|
||||||
if opt := s.getMsgOptions(); opt != nil {
|
|
||||||
msg.Additionals = append(msg.Additionals, *opt)
|
|
||||||
}
|
|
||||||
msgs = append(msgs, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
if option.IPv6Enable {
|
|
||||||
msg := new(dnsmessage.Message)
|
|
||||||
msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeAAAA)
|
|
||||||
msg.Header.RecursionDesired = true
|
|
||||||
msg.Questions = []dnsmessage.Question{qAAAA}
|
|
||||||
if opt := s.getMsgOptions(); opt != nil {
|
|
||||||
msg.Additionals = append(msg.Additionals, *opt)
|
|
||||||
}
|
|
||||||
msgs = append(msgs, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
return msgs
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option IPOption) {
|
func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option IPOption) {
|
||||||
newError("querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))
|
newError(s.name, " querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))
|
||||||
|
|
||||||
msgs := s.buildMsgs(domain, option)
|
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP))
|
||||||
|
|
||||||
for _, msg := range msgs {
|
|
||||||
b, _ := dns.PackMessage(msg)
|
|
||||||
|
|
||||||
|
for _, req := range reqs {
|
||||||
|
s.addPendingRequest(req)
|
||||||
|
b, _ := dns.PackMessage(req.msg)
|
||||||
udpCtx := context.Background()
|
udpCtx := context.Background()
|
||||||
if inbound := session.InboundFromContext(ctx); inbound != nil {
|
if inbound := session.InboundFromContext(ctx); inbound != nil {
|
||||||
udpCtx = session.ContextWithInbound(udpCtx, inbound)
|
udpCtx = session.ContextWithInbound(udpCtx, inbound)
|
||||||
@ -418,18 +235,13 @@ func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) ([]
|
|||||||
return nil, dns_feature.ErrEmptyResponse
|
return nil, dns_feature.ErrEmptyResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
func Fqdn(domain string) string {
|
|
||||||
if len(domain) > 0 && domain[len(domain)-1] == '.' {
|
|
||||||
return domain
|
|
||||||
}
|
|
||||||
return domain + "."
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
|
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
|
||||||
|
|
||||||
fqdn := Fqdn(domain)
|
fqdn := Fqdn(domain)
|
||||||
|
|
||||||
ips, err := s.findIPsForDomain(fqdn, option)
|
ips, err := s.findIPsForDomain(fqdn, option)
|
||||||
if err != errRecordNotFound {
|
if err != errRecordNotFound {
|
||||||
|
newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
|
||||||
return ips, err
|
return ips, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,12 +68,13 @@ func NewHandler(ctx context.Context, config *core.OutboundHandlerConfig) (outbou
|
|||||||
return nil, newError("not an outbound handler")
|
return nil, newError("not an outbound handler")
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.senderSettings != nil && h.senderSettings.MultiplexSettings != nil && h.senderSettings.MultiplexSettings.Enabled {
|
if h.senderSettings != nil && h.senderSettings.MultiplexSettings != nil {
|
||||||
config := h.senderSettings.MultiplexSettings
|
config := h.senderSettings.MultiplexSettings
|
||||||
if config.Concurrency < 1 || config.Concurrency > 1024 {
|
if config.Concurrency < 1 || config.Concurrency > 1024 {
|
||||||
return nil, newError("invalid mux concurrency: ", config.Concurrency).AtWarning()
|
return nil, newError("invalid mux concurrency: ", config.Concurrency).AtWarning()
|
||||||
}
|
}
|
||||||
h.mux = &mux.ClientManager{
|
h.mux = &mux.ClientManager{
|
||||||
|
Enabled: h.senderSettings.MultiplexSettings.Enabled,
|
||||||
Picker: &mux.IncrementalWorkerPicker{
|
Picker: &mux.IncrementalWorkerPicker{
|
||||||
Factory: &mux.DialingWorkerFactory{
|
Factory: &mux.DialingWorkerFactory{
|
||||||
Proxy: proxyHandler,
|
Proxy: proxyHandler,
|
||||||
@ -98,7 +99,7 @@ func (h *Handler) Tag() string {
|
|||||||
|
|
||||||
// Dispatch implements proxy.Outbound.Dispatch.
|
// Dispatch implements proxy.Outbound.Dispatch.
|
||||||
func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
|
func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
|
||||||
if h.mux != nil {
|
if h.mux != nil && (h.mux.Enabled || session.MuxPreferedFromContext(ctx)) {
|
||||||
if err := h.mux.Dispatch(ctx, link); err != nil {
|
if err := h.mux.Dispatch(ctx, link); err != nil {
|
||||||
newError("failed to process mux outbound traffic").Base(err).WriteToLog(session.ExportIDToError(ctx))
|
newError("failed to process mux outbound traffic").Base(err).WriteToLog(session.ExportIDToError(ctx))
|
||||||
common.Interrupt(link.Writer)
|
common.Interrupt(link.Writer)
|
||||||
|
@ -21,6 +21,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ClientManager struct {
|
type ClientManager struct {
|
||||||
|
Enabled bool // wheather mux is enabled from user config
|
||||||
Picker WorkerPicker
|
Picker WorkerPicker
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@ const (
|
|||||||
inboundSessionKey
|
inboundSessionKey
|
||||||
outboundSessionKey
|
outboundSessionKey
|
||||||
contentSessionKey
|
contentSessionKey
|
||||||
|
muxPreferedSessionKey
|
||||||
)
|
)
|
||||||
|
|
||||||
// ContextWithID returns a new context with the given ID.
|
// ContextWithID returns a new context with the given ID.
|
||||||
@ -56,3 +57,16 @@ func ContentFromContext(ctx context.Context) *Content {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ContextWithMuxPrefered returns a new context with the given bool
|
||||||
|
func ContextWithMuxPrefered(ctx context.Context, forced bool) context.Context {
|
||||||
|
return context.WithValue(ctx, muxPreferedSessionKey, forced)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MuxPreferedFromContext returns value in this context, or false if not contained.
|
||||||
|
func MuxPreferedFromContext(ctx context.Context) bool {
|
||||||
|
if val, ok := ctx.Value(muxPreferedSessionKey).(bool); ok {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@ -21,7 +21,6 @@ func (c *NameServerConfig) UnmarshalJSON(data []byte) error {
|
|||||||
var address Address
|
var address Address
|
||||||
if err := json.Unmarshal(data, &address); err == nil {
|
if err := json.Unmarshal(data, &address); err == nil {
|
||||||
c.Address = &address
|
c.Address = &address
|
||||||
c.Port = 53
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,14 +76,24 @@ func (c *SniffingConfig) Build() (*proxyman.SniffingConfig, error) {
|
|||||||
|
|
||||||
type MuxConfig struct {
|
type MuxConfig struct {
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
Concurrency uint16 `json:"concurrency"`
|
Concurrency int16 `json:"concurrency"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *MuxConfig) GetConcurrency() uint16 {
|
// Build creates MultiplexingConfig, Concurrency < 0 completely disables mux.
|
||||||
if c.Concurrency == 0 {
|
func (m *MuxConfig) Build() *proxyman.MultiplexingConfig {
|
||||||
return 8
|
if m.Concurrency < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var con uint32 = 8
|
||||||
|
if m.Concurrency > 0 {
|
||||||
|
con = uint32(m.Concurrency)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proxyman.MultiplexingConfig{
|
||||||
|
Enabled: m.Enabled,
|
||||||
|
Concurrency: con,
|
||||||
}
|
}
|
||||||
return c.Concurrency
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type InboundDetourAllocationConfig struct {
|
type InboundDetourAllocationConfig struct {
|
||||||
@ -246,11 +256,8 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
|
|||||||
senderSettings.ProxySettings = ps
|
senderSettings.ProxySettings = ps
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.MuxSettings != nil && c.MuxSettings.Enabled {
|
if c.MuxSettings != nil {
|
||||||
senderSettings.MultiplexSettings = &proxyman.MultiplexingConfig{
|
senderSettings.MultiplexSettings = c.MuxSettings.Build()
|
||||||
Enabled: true,
|
|
||||||
Concurrency: uint32(c.MuxSettings.GetConcurrency()),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
settings := []byte("{}")
|
settings := []byte("{}")
|
||||||
|
@ -2,15 +2,16 @@ package conf_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
|
|
||||||
"v2ray.com/core"
|
"v2ray.com/core"
|
||||||
"v2ray.com/core/app/dispatcher"
|
"v2ray.com/core/app/dispatcher"
|
||||||
"v2ray.com/core/app/log"
|
"v2ray.com/core/app/log"
|
||||||
"v2ray.com/core/app/proxyman"
|
"v2ray.com/core/app/proxyman"
|
||||||
"v2ray.com/core/app/router"
|
"v2ray.com/core/app/router"
|
||||||
|
"v2ray.com/core/common"
|
||||||
clog "v2ray.com/core/common/log"
|
clog "v2ray.com/core/common/log"
|
||||||
"v2ray.com/core/common/net"
|
"v2ray.com/core/common/net"
|
||||||
"v2ray.com/core/common/protocol"
|
"v2ray.com/core/common/protocol"
|
||||||
@ -337,3 +338,34 @@ func TestV2RayConfig(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMuxConfig_Build(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields string
|
||||||
|
want *proxyman.MultiplexingConfig
|
||||||
|
}{
|
||||||
|
{"default", `{"enabled": true, "concurrency": 16}`, &proxyman.MultiplexingConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Concurrency: 16,
|
||||||
|
}},
|
||||||
|
{"empty def", `{}`, &proxyman.MultiplexingConfig{
|
||||||
|
Enabled: false,
|
||||||
|
Concurrency: 8,
|
||||||
|
}},
|
||||||
|
{"not enable", `{"enabled": false, "concurrency": 4}`, &proxyman.MultiplexingConfig{
|
||||||
|
Enabled: false,
|
||||||
|
Concurrency: 4,
|
||||||
|
}},
|
||||||
|
{"forbidden", `{"enabled": false, "concurrency": -1}`, nil},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
m := &MuxConfig{}
|
||||||
|
common.Must(json.Unmarshal([]byte(tt.fields), m))
|
||||||
|
if got := m.Build(); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("MuxConfig.Build() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user