1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2025-01-05 00:47:51 -05:00

extract all session context before checking conditions

This commit is contained in:
Darien Raymond 2019-02-28 09:28:55 +01:00
parent cc513c1002
commit 0d31a68694
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
5 changed files with 103 additions and 113 deletions

View File

@ -3,16 +3,14 @@
package router package router
import ( import (
"context"
"strings" "strings"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/session"
"v2ray.com/core/common/strmatcher" "v2ray.com/core/common/strmatcher"
) )
type Condition interface { type Condition interface {
Apply(ctx context.Context) bool Apply(ctx *Context) bool
} }
type ConditionChan []Condition type ConditionChan []Condition
@ -27,7 +25,7 @@ func (v *ConditionChan) Add(cond Condition) *ConditionChan {
return v return v
} }
func (v *ConditionChan) Apply(ctx context.Context) bool { func (v *ConditionChan) Apply(ctx *Context) bool {
for _, cond := range *v { for _, cond := range *v {
if !cond.Apply(ctx) { if !cond.Apply(ctx) {
return false return false
@ -84,46 +82,36 @@ func (m *DomainMatcher) ApplyDomain(domain string) bool {
return m.matchers.Match(domain) > 0 return m.matchers.Match(domain) > 0
} }
func (m *DomainMatcher) Apply(ctx context.Context) bool { func (m *DomainMatcher) Apply(ctx *Context) bool {
outbound := session.OutboundFromContext(ctx) if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
if outbound == nil || !outbound.Target.IsValid() {
return false return false
} }
dest := outbound.Target dest := ctx.Outbound.Target
if !dest.Address.Family().IsDomain() { if !dest.Address.Family().IsDomain() {
return false return false
} }
return m.ApplyDomain(dest.Address.Domain()) return m.ApplyDomain(dest.Address.Domain())
} }
func sourceFromContext(ctx context.Context) net.Destination { func getIPsFromSource(ctx *Context) []net.IP {
inbound := session.InboundFromContext(ctx) if ctx.Inbound == nil || !ctx.Inbound.Source.IsValid() {
if inbound == nil {
return net.Destination{}
}
return inbound.Source
}
func targetFromContent(ctx context.Context) net.Destination {
outbound := session.OutboundFromContext(ctx)
if outbound == nil {
return net.Destination{}
}
return outbound.Target
}
func resolvedIPFromContext(ctx context.Context) []net.IP {
outbound := session.OutboundFromContext(ctx)
if outbound == nil {
return nil return nil
} }
return outbound.ResolvedIPs dest := ctx.Inbound.Source
if dest.Address.Family().IsDomain() {
return nil
}
return []net.IP{dest.Address.IP()}
}
func getIPsFromTarget(ctx *Context) []net.IP {
return ctx.GetTargetIPs()
} }
type MultiGeoIPMatcher struct { type MultiGeoIPMatcher struct {
matchers []*GeoIPMatcher matchers []*GeoIPMatcher
destFunc func(context.Context) net.Destination ipFunc func(*Context) []net.IP
resolvedIPFunc func(context.Context) []net.IP
} }
func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) { func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) {
@ -141,30 +129,16 @@ func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, e
} }
if onSource { if onSource {
matcher.destFunc = sourceFromContext matcher.ipFunc = getIPsFromSource
} else { } else {
matcher.destFunc = targetFromContent matcher.ipFunc = getIPsFromTarget
matcher.resolvedIPFunc = resolvedIPFromContext
} }
return matcher, nil return matcher, nil
} }
func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool { func (m *MultiGeoIPMatcher) Apply(ctx *Context) bool {
ips := make([]net.IP, 0, 4) ips := m.ipFunc(ctx)
dest := m.destFunc(ctx)
if dest.IsValid() && dest.Address.Family().IsIP() {
ips = append(ips, dest.Address.IP())
}
if m.resolvedIPFunc != nil {
rips := m.resolvedIPFunc(ctx)
if len(rips) > 0 {
ips = append(ips, rips...)
}
}
for _, ip := range ips { for _, ip := range ips {
for _, matcher := range m.matchers { for _, matcher := range m.matchers {
@ -186,12 +160,11 @@ func NewPortMatcher(list *net.PortList) *PortMatcher {
} }
} }
func (v *PortMatcher) Apply(ctx context.Context) bool { func (v *PortMatcher) Apply(ctx *Context) bool {
outbound := session.OutboundFromContext(ctx) if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
if outbound == nil || !outbound.Target.IsValid() {
return false return false
} }
return v.port.Contains(outbound.Target.Port) return v.port.Contains(ctx.Outbound.Target.Port)
} }
type NetworkMatcher struct { type NetworkMatcher struct {
@ -206,12 +179,11 @@ func NewNetworkMatcher(network []net.Network) NetworkMatcher {
return matcher return matcher
} }
func (v NetworkMatcher) Apply(ctx context.Context) bool { func (v NetworkMatcher) Apply(ctx *Context) bool {
outbound := session.OutboundFromContext(ctx) if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
if outbound == nil || !outbound.Target.IsValid() {
return false return false
} }
return v.list[int(outbound.Target.Network)] return v.list[int(ctx.Outbound.Target.Network)]
} }
type UserMatcher struct { type UserMatcher struct {
@ -230,13 +202,12 @@ func NewUserMatcher(users []string) *UserMatcher {
} }
} }
func (v *UserMatcher) Apply(ctx context.Context) bool { func (v *UserMatcher) Apply(ctx *Context) bool {
inbound := session.InboundFromContext(ctx) if ctx.Inbound == nil {
if inbound == nil {
return false return false
} }
user := inbound.User user := ctx.Inbound.User
if user == nil { if user == nil {
return false return false
} }
@ -264,12 +235,11 @@ func NewInboundTagMatcher(tags []string) *InboundTagMatcher {
} }
} }
func (v *InboundTagMatcher) Apply(ctx context.Context) bool { func (v *InboundTagMatcher) Apply(ctx *Context) bool {
inbound := session.InboundFromContext(ctx) if ctx.Inbound == nil || len(ctx.Inbound.Tag) == 0 {
if inbound == nil || len(inbound.Tag) == 0 {
return false return false
} }
tag := inbound.Tag tag := ctx.Inbound.Tag
for _, t := range v.tags { for _, t := range v.tags {
if t == tag { if t == tag {
return true return true
@ -296,14 +266,12 @@ func NewProtocolMatcher(protocols []string) *ProtocolMatcher {
} }
} }
func (m *ProtocolMatcher) Apply(ctx context.Context) bool { func (m *ProtocolMatcher) Apply(ctx *Context) bool {
content := session.ContentFromContext(ctx) if ctx.Content == nil {
if content == nil {
return false return false
} }
protocol := content.Protocol protocol := ctx.Content.Protocol
for _, p := range m.protocols { for _, p := range m.protocols {
if strings.HasPrefix(protocol, p) { if strings.HasPrefix(protocol, p) {
return true return true

View File

@ -1,7 +1,6 @@
package router_test package router_test
import ( import (
"context"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
@ -28,17 +27,17 @@ func init() {
common.Must(filesystem.CopyFile(platform.GetAssetLocation("geosite.dat"), filepath.Join(wd, "..", "..", "release", "config", "geosite.dat"))) common.Must(filesystem.CopyFile(platform.GetAssetLocation("geosite.dat"), filepath.Join(wd, "..", "..", "release", "config", "geosite.dat")))
} }
func withOutbound(outbound *session.Outbound) context.Context { func withOutbound(outbound *session.Outbound) *Context {
return session.ContextWithOutbound(context.Background(), outbound) return &Context{Outbound: outbound}
} }
func withInbound(inbound *session.Inbound) context.Context { func withInbound(inbound *session.Inbound) *Context {
return session.ContextWithInbound(context.Background(), inbound) return &Context{Inbound: inbound}
} }
func TestRoutingRule(t *testing.T) { func TestRoutingRule(t *testing.T) {
type ruleTest struct { type ruleTest struct {
input context.Context input *Context
output bool output bool
} }
@ -89,7 +88,7 @@ func TestRoutingRule(t *testing.T) {
output: false, output: false,
}, },
{ {
input: context.Background(), input: &Context{},
output: false, output: false,
}, },
}, },
@ -125,7 +124,7 @@ func TestRoutingRule(t *testing.T) {
output: true, output: true,
}, },
{ {
input: context.Background(), input: &Context{},
output: false, output: false,
}, },
}, },
@ -165,7 +164,7 @@ func TestRoutingRule(t *testing.T) {
output: true, output: true,
}, },
{ {
input: context.Background(), input: &Context{},
output: false, output: false,
}, },
}, },
@ -206,7 +205,7 @@ func TestRoutingRule(t *testing.T) {
output: false, output: false,
}, },
{ {
input: context.Background(), input: &Context{},
output: false, output: false,
}, },
}, },
@ -217,7 +216,7 @@ func TestRoutingRule(t *testing.T) {
}, },
test: []ruleTest{ test: []ruleTest{
{ {
input: session.ContextWithContent(context.Background(), &session.Content{Protocol: (&http.SniffHeader{}).Protocol()}), input: &Context{Content: &session.Content{Protocol: (&http.SniffHeader{}).Protocol()}},
output: true, output: true,
}, },
}, },

View File

@ -3,9 +3,7 @@
package router package router
import ( import (
"context" "v2ray.com/core/common/net"
net "v2ray.com/core/common/net"
"v2ray.com/core/features/outbound" "v2ray.com/core/features/outbound"
) )
@ -61,7 +59,7 @@ func (r *Rule) GetTag() (string, error) {
return r.Tag, nil return r.Tag, nil
} }
func (r *Rule) Apply(ctx context.Context) bool { func (r *Rule) Apply(ctx *Context) bool {
return r.Condition.Apply(ctx) return r.Condition.Apply(ctx)
} }

View File

@ -9,6 +9,7 @@ import (
"v2ray.com/core" "v2ray.com/core"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/net"
"v2ray.com/core/common/session" "v2ray.com/core/common/session"
"v2ray.com/core/features/dns" "v2ray.com/core/features/dns"
"v2ray.com/core/features/outbound" "v2ray.com/core/features/outbound"
@ -85,44 +86,33 @@ func isDomainOutbound(outbound *session.Outbound) bool {
return outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain() return outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain()
} }
func (r *Router) resolveIP(outbound *session.Outbound) error {
domain := outbound.Target.Address.Domain()
ips, err := r.dns.LookupIP(domain)
if err != nil {
return err
}
outbound.ResolvedIPs = ips
return nil
}
// PickRoute implements routing.Router. // PickRoute implements routing.Router.
func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) { func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
outbound := session.OutboundFromContext(ctx) sessionContext := &Context{
if r.domainStrategy == Config_IpOnDemand && isDomainOutbound(outbound) { Inbound: session.InboundFromContext(ctx),
if err := r.resolveIP(outbound); err != nil { Outbound: session.OutboundFromContext(ctx),
newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx)) Content: session.ContentFromContext(ctx),
} }
if r.domainStrategy == Config_IpOnDemand {
sessionContext.dnsClient = r.dns
} }
for _, rule := range r.rules { for _, rule := range r.rules {
if rule.Apply(ctx) { if rule.Apply(sessionContext) {
return rule, nil return rule, nil
} }
} }
if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(outbound) { if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(sessionContext.Outbound) {
return nil, common.ErrNoClue return nil, common.ErrNoClue
} }
if err := r.resolveIP(outbound); err != nil { sessionContext.dnsClient = r.dns
newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
return nil, common.ErrNoClue
}
// Try applying rules again if we have IPs. // Try applying rules again if we have IPs.
for _, rule := range r.rules { for _, rule := range r.rules {
if rule.Apply(ctx) { if rule.Apply(sessionContext) {
return rule, nil return rule, nil
} }
} }
@ -144,3 +134,37 @@ func (*Router) Close() error {
func (*Router) Type() interface{} { func (*Router) Type() interface{} {
return routing.RouterType() return routing.RouterType()
} }
type Context struct {
Inbound *session.Inbound
Outbound *session.Outbound
Content *session.Content
dnsClient dns.Client
}
func (c *Context) GetTargetIPs() []net.IP {
if c.Outbound == nil || !c.Outbound.Target.IsValid() {
return nil
}
if c.Outbound.Target.Address.Family().IsIP() {
return []net.IP{c.Outbound.Target.Address.IP()}
}
if len(c.Outbound.ResolvedIPs) > 0 {
return c.Outbound.ResolvedIPs
}
if c.dnsClient != nil {
domain := c.Outbound.Target.Address.Domain()
ips, err := c.dnsClient.LookupIP(domain)
if err == nil {
c.Outbound.ResolvedIPs = ips
return ips
}
newError("resolve ip for ", domain).Base(err).WriteToLog()
}
return nil
}

View File

@ -1,6 +1,7 @@
package router_test package router_test
import ( import (
"context"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
@ -42,7 +43,7 @@ func TestSimpleRouter(t *testing.T) {
HandlerSelector: mockHs, HandlerSelector: mockHs,
})) }))
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
tag, err := r.PickRoute(ctx) tag, err := r.PickRoute(ctx)
common.Must(err) common.Must(err)
if tag != "test" { if tag != "test" {
@ -83,7 +84,7 @@ func TestSimpleBalancer(t *testing.T) {
HandlerSelector: mockHs, HandlerSelector: mockHs,
})) }))
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
tag, err := r.PickRoute(ctx) tag, err := r.PickRoute(ctx)
common.Must(err) common.Must(err)
if tag != "test" { if tag != "test" {
@ -118,7 +119,7 @@ func TestIPOnDemand(t *testing.T) {
r := new(Router) r := new(Router)
common.Must(r.Init(config, mockDns, nil)) common.Must(r.Init(config, mockDns, nil))
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
tag, err := r.PickRoute(ctx) tag, err := r.PickRoute(ctx)
common.Must(err) common.Must(err)
if tag != "test" { if tag != "test" {
@ -153,7 +154,7 @@ func TestIPIfNonMatchDomain(t *testing.T) {
r := new(Router) r := new(Router)
common.Must(r.Init(config, mockDns, nil)) common.Must(r.Init(config, mockDns, nil))
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)}) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
tag, err := r.PickRoute(ctx) tag, err := r.PickRoute(ctx)
common.Must(err) common.Must(err)
if tag != "test" { if tag != "test" {
@ -187,7 +188,7 @@ func TestIPIfNonMatchIP(t *testing.T) {
r := new(Router) r := new(Router)
common.Must(r.Init(config, mockDns, nil)) common.Must(r.Init(config, mockDns, nil))
ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)}) ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
tag, err := r.PickRoute(ctx) tag, err := r.PickRoute(ctx)
common.Must(err) common.Must(err)
if tag != "test" { if tag != "test" {