From b4b4b3d032bbd441922822e59feb241dc68cac74 Mon Sep 17 00:00:00 2001 From: vcptr <51714622+vcptr@users.noreply.github.com> Date: Fri, 6 Dec 2019 12:55:14 +0800 Subject: [PATCH] doh config use RFC8484 url format --- app/dns/server.go | 43 ++++++++++++++++++++++++++++++++------- infra/conf/common_test.go | 19 +++++++++++++++++ 2 files changed, 55 insertions(+), 7 deletions(-) diff --git a/app/dns/server.go b/app/dns/server.go index 9f3f322ca..9ba791abc 100644 --- a/app/dns/server.go +++ b/app/dns/server.go @@ -7,6 +7,8 @@ package dns import ( "context" "log" + "net/url" + "strconv" "strings" "sync" "time" @@ -85,22 +87,49 @@ func New(ctx context.Context, config *Config) (*Server, error) { } server.hosts = hosts + parseDOHURI := func(d string, endpoint *net.Endpoint) (host string, port uint32, err error) { + u, err := url.Parse(d) + if err != nil { + return "", 0, err + } + host = u.Hostname() + port = 443 + if u.Port() != "" { + p, err := strconv.ParseUint(u.Port(), 10, 16) + if err != nil { + return "", 0, err + } + port = uint32(p) + } + if endpoint.Port != 0 { + port = endpoint.Port + } + return + } + addNameServer := func(endpoint *net.Endpoint) int { address := endpoint.Address.AsAddress() if address.Family().IsDomain() && address.Domain() == "localhost" { server.clients = append(server.clients, NewLocalNameServer()) - } else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOHL_") { - dohHost := address.Domain()[5:] - server.clients = append(server.clients, NewDoHLocalNameServer(dohHost, endpoint.Port, server.clientIP)) - } else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOH_") { - // DOH_ prefix makes net.Address think it's a domain - dohHost := address.Domain()[4:] + } else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "https+local://") { + // URI schemed string treated as domain + dohlHost, dohlPort, err := parseDOHURI(address.Domain(), endpoint) + if err != nil { + log.Fatalln(newError("DNS config error").Base(err)) + } + server.clients = append(server.clients, NewDoHLocalNameServer(dohlHost, dohlPort, server.clientIP)) + } else if address.Family().IsDomain() && + strings.HasPrefix(address.Domain(), "https://") { + dohHost, dohPort, err := parseDOHURI(address.Domain(), endpoint) + if err != nil { + log.Fatalln(newError("DNS config error").Base(err)) + } idx := len(server.clients) server.clients = append(server.clients, nil) // need the core dispatcher, register DOHClient at callback common.Must(core.RequireFeatures(ctx, func(d routing.Dispatcher) { - c, err := NewDoHNameServer(dohHost, endpoint.Port, d, server.clientIP) + c, err := NewDoHNameServer(dohHost, dohPort, d, server.clientIP) if err != nil { log.Fatalln(newError("DNS config error").Base(err)) } diff --git a/infra/conf/common_test.go b/infra/conf/common_test.go index 1a4193e88..a8159c8c2 100644 --- a/infra/conf/common_test.go +++ b/infra/conf/common_test.go @@ -51,6 +51,25 @@ func TestDomainParsing(t *testing.T) { } } +func TestURLParsing(t *testing.T) { + { + rawJson := "\"https://dns.google/dns-query\"" + var address Address + common.Must(json.Unmarshal([]byte(rawJson), &address)) + if address.Domain() != "https://dns.google/dns-query" { + t.Error("URL: ", address.Domain()) + } + } + { + rawJson := "\"https+local://dns.google/dns-query\"" + var address Address + common.Must(json.Unmarshal([]byte(rawJson), &address)) + if address.Domain() != "https+local://dns.google/dns-query" { + t.Error("URL: ", address.Domain()) + } + } +} + func TestInvalidAddressJson(t *testing.T) { rawJson := "1234" var address Address