diff --git a/app/router/condition.go b/app/router/condition.go index aac5c960e..4face51d1 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -111,9 +111,18 @@ func targetFromContent(ctx context.Context) net.Destination { return outbound.Target } +func resolvedIPFromContext(ctx context.Context) []net.IP { + outbound := session.OutboundFromContext(ctx) + if outbound == nil { + return nil + } + return outbound.ResolvedIPs +} + type MultiGeoIPMatcher struct { - matchers []*GeoIPMatcher - destFunc func(context.Context) net.Destination + matchers []*GeoIPMatcher + destFunc func(context.Context) net.Destination + resolvedIPFunc func(context.Context) []net.IP } func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) { @@ -126,17 +135,18 @@ func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, e matchers = append(matchers, matcher) } - var destFunc func(context.Context) net.Destination - if onSource { - destFunc = sourceFromContext - } else { - destFunc = targetFromContent + matcher := &MultiGeoIPMatcher{ + matchers: matchers, } - return &MultiGeoIPMatcher{ - matchers: matchers, - destFunc: destFunc, - }, nil + if onSource { + matcher.destFunc = sourceFromContext + } else { + matcher.destFunc = targetFromContent + matcher.resolvedIPFunc = resolvedIPFromContext + } + + return matcher, nil } func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool { @@ -146,10 +156,12 @@ func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool { if dest.IsValid() && dest.Address.Family().IsIP() { ips = append(ips, dest.Address.IP()) - } else if resolver, ok := ResolvedIPsFromContext(ctx); ok { - resolvedIPs := resolver.Resolve() - for _, rip := range resolvedIPs { - ips = append(ips, rip.IP()) + } + + if m.resolvedIPFunc != nil { + rips := m.resolvedIPFunc(ctx) + if len(rips) > 0 { + ips = append(ips, rips...) } } diff --git a/app/router/router.go b/app/router/router.go index 6269117cf..910dd06d8 100644 --- a/app/router/router.go +++ b/app/router/router.go @@ -7,32 +7,12 @@ import ( "v2ray.com/core" "v2ray.com/core/common" - "v2ray.com/core/common/net" "v2ray.com/core/common/session" "v2ray.com/core/features/dns" "v2ray.com/core/features/outbound" "v2ray.com/core/features/routing" ) -type key uint32 - -const ( - resolvedIPsKey key = iota -) - -type IPResolver interface { - Resolve() []net.Address -} - -func ContextWithResolveIPs(ctx context.Context, f IPResolver) context.Context { - return context.WithValue(ctx, resolvedIPsKey, f) -} - -func ResolvedIPsFromContext(ctx context.Context) (IPResolver, bool) { - ips, ok := ctx.Value(resolvedIPsKey).(IPResolver) - return ips, ok -} - func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { r := new(Router) @@ -91,34 +71,6 @@ func (r *Router) Init(config *Config, d dns.Client, ohm outbound.Manager) error return nil } -type ipResolver struct { - dns dns.Client - ip []net.Address - domain string - resolved bool -} - -func (r *ipResolver) Resolve() []net.Address { - if r.resolved { - return r.ip - } - - newError("looking for IP for domain: ", r.domain).WriteToLog() - r.resolved = true - ips, err := r.dns.LookupIP(r.domain) - if err != nil { - newError("failed to get IP address").Base(err).WriteToLog() - } - if len(ips) == 0 { - return nil - } - r.ip = make([]net.Address, len(ips)) - for i, ip := range ips { - r.ip[i] = net.IPAddress(ip) - } - return r.ip -} - func (r *Router) PickRoute(ctx context.Context) (string, error) { rule, err := r.pickRouteInternal(ctx) if err != nil { @@ -127,17 +79,27 @@ func (r *Router) PickRoute(ctx context.Context) (string, error) { return rule.GetTag() } -// PickRoute implements routing.Router. -func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) { - resolver := &ipResolver{ - dns: r.dns, +func isDomainOutbound(outbound *session.Outbound) bool { + return outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain() +} + +func (r *Router) resolveIP(outbound *session.Outbound) error { + domain := outbound.Target.Address.Domain() + ips, err := r.dns.LookupIP(domain) + if err != nil { + return err } + outbound.ResolvedIPs = ips + return nil +} + +// PickRoute implements routing.Router. +func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) { outbound := session.OutboundFromContext(ctx) - if r.domainStrategy == Config_IpOnDemand { - if outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain() { - resolver.domain = outbound.Target.Address.Domain() - ctx = ContextWithResolveIPs(ctx, resolver) + if r.domainStrategy == Config_IpOnDemand && isDomainOutbound(outbound) { + if err := r.resolveIP(outbound); err != nil { + newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx)) } } @@ -147,21 +109,19 @@ func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) { } } - if outbound == nil || !outbound.Target.IsValid() { + if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(outbound) { return nil, common.ErrNoClue } - dest := outbound.Target - if r.domainStrategy == Config_IpIfNonMatch && dest.Address.Family().IsDomain() { - resolver.domain = dest.Address.Domain() - ips := resolver.Resolve() - if len(ips) > 0 { - ctx = ContextWithResolveIPs(ctx, resolver) - for _, rule := range r.rules { - if rule.Apply(ctx) { - return rule, nil - } - } + if err := r.resolveIP(outbound); err != nil { + newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx)) + return nil, common.ErrNoClue + } + + // Try applying rules again if we have IPs. + for _, rule := range r.rules { + if rule.Apply(ctx) { + return rule, nil } } diff --git a/app/router/router_test.go b/app/router/router_test.go index 565f4038e..ccb135cce 100644 --- a/app/router/router_test.go +++ b/app/router/router_test.go @@ -125,3 +125,72 @@ func TestIPOnDemand(t *testing.T) { t.Error("expect tag 'test', bug actually ", tag) } } + +func TestIPIfNonMatchDomain(t *testing.T) { + config := &Config{ + DomainStrategy: Config_IpIfNonMatch, + Rule: []*RoutingRule{ + { + TargetTag: &RoutingRule_Tag{ + Tag: "test", + }, + Cidr: []*CIDR{ + { + Ip: []byte{192, 168, 0, 0}, + Prefix: 16, + }, + }, + }, + }, + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockDns := mocks.NewDNSClient(mockCtl) + mockDns.EXPECT().LookupIP(gomock.Eq("v2ray.com")).Return([]net.IP{{192, 168, 0, 1}}, nil).AnyTimes() + + r := new(Router) + common.Must(r.Init(config, mockDns, nil)) + + ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) + tag, err := r.PickRoute(ctx) + common.Must(err) + if tag != "test" { + t.Error("expect tag 'test', bug actually ", tag) + } +} + +func TestIPIfNonMatchIP(t *testing.T) { + config := &Config{ + DomainStrategy: Config_IpIfNonMatch, + Rule: []*RoutingRule{ + { + TargetTag: &RoutingRule_Tag{ + Tag: "test", + }, + Cidr: []*CIDR{ + { + Ip: []byte{127, 0, 0, 0}, + Prefix: 8, + }, + }, + }, + }, + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockDns := mocks.NewDNSClient(mockCtl) + + r := new(Router) + common.Must(r.Init(config, mockDns, nil)) + + ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)}) + tag, err := r.PickRoute(ctx) + common.Must(err) + if tag != "test" { + t.Error("expect tag 'test', bug actually ", tag) + } +}