diff --git a/app/dns/server.go b/app/dns/server.go index 6598008fb..33394f693 100644 --- a/app/dns/server.go +++ b/app/dns/server.go @@ -86,10 +86,18 @@ func New(ctx context.Context, config *Config) (*Server, error) { } server.hosts = hosts - addNameServer := func(endpoint *net.Endpoint) int { + addNameServer := func(ns *NameServer) int { + endpoint := ns.Address address := endpoint.Address.AsAddress() if address.Family().IsDomain() && address.Domain() == "localhost" { server.clients = append(server.clients, NewLocalNameServer()) + if len(ns.PrioritizedDomain) == 0 { // Priotize local domain with .local domain or without any dot to local DNS + ns.PrioritizedDomain = []*NameServer_PriorityDomain{ + {Type: DomainMatchingType_Regex, Domain: "^[^.]*$"}, // This will only match domain without any dot + {Type: DomainMatchingType_Subdomain, Domain: "local"}, + {Type: DomainMatchingType_Subdomain, Domain: "localdomain"}, + } + } } else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "https+local://") { // URI schemed string treated as domain // DOH Local mode @@ -137,7 +145,7 @@ func New(ctx context.Context, config *Config) (*Server, error) { if len(config.NameServers) > 0 { features.PrintDeprecatedFeatureWarning("simple DNS server") for _, destPB := range config.NameServers { - addNameServer(destPB) + addNameServer(&NameServer{Address: destPB}) } } @@ -148,7 +156,7 @@ func New(ctx context.Context, config *Config) (*Server, error) { var geoIPMatcherContainer router.GeoIPMatcherContainer for _, ns := range config.NameServer { - idx := addNameServer(ns.Address) + idx := addNameServer(ns) for _, domain := range ns.PrioritizedDomain { matcher, err := toStrMatcher(domain.Type, domain.Domain) @@ -307,11 +315,6 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err domain = domain[:len(domain)-1] } - // skip domain without any dot - if !strings.Contains(domain, ".") { - return nil, newError("invalid domain name").AtWarning() - } - ips := s.lookupStatic(domain, option, 0) if ips != nil && ips[0].Family().IsIP() { newError("returning ", len(ips), " IPs for domain ", domain).WriteToLog() diff --git a/app/dns/server_test.go b/app/dns/server_test.go index 056174114..c340b41b1 100644 --- a/app/dns/server_test.go +++ b/app/dns/server_test.go @@ -63,6 +63,27 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { ans.Answer = append(ans.Answer, rr) } else if q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA { ans.MsgHdr.Rcode = dns.RcodeNameError + } else if q.Name == "hostname." && q.Qtype == dns.TypeA { + rr, _ := dns.NewRR("hostname. IN A 127.0.0.1") + ans.Answer = append(ans.Answer, rr) + } else if q.Name == "hostname.local." && q.Qtype == dns.TypeA { + rr, _ := dns.NewRR("hostname.local. IN A 127.0.0.1") + ans.Answer = append(ans.Answer, rr) + } else if q.Name == "hostname.localdomain." && q.Qtype == dns.TypeA { + rr, _ := dns.NewRR("hostname.localdomain. IN A 127.0.0.1") + ans.Answer = append(ans.Answer, rr) + } else if q.Name == "localhost." && q.Qtype == dns.TypeA { + rr, _ := dns.NewRR("localhost. IN A 127.0.0.2") + ans.Answer = append(ans.Answer, rr) + } else if q.Name == "localhost-a." && q.Qtype == dns.TypeA { + rr, _ := dns.NewRR("localhost-a. IN A 127.0.0.3") + ans.Answer = append(ans.Answer, rr) + } else if q.Name == "localhost-b." && q.Qtype == dns.TypeA { + rr, _ := dns.NewRR("localhost-b. IN A 127.0.0.4") + ans.Answer = append(ans.Answer, rr) + } else if q.Name == "Mijia\\ Cloud." && q.Qtype == dns.TypeA { + rr, _ := dns.NewRR("Mijia\\ Cloud. IN A 127.0.0.1") + ans.Answer = append(ans.Answer, rr) } } w.WriteMsg(ans) @@ -537,3 +558,199 @@ func TestIPMatch(t *testing.T) { t.Error("DNS query doesn't finish in 2 seconds.") } } + +func TestLocalDomain(t *testing.T) { + port := udp.PickPort() + + dnsServer := dns.Server{ + Addr: "127.0.0.1:" + port.String(), + Net: "udp", + Handler: &staticHandler{}, + UDPSize: 1200, + } + + go dnsServer.ListenAndServe() + time.Sleep(time.Second) + + config := &core.Config{ + App: []*serial.TypedMessage{ + serial.ToTypedMessage(&Config{ + NameServers: []*net.Endpoint{ + { + Network: net.Network_UDP, + Address: &net.IPOrDomain{ + Address: &net.IPOrDomain_Ip{ + Ip: []byte{127, 0, 0, 1}, + }, + }, + Port: 9999, /* unreachable */ + }, + }, + NameServer: []*NameServer{ + { + Address: &net.Endpoint{ + Network: net.Network_UDP, + Address: &net.IPOrDomain{ + Address: &net.IPOrDomain_Ip{ + Ip: []byte{127, 0, 0, 1}, + }, + }, + Port: uint32(port), + }, + PrioritizedDomain: []*NameServer_PriorityDomain{ + // Equivalent of dotless:localhost + {Type: DomainMatchingType_Regex, Domain: "^[^.]*localhost[^.]*$"}, + }, + Geoip: []*router.GeoIP{ + { // Will match localhost, localhost-a and localhost-b, + CountryCode: "local", + Cidr: []*router.CIDR{ + {Ip: []byte{127, 0, 0, 2}, Prefix: 32}, + {Ip: []byte{127, 0, 0, 3}, Prefix: 32}, + {Ip: []byte{127, 0, 0, 4}, Prefix: 32}, + }, + }, + }, + }, + { + Address: &net.Endpoint{ + Network: net.Network_UDP, + Address: &net.IPOrDomain{ + Address: &net.IPOrDomain_Ip{ + Ip: []byte{127, 0, 0, 1}, + }, + }, + Port: uint32(port), + }, + PrioritizedDomain: []*NameServer_PriorityDomain{ + // Equivalent of dotless: and domain:local + {Type: DomainMatchingType_Regex, Domain: "^[^.]*$"}, + {Type: DomainMatchingType_Subdomain, Domain: "local"}, + {Type: DomainMatchingType_Subdomain, Domain: "localdomain"}, + }, + }, + }, + StaticHosts: []*Config_HostMapping{ + { + Type: DomainMatchingType_Full, + Domain: "hostnamestatic", + Ip: [][]byte{{127, 0, 0, 53}}, + }, + { + Type: DomainMatchingType_Full, + Domain: "hostnamealias", + ProxiedDomain: "hostname.localdomain", + }, + }, + }), + serial.ToTypedMessage(&dispatcher.Config{}), + serial.ToTypedMessage(&proxyman.OutboundConfig{}), + serial.ToTypedMessage(&policy.Config{}), + }, + Outbound: []*core.OutboundHandlerConfig{ + { + ProxySettings: serial.ToTypedMessage(&freedom.Config{}), + }, + }, + } + + v, err := core.New(config) + common.Must(err) + + client := v.GetFeature(feature_dns.ClientType()).(feature_dns.Client) + + startTime := time.Now() + + { // Will match dotless: + ips, err := client.LookupIP("hostname") + if err != nil { + t.Fatal("unexpected error: ", err) + } + + if r := cmp.Diff(ips, []net.IP{{127, 0, 0, 1}}); r != "" { + t.Fatal(r) + } + } + + { // Will match domain:local + ips, err := client.LookupIP("hostname.local") + if err != nil { + t.Fatal("unexpected error: ", err) + } + + if r := cmp.Diff(ips, []net.IP{{127, 0, 0, 1}}); r != "" { + t.Fatal(r) + } + } + + { // Will match static ip + ips, err := client.LookupIP("hostnamestatic") + if err != nil { + t.Fatal("unexpected error: ", err) + } + + if r := cmp.Diff(ips, []net.IP{{127, 0, 0, 53}}); r != "" { + t.Fatal(r) + } + } + + { // Will match domain replacing + ips, err := client.LookupIP("hostnamealias") + if err != nil { + t.Fatal("unexpected error: ", err) + } + + if r := cmp.Diff(ips, []net.IP{{127, 0, 0, 1}}); r != "" { + t.Fatal(r) + } + } + + { // Will match dotless:localhost, but not expectIPs: 127.0.0.2, 127.0.0.3, then matches at dotless: + ips, err := client.LookupIP("localhost") + if err != nil { + t.Fatal("unexpected error: ", err) + } + + if r := cmp.Diff(ips, []net.IP{{127, 0, 0, 2}}); r != "" { + t.Fatal(r) + } + } + + { // Will match dotless:localhost, and expectIPs: 127.0.0.2, 127.0.0.3 + ips, err := client.LookupIP("localhost-a") + if err != nil { + t.Fatal("unexpected error: ", err) + } + + if r := cmp.Diff(ips, []net.IP{{127, 0, 0, 3}}); r != "" { + t.Fatal(r) + } + } + + { // Will match dotless:localhost, and expectIPs: 127.0.0.2, 127.0.0.3 + ips, err := client.LookupIP("localhost-b") + if err != nil { + t.Fatal("unexpected error: ", err) + } + + if r := cmp.Diff(ips, []net.IP{{127, 0, 0, 4}}); r != "" { + t.Fatal(r) + } + } + + { // Will match dotless: + ips, err := client.LookupIP("Mijia Cloud") + if err != nil { + t.Fatal("unexpected error: ", err) + } + + if r := cmp.Diff(ips, []net.IP{{127, 0, 0, 1}}); r != "" { + t.Fatal(r) + } + } + + endTime := time.Now() + if startTime.After(endTime.Add(time.Second * 2)) { + t.Error("DNS query doesn't finish in 2 seconds.") + } +} diff --git a/app/dns/udpns.go b/app/dns/udpns.go index f1f60aca6..a362c1dff 100644 --- a/app/dns/udpns.go +++ b/app/dns/udpns.go @@ -134,7 +134,7 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot } elapsed := time.Since(req.start) - newError(s.name, " got answere: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() + newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) { s.updateIP(req.domain, rec) } diff --git a/infra/conf/router.go b/infra/conf/router.go index 03044de67..77549481e 100644 --- a/infra/conf/router.go +++ b/infra/conf/router.go @@ -299,6 +299,16 @@ func parseDomainRule(domain string) ([]*router.Domain, error) { case strings.HasPrefix(domain, "keyword:"): domainRule.Type = router.Domain_Plain domainRule.Value = domain[8:] + case strings.HasPrefix(domain, "dotless:"): + domainRule.Type = router.Domain_Regex + switch substr := domain[8:]; { + case substr == "": + domainRule.Value = "^[^.]*$" + case !strings.Contains(substr, "."): + domainRule.Value = "^[^.]*" + substr + "[^.]*$" + default: + return nil, newError("Substr in dotless rule should not contain a dot: ", substr) + } default: domainRule.Type = router.Domain_Plain domainRule.Value = domain