diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index 7c22e05a7..c5aaf84f3 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -20,6 +20,7 @@ import ( "v2ray.com/core/features/outbound" "v2ray.com/core/features/policy" "v2ray.com/core/features/routing" + routing_session "v2ray.com/core/features/routing/session" "v2ray.com/core/features/stats" "v2ray.com/core/transport" "v2ray.com/core/transport/pipe" @@ -265,7 +266,7 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport. } if d.router != nil && !skipRoutePick { - if tag, err := d.router.PickRoute(ctx); err == nil { + if tag, err := d.router.PickRoute(routing_session.AsRoutingContext(ctx)); err == nil { if h := d.ohm.GetHandler(tag); h != nil { newError("taking detour [", tag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx)) handler = h diff --git a/app/router/condition.go b/app/router/condition.go index ffafdb3c3..5db1fa3c5 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -10,10 +10,11 @@ import ( "v2ray.com/core/common/net" "v2ray.com/core/common/strmatcher" + "v2ray.com/core/features/routing" ) type Condition interface { - Apply(ctx *Context) bool + Apply(ctx routing.Context) bool } type ConditionChan []Condition @@ -28,7 +29,8 @@ func (v *ConditionChan) Add(cond Condition) *ConditionChan { return v } -func (v *ConditionChan) Apply(ctx *Context) bool { +// Apply applies all conditions registered in this chan. +func (v *ConditionChan) Apply(ctx routing.Context) bool { for _, cond := range *v { if !cond.Apply(ctx) { return false @@ -85,36 +87,18 @@ func (m *DomainMatcher) ApplyDomain(domain string) bool { return len(m.matchers.Match(domain)) > 0 } -func (m *DomainMatcher) Apply(ctx *Context) bool { - if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() { +// Apply implements Condition. +func (m *DomainMatcher) Apply(ctx routing.Context) bool { + domain := ctx.GetTargetDomain() + if len(domain) == 0 { return false } - dest := ctx.Outbound.Target - if !dest.Address.Family().IsDomain() { - return false - } - return m.ApplyDomain(dest.Address.Domain()) -} - -func getIPsFromSource(ctx *Context) []net.IP { - if ctx.Inbound == nil || !ctx.Inbound.Source.IsValid() { - return nil - } - dest := ctx.Inbound.Source - if dest.Address.Family().IsDomain() { - return nil - } - - return []net.IP{dest.Address.IP()} -} - -func getIPsFromTarget(ctx *Context) []net.IP { - return ctx.GetTargetIPs() + return m.ApplyDomain(domain) } type MultiGeoIPMatcher struct { matchers []*GeoIPMatcher - ipFunc func(*Context) []net.IP + onSource bool } func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) { @@ -129,20 +113,20 @@ func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, e matcher := &MultiGeoIPMatcher{ matchers: matchers, - } - - if onSource { - matcher.ipFunc = getIPsFromSource - } else { - matcher.ipFunc = getIPsFromTarget + onSource: onSource, } return matcher, nil } -func (m *MultiGeoIPMatcher) Apply(ctx *Context) bool { - ips := m.ipFunc(ctx) - +// Apply implements Condition. +func (m *MultiGeoIPMatcher) Apply(ctx routing.Context) bool { + var ips []net.IP + if m.onSource { + ips = ctx.GetSourceIPs() + } else { + ips = ctx.GetTargetIPs() + } for _, ip := range ips { for _, matcher := range m.matchers { if matcher.Match(ip) { @@ -166,20 +150,13 @@ func NewPortMatcher(list *net.PortList, onSource bool) *PortMatcher { } } -func (v *PortMatcher) Apply(ctx *Context) bool { - var port net.Port +// Apply implements Condition. +func (v *PortMatcher) Apply(ctx routing.Context) bool { if v.onSource { - if ctx.Inbound == nil || !ctx.Inbound.Source.IsValid() { - return false - } - port = ctx.Inbound.Source.Port + return v.port.Contains(ctx.GetSourcePort()) } else { - if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() { - return false - } - port = ctx.Outbound.Target.Port + return v.port.Contains(ctx.GetTargetPort()) } - return v.port.Contains(port) } type NetworkMatcher struct { @@ -194,11 +171,9 @@ func NewNetworkMatcher(network []net.Network) NetworkMatcher { return matcher } -func (v NetworkMatcher) Apply(ctx *Context) bool { - if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() { - return false - } - return v.list[int(ctx.Outbound.Target.Network)] +// Apply implements Condition. +func (v NetworkMatcher) Apply(ctx routing.Context) bool { + return v.list[int(ctx.GetNetwork())] } type UserMatcher struct { @@ -217,17 +192,14 @@ func NewUserMatcher(users []string) *UserMatcher { } } -func (v *UserMatcher) Apply(ctx *Context) bool { - if ctx.Inbound == nil { - return false - } - - user := ctx.Inbound.User - if user == nil { +// Apply implements Condition. +func (v *UserMatcher) Apply(ctx routing.Context) bool { + user := ctx.GetUser() + if len(user) == 0 { return false } for _, u := range v.user { - if u == user.Email { + if u == user { return true } } @@ -250,11 +222,12 @@ func NewInboundTagMatcher(tags []string) *InboundTagMatcher { } } -func (v *InboundTagMatcher) Apply(ctx *Context) bool { - if ctx.Inbound == nil || len(ctx.Inbound.Tag) == 0 { +// Apply implements Condition. +func (v *InboundTagMatcher) Apply(ctx routing.Context) bool { + tag := ctx.GetInboundTag() + if len(tag) == 0 { return false } - tag := ctx.Inbound.Tag for _, t := range v.tags { if t == tag { return true @@ -281,18 +254,17 @@ func NewProtocolMatcher(protocols []string) *ProtocolMatcher { } } -func (m *ProtocolMatcher) Apply(ctx *Context) bool { - if ctx.Content == nil { +// Apply implements Condition. +func (m *ProtocolMatcher) Apply(ctx routing.Context) bool { + protocol := ctx.GetProtocol() + if len(protocol) == 0 { return false } - - protocol := ctx.Content.Protocol for _, p := range m.protocols { if strings.HasPrefix(protocol, p) { return true } } - return false } @@ -343,9 +315,11 @@ func (m *AttributeMatcher) Match(attrs map[string]interface{}) bool { return satisfied != nil && bool(satisfied.Truth()) } -func (m *AttributeMatcher) Apply(ctx *Context) bool { - if ctx.Content == nil { +// Apply implements Condition. +func (m *AttributeMatcher) Apply(ctx routing.Context) bool { + attributes := ctx.GetAttributes() + if attributes == nil { return false } - return m.Match(ctx.Content.Attributes) + return m.Match(attributes) } diff --git a/app/router/condition_test.go b/app/router/condition_test.go index 0c0f48f3c..caef73679 100644 --- a/app/router/condition_test.go +++ b/app/router/condition_test.go @@ -17,6 +17,8 @@ import ( "v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol/http" "v2ray.com/core/common/session" + "v2ray.com/core/features/routing" + routing_session "v2ray.com/core/features/routing/session" ) func init() { @@ -31,17 +33,25 @@ func init() { } } -func withOutbound(outbound *session.Outbound) *Context { - return &Context{Outbound: outbound} +func withBackground() routing.Context { + return &routing_session.Context{} } -func withInbound(inbound *session.Inbound) *Context { - return &Context{Inbound: inbound} +func withOutbound(outbound *session.Outbound) routing.Context { + return &routing_session.Context{Outbound: outbound} +} + +func withInbound(inbound *session.Inbound) routing.Context { + return &routing_session.Context{Inbound: inbound} +} + +func withContent(content *session.Content) routing.Context { + return &routing_session.Context{Content: content} } func TestRoutingRule(t *testing.T) { type ruleTest struct { - input *Context + input routing.Context output bool } @@ -92,7 +102,7 @@ func TestRoutingRule(t *testing.T) { output: false, }, { - input: &Context{}, + input: withBackground(), output: false, }, }, @@ -128,7 +138,7 @@ func TestRoutingRule(t *testing.T) { output: true, }, { - input: &Context{}, + input: withBackground(), output: false, }, }, @@ -168,7 +178,7 @@ func TestRoutingRule(t *testing.T) { output: true, }, { - input: &Context{}, + input: withBackground(), output: false, }, }, @@ -209,7 +219,7 @@ func TestRoutingRule(t *testing.T) { output: false, }, { - input: &Context{}, + input: withBackground(), output: false, }, }, @@ -220,7 +230,7 @@ func TestRoutingRule(t *testing.T) { }, test: []ruleTest{ { - input: &Context{Content: &session.Content{Protocol: (&http.SniffHeader{}).Protocol()}}, + input: withContent(&session.Content{Protocol: (&http.SniffHeader{}).Protocol()}), output: true, }, }, @@ -303,7 +313,7 @@ func TestRoutingRule(t *testing.T) { }, test: []ruleTest{ { - input: &Context{Content: &session.Content{Protocol: "http/1.1", Attributes: map[string]interface{}{":path": "/test/1"}}}, + input: withContent(&session.Content{Protocol: "http/1.1", Attributes: map[string]interface{}{":path": "/test/1"}}), output: true, }, }, diff --git a/app/router/config.go b/app/router/config.go index 932fcd85a..8eb9d5aa1 100644 --- a/app/router/config.go +++ b/app/router/config.go @@ -5,6 +5,7 @@ package router import ( "v2ray.com/core/common/net" "v2ray.com/core/features/outbound" + "v2ray.com/core/features/routing" ) // CIDRList is an alias of []*CIDR to provide sort.Interface. @@ -59,7 +60,8 @@ func (r *Rule) GetTag() (string, error) { return r.Tag, nil } -func (r *Rule) Apply(ctx *Context) bool { +// Apply checks rule matching of current routing context. +func (r *Rule) Apply(ctx routing.Context) bool { return r.Condition.Apply(ctx) } diff --git a/app/router/router.go b/app/router/router.go index 542ff3d7b..7e04c554d 100644 --- a/app/router/router.go +++ b/app/router/router.go @@ -10,7 +10,6 @@ import ( "v2ray.com/core" "v2ray.com/core/common" "v2ray.com/core/common/net" - "v2ray.com/core/common/session" "v2ray.com/core/features/dns" "v2ray.com/core/features/outbound" "v2ray.com/core/features/routing" @@ -74,7 +73,8 @@ func (r *Router) Init(config *Config, d dns.Client, ohm outbound.Manager) error return nil } -func (r *Router) PickRoute(ctx context.Context) (string, error) { +// PickRoute implements routing.Router. +func (r *Router) PickRoute(ctx routing.Context) (string, error) { rule, err := r.pickRouteInternal(ctx) if err != nil { return "", err @@ -82,37 +82,26 @@ func (r *Router) PickRoute(ctx context.Context) (string, error) { return rule.GetTag() } -func isDomainOutbound(outbound *session.Outbound) bool { - return outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain() -} - -// PickRoute implements routing.Router. -func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) { - sessionContext := &Context{ - Inbound: session.InboundFromContext(ctx), - Outbound: session.OutboundFromContext(ctx), - Content: session.ContentFromContext(ctx), - } - +func (r *Router) pickRouteInternal(ctx routing.Context) (*Rule, error) { if r.domainStrategy == Config_IpOnDemand { - sessionContext.dnsClient = r.dns + ctx = ContextWithDNSClient(ctx, r.dns) } for _, rule := range r.rules { - if rule.Apply(sessionContext) { + if rule.Apply(ctx) { return rule, nil } } - if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(sessionContext.Outbound) { + if r.domainStrategy != Config_IpIfNonMatch || len(ctx.GetTargetDomain()) == 0 { return nil, common.ErrNoClue } - sessionContext.dnsClient = r.dns + ctx = ContextWithDNSClient(ctx, r.dns) // Try applying rules again if we have IPs. for _, rule := range r.rules { - if rule.Apply(sessionContext) { + if rule.Apply(ctx) { return rule, nil } } @@ -135,32 +124,30 @@ func (*Router) Type() interface{} { return routing.RouterType() } -type Context struct { - Inbound *session.Inbound - Outbound *session.Outbound - Content *session.Content - - dnsClient dns.Client +// ContextWithDNSClient creates a new routing context with domain resolving capability. Resolved domain IPs can be retrieved by GetTargetIPs(). +func ContextWithDNSClient(ctx routing.Context, client dns.Client) routing.Context { + return &resolvableContext{Context: ctx, dnsClient: client} } -func (c *Context) GetTargetIPs() []net.IP { - if c.Outbound == nil || !c.Outbound.Target.IsValid() { - return nil +type resolvableContext struct { + routing.Context + dnsClient dns.Client + resolvedIPs []net.IP +} + +func (ctx *resolvableContext) GetTargetIPs() []net.IP { + if ips := ctx.Context.GetTargetIPs(); len(ips) != 0 { + return ips } - if c.Outbound.Target.Address.Family().IsIP() { - return []net.IP{c.Outbound.Target.Address.IP()} + if len(ctx.resolvedIPs) > 0 { + return ctx.resolvedIPs } - if len(c.Outbound.ResolvedIPs) > 0 { - return c.Outbound.ResolvedIPs - } - - if c.dnsClient != nil { - domain := c.Outbound.Target.Address.Domain() - ips, err := c.dnsClient.LookupIP(domain) + if domain := ctx.GetTargetDomain(); len(domain) != 0 { + ips, err := ctx.dnsClient.LookupIP(domain) if err == nil { - c.Outbound.ResolvedIPs = ips + ctx.resolvedIPs = ips return ips } newError("resolve ip for ", domain).Base(err).WriteToLog() diff --git a/app/router/router_test.go b/app/router/router_test.go index 0992e1c9a..0ed5f033d 100644 --- a/app/router/router_test.go +++ b/app/router/router_test.go @@ -10,6 +10,7 @@ import ( "v2ray.com/core/common/net" "v2ray.com/core/common/session" "v2ray.com/core/features/outbound" + routing_session "v2ray.com/core/features/routing/session" "v2ray.com/core/testing/mocks" ) @@ -44,7 +45,7 @@ func TestSimpleRouter(t *testing.T) { })) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) - tag, err := r.PickRoute(ctx) + tag, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag != "test" { t.Error("expect tag 'test', bug actually ", tag) @@ -85,7 +86,7 @@ func TestSimpleBalancer(t *testing.T) { })) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) - tag, err := r.PickRoute(ctx) + tag, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag != "test" { t.Error("expect tag 'test', bug actually ", tag) @@ -120,7 +121,7 @@ func TestIPOnDemand(t *testing.T) { common.Must(r.Init(config, mockDns, nil)) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) - tag, err := r.PickRoute(ctx) + tag, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag != "test" { t.Error("expect tag 'test', bug actually ", tag) @@ -155,7 +156,7 @@ func TestIPIfNonMatchDomain(t *testing.T) { common.Must(r.Init(config, mockDns, nil)) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) - tag, err := r.PickRoute(ctx) + tag, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag != "test" { t.Error("expect tag 'test', bug actually ", tag) @@ -189,7 +190,7 @@ func TestIPIfNonMatchIP(t *testing.T) { common.Must(r.Init(config, mockDns, nil)) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)}) - tag, err := r.PickRoute(ctx) + tag, err := r.PickRoute(routing_session.AsRoutingContext(ctx)) common.Must(err) if tag != "test" { t.Error("expect tag 'test', bug actually ", tag) diff --git a/common/session/session.go b/common/session/session.go index 86172c8a0..8d7b1ff6a 100644 --- a/common/session/session.go +++ b/common/session/session.go @@ -51,8 +51,6 @@ type Outbound struct { Target net.Destination // Gateway address Gateway net.Address - // ResolvedIPs is the resolved IP addresses, if the Targe is a domain address. - ResolvedIPs []net.IP } type SniffingRequest struct { diff --git a/features/routing/context.go b/features/routing/context.go new file mode 100644 index 000000000..b4adabafd --- /dev/null +++ b/features/routing/context.go @@ -0,0 +1,40 @@ +package routing + +import ( + "v2ray.com/core/common/net" +) + +// Context is a feature to store connection information for routing. +// +// v2ray:api:beta +type Context interface { + // GetInboundTag returns the tag of the inbound the connection was from. + GetInboundTag() string + + // GetSourcesIPs returns the source IPs bound to the connection. + GetSourceIPs() []net.IP + + // GetSourcePort returns the source port of the connection. + GetSourcePort() net.Port + + // GetTargetIPs returns the target IP of the connection or resolved IPs of target domain. + GetTargetIPs() []net.IP + + // GetTargetPort returns the target port of the connection. + GetTargetPort() net.Port + + // GetTargetDomain returns the target domain of the connection, if exists. + GetTargetDomain() string + + // GetNetwork returns the network type of the connection. + GetNetwork() net.Network + + // GetProtocol returns the protocol from the connection content, if sniffed out. + GetProtocol() string + + // GetUser returns the user email from the connection content, if exists. + GetUser() string + + // GetAttributes returns extra attributes from the conneciont content. + GetAttributes() map[string]interface{} +} diff --git a/features/routing/router.go b/features/routing/router.go index a71927b88..f473431ae 100644 --- a/features/routing/router.go +++ b/features/routing/router.go @@ -1,20 +1,18 @@ package routing import ( - "context" - "v2ray.com/core/common" "v2ray.com/core/features" ) // Router is a feature to choose an outbound tag for the given request. // -// v2ray:api:stable +// v2ray:api:beta type Router interface { features.Feature // PickRoute returns a tag of an OutboundHandler based on the given context. - PickRoute(ctx context.Context) (string, error) + PickRoute(ctx Context) (string, error) } // RouterType return the type of Router interface. Can be used to implement common.HasType. @@ -33,7 +31,7 @@ func (DefaultRouter) Type() interface{} { } // PickRoute implements Router. -func (DefaultRouter) PickRoute(ctx context.Context) (string, error) { +func (DefaultRouter) PickRoute(ctx Context) (string, error) { return "", common.ErrNoClue } diff --git a/features/routing/session/context.go b/features/routing/session/context.go new file mode 100644 index 000000000..6d61d4f94 --- /dev/null +++ b/features/routing/session/context.go @@ -0,0 +1,119 @@ +package session + +import ( + "context" + + "v2ray.com/core/common/net" + "v2ray.com/core/common/session" + "v2ray.com/core/features/routing" +) + +// Context is an implementation of routing.Context, which is a wrapper of context.context with session info. +type Context struct { + Inbound *session.Inbound + Outbound *session.Outbound + Content *session.Content +} + +// GetInboundTag implements routing.Context. +func (ctx *Context) GetInboundTag() string { + if ctx.Inbound == nil { + return "" + } + return ctx.Inbound.Tag +} + +// GetSourceIPs implements routing.Context. +func (ctx *Context) GetSourceIPs() []net.IP { + if ctx.Inbound == nil || !ctx.Inbound.Source.IsValid() { + return nil + } + dest := ctx.Inbound.Source + if dest.Address.Family().IsDomain() { + return nil + } + + return []net.IP{dest.Address.IP()} +} + +// GetSourcePort implements routing.Context. +func (ctx *Context) GetSourcePort() net.Port { + if ctx.Inbound == nil || !ctx.Inbound.Source.IsValid() { + return 0 + } + return ctx.Inbound.Source.Port +} + +// GetTargetIPs implements routing.Context. +func (ctx *Context) GetTargetIPs() []net.IP { + if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() { + return nil + } + + if ctx.Outbound.Target.Address.Family().IsIP() { + return []net.IP{ctx.Outbound.Target.Address.IP()} + } + + return nil +} + +// GetTargetPort implements routing.Context. +func (ctx *Context) GetTargetPort() net.Port { + if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() { + return 0 + } + return ctx.Outbound.Target.Port +} + +// GetTargetDomain implements routing.Context. +func (ctx *Context) GetTargetDomain() string { + if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() { + return "" + } + dest := ctx.Outbound.Target + if !dest.Address.Family().IsDomain() { + return "" + } + return dest.Address.Domain() +} + +// GetNetwork implements routing.Context. +func (ctx *Context) GetNetwork() net.Network { + if ctx.Outbound == nil { + return net.Network_Unknown + } + return ctx.Outbound.Target.Network +} + +// GetProtocol implements routing.Context. +func (ctx *Context) GetProtocol() string { + if ctx.Content == nil { + return "" + } + return ctx.Content.Protocol +} + +// GetUser implements routing.Context. +func (ctx *Context) GetUser() string { + if ctx.Inbound == nil { + return "" + } + return ctx.Inbound.User.Email +} + +// GetAttributes implements routing.Context. +func (ctx *Context) GetAttributes() map[string]interface{} { + if ctx.Content == nil { + return nil + } + return ctx.Content.Attributes +} + +// AsRoutingContext creates a context from context.context with session info. +func AsRoutingContext(ctx context.Context) routing.Context { + return &Context{ + Inbound: session.InboundFromContext(ctx), + Outbound: session.OutboundFromContext(ctx), + Content: session.ContentFromContext(ctx), + } +}