diff --git a/app/router/condition.go b/app/router/condition.go index 7ff0db274..7d961bf52 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -3,16 +3,14 @@ package router import ( - "context" "strings" "v2ray.com/core/common/net" - "v2ray.com/core/common/session" "v2ray.com/core/common/strmatcher" ) type Condition interface { - Apply(ctx context.Context) bool + Apply(ctx *Context) bool } type ConditionChan []Condition @@ -27,7 +25,7 @@ func (v *ConditionChan) Add(cond Condition) *ConditionChan { return v } -func (v *ConditionChan) Apply(ctx context.Context) bool { +func (v *ConditionChan) Apply(ctx *Context) bool { for _, cond := range *v { if !cond.Apply(ctx) { return false @@ -84,46 +82,36 @@ func (m *DomainMatcher) ApplyDomain(domain string) bool { return m.matchers.Match(domain) > 0 } -func (m *DomainMatcher) Apply(ctx context.Context) bool { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { +func (m *DomainMatcher) Apply(ctx *Context) bool { + if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() { return false } - dest := outbound.Target + dest := ctx.Outbound.Target if !dest.Address.Family().IsDomain() { return false } return m.ApplyDomain(dest.Address.Domain()) } -func sourceFromContext(ctx context.Context) net.Destination { - inbound := session.InboundFromContext(ctx) - if inbound == nil { - return net.Destination{} - } - return inbound.Source -} - -func targetFromContent(ctx context.Context) net.Destination { - outbound := session.OutboundFromContext(ctx) - if outbound == nil { - return net.Destination{} - } - return outbound.Target -} - -func resolvedIPFromContext(ctx context.Context) []net.IP { - outbound := session.OutboundFromContext(ctx) - if outbound == nil { +func getIPsFromSource(ctx *Context) []net.IP { + if ctx.Inbound == nil || !ctx.Inbound.Source.IsValid() { return nil } - return outbound.ResolvedIPs + dest := ctx.Inbound.Source + if dest.Address.Family().IsDomain() { + return nil + } + + return []net.IP{dest.Address.IP()} +} + +func getIPsFromTarget(ctx *Context) []net.IP { + return ctx.GetTargetIPs() } type MultiGeoIPMatcher struct { - matchers []*GeoIPMatcher - destFunc func(context.Context) net.Destination - resolvedIPFunc func(context.Context) []net.IP + matchers []*GeoIPMatcher + ipFunc func(*Context) []net.IP } func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) { @@ -141,30 +129,16 @@ func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, e } if onSource { - matcher.destFunc = sourceFromContext + matcher.ipFunc = getIPsFromSource } else { - matcher.destFunc = targetFromContent - matcher.resolvedIPFunc = resolvedIPFromContext + matcher.ipFunc = getIPsFromTarget } return matcher, nil } -func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool { - ips := make([]net.IP, 0, 4) - - dest := m.destFunc(ctx) - - if dest.IsValid() && dest.Address.Family().IsIP() { - ips = append(ips, dest.Address.IP()) - } - - if m.resolvedIPFunc != nil { - rips := m.resolvedIPFunc(ctx) - if len(rips) > 0 { - ips = append(ips, rips...) - } - } +func (m *MultiGeoIPMatcher) Apply(ctx *Context) bool { + ips := m.ipFunc(ctx) for _, ip := range ips { for _, matcher := range m.matchers { @@ -186,12 +160,11 @@ func NewPortMatcher(list *net.PortList) *PortMatcher { } } -func (v *PortMatcher) Apply(ctx context.Context) bool { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { +func (v *PortMatcher) Apply(ctx *Context) bool { + if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() { return false } - return v.port.Contains(outbound.Target.Port) + return v.port.Contains(ctx.Outbound.Target.Port) } type NetworkMatcher struct { @@ -206,12 +179,11 @@ func NewNetworkMatcher(network []net.Network) NetworkMatcher { return matcher } -func (v NetworkMatcher) Apply(ctx context.Context) bool { - outbound := session.OutboundFromContext(ctx) - if outbound == nil || !outbound.Target.IsValid() { +func (v NetworkMatcher) Apply(ctx *Context) bool { + if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() { return false } - return v.list[int(outbound.Target.Network)] + return v.list[int(ctx.Outbound.Target.Network)] } type UserMatcher struct { @@ -230,13 +202,12 @@ func NewUserMatcher(users []string) *UserMatcher { } } -func (v *UserMatcher) Apply(ctx context.Context) bool { - inbound := session.InboundFromContext(ctx) - if inbound == nil { +func (v *UserMatcher) Apply(ctx *Context) bool { + if ctx.Inbound == nil { return false } - user := inbound.User + user := ctx.Inbound.User if user == nil { return false } @@ -264,12 +235,11 @@ func NewInboundTagMatcher(tags []string) *InboundTagMatcher { } } -func (v *InboundTagMatcher) Apply(ctx context.Context) bool { - inbound := session.InboundFromContext(ctx) - if inbound == nil || len(inbound.Tag) == 0 { +func (v *InboundTagMatcher) Apply(ctx *Context) bool { + if ctx.Inbound == nil || len(ctx.Inbound.Tag) == 0 { return false } - tag := inbound.Tag + tag := ctx.Inbound.Tag for _, t := range v.tags { if t == tag { return true @@ -296,14 +266,12 @@ func NewProtocolMatcher(protocols []string) *ProtocolMatcher { } } -func (m *ProtocolMatcher) Apply(ctx context.Context) bool { - content := session.ContentFromContext(ctx) - - if content == nil { +func (m *ProtocolMatcher) Apply(ctx *Context) bool { + if ctx.Content == nil { return false } - protocol := content.Protocol + protocol := ctx.Content.Protocol for _, p := range m.protocols { if strings.HasPrefix(protocol, p) { return true diff --git a/app/router/condition_test.go b/app/router/condition_test.go index bb193e004..9d49428f2 100644 --- a/app/router/condition_test.go +++ b/app/router/condition_test.go @@ -1,7 +1,6 @@ package router_test import ( - "context" "os" "path/filepath" "strconv" @@ -28,17 +27,17 @@ func init() { common.Must(filesystem.CopyFile(platform.GetAssetLocation("geosite.dat"), filepath.Join(wd, "..", "..", "release", "config", "geosite.dat"))) } -func withOutbound(outbound *session.Outbound) context.Context { - return session.ContextWithOutbound(context.Background(), outbound) +func withOutbound(outbound *session.Outbound) *Context { + return &Context{Outbound: outbound} } -func withInbound(inbound *session.Inbound) context.Context { - return session.ContextWithInbound(context.Background(), inbound) +func withInbound(inbound *session.Inbound) *Context { + return &Context{Inbound: inbound} } func TestRoutingRule(t *testing.T) { type ruleTest struct { - input context.Context + input *Context output bool } @@ -89,7 +88,7 @@ func TestRoutingRule(t *testing.T) { output: false, }, { - input: context.Background(), + input: &Context{}, output: false, }, }, @@ -125,7 +124,7 @@ func TestRoutingRule(t *testing.T) { output: true, }, { - input: context.Background(), + input: &Context{}, output: false, }, }, @@ -165,7 +164,7 @@ func TestRoutingRule(t *testing.T) { output: true, }, { - input: context.Background(), + input: &Context{}, output: false, }, }, @@ -206,7 +205,7 @@ func TestRoutingRule(t *testing.T) { output: false, }, { - input: context.Background(), + input: &Context{}, output: false, }, }, @@ -217,7 +216,7 @@ func TestRoutingRule(t *testing.T) { }, test: []ruleTest{ { - input: session.ContextWithContent(context.Background(), &session.Content{Protocol: (&http.SniffHeader{}).Protocol()}), + input: &Context{Content: &session.Content{Protocol: (&http.SniffHeader{}).Protocol()}}, output: true, }, }, diff --git a/app/router/config.go b/app/router/config.go index 7cb163a82..8a5141840 100644 --- a/app/router/config.go +++ b/app/router/config.go @@ -3,9 +3,7 @@ package router import ( - "context" - - net "v2ray.com/core/common/net" + "v2ray.com/core/common/net" "v2ray.com/core/features/outbound" ) @@ -61,7 +59,7 @@ func (r *Rule) GetTag() (string, error) { return r.Tag, nil } -func (r *Rule) Apply(ctx context.Context) bool { +func (r *Rule) Apply(ctx *Context) bool { return r.Condition.Apply(ctx) } diff --git a/app/router/router.go b/app/router/router.go index dad69728b..542ff3d7b 100644 --- a/app/router/router.go +++ b/app/router/router.go @@ -9,6 +9,7 @@ import ( "v2ray.com/core" "v2ray.com/core/common" + "v2ray.com/core/common/net" "v2ray.com/core/common/session" "v2ray.com/core/features/dns" "v2ray.com/core/features/outbound" @@ -85,44 +86,33 @@ func isDomainOutbound(outbound *session.Outbound) bool { return outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain() } -func (r *Router) resolveIP(outbound *session.Outbound) error { - domain := outbound.Target.Address.Domain() - ips, err := r.dns.LookupIP(domain) - if err != nil { - return err - } - - outbound.ResolvedIPs = ips - return nil -} - // PickRoute implements routing.Router. func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) { - outbound := session.OutboundFromContext(ctx) - if r.domainStrategy == Config_IpOnDemand && isDomainOutbound(outbound) { - if err := r.resolveIP(outbound); err != nil { - newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx)) - } + sessionContext := &Context{ + Inbound: session.InboundFromContext(ctx), + Outbound: session.OutboundFromContext(ctx), + Content: session.ContentFromContext(ctx), + } + + if r.domainStrategy == Config_IpOnDemand { + sessionContext.dnsClient = r.dns } for _, rule := range r.rules { - if rule.Apply(ctx) { + if rule.Apply(sessionContext) { return rule, nil } } - if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(outbound) { + if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(sessionContext.Outbound) { return nil, common.ErrNoClue } - if err := r.resolveIP(outbound); err != nil { - newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx)) - return nil, common.ErrNoClue - } + sessionContext.dnsClient = r.dns // Try applying rules again if we have IPs. for _, rule := range r.rules { - if rule.Apply(ctx) { + if rule.Apply(sessionContext) { return rule, nil } } @@ -144,3 +134,37 @@ func (*Router) Close() error { func (*Router) Type() interface{} { return routing.RouterType() } + +type Context struct { + Inbound *session.Inbound + Outbound *session.Outbound + Content *session.Content + + dnsClient dns.Client +} + +func (c *Context) GetTargetIPs() []net.IP { + if c.Outbound == nil || !c.Outbound.Target.IsValid() { + return nil + } + + if c.Outbound.Target.Address.Family().IsIP() { + return []net.IP{c.Outbound.Target.Address.IP()} + } + + if len(c.Outbound.ResolvedIPs) > 0 { + return c.Outbound.ResolvedIPs + } + + if c.dnsClient != nil { + domain := c.Outbound.Target.Address.Domain() + ips, err := c.dnsClient.LookupIP(domain) + if err == nil { + c.Outbound.ResolvedIPs = ips + return ips + } + newError("resolve ip for ", domain).Base(err).WriteToLog() + } + + return nil +} diff --git a/app/router/router_test.go b/app/router/router_test.go index ccb135cce..0992e1c9a 100644 --- a/app/router/router_test.go +++ b/app/router/router_test.go @@ -1,6 +1,7 @@ package router_test import ( + "context" "testing" "github.com/golang/mock/gomock" @@ -42,7 +43,7 @@ func TestSimpleRouter(t *testing.T) { HandlerSelector: mockHs, })) - ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) + ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) tag, err := r.PickRoute(ctx) common.Must(err) if tag != "test" { @@ -83,7 +84,7 @@ func TestSimpleBalancer(t *testing.T) { HandlerSelector: mockHs, })) - ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) + ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) tag, err := r.PickRoute(ctx) common.Must(err) if tag != "test" { @@ -118,7 +119,7 @@ func TestIPOnDemand(t *testing.T) { r := new(Router) common.Must(r.Init(config, mockDns, nil)) - ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) + ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) tag, err := r.PickRoute(ctx) common.Must(err) if tag != "test" { @@ -153,7 +154,7 @@ func TestIPIfNonMatchDomain(t *testing.T) { r := new(Router) common.Must(r.Init(config, mockDns, nil)) - ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) + ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) tag, err := r.PickRoute(ctx) common.Must(err) if tag != "test" { @@ -187,7 +188,7 @@ func TestIPIfNonMatchIP(t *testing.T) { r := new(Router) common.Must(r.Init(config, mockDns, nil)) - ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)}) + ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)}) tag, err := r.PickRoute(ctx) common.Must(err) if tag != "test" {