diff --git a/app/dns/dohdns.go b/app/dns/dohdns.go index c7a3f5c25..61237ba8c 100644 --- a/app/dns/dohdns.go +++ b/app/dns/dohdns.go @@ -8,6 +8,7 @@ import ( "fmt" "io/ioutil" "net/http" + "net/url" "sync" "sync/atomic" "time" @@ -41,25 +42,25 @@ type DoHNameServer struct { } // NewDoHNameServer creates DOH client object for remote resolving -func NewDoHNameServer(dohHost string, dohPort uint32, dispatcher routing.Dispatcher, clientIP net.IP) (*DoHNameServer, error) { +func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, clientIP net.IP) (*DoHNameServer, error) { - dohAddr := net.ParseAddress(dohHost) - var dests []net.Destination - - if dohPort == 0 { - dohPort = 443 + dohAddr := net.ParseAddress(url.Hostname()) + dohPort := "443" + if url.Port() != "" { + dohPort = url.Port() } - parseIPDest := func(ip net.IP, port uint32) net.Destination { + parseIPDest := func(ip net.IP, port string) net.Destination { strIP := ip.String() if len(ip) == net.IPv6len { strIP = fmt.Sprintf("[%s]", strIP) } - dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:%d", strIP, port)) + dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:%s", strIP, port)) common.Must(err) return dest } + var dests []net.Destination if dohAddr.Family().IsDomain() { // resolve DOH server in advance ips, err := net.LookupIP(dohAddr.Domain()) @@ -74,8 +75,8 @@ func NewDoHNameServer(dohHost string, dohPort uint32, dispatcher routing.Dispatc dests = append(dests, parseIPDest(ip, dohPort)) } - newError("DNS: created remote DOH client for https://", dohHost, ":", dohPort).AtInfo().WriteToLog() - s := baseDOHNameServer(dohHost, dohPort, "DOH", clientIP) + newError("DNS: created Remote DOH client for ", url.String(), ", preresolved Dests: ", dests).AtInfo().WriteToLog() + s := baseDOHNameServer(url, "DOH", clientIP) s.dispatcher = dispatcher s.dohDests = dests @@ -102,32 +103,24 @@ func NewDoHNameServer(dohHost string, dohPort uint32, dispatcher routing.Dispatc } // NewDoHLocalNameServer creates DOH client object for local resolving -func NewDoHLocalNameServer(dohHost string, dohPort uint32, clientIP net.IP) *DoHNameServer { - - if dohPort == 0 { - dohPort = 443 - } - - s := baseDOHNameServer(dohHost, dohPort, "DOHL", clientIP) +func NewDoHLocalNameServer(url *url.URL, clientIP net.IP) *DoHNameServer { + url.Scheme = "https" + s := baseDOHNameServer(url, "DOHL", clientIP) s.httpClient = &http.Client{ Timeout: time.Second * 180, } - newError("DNS: created local DOH client for https://", dohHost, ":", dohPort).AtInfo().WriteToLog() + newError("DNS: created Local DOH client for ", url.String()).AtInfo().WriteToLog() return s } -func baseDOHNameServer(dohHost string, dohPort uint32, prefix string, clientIP net.IP) *DoHNameServer { - - if dohPort == 0 { - dohPort = 443 - } +func baseDOHNameServer(url *url.URL, prefix string, clientIP net.IP) *DoHNameServer { s := &DoHNameServer{ ips: make(map[string]record), clientIP: clientIP, pub: pubsub.NewService(), - name: fmt.Sprintf("%s:%s:%d", prefix, dohHost, dohPort), - dohURL: fmt.Sprintf("https://%s:%d/dns-query", dohHost, dohPort), + name: fmt.Sprintf("%s//%s", prefix, url.Host), + dohURL: url.String(), } s.cleanup = &task.Periodic{ Interval: time.Minute, diff --git a/app/dns/server.go b/app/dns/server.go index 9f3f322ca..c2bf83f91 100644 --- a/app/dns/server.go +++ b/app/dns/server.go @@ -7,6 +7,7 @@ package dns import ( "context" "log" + "net/url" "strings" "sync" "time" @@ -89,24 +90,34 @@ func New(ctx context.Context, config *Config) (*Server, error) { address := endpoint.Address.AsAddress() if address.Family().IsDomain() && address.Domain() == "localhost" { server.clients = append(server.clients, NewLocalNameServer()) - } else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOHL_") { - dohHost := address.Domain()[5:] - server.clients = append(server.clients, NewDoHLocalNameServer(dohHost, endpoint.Port, server.clientIP)) - } else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOH_") { - // DOH_ prefix makes net.Address think it's a domain - dohHost := address.Domain()[4:] + } else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "https+local://") { + // URI schemed string treated as domain + // DOH Local mode + u, err := url.Parse(address.Domain()) + if err != nil { + log.Fatalln(newError("DNS config error").Base(err)) + } + server.clients = append(server.clients, NewDoHLocalNameServer(u, server.clientIP)) + } else if address.Family().IsDomain() && + strings.HasPrefix(address.Domain(), "https://") { + // DOH Remote mode + u, err := url.Parse(address.Domain()) + if err != nil { + log.Fatalln(newError("DNS config error").Base(err)) + } idx := len(server.clients) server.clients = append(server.clients, nil) // need the core dispatcher, register DOHClient at callback common.Must(core.RequireFeatures(ctx, func(d routing.Dispatcher) { - c, err := NewDoHNameServer(dohHost, endpoint.Port, d, server.clientIP) + c, err := NewDoHNameServer(u, d, server.clientIP) if err != nil { log.Fatalln(newError("DNS config error").Base(err)) } server.clients[idx] = c })) } else { + // UDP classic DNS mode dest := endpoint.AsDestination() if dest.Network == net.Network_Unknown { dest.Network = net.Network_UDP diff --git a/common/buf/data/test_ReadBuffer.dat b/common/buf/data/test_MultiBufferReadAllToByte.dat similarity index 100% rename from common/buf/data/test_ReadBuffer.dat rename to common/buf/data/test_MultiBufferReadAllToByte.dat diff --git a/common/buf/io.go b/common/buf/io.go index cf27ba892..2a4cd6705 100644 --- a/common/buf/io.go +++ b/common/buf/io.go @@ -3,6 +3,7 @@ package buf import ( "io" "net" + "os" "syscall" "time" ) @@ -57,19 +58,14 @@ func NewReader(reader io.Reader) Reader { } } - if useReadv { + _, isFile := reader.(*os.File) + if !isFile && useReadv { if sc, ok := reader.(syscall.Conn); ok { rawConn, err := sc.SyscallConn() if err != nil { newError("failed to get sysconn").Base(err).WriteToLog() } else { - /* - Check if ReadVReader Can be used on this reader first - Fix https://github.com/v2ray/v2ray-core/issues/1666 - */ - if ok, _ := checkReadVConstraint(rawConn); ok { - return NewReadVReader(reader, rawConn) - } + return NewReadVReader(reader, rawConn) } } } diff --git a/common/buf/multi_buffer_test.go b/common/buf/multi_buffer_test.go index 6d8387043..add7a1e19 100644 --- a/common/buf/multi_buffer_test.go +++ b/common/buf/multi_buffer_test.go @@ -7,6 +7,8 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "io/ioutil" + "os" "v2ray.com/core/common" . "v2ray.com/core/common/buf" @@ -98,14 +100,33 @@ func TestMultiBufferSplitFirst(t *testing.T) { } func TestMultiBufferReadAllToByte(t *testing.T) { - lb := make([]byte, 8*1024) - common.Must2(io.ReadFull(rand.Reader, lb)) - rd := bytes.NewBuffer(lb) - b, err := ReadAllToBytes(rd) - common.Must(err) + { + lb := make([]byte, 8*1024) + common.Must2(io.ReadFull(rand.Reader, lb)) + rd := bytes.NewBuffer(lb) + b, err := ReadAllToBytes(rd) + common.Must(err) - if l := len(b); l != 8*1024 { - t.Error("unexpceted length from ReadAllToBytes", l) + if l := len(b); l != 8*1024 { + t.Error("unexpceted length from ReadAllToBytes", l) + } + + } + { + const dat = "data/test_MultiBufferReadAllToByte.dat" + f, err := os.Open(dat) + common.Must(err) + + buf2, err := ReadAllToBytes(f) + common.Must(err) + f.Close() + + cnt, err := ioutil.ReadFile(dat) + common.Must(err) + + if d := cmp.Diff(buf2, cnt); d != "" { + t.Error("fail to read from file: ", d) + } } } diff --git a/common/buf/reader_test.go b/common/buf/reader_test.go index 74e4fd541..95622eefe 100644 --- a/common/buf/reader_test.go +++ b/common/buf/reader_test.go @@ -3,12 +3,9 @@ package buf_test import ( "bytes" "io" - "io/ioutil" - "os" "strings" "testing" - "github.com/google/go-cmp/cmp" "v2ray.com/core/common" . "v2ray.com/core/common/buf" "v2ray.com/core/transport/pipe" @@ -92,23 +89,6 @@ func TestReadBuffer(t *testing.T) { buf.Release() } - { - const dat = "data/test_ReadBuffer.dat" - f, err := os.Open(dat) - common.Must(err) - defer f.Close() - - buf2, err := ReadBuffer(f) - common.Must(err) - - cnt, err := ioutil.ReadFile(dat) - common.Must(err) - - if cmp.Diff(buf2.Bytes(), cnt) != "" { - t.Error("fail to read from file") - } - buf2.Release() - } } func TestReadAtMost(t *testing.T) { diff --git a/common/buf/readv_constraint_other.go b/common/buf/readv_constraint_other.go deleted file mode 100644 index 315ce61fd..000000000 --- a/common/buf/readv_constraint_other.go +++ /dev/null @@ -1,9 +0,0 @@ -// +build !windows - -package buf - -import "syscall" - -func checkReadVConstraint(conn syscall.RawConn) (bool, error) { - return true, nil -} diff --git a/common/buf/readv_constraint_windows.go b/common/buf/readv_constraint_windows.go deleted file mode 100644 index d78cbaa8d..000000000 --- a/common/buf/readv_constraint_windows.go +++ /dev/null @@ -1,37 +0,0 @@ -// +build windows -package buf - -import ( - "syscall" -) - -func checkReadVConstraint(conn syscall.RawConn) (bool, error) { - var isSocketReady = false - var reason error - /* - In Windows, WSARecv system call only support socket connection. - - It it required to check if the given fd is of a socket type - - Fix https://github.com/v2ray/v2ray-core/issues/1666 - - Additional Information: - https://docs.microsoft.com/en-us/windows/desktop/api/winsock2/nf-winsock2-wsarecv - https://docs.microsoft.com/en-us/windows/desktop/api/winsock/nf-winsock-getsockopt - https://docs.microsoft.com/en-us/windows/desktop/WinSock/sol-socket-socket-options - - */ - err := conn.Control(func(fd uintptr) { - var val [4]byte - var le = int32(len(val)) - err := syscall.Getsockopt(syscall.Handle(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF, &val[0], &le) - if err != nil { - isSocketReady = false - } else { - isSocketReady = true - } - reason = err - }) - - return isSocketReady, err -} diff --git a/infra/conf/common_test.go b/infra/conf/common_test.go index 1a4193e88..a8159c8c2 100644 --- a/infra/conf/common_test.go +++ b/infra/conf/common_test.go @@ -51,6 +51,25 @@ func TestDomainParsing(t *testing.T) { } } +func TestURLParsing(t *testing.T) { + { + rawJson := "\"https://dns.google/dns-query\"" + var address Address + common.Must(json.Unmarshal([]byte(rawJson), &address)) + if address.Domain() != "https://dns.google/dns-query" { + t.Error("URL: ", address.Domain()) + } + } + { + rawJson := "\"https+local://dns.google/dns-query\"" + var address Address + common.Must(json.Unmarshal([]byte(rawJson), &address)) + if address.Domain() != "https+local://dns.google/dns-query" { + t.Error("URL: ", address.Domain()) + } + } +} + func TestInvalidAddressJson(t *testing.T) { rawJson := "1234" var address Address diff --git a/proxy/dns/dns_test.go b/proxy/dns/dns_test.go index da9175968..441979433 100644 --- a/proxy/dns/dns_test.go +++ b/proxy/dns/dns_test.go @@ -162,6 +162,7 @@ func TestUDPDNSTunnel(t *testing.T) { m1.Question[0] = dns.Question{"ipv4only.google.com.", dns.TypeAAAA, dns.ClassINET} c := new(dns.Client) + c.Timeout = 10 * time.Second in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort))) common.Must(err) diff --git a/testing/scenarios/vmess_test.go b/testing/scenarios/vmess_test.go index cb544337b..813121393 100644 --- a/testing/scenarios/vmess_test.go +++ b/testing/scenarios/vmess_test.go @@ -810,10 +810,10 @@ func TestVMessKCPLarge(t *testing.T) { Protocol: internet.TransportProtocol_MKCP, Settings: serial.ToTypedMessage(&kcp.Config{ ReadBuffer: &kcp.ReadBuffer{ - Size: 4096, + Size: 512 * 1024, }, WriteBuffer: &kcp.WriteBuffer{ - Size: 4096, + Size: 512 * 1024, }, UplinkCapacity: &kcp.UplinkCapacity{ Value: 20, @@ -897,10 +897,10 @@ func TestVMessKCPLarge(t *testing.T) { Protocol: internet.TransportProtocol_MKCP, Settings: serial.ToTypedMessage(&kcp.Config{ ReadBuffer: &kcp.ReadBuffer{ - Size: 4096, + Size: 512 * 1024, }, WriteBuffer: &kcp.WriteBuffer{ - Size: 4096, + Size: 512 * 1024, }, UplinkCapacity: &kcp.UplinkCapacity{ Value: 20,