diff --git a/app/dns/dohdns.go b/app/dns/dohdns.go index c7a3f5c25..61237ba8c 100644 --- a/app/dns/dohdns.go +++ b/app/dns/dohdns.go @@ -8,6 +8,7 @@ import ( "fmt" "io/ioutil" "net/http" + "net/url" "sync" "sync/atomic" "time" @@ -41,25 +42,25 @@ type DoHNameServer struct { } // NewDoHNameServer creates DOH client object for remote resolving -func NewDoHNameServer(dohHost string, dohPort uint32, dispatcher routing.Dispatcher, clientIP net.IP) (*DoHNameServer, error) { +func NewDoHNameServer(url *url.URL, dispatcher routing.Dispatcher, clientIP net.IP) (*DoHNameServer, error) { - dohAddr := net.ParseAddress(dohHost) - var dests []net.Destination - - if dohPort == 0 { - dohPort = 443 + dohAddr := net.ParseAddress(url.Hostname()) + dohPort := "443" + if url.Port() != "" { + dohPort = url.Port() } - parseIPDest := func(ip net.IP, port uint32) net.Destination { + parseIPDest := func(ip net.IP, port string) net.Destination { strIP := ip.String() if len(ip) == net.IPv6len { strIP = fmt.Sprintf("[%s]", strIP) } - dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:%d", strIP, port)) + dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:%s", strIP, port)) common.Must(err) return dest } + var dests []net.Destination if dohAddr.Family().IsDomain() { // resolve DOH server in advance ips, err := net.LookupIP(dohAddr.Domain()) @@ -74,8 +75,8 @@ func NewDoHNameServer(dohHost string, dohPort uint32, dispatcher routing.Dispatc dests = append(dests, parseIPDest(ip, dohPort)) } - newError("DNS: created remote DOH client for https://", dohHost, ":", dohPort).AtInfo().WriteToLog() - s := baseDOHNameServer(dohHost, dohPort, "DOH", clientIP) + newError("DNS: created Remote DOH client for ", url.String(), ", preresolved Dests: ", dests).AtInfo().WriteToLog() + s := baseDOHNameServer(url, "DOH", clientIP) s.dispatcher = dispatcher s.dohDests = dests @@ -102,32 +103,24 @@ func NewDoHNameServer(dohHost string, dohPort uint32, dispatcher routing.Dispatc } // NewDoHLocalNameServer creates DOH client object for local resolving -func NewDoHLocalNameServer(dohHost string, dohPort uint32, clientIP net.IP) *DoHNameServer { - - if dohPort == 0 { - dohPort = 443 - } - - s := baseDOHNameServer(dohHost, dohPort, "DOHL", clientIP) +func NewDoHLocalNameServer(url *url.URL, clientIP net.IP) *DoHNameServer { + url.Scheme = "https" + s := baseDOHNameServer(url, "DOHL", clientIP) s.httpClient = &http.Client{ Timeout: time.Second * 180, } - newError("DNS: created local DOH client for https://", dohHost, ":", dohPort).AtInfo().WriteToLog() + newError("DNS: created Local DOH client for ", url.String()).AtInfo().WriteToLog() return s } -func baseDOHNameServer(dohHost string, dohPort uint32, prefix string, clientIP net.IP) *DoHNameServer { - - if dohPort == 0 { - dohPort = 443 - } +func baseDOHNameServer(url *url.URL, prefix string, clientIP net.IP) *DoHNameServer { s := &DoHNameServer{ ips: make(map[string]record), clientIP: clientIP, pub: pubsub.NewService(), - name: fmt.Sprintf("%s:%s:%d", prefix, dohHost, dohPort), - dohURL: fmt.Sprintf("https://%s:%d/dns-query", dohHost, dohPort), + name: fmt.Sprintf("%s//%s", prefix, url.Host), + dohURL: url.String(), } s.cleanup = &task.Periodic{ Interval: time.Minute, diff --git a/app/dns/server.go b/app/dns/server.go index 9ba791abc..c2bf83f91 100644 --- a/app/dns/server.go +++ b/app/dns/server.go @@ -8,7 +8,6 @@ import ( "context" "log" "net/url" - "strconv" "strings" "sync" "time" @@ -87,40 +86,22 @@ func New(ctx context.Context, config *Config) (*Server, error) { } server.hosts = hosts - parseDOHURI := func(d string, endpoint *net.Endpoint) (host string, port uint32, err error) { - u, err := url.Parse(d) - if err != nil { - return "", 0, err - } - host = u.Hostname() - port = 443 - if u.Port() != "" { - p, err := strconv.ParseUint(u.Port(), 10, 16) - if err != nil { - return "", 0, err - } - port = uint32(p) - } - if endpoint.Port != 0 { - port = endpoint.Port - } - return - } - addNameServer := func(endpoint *net.Endpoint) int { address := endpoint.Address.AsAddress() if address.Family().IsDomain() && address.Domain() == "localhost" { server.clients = append(server.clients, NewLocalNameServer()) } else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "https+local://") { // URI schemed string treated as domain - dohlHost, dohlPort, err := parseDOHURI(address.Domain(), endpoint) + // DOH Local mode + u, err := url.Parse(address.Domain()) if err != nil { log.Fatalln(newError("DNS config error").Base(err)) } - server.clients = append(server.clients, NewDoHLocalNameServer(dohlHost, dohlPort, server.clientIP)) + server.clients = append(server.clients, NewDoHLocalNameServer(u, server.clientIP)) } else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "https://") { - dohHost, dohPort, err := parseDOHURI(address.Domain(), endpoint) + // DOH Remote mode + u, err := url.Parse(address.Domain()) if err != nil { log.Fatalln(newError("DNS config error").Base(err)) } @@ -129,13 +110,14 @@ func New(ctx context.Context, config *Config) (*Server, error) { // need the core dispatcher, register DOHClient at callback common.Must(core.RequireFeatures(ctx, func(d routing.Dispatcher) { - c, err := NewDoHNameServer(dohHost, dohPort, d, server.clientIP) + c, err := NewDoHNameServer(u, d, server.clientIP) if err != nil { log.Fatalln(newError("DNS config error").Base(err)) } server.clients[idx] = c })) } else { + // UDP classic DNS mode dest := endpoint.AsDestination() if dest.Network == net.Network_Unknown { dest.Network = net.Network_UDP